from IPython.display import HTML, IFrame
"https://thispersondoesnotexist.com", 400, 400) IFrame(
Introduction
This is a post about Generative Adversarial Networks (GANs). It is heavily influenced by, and borrows code from:
- Video from Luis Serrano
- This article on Machine Learning Mastery
These folks deserve all the credit! I am writing this post mostly for my own learning.
I’d highly recommend reading the two resources mentioned above.
Goal
The goal of GANs is to generate realistic data, i.e. data with statistics similar to the training data.
See below a “generated” face on https://thispersondoesnotexist.com
Refresh this page to get a new face each time!
These are people that do not exist but their faces have been generated using GANs.
Overall Block Diagram
Conceptually, GANs are simple. They have two main components:
- A discriminator, which tries to accurately tell generated data apart from real data (from the training set)
- A generator, which generates data given some random numbers
The goal of GANs is to use the generator to create data realistic enough that the discriminator thinks it is real (i.e. coming from the training dataset).
The two components, the discriminator and the generator, are “fighting”, where:
- the goal of the discriminator is to tell fake (generated) data apart from true data (from the training set), even when the generator is fairly good
- the goal of the generator is to generate realistic data such that the discriminator thinks it is real data
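For reference, this tug-of-war between the two networks is usually written as the minimax objective from the original GAN paper (you do not need it for the rest of this post, but it makes the two competing goals explicit):

$$\min_G \max_D \; \mathbb{E}_{x \sim p_{\text{data}}}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right]$$

Here $D(x)$ is the discriminator’s estimated probability that $x$ is real, and $G(z)$ is the sample the generator produces from random noise $z$.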
Creating “true” distribution
Let us now create some data from the true/known distribution. We will essentially be creating 2x2 matrices (images) as explained in Luis Serrano’s tutorial. The (0, 0) and (1, 1) positions will have a high value between 0.8 and 1, whereas the other two positions (0, 1) and (1, 0) will have values between 0 and 0.1.
import numpy as np
import matplotlib.pyplot as plt
import mediapy as media
%matplotlib inline
np.random.seed(40)

import warnings
warnings.filterwarnings('ignore')

import logging
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # FATAL
logging.getLogger('tensorflow').setLevel(logging.FATAL)

import tensorflow as tf
tf.get_logger().setLevel('ERROR')
tf.random.set_seed(42)
SIZE = 5000
faces = np.vstack((np.random.uniform(0.8, 1, SIZE),
                   np.random.uniform(0., 0.1, SIZE),
                   np.random.uniform(0., 0.1, SIZE),
                   np.random.uniform(0.8, 1, SIZE))).T
faces.shape
(5000, 4)
def plot_face(f):
    f_reshape = f.reshape(2, 2)
    plt.imshow(f_reshape, cmap="Greys")

def plot_faces(faces, subset=1):
    images = {
        f'Image={im}': faces[im].reshape(2, 2)
        for im in range(len(faces))[::subset]
    }
    media.show_images(images, border=True, columns=8, height=80, cmap='Greys')

plot_faces(faces, subset=700)
[Output: 8 sample faces drawn from the true distribution (Image=0, 700, 1400, …, 4900)]
The above shows some samples drawn from the true distribution. Let us also now create some random/noisy samples. These samples do not have any relationship between the 4 positions.
# Examples of noisy images
noise = np.random.randn(40, 4)
noise = np.abs(noise)
noise = noise/noise.max()
plot_faces(noise)
[Output: 40 noisy 2x2 images (Image=0 through Image=39) with no visible pattern]
Creating the discriminator
Our discriminator is simple.
- It accepts a 4-dimensional input (the flattened 2x2 image)
- It outputs a single number with a sigmoid activation, denoting the probability of the image being real, where:
  - class 0 means the image is fake, i.e. generated by the generator
  - class 1 means the image is real, i.e. sampled from the training dataset
- We use the binary cross-entropy loss (shown below)
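For reference, the binary cross-entropy loss for a single example with true label $y \in \{0, 1\}$ and predicted probability $\hat{y}$ is:

$$\mathcal{L}(y, \hat{y}) = -\left[\, y \log \hat{y} + (1 - y) \log (1 - \hat{y}) \,\right]$$

so the discriminator is pushed to output values close to 1 for real images and close to 0 for fake ones.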
To make the above crystal clear, I’ll use the following gist to draw this NN
from draw_nn import draw_neural_net
fig = plt.figure(figsize=(4, 4))
ax = fig.gca()
ax.axis('off')
draw_neural_net(ax, .1, 0.9, .1, .6, [4, 1])
plt.tight_layout()
plt.title("Discriminator NN");
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.optimizers import Adam
discriminator = Sequential([
    Dense(1, activation='sigmoid', input_shape=(4, )),
])
discriminator._name = "Discriminator"
discriminator.compile(
    optimizer=Adam(0.001),
    loss='binary_crossentropy'
)
discriminator.summary()
Model: "Discriminator"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense (Dense) (None, 1) 5
=================================================================
Total params: 5
Trainable params: 5
Non-trainable params: 0
_________________________________________________________________
As expected, the discriminator has 5 parameters (4 weights, one from each of the 4 inputs to the output node, plus 1 bias term). Now, let us create the generator.
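As a quick sanity check (a minimal sketch; count_params and get_weights are standard Keras methods), we can inspect the weights directly:

# Sanity check: the single Dense layer holds a (4, 1) weight matrix and a (1,) bias vector
weights, bias = discriminator.layers[0].get_weights()
print(weights.shape, bias.shape)     # (4, 1) (1,)
print(discriminator.count_params())  # 5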
Creating the generator
Let us now create the generator model. We create a very simple one:
- It accepts as input a single random number
- It creates a vector of size 4
The illustration below shows this network. It should be noted that the single random input is an arbitrary choice. We could use any number really!
fig = plt.figure(figsize=(4, 4))
ax = fig.gca()
ax.axis('off')
draw_neural_net(ax, .1, 0.9, .1, .6, [1, 4])
plt.tight_layout()
plt.title("Generator NN");
from keras.layers import ReLU
generator = Sequential([
    Dense(4, input_shape=(1, )),
    ReLU(max_value=1.0)
])
generator._name = "Generator"
generator.compile(
    optimizer=Adam(0.001),
    loss='binary_crossentropy'
)
generator.summary()
Model: "Generator"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense_1 (Dense) (None, 4) 8
_________________________________________________________________
module_wrapper (ModuleWrappe (None, 4) 0
=================================================================
Total params: 8
Trainable params: 8
Non-trainable params: 0
_________________________________________________________________
We can verify that the network has 8 parameters (one weight from the single input plus one bias for each of the 4 output nodes, i.e. 4 + 4 = 8).
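As noted earlier, feeding the generator a single random number is an arbitrary choice. Just for illustration (the name generator_wide and the latent size of 8 are made up; the rest of this post keeps using 1), a wider-input variant would look like:

# Hypothetical variant: same architecture, but fed 8 random numbers instead of 1
LATENT_DIM = 8
generator_wide = Sequential([
    Dense(4, input_shape=(LATENT_DIM, )),  # 8*4 weights + 4 biases = 36 parameters
    ReLU(max_value=1.0)
])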
Generating samples from Generator
We can now use our generator to generate some samples and plot them.
# Generate n_samples fake images from the generator, labelled as class 0
def gen_fake(n_samples):
    x_input = np.random.randn(n_samples, 1)
    X = generator.predict(x_input)
    y = np.zeros((n_samples, 1))
    return X, y
As expected, the samples look random, without any specific pattern and do not resemble the training data as our generator is untrained. Further, it is important to reiterate that the class associated with the fake samples generated from the generator is 0. Thus, we have the line np.zeros((n_samples, 1))
in the code above.
plot_faces(gen_fake(20)[0])
[Output: 20 images generated by the untrained generator (Image=0 through Image=19)]
Sampling from the Real (Train) Dataset
# Sample n_samples real images from the training set, labelled as class 1
def gen_real(n_samples):
    ix = np.random.randint(0, faces.shape[0], n_samples)
    X = faces[ix]
    y = np.ones((n_samples, 1))
    return X, y

plot_faces(gen_real(20)[0])
[Output: 20 real images sampled from the training set (Image=0 through Image=19)]
We can clearly see the pattern in the images coming from the training dataset.
Training the GAN
The block diagram below shows the main idea behind training GANs. The procedure is similar to alternating least squares: in each iteration we update the discriminator while holding the generator fixed, and then update the generator (through the combined model) while holding the discriminator fixed.
# Stack the generator and the (frozen) discriminator into a combined model used to train the generator
def define_gan(g_model, d_model):
    d_model.trainable = False
    model = Sequential()
    model.add(g_model)
    model.add(d_model)
    opt = Adam(lr=0.001)
    model.compile(loss='binary_crossentropy', optimizer=opt)
    return model

gan_model = define_gan(generator, discriminator)
It is important to note that we will train two networks:
- Discriminator on
  - Fake data (class 0)
  - Real data (class 1)
- Combined model consisting of Generator + Discriminator (where the Discriminator is fixed) on
  - Fake data (class 0) posing as real data (class 1) to the model
Thus, we do not train the generator separately.
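Put differently, feeding the combined model fake samples labelled as real means the generator is trained with the binary cross-entropy at $y = 1$, i.e. its effective loss is

$$\mathcal{L}_G = -\log D(G(z)),$$

which is the commonly used “non-saturating” form of the generator loss: it is minimised when the (frozen) discriminator is fooled into outputting values close to 1.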
samples_saved = {}
losses = {}
N_ITER = 1000
STEP = N_ITER//10

for i in range(N_ITER):
    # Generate some fake data
    X_fake, y_fake = gen_fake(2)
    X_real, y_real = gen_real(2)

    X, y = np.vstack((X_fake, X_real)), np.vstack((y_fake, y_real))

    # Discriminator
    d_loss = discriminator.train_on_batch(X, y)

    # Generator
    n_samples = 4
    g_loss = gan_model.train_on_batch(np.random.randn(n_samples, 1), np.ones(n_samples))
    losses[i] = {'Gen. loss': g_loss, 'Disc. loss': d_loss}

    # Save 5 samples
    samples_saved[i] = gen_fake(5)[0]

    if i % STEP == 0:
        # Save model
        generator.save(f"models/gen-{i}")
        print("")
        print("Iteration: {}".format(i))
        print("Discriminator loss: {:0.2f}".format(d_loss))
        print("Generator loss: {:0.2f}".format(g_loss))
Iteration: 0
Discriminator loss: 0.61
Generator loss: 0.72
Iteration: 100
Discriminator loss: 0.61
Generator loss: 0.75
Iteration: 200
Discriminator loss: 0.59
Generator loss: 0.76
Iteration: 300
Discriminator loss: 0.57
Generator loss: 0.72
Iteration: 400
Discriminator loss: 0.60
Generator loss: 0.77
Iteration: 500
Discriminator loss: 0.61
Generator loss: 0.68
Iteration: 600
Discriminator loss: 0.64
Generator loss: 0.66
Iteration: 700
Discriminator loss: 0.60
Generator loss: 0.71
Iteration: 800
Discriminator loss: 0.65
Generator loss: 0.66
Iteration: 900
Discriminator loss: 0.70
Generator loss: 0.63
Convergence
import pandas as pd
losses_df = pd.DataFrame(losses)
losses_df.T.plot();
plt.xlabel("Iteration Number");
You might expect that, over time, the generator loss reduces as it becomes better and, correspondingly, the discriminator has a harder time!
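Since the per-iteration losses are computed on tiny batches they are quite noisy, so a rolling average can make the trend easier to see. A minimal sketch reusing losses_df from above (the window of 50 is an arbitrary choice):

# Smooth the noisy per-iteration losses with a rolling mean
smoothed = losses_df.T.rolling(window=50).mean()
smoothed.plot();
plt.xlabel("Iteration Number");
plt.ylabel("Smoothed loss");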
Generating some “fake” images from the trained generator
plot_faces(gen_fake(20)[0])
[Output: 20 images generated by the trained generator (Image=0 through Image=19)]
You could not tell these apart from the real ones, right? The generator has been trained well!
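One way to check this quantitatively (a small sketch reusing gen_fake, gen_real and the trained discriminator from above) is to look at the average probability the discriminator assigns to fake versus real batches; if the generator has done its job, the two numbers should be close:

# Mean probability of "real" assigned by the trained discriminator
p_fake = discriminator.predict(gen_fake(100)[0]).mean()
p_real = discriminator.predict(gen_real(100)[0]).mean()
print("Mean P(real) on fake samples: {:0.2f}".format(p_fake))
print("Mean P(real) on real samples: {:0.2f}".format(p_real))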
Visualising evolution of generator
Let us now visualise the evolution of the generator. To do so, we use the already saved generator models at different iterations and feed them the same “random” input.
o = {}
for i in range(0, N_ITER, STEP):
    for inp in [0., 0.2, 0.4, 0.6, 1.]:
        o[f'It:{i}-Inp:{inp}'] = load_model(f"models/gen-{i}").predict(np.array([inp])).reshape(2, 2)

media.show_images(o, border=True, columns=5, height=80, cmap='Greys')
[Output: a 10×5 grid of generated images; rows correspond to iterations 0, 100, …, 900 and columns to inputs 0.0, 0.2, 0.4, 0.6, 1.0]
We can see above how the generated images improve over the iterations and across the different inputs! That is it for this article. Happy GANning.