最初のGAN(Genera tive Adversarial Networks)

最初のGAN(Genera tive Adversarial Networks)の実装を行いながら、生成モデルや敵対的生成ネットワークについて調べてみました。最初のGANは構造がシンプルなので基本的な仕組みの理解には最適です。

生成モデル(Generative model)と敵対的生成ネットワーク(Generative Adversarial Network:GAN)

まず、生成モデルについてです。深層学習には様々なモデルの分類があり、生成モデルもその一つとなります。生成モデル(Generative model)とは「サンプルデータ(テストデータ)は何かしらの確率分布に従っている」という仮定のもとで、与えられた標本分布から母集団の確率分布をモデル化し、その分布に従った新たなデータを生成するモデルとなります。例えば画像生成の場合、入力画像データがN枚ある場合、サンプルデータの標本分布(p(画像1)、p(画像2)、p(画像3)、…p(画像N))から母集団の確率分布を学習しモデル化します。生成時はモデル化された確率分布に従う新たな画像データを出力します。また、生成モデルには教師データはありませんので、教師なし学習の一種となります。

GANの基礎

サンプリングされた画像は[0,255] ^ (チャンネル数 x 縦 x 横)の高次元な空間に存在することになります。この空間中の与えられたサンプル画像の確率分布を探索することによって、サンプル画像に類似した画像が生成されるようになります。MNISTの場合はグレー画像の縦28、横28の画像となりますので、[0,255]の(28X28=)784乗の空間に全ての画像が含まれていることになります。このように高次元な場合は説明がとても困難ですので、ここでは1次元と仮定した図で説明します。

 1.テストデータの標本分布のイメージ。MNISTの画像は標本分布に従っている。

 2.学習により、母集団の確率分布に近似していく(赤:G(z))

3.近似した確率分布に従った画像を出力する

論文にも同様に1次元化した図が載ってますので、そちらも参考にしてください。

敵対的生成ネットワーク

敵対的生成ネットワーク(Generative Adversarial Network)とは、2014年にイアン・グッドフェローらが考案した、識別器(Discriminator、以下Dと表記)と生成器(Generator、以下Gと表記)の2つのネットワークを用いて相互に学習していくアーキティクチャで構成された生成モデルです。学習時にGはランダムノイズ(以下、z)を入力として、学習した確率分布に従って画像を生成、出力します。Dはサンプリングされた本物の画像(x)とGが生成した画像(G(z))をそれぞれ本物であるか、Gが作成した画像であるのかを識別するように学習を行います。学習の結果、GはDが本物画像(x)と生成画像(G(z))の違いが識別できなくなるようになり本物そっくりの(推測した確率分布に従った)画像を生成することが可能になります。GとDのネットワークがお互いの結果を用いて学習を繰り返す様をお互いが競い合っているという表現から敵対的生成ネットワークという名前が付けられています。

GANの学習と推論の構成について

GANの学習

$$
\max_G \max_D V(D, G) = \max_G \max_D \mathbb{E}_{\mathbb{x} \sim P_{data({\bf x})}}\ \ [\log D(\mathbb{x})]
+ \mathbb{E}_{\mathbb{z} \sim P_({\bf z})}\ \ [\log (1 – D(G(\mathbb{z})))]
$$

有名なGANの損失関数の論理式です。いろんなサイトで既に解説されているので、数式の詳細についてはここでは割愛しますが、要はDは本物である確率を最大に、GはDが偽物である確率を最大にするように学習していくことを数式で表しています。これは論文に記載されている「minmax法」の数式での表現となります。では、具体的なGANの学習の流れを見ていきます。

学習の流れは以下の通りです。

  1. Dの学習。テスト画像を入力したDの識別結果と正解ラベルのクロスバイナリーエントロピー損失を計算する。
  2. Dの学習。Gが生成した画像を入力したDの識別結果と偽物ラベルのクロスバイナリーエントロピー損失を計算する。
  3. Dの学習。1と2の損失を足し合わせて、Dに誤差逆伝播を行う
  4. Gの学習。Gが生成した画像を入力したDの識別結果と正解ラベルのクロスエントロピー損失を計算する。
  5. Gの学習。4の損失をGの誤差逆伝播を行う。

の繰り返しとなります。それぞれの学習の際に渡す、正解/偽物ラベルの違いに注意してください。これを実装したコードは以下のようになります。

for epoch in range(EPOCHS):
    for i, (imgs, _) in enumerate(dataloader):

        # サンプルノイズの生成(正規分布に従ったランダムな値)
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], LATENT_DIM))))

        # 判定ラベル
        valid = Variable(Tensor(imgs.shape[0], 1).fill_(1.0), requires_grad=False)
        fake = Variable(Tensor(imgs.shape[0], 1).fill_(0.0), requires_grad=False)


        # テスト画像
        real_imgs = Variable(imgs.type(Tensor))

        # ---------------------
        #  Discriminatorの学習
        # ---------------------
        optimizer_D.zero_grad()

        # テスト画像と正解ラベルのペアで損失を計算
        real_loss = adversarial_loss(discriminator(real_imgs), valid)

        # G(z)と偽物ラベルのペアで損失を計算
        fake_imgs = generator(z)
        fake_loss = adversarial_loss(discriminator(fake_imgs.detach()), fake)

        # それぞれのLossを加算してDiscriminatorの損失
        d_loss = real_loss + fake_loss

        # 損失の誤差逆伝播
        d_loss.backward()
        optimizer_D.step()

        # -----------------
        #  Generatorの学習
        # -----------------
        optimizer_G.zero_grad()

        # D(G(z))と正解ラベルで損失を計算
        g_loss = adversarial_loss(discriminator(fake_imgs), valid)

        # 損失の誤差逆伝播
        g_loss.backward()
        optimizer_G.step()

        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch, EPOCHS, i, len(dataloader), d_loss.item(), g_loss.item())
        )

        batches_done = epoch * len(dataloader) + i
        if batches_done % SAMPLE_INTERVAL == 0:
            save_image(fake_imgs.data[:25], IMAGES_PATH + "/%07d.png" % batches_done, nrow=5, normalize=True)
            # ログ情報の収集
            G_losses.append(g_loss.item())
            D_losses.append(d_loss.item())
            img_list.append(fake_imgs.data[:25])

学習過程の可視化について、面白いツィートがありました。こちらは一次元の正規分布と仮定したD(x)とG(z)の学習状況を擬似的に可視化したものです。

興味深かったので、記事元のソースを自分でも実行してみました。Twitterの記事と色は違いますが、緑のD(x)の識別結果によって橙のG(z)が青波線のP(Data)に近似していく様が面白いです。

GANの実装

最初のGANは実装がとてもシンプルです。pytorchを利用して実装してみました。最初のGANのGとDはそれぞれ全結合層で構成されています。

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(LATENT_DIM, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *img_shape)
        return img
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)

        return validity

Google Colabで動作するソースはGithubに置いてます。興味のあるかたは試してみてください。

最初のG(z)
100ループ目のG(z)
500ループ目のG(z)
1000ループ目のG(z)

どんどん綺麗な画像で生成され始めました。損失の推移と生成画像の変化をGIFにまとめてみました。

DとGの損失の推移
生成画像の推移

最後に

こちらを見ても分かる通り、現在、GANはたくさんの種類があります。今回、最初のGANを実装しながら、基本的な構成を見てみました。このあとも様々なGANをひとつずつ見ていきながらそれぞれの特徴を学習していきます。