DCGAN (Deep Convolutional GAN)
DCGAN とは、Deep Convolutional GANの略で、畳み込みニューラルネットワークによる敵対的生成を意味しています。今回はDCGANを実装してみました。DCGANのチュートリアルではデータセット「CelebA」を使われることが多いのですが、最初のGANとの比較のためにMNISTデータを利用しています。
DCGANの特徴
最初のGANと比較して、大きな特徴点は識別器(Discriminator、以下Dと表記)と生成器(Generator、以下Gと表記)が畳み込み演算を採用していることです。最初のGANはDおよびGは全結合層となっていましたが、DCGANではDは通常の畳み込み、Gは逆畳み込みのレイヤで構成されています。それ以外の学習等に関しては最初のGANと同等となります。
DCGANの実装
DとGの実装です。D:畳み込み、G:逆畳み込みで構成されています。
# Generatorの定義
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.init_size = IMG_SIZE // 4
self.l1 = nn.Sequential(nn.Linear(LATENT_DIM, 128 * self.init_size ** 2))
self.conv_blocks = nn.Sequential(
nn.BatchNorm2d(128),
nn.Upsample(scale_factor=2),
nn.Conv2d(128, 128, 3, stride=1, padding=1),
nn.BatchNorm2d(128, 0.8),
nn.LeakyReLU(0.2, inplace=True),
nn.Upsample(scale_factor=2),
nn.Conv2d(128, 64, 3, stride=1, padding=1),
nn.BatchNorm2d(64, 0.8),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(64, CHANNELS, 3, stride=1, padding=1),
nn.Tanh(),
)
def forward(self, z):
out = self.l1(z)
out = out.view(out.shape[0], 128, self.init_size, self.init_size)
img = self.conv_blocks(out)
return img
# Discriminatorの定義
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
def discriminator_block(in_filters, out_filters, bn=True):
block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
if bn:
block.append(nn.BatchNorm2d(out_filters, 0.8))
return block
self.model = nn.Sequential(
*discriminator_block(CHANNELS, 16, bn=False),
*discriminator_block(16, 32),
*discriminator_block(32, 64),
*discriminator_block(64, 128),
)
# The height and width of downsampled image
ds_size = IMG_SIZE // 2 ** 4
self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())
def forward(self, img):
out = self.model(img)
out = out.view(out.shape[0], -1)
validity = self.adv_layer(out)
return validity
Google Colabで動作するソースはGithubに置いてます。興味のあるかたは試してみてください。
画像の生成が明らかに最初のGANと違いました。損失の推移と生成画像の変化をGIFにまとめてみました。
最後に
損失の推移については少し疑問を残す形となっていますが、GANと比べて明らかに綺麗な画像が生成されていることがわかりました。DCGANでは畳み込みと逆畳み込み(Fractionally strided convolution)が提案導入されていることが大きなポイントで、これによって生成精度が大きく向上しました。以降に続く様々なGANも畳み込みの採用を行なっていることから、GANの成長に大きく寄与した提案だと考えています。
比較した最初のGANについてはこちらを参照してください。
DCGANの論文はこちらです。