今週はGenerative Adversarial Networks(以下GAN)とDeep Convolutional GAN(以下DCGAN)を何となく読んだので、その辺の概要をまとめておく。個人的な備忘録の側面が強いので、あまり色んな人が読むことは想定しておらず、ある程度前提知識を要求している気がするし、参考になるかわからないけれど読む人は参考程度に。もとの論文がいつでも一番正確なはず(GAN、DCGAN)。そもそも既に多くの人が解説記事を書いているし、多くの人にとってこの記事はあまり価値を持たないんじゃないかなと思う*1。
GANとDCGANの位置づけ*2
GANの目的は与えられた学習データから生成モデルを推定すること。つまりは訓練データから学習することによって、確率分布に基づく乱数を吐くが如く、それっぽいデータを吐くようなサンプラーをつくることができるということだろう。たとえばパラメタ集合によって定まる何らかの確率分布を定義して、訓練データ集合に対して、最尤推定などをして(ここでは尤度を最大化するパラメタ集合)あたりをサンプラーとして使おうみたいなことがまず考えられるのかなと思う。
当然サンプラーの表現力はの構成方法に大きく依存していて*3、例えば画像のような超絶複雑な何かをいい感じに表現するためには、統計の教科書に乗ってるような、よく知られた確率分布なんかをとしてそのまま使っても表現力が全然たりなくて、『じゃあそこに多層パーセプトロンとかを組み込んじゃえば超複雑な生成モデルが構成できるのでは?』なんて考える人がいてもおかしくはないだろう。変分オートエンコーダあたりはまさにそういう感じで、GAN(とDCGAN)も概ねこういう考え方に沿っているのかなと思う。
Adversarialな問題設定
ちなみに英語力の無い僕は、Adversarialという単語がわからなかった。反対の、敵対する、対立するなどの意味を持つらしい。 さて、まず目的関数を眺めてみる。
ここで、はデータを生成する分布では訓練データのサンプルと見て良いと思う*4。は何らかの確率分布(の事前分布としての役割っぽい)、関数とは多層パーセプトロンによって構成される関数。ということでこれらは多くのパラメータを含み、これらのパラメータを動かしていって上述の目的関数をどうにかしようという話。最終的には適当な乱数をから生成し、多層パーセプトロンにブチ込むと学習データらしきもの(たとえば画像)が出てくる。そういう仕組みらしい。
目的関数がわかったので、これを最大化したいのか最小化したいのか、GANのユニークなところはたぶんここで、以下のような問題設定になっている。
に関しては目的関数を最大化、一方でに関しては最小化しようとする。何故か対立していて、これがたぶんAdversarialとつく所以。 とはDiscriminative modelとGenerative modelを表していてるとのことで、論文を眺めていたとき、後者はともかく前者の意味がさっぱりわからなかった。Weblioで調べてみるなどしても、『慎重な判断を表現するさま』『微細に区別することができる』などと書いてあって、何を判断したり区別するんだよ感がすごい。というわけでまずのお気持ちを考えることにした。
は何をしたいのか
先程の目的関数、何となく既視感があるなと思ったらベルヌーイ分布の対数尤度のように見えてきた。そう聞くとなんとなくdiscriminativeな感じがしてくる。つまりは2値分類。たとえば、 、 (もちろんどちらもi.i.d)とするととが十分大きければ、
と書ける。ここまで言及してこなかったけど、の出力は1との減算が行われていることから分かる通りスカラー値で、もしこのの値域が(0,1)なら、は1、は0にできるだけ近づくようなにしてあげれば、目的関数を最大化できそうだ。でも何か論文にはアクティベーション関数はmaxout使うって書いてあって、そういう値域の縛りは無さそうだった*5。でもまぁいずれにせよは大きな値、は小さな値を返すを選ぶのが良さそうだ。ここまでの話を一言でまとめると、つまりははとを上手に見分ける1次元の特徴量を吐く関数になろうとしているようだ。
じゃあは何をしたいのか
前述の通り、先程の関数を最大化するということは、とを上手に見分けられるようにすることに対応する。一方ではこれを最小化しようとする。つまりは何とかしてを騙して、と判別がつかないものをから生成しようとするのがのようだ。こういう回りくどいことをすることで、過学習を避けてるのかなと思った*6。
一方でを騙そうとしているに基づいて生成されるサンプルの分布が、の良い近似にそもそもなるのだろうかというのは自然な疑問だけれど、に関してとなるときが大域的最適解であり、またこのときに限ることの証明がついていた*7。
アルゴリズムについて細かい話
基本はと交互にミニバッチでSGD的な反復法を行う(論文にはmomentum使ったって書いてあった)。 パラメータというのがあって、を回更新してからを1回更新するみたいなことが書いてあったのだが、でやったって書いてあって意義があまりわからなかった。
あとは最初のうちはの更新の際、がサチりやすいから、これを最小化するのではなく、を最大化するのが良いとあった。たぶん情報落ち対策*8。
DCGANにおける改善点
DCGANは、要は畳み込みネットワークをGANで構成するとなかなか上手く行かないから、以下のことをすると安定するよ、という話。
- Pooling layerをのものはStrided convotuionに、のものはfractional-strided convolutionにする
- Batch normを使う。
- 深いアーキテクチャの場合は全結合の隠れレイヤーを取り除く
- のアウトプット以外の全てのアクティベーションはReLUにする
- のアクティベーションはLeakyReLUにする
GANが言うこと聞かなくて試行錯誤したんだろうね。