DCGAN “Deep Convolutional Generative Adversarial Network” に触れる Part(1)

colaboratory

はじめに

こんにちは、swim-loverです。Pythonをとりかかりとして、Pytorch Tensorflowで機械学習を勉強していいます。「使いながら覚える」をコンセプトに勉強しています。

このBlogでは、識別モデルの物体検出について何度か取り扱ってきましたが、今回、生成モデルの一つである、GAN”ギャン”についても試してみます。

参考WebSite

Deep Convolutional Generative Adversarial Network  |  TensorFlow Core
GAN (Generative Adversarial Networks):敵対的生成ネットワーク
オイラはCG上がりの人間なので、ディープラーニングへの興味は画像認識のような識別系よりもっぱら生成系なのである。 最近はCG系の学会でもお馴染みになりつつあるGAN(敵対的生成ネットワーク)についてちゃんと知りたくて、その前知識としてニュー

GAN 概要

GAN, 生成的敵対的ネットワークは、GENERATORとDISCRIMINATORの2つのモデルが同時に学習することが特徴となっています。

GENERATORはノイズベクトルから本物に近い画像を作り出すように学習し、DISCRIMINATORは、偽物か本物かを判別することを学習します。

GAN 学習

学習は、DISCRIMINATORとGENERTORが交互に学習していきます。

DCGAN 実践

早速、GANを動作させてみます。実装サンプルは、DCGAN “Deep Convolutional GAN”を用いいます。

GoogleColabの実装サンプルを自分なりに理解しながら進めていきます。

Google Colab

DCGAN セットアップ

実装サンプルそのままになります。

import tensorflow as tf
tf.__version__
# To generate GIFs
!pip install imageio
!pip install git+https://github.com/tensorflow/docs

8行目でplot_modelを追加してみました。機械学習は初心者なので、モデルの確認できるとより理解が進むと思いました。

import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
from tensorflow.keras import layers
from tensorflow.keras.utils import plot_model
import time

from IPython import display

DCGAN データセット準備

実装サンプルそのままになります。tensorflowのAPIを用いて、MINISTデータをダウンロードします。

(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
train_images = (train_images - 127.5) / 127.5  # Normalize the images to [-1, 1]
BUFFER_SIZE = 60000
BATCH_SIZE = 256
# Batch and shuffle the data
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

DCGNA Generatorモデルの作成

Generatorのモデルは、Keras Sequential APIを用いて作成しています。

Kerasは、Tensorflowの高レベルAPIとのこと。

Sequential APIは、シーケンシャルにmodel.addによってモデルを繋げることができるようです。

3~7行目に関して

  • Denseは全結合のニューラルネットを作成するようです。input_shapeで入力データの形式を指定します。7x7x256=12544のユニット数を指定しています。
  • BatchNormalization()はノーマライズ、正規化処理になります。データの正規化処理自体は、特段珍しいものではないと思いますが、BatchNormalization()はDeep Learningではかなり有効な処理のようです。
  • LeakyReLU()は、活性化関数ReLUの特別バージョンのようです。
  • Reshape()で、7x7x256個のデータの整形します。

10~13行目に関して

Conv2DTranspose()は、畳込み処理になります。Conv2Dは、通常の畳込み処理、Conv2Dの入力データをずらしながら、係数データとの積和演算を行います。係数の値によって出力されるデータが変わってきます。以下のQuitaの説明サイトがとても丁寧に解説してくれていました。

kerasのConv2D(2次元畳み込み層)について調べてみた - Qiita
やりたいことkerasのConv2Dを理解したいそれにより下記のようなコードを理解したい(それぞれの関数が何をやっているのか?や引数の意味を説明できるようになりたい)。from keras i…

Conv2DTranspose()は、アップサンプリング用途で用いられるようです。ConvolutionMatrixを転置して演算しています。以下のサイトが大変丁寧背説明してくれていました。

Up-sampling with Transposed Convolution
If you’ve heard about the transposed convolution and got confused about what it actually means, this article is written for you.

padding=’same’を指定すると、stridesの指定値で整数倍にアップサンプリングしてくれるようです。

Conv2DTranspose()を3回実行しているので、7×7(stride=1)->14×14(stride=2)->28×28(stride=2)の出力にアップサンプリングされているようです。

def make_generator_model():
    model = tf.keras.Sequential()
    model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Reshape((7, 7, 256)))
    assert model.output_shape == (None, 7, 7, 256)  # Note: None is the batch size

    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    assert model.output_shape == (None, 7, 7, 128)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    assert model.output_shape == (None, 14, 14, 64)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    assert model.output_shape == (None, 28, 28, 1)

    return model

最後にGeneratorのモデルとGeneratorが生成した画像を確認してみました。

generator=make_generator_model()
plot_model(
    generator,
    show_shapes=True,
)

noise = tf.random.normal([1, 100])
generated_image = generator(noise, training=False)

plt.imshow(generated_image[0, :, :, 0], cmap='gray')

まとめ

今回、生成モデルの一つである、GANについて調べてみました。第一回は、DCGANのGeneratorまでを確認してみました。

次回では、Discriminatorについて確認してみます。

コメント

タイトルとURLをコピーしました