Table of contents
1. Introduction
2. Training Progressive Growing GAN Model
3. FAQs
4. Key Takeaways
Last Updated: Jun 30, 2023

Progressive Growing GAN - Part 2

Author: Soham Medewar

Introduction

In the previous blog, we defined the generator and discriminator models; now it's time to train them. If you haven't read that article, the link for part 1 is here.

In this blog, we'll walk through training the Progressive Growing GAN models.

Training Progressive Growing GAN Model
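
The snippets below are parts of a single training script and assume a common set of imports. The following consolidated list is our assumption of what the full script needs, based on the functions used later; WeightedSum and the define_discriminator(), define_generator(), and define_composite() functions come from part 1, and resize is assumed to be skimage's implementation.

# assumed imports for the training script below
from math import sqrt
from numpy import load, ones, asarray
from numpy.random import randn, randint
from skimage.transform import resize
from keras import backend
from matplotlib import pyplot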

Loading Dataset

def load_real_samples(filename):
	data = load(filename)
	# extracting numpy array
	X = data['arr_0']
	# converting from ints to floats
	X = X.astype('float32')
	# scaling from [0,255] to [-1,1]
	X = (X - 127.5) / 127.5
	return X
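For reference, a compressed archive such as img_align_celeba_128.npz can be created with numpy.savez_compressed, which stores the array under the default key 'arr_0' used above. The sketch below only illustrates the expected format, not the actual dataset-preparation code, and the folder path is hypothetical.

from os import listdir
from numpy import asarray, savez_compressed
from PIL import Image

# illustrative sketch: resize face images from a hypothetical folder to 128x128
# and save them as a single compressed archive under the default key 'arr_0'
def prepare_dataset(folder='celeba_faces/', out_file='img_align_celeba_128.npz'):
	images = list()
	for name in listdir(folder):
		img = Image.open(folder + name).resize((128, 128))
		images.append(asarray(img))
	savez_compressed(out_file, asarray(images, dtype='uint8'))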

The next step is to retrieve a random sample of photos for updating the discriminator.

# selecting real samples
def generate_real_samples(dataset, n_samples):
	# choosing random instances
	ix = randint(0, dataset.shape[0], n_samples)
	# selecting images
	X = dataset[ix]
	# generating class labels
	y = ones((n_samples, 1))
	return X, y
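As a quick, hypothetical usage example, the helper returns an image batch together with class labels of 1 marking the samples as real (the shapes shown in the comment assume the dataset has already been scaled to 4×4).

# hypothetical usage: sample half a batch of real images and their labels
X_real, y_real = generate_real_samples(dataset, 8)
print(X_real.shape, y_real.shape)  # e.g. (8, 4, 4, 3) and (8, 1)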

Next, we'll need a sample of latent points to use with the generator model to build synthetic images.

# generating points in latent space as input for the generator
def generate_latent_points(latent_dim, n_samples):
	# generating points in the latent space
	x_input = randn(latent_dim * n_samples)
	# reshaping into a batch of inputs for the network
	x_input = x_input.reshape(n_samples, latent_dim)
	return x_input

The generate_fake_samples() function takes the generator model and returns a batch of synthetic images. The class label for all generated images is -1, indicating that they are fake.

# using the generator to generate n fake samples, with class labels
def generate_fake_samples(g_model, latent_dim, n_samples):
	# generating points in latent space
	x_input = generate_latent_points(latent_dim, n_samples)
	# generating images from the latent points
	X = g_model.predict(x_input)
	# creating class labels of -1 to mark the images as fake
	y = -ones((n_samples, 1))
	return X, y

The models are trained in two phases: a fade-in phase, in which the transition from a lower resolution to a higher resolution is made gradually, and a normal phase, in which the models are fine-tuned at the new, higher resolution.

# updating the alpha value on each instance of WeightedSum
def update_fadein(models, step, n_steps):
	# calculating current alpha (linear from 0 to 1)
	alpha = step / float(n_steps - 1)
	# updating the alpha for each model
	for m in models:
		for layer in m.layers:
			if isinstance(layer, WeightedSum):
				backend.set_value(layer.alpha, alpha)
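update_fadein() assumes that the WeightedSum layer defined in part 1 stores its alpha as a Keras backend variable, so that backend.set_value() can update it during training. As a reminder, a weighted-sum merge layer in that spirit looks roughly like the sketch below; refer to part 1 for the actual definition.

from keras import backend
from keras.layers import Add

# weighted sum of the old (inputs[0]) and new (inputs[1]) pathways, controlled by alpha
class WeightedSum(Add):
	def __init__(self, alpha=0.0, **kwargs):
		super(WeightedSum, self).__init__(**kwargs)
		# alpha is a backend variable so backend.set_value() can change it
		self.alpha = backend.variable(alpha, name='ws_alpha')

	def _merge_function(self, inputs):
		# only supports a merge of exactly two inputs
		assert (len(inputs) == 2)
		return ((1.0 - self.alpha) * inputs[0]) + (self.alpha * inputs[1])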

We can then define the procedure for training the models during a single training phase.

During a training phase, one generator, discriminator, and composite model are updated on the dataset for a fixed number of training epochs. The training phase is either a fade-in transition to a higher resolution, in which case update_fadein() must be called on each iteration, or a normal tuning phase, in which case no WeightedSum layers need updating.

# training a generator and discriminator for one training phase
def train_epochs(g_model, d_model, gan_model, dataset, n_epochs, n_batch, fadein=False):
	# calculating the number of batches per training epoch
	bpe = int(dataset.shape[0] / n_batch)
	# calculating the number of training iterations
	ns = bpe * n_epochs
	# calculating the size of half a batch of samples
	hb = int(n_batch / 2)
	# manually enumerating epochs (latent_dim is defined at the module level below)
	for x in range(ns):
		# updating alpha for all WeightedSum layers when fading in new blocks
		if fadein:
			update_fadein([g_model, d_model, gan_model], x, ns)
		# preparing real and fake samples
		X_real, y_real = generate_real_samples(dataset, hb)
		X_fake, y_fake = generate_fake_samples(g_model, latent_dim, hb)
		# updating discriminator model
		d_loss1 = d_model.train_on_batch(X_real, y_real)
		d_loss2 = d_model.train_on_batch(X_fake, y_fake)
		# updating the generator via the discriminator's error
		z_input = generate_latent_points(latent_dim, n_batch)
		y_real2 = ones((n_batch, 1))
		g_loss = gan_model.train_on_batch(z_input, y_real2)
		# summarizing loss on this batch
		print('>%d, d1=%.3f, d2=%.3f g=%.3f' % (x+1, d_loss1, d_loss2, g_loss))

For each training phase, we must then call the train_epochs() function.

Scaling the training dataset to the appropriate pixel dimensions, such as 4×4 or 8×8, is the first step. This is accomplished via the scale_dataset() function, which takes a dataset and returns a scaled version.

# scaling images to preferred size
def scale_dataset(images, new_shape):
	images_list = list()
	for image in images:
		# resizing with nearest neighbor interpolation
		new_image = resize(image, new_shape, 0)
		# storing
		images_list.append(new_image)
	return asarray(images_list)
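For example (a hypothetical call), downscaling the 128×128 training images to the 4×4 starting resolution looks like this; the exact output shape depends on the number of images in the dataset.

# hypothetical usage: downscale the dataset to the 4x4 starting resolution
small = scale_dataset(dataset, (4, 4, 3))
print(small.shape)  # e.g. (n_images, 4, 4, 3)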

We also need to preserve a plot of generated images and the current state of the generator model after each training run.

# generating samples and saving as a plot and saving the model
def summarize_performance(status, g_model, latent_dim, n_samples=25):
	# devising a name
	gs = g_model.output_shape
	nm = '%03dx%03d-%s' % (gs[1], gs[2], status)
	# generating images
	X, _ = generate_fake_samples(g_model, latent_dim, n_samples)
	# normalizing pixel values to the range [0,1]
	X = (X - X.min()) / (X.max() - X.min())
	# plotting real images
	square = int(sqrt(n_samples))
	for i in range(n_samples):
		pyplot.subplot(square, square, 1 + i)
		pyplot.axis('off')
		pyplot.imshow(X[i])
	# saving a plot to file
	f1 = 'plot_%s.png' % (nm)
	pyplot.savefig(f1)
	pyplot.close()
	# saving the generator model
	f2 = 'model_%s.h5' % (nm)
	g_model.save(f2)
	print('>Saved: %s and %s' % (f1, f2))

The train() function takes the lists of defined models as input, along with the list of batch sizes and the number of training epochs for the normal and fade-in phases at each level of growth.

# training the generator and discriminator
def train(g_models, d_models, gan_models, dataset, latent_dim, e_norm, e_fadein, n_batch):
	# fitting the baseline model
	gn, dn, gan_n = g_models[0][0], d_models[0][0], gan_models[0][0]
	# scaling dataset to appropriate size
	gs = gn.output_shape
	sd = scale_dataset(dataset, gs[1:])
	print('Scaled Data', sd.shape)
	# training normal or straight-through models
	train_epochs(gn, dn, gan_n, sd, e_norm[0], n_batch[0])
	summarize_performance('tuned', gn, latent_dim)
	# processing each level of growth
	for i in range(1, len(g_models)):
		# retrieving models for this level of growth
		[gn, g_fadein] = g_models[i]
		[dn, d_fadein] = d_models[i]
		[gan_n, gan_fadein] = gan_models[i]
		# scaling the dataset to appropriate size
		gs = gn.output_shape
		sd = scale_dataset(dataset, gs[1:])
		print('Scaled Data', sd.shape)
		# training fade-in models for next level of growth
		train_epochs(g_fadein, d_fadein, gan_fadein, sd, e_fadein[i], n_batch[i], True)
		summarize_performance('faded', g_fadein, latent_dim)
		# training normal or straight-through models
		train_epochs(gn, dn, gan_n, sd, e_norm[i], n_batch[i])
		summarize_performance('tuned', gn, latent_dim)

Finally, let us define the models and call the train() function to start the training process.

# number of growth phases, e.g. 6 == [4, 8, 16, 32, 64, 128]
n_blocks = 6
# size of the latent space
latent_dim = 100
# defining the discriminator models
dm = define_discriminator(n_blocks)
# defining the generator models
gm = define_generator(latent_dim, n_blocks)
# defining composite models
gan_m = define_composite(dm, gm)
# loading image data
dataset = load_real_samples('img_align_celeba_128.npz')
print('Loaded', dataset.shape)
# training model
n_batch = [16, 16, 16, 8, 4, 4]
# 10 epochs == 500K images per training phase
n_epochs = [5, 8, 8, 10, 10, 10]
train(gm, dm, gan_m, dataset, latent_dim, n_epochs, n_epochs, n_batch)

Following are the results of the trained model.

For instance, the plot plot_004x004-tuned.png shows a sample of images created after the first 4×4 training phase. Not much detail is visible at this resolution.

[Figure: sample of generated 4×4 faces]

Similarly, the following are the results at 128×128 resolution.

[Figure: sample of generated 128×128 faces]

FAQs

1. Do GANs need a lot of data?

GAN models are data-hungry, requiring massive amounts of varied and high-quality training samples to create high-fidelity natural pictures of various categories.

 

2. What is a discriminator in GAN?

In a GAN, the discriminator is simply a classifier. It attempts to distinguish real data from data produced by the generator, and it can use any network architecture appropriate to the type of data it classifies.
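
For intuition, a minimal image discriminator is just a binary classifier. The sketch below is a generic example, not the ProGAN discriminator from part 1.

from keras.models import Sequential
from keras.layers import Conv2D, LeakyReLU, Flatten, Dense

# a generic discriminator: a small convolutional feature extractor with a
# single real/fake output unit (illustrative only, not the ProGAN architecture)
def simple_discriminator(in_shape=(32, 32, 3)):
	model = Sequential()
	model.add(Conv2D(64, (3, 3), strides=(2, 2), padding='same', input_shape=in_shape))
	model.add(LeakyReLU(alpha=0.2))
	model.add(Conv2D(64, (3, 3), strides=(2, 2), padding='same'))
	model.add(LeakyReLU(alpha=0.2))
	model.add(Flatten())
	model.add(Dense(1, activation='sigmoid'))
	model.compile(loss='binary_crossentropy', optimizer='adam')
	return model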

 

3. What is proGAN?

The term "ProGAN" refers to a type of generative adversarial network that was invented at NVIDIA. Progressive Growing of GANs For Improved Quality, Stability, and Variation.

 

4. What is Wasserstein loss?

The Wasserstein loss function aims to maximize the gap between the scores for real and generated images. The paper summarizes it as: Critic Loss = [average critic score on real images] − [average critic score on fake images].
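
In code, this loss is often implemented as a simple product of the labels and the critic scores. A minimal sketch, assuming the +1/-1 labelling used by the sampling helpers above, is:

from keras import backend

# minimal Wasserstein loss: multiplying scores by +1/-1 labels and minimizing
# the mean drives the scores for real and fake samples apart
def wasserstein_loss(y_true, y_pred):
	return backend.mean(y_true * y_pred)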

Key Takeaways

In this article, we have discussed how to train the ProGAN models whose architecture and implementation we covered in part 1. To understand the basics of GANs and ProGANs, visit the first part of this article.

Want to learn more about Machine Learning? Here is an excellent course that can guide you in learning. 

Happy Coding!
