WGAN(Wasserstein GAN)

WGANとは、Wasserstein距離により損失計算を導入したGANとなります。GANの問題点として、「学習が安定しない」「勾配消失が起こりやすい」「モード崩壊が起こる」点が指摘されています。WGANではこのような問題の解決としてWasserstein距離により損失計算が提案されました。今回も最初のGANとの比較を行いながら実装してみることにします。

WGANの特徴

これまでのGANではD(x) – D(G(z))の分布の比較をJSダイバージェンスを用いて計算し、DとGの誤差逆伝播による学習を行なっていました。WGANでは、2つの分布間の損失をWasserstein距離として計算しています。DCGANでは識別器(D)と生成器(G)の構成が変わりましたが、WGANでは以下の点が変更となります。

最初のGANからの変更点

・最適化関数をAdamからRMSPropに変更

・識別器(D)の活性化関数を省略

・学習方法を変更(以下を参照)

学習の実装コードです。Dの損失をD(x)とD(G(z))で計算しているところ、Gの損失はD(G(z))で計算しているところ、Dの重みを一定の範囲でクリッピングしているところが見どころです。

for epoch in range(EPOCHS):
    for i, (imgs, _) in enumerate(dataloader):

        # サンプルノイズの生成(正規分布に従ったランダムな値)
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], LATENT_DIM))))

        # 判定ラベル
        valid = Variable(Tensor(imgs.shape[0], 1).fill_(1.0), requires_grad=False)
        fake = Variable(Tensor(imgs.shape[0], 1).fill_(0.0), requires_grad=False)

        # テスト画像
        real_imgs = Variable(imgs.type(Tensor))

        # G(z)
        fake_imgs = generator(z)

        # ---------------------
        #  Discriminatorの学習
        # ---------------------
        optimizer_D.zero_grad()

        # D(x)とD(G(z))の敵対的損失
        loss_D = -torch.mean(discriminator(real_imgs)) + torch.mean(discriminator(fake_imgs))

        # 損失の誤差逆伝播
        loss_D.backward()
        optimizer_D.step()

        # 重みを一定の範囲にクリッピング
        for p in discriminator.parameters():
            p.data.clamp_(-CLIP_VALUE, CLIP_VALUE)

        if i % N_CRITIC == 0:
            # -----------------
            #  Generatorの学習
            # -----------------
            optimizer_G.zero_grad()

            # G(z)
            gen_imgs = generator(z)

            # D(G(z))の敵対的損失
            loss_G = -torch.mean(discriminator(gen_imgs))

            loss_G.backward()
            optimizer_G.step()

            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                % (epoch, EPOCHS, batches_done % len(dataloader), len(dataloader), loss_D.item(), loss_G.item())
            )
            
            # ログ情報の収集
            G_losses.append(loss_G.item())
            D_losses.append(loss_D.item())


        if batches_done % SAMPLE_INTERVAL == 0:
            save_image(fake_imgs.data[:25], IMAGES_PATH + "/%07d.png" % batches_done, nrow=5, normalize=True)
            # ログ情報の収集
            G_losses.append(loss_G.item())
            D_losses.append(loss_D.item())
            img_list.append(fake_imgs.data[:25])

        batches_done += 1

WGANの実装

Google Colabで動作するソースはGithubに置いてます。興味のあるかたは試してみてください。

最初のG(z)
100ループ目のG(z)
500ループ目のG(z)
1000ループ目のG(z)

画像の生成が最初のGANともDCGANとも違いました。損失の推移と生成画像の変化をGIFにまとめてみました。

DとGの損失の推移
生成画像の推移

最後に

WGANはクリッピングされた重みが2極化するため勾配消失や勾配爆発が起きやすくなる課題があります。この課題を勾配に制約をつけることで改善したWGAN-GPが提案されました。こちらは次回実装してみます。

比較した最初のGANについてはこちらを参照してください。

WGANの論文はこちらです。