Master WGANs: Boost Image, Audio, and Text Generation with Wasserstein GANs

Introduction

Wasserstein GANs (WGANs) are revolutionizing the world of generative adversarial networks (GANs) by using the Wasserstein distance to enhance stability and output quality. Unlike traditional GANs, WGANs solve issues like mode collapse and unstable training by introducing key modifications, including weight clipping and gradient penalties. These innovations ensure smoother training, enabling higher-quality results in image, audio, and text generation. In this article, we dive into how WGANs improve generative models and how they can boost your image, audio, and text generation projects.

What Are Wasserstein Generative Adversarial Networks (WGANs)?

Wasserstein Generative Adversarial Networks (WGANs) are an improved version of traditional GANs designed to generate more realistic data while addressing issues like unstable training and poor convergence. They achieve this by using a different loss function, called Wasserstein distance, which makes the training process more stable and helps produce higher-quality results. While they may take longer to train, WGANs are used in tasks like image, audio, and text generation, offering better reliability and output quality.

Understanding WGANs

Generative Adversarial Networks (GANs) revolve around two probability distributions. First, you’ve got the probability distribution of the generator (Pg), which describes the outputs the model creates. Then, there’s the probability distribution of the real images (Pr), which corresponds to the actual, real-world data the model is trying to imitate. The main goal of a GAN is to bring these two distributions—the one for real data and the one for generated data—as close together as possible, so that the generator ends up producing realistic, high-quality data that looks a lot like the real thing.

To measure how far apart these two distributions are, there are a few mathematical methods that can be used. Some of the common ones include Kullback–Leibler (KL) divergence, Jensen–Shannon (JS) divergence, and Wasserstein distance. While Jensen–Shannon divergence is used a lot in basic GAN models, it has some serious problems, especially with its gradients. These problems can cause the model to train in an unstable way, leading to poor results. To fix this, Wasserstein distance is used in Wasserstein GANs (WGANs) to improve the model’s stability and help it train more effectively. The Wasserstein distance gives a more meaningful and consistent measure of how close the generated data is to the real data, making the model perform better overall.
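For reference, here are the standard textbook definitions of those two classical divergences between the real distribution $P_r$ and the generated distribution $P_g$ (included only to make the comparison concrete):

$$D_{KL}(P_r \,\|\, P_g) = \mathbb{E}_{x \sim P_r}\!\left[\log \frac{P_r(x)}{P_g(x)}\right], \qquad D_{JS}(P_r \,\|\, P_g) = \frac{1}{2} D_{KL}\!\left(P_r \,\Big\|\, \frac{P_r + P_g}{2}\right) + \frac{1}{2} D_{KL}\!\left(P_g \,\Big\|\, \frac{P_r + P_g}{2}\right)$$

The gradient problem comes from the fact that when $P_r$ and $P_g$ barely overlap, the JS divergence saturates at a constant value ($\log 2$), so it gives the generator almost no gradient signal for moving its outputs closer to the real data.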

The formula for Wasserstein distance is shown below. It helps explain how this metric works with the generator and discriminator. In this formula, the “max” value represents the constraint placed on the discriminator. This constraint is key for ensuring that the discriminator, also known as the “critic” in WGANs, does its job right. The reason it’s called a “critic” instead of a “discriminator” is because it doesn’t use the sigmoid activation function. In traditional GANs, the sigmoid function limits the output to either 0 (fake) or 1 (real). In WGANs, the critic outputs a range of values, which lets it give a more detailed and nuanced evaluation of the data’s quality.
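Written in the Kantorovich–Rubinstein dual form used by WGANs, with $D$ denoting the critic, the distance looks like this (the standard form from the WGAN literature):

$$W(P_r, P_g) = \max_{\|D\|_L \le 1} \; \mathbb{E}_{x \sim P_r}[D(x)] - \mathbb{E}_{\tilde{x} \sim P_g}[D(\tilde{x})]$$

The “max” is taken over all 1-Lipschitz functions $D$, which is exactly the constraint on the critic described above.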

The term “critic” in WGANs is important to understand as it differs from the traditional “discriminator” used in regular GANs.

Here’s how to understand the formula: the first term represents the real data, and the second term represents the generated data. The critic’s goal is to maximize the difference between these two terms, meaning it wants to clearly distinguish between real and fake data. On the other hand, the generator’s job is to minimize that difference by creating data that looks as much like the real thing as possible, making it seem “real” in the eyes of the critic. So, while the critic aims to make the distinction between real and fake as clear as possible, the generator works hard to reduce that gap, constantly improving the generated data’s quality.

Read more about Wasserstein Generative Adversarial Networks in the detailed exploration of its implementation and training in the article Improved Training of Wasserstein GANs.

Learning the details for the implementation of WGANs

The original setup of the Wasserstein Generative Adversarial Network (WGAN) goes into great detail about how the architecture works, and its main goal is to make the training of GANs better. A key part of this architecture is the “critic,” which is responsible for providing a useful way of evaluating the output from the generator. The critic helps stabilize the training by making it easier to tell the difference between real and fake data.

However, the initial paper that introduced WGAN pointed out some challenges with the weight clipping method used in the architecture. Weight clipping was supposed to help control the critic’s function, but it didn’t always work as well as expected. For example, when the weight clipping was set too high, it caused longer training times. This happened because the critic needed more time to adjust to the weights in the network. On the flip side, if the weight clipping was set too low, it led to vanishing gradients—this is a common problem that pops up when the network has a lot of layers. It was especially noticeable in situations where batch normalization wasn’t used or when Recurrent Neural Networks (RNNs) were involved.
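To make the idea concrete, here is a minimal sketch of what weight clipping looks like in practice. This is an illustration only, not part of the WGAN-GP code used later in this article; it assumes critic is a Keras model and uses the 0.01 clip constant suggested in the original WGAN paper.

import tensorflow as tf

def clip_critic_weights(critic, clip_value=0.01):
    # Clamp every trainable weight of the critic to [-clip_value, clip_value].
    # This is the original WGAN trick for roughly enforcing the Lipschitz
    # constraint; WGAN-GP replaces it with the gradient penalty shown later.
    for weight in critic.trainable_variables:
        weight.assign(tf.clip_by_value(weight, -clip_value, clip_value))

In the original setup, this clipping step would be applied after every optimizer update of the critic.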

To solve these problems and improve WGAN training, a major update came in the paper titled “Improved Training of Wasserstein GANs.” Instead of weight clipping, the paper suggested using a gradient penalty method, which helped make training smoother. The gradient penalty approach is now the go-to method for training WGANs, and it works much better in practice.

The WGAN-GP (Wasserstein GAN with Gradient Penalty) method adds a regularization term to the loss function, called the gradient penalty. This penalty ensures that the L2 norm of the critic’s gradients stays close to 1, which makes the training process faster and more stable. The algorithm laid out in the paper defines a few important parameters. For example, the lambda value controls how strong the gradient penalty is, while the n_critic setting tells you how many times the critic should be trained per generator update. The alpha and beta values are hyperparameters for the Adam optimizer, which is used to fine-tune the weights during training.

To add the gradient penalty, an interpolation image is created, which is a mix of real and generated images. This image is then passed through the discriminator to calculate the gradient penalty. This technique helps to meet the Lipschitz continuity constraint needed to train the WGAN model correctly. The training process keeps running until the generator is producing high-quality, realistic data.
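Putting those pieces together, the full critic objective from the Improved Training of Wasserstein GANs paper can be written as:

$$L = \mathbb{E}_{\tilde{x} \sim P_g}[D(\tilde{x})] - \mathbb{E}_{x \sim P_r}[D(x)] + \lambda \, \mathbb{E}_{\hat{x}}\!\left[\left(\|\nabla_{\hat{x}} D(\hat{x})\|_2 - 1\right)^2\right]$$

Here $\hat{x} = \epsilon x + (1 - \epsilon)\tilde{x}$ is the interpolated image described above, with $\epsilon$ drawn uniformly from $[0, 1]$, and $\lambda$ is the gradient penalty weight (the paper uses $\lambda = 10$).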

Next, we’ll dive into how to practically set up the WGAN architecture with the gradient penalty method to tackle the MNIST project. With the gradient penalty in place, we’ll be able to boost both the quality and stability of the model, ensuring the generator continues to deliver accurate results over time.

For a deeper understanding of gradient penalty methods and their application in WGANs, check out the detailed explanation in the research paper Improved Training of Wasserstein GANs.

Construct a project with WGANs

In this part of the article, we’re going to put our knowledge of WGANs into practice by building out the networks, focusing on how they work and how to set them up. We’ll make sure to use the gradient penalty method during training to keep things running smoothly. To do this, we’ll use the WGAN-GP (Wasserstein GAN with Gradient Penalty) approach, which comes straight from the official Keras website. Most of the code will be adapted from there, so we’re in good hands!

Importing the essential libraries

First, we’ll need some tools to get things started. We’ll be using TensorFlow and Keras for building the WGAN architecture. These libraries are perfect for efficiently setting up and training our neural networks. If you’re not already familiar with them, no worries—feel free to check out my previous articles where I dive into these in more detail. We’ll also be bringing in numpy for handling array computations and matplotlib for making visualizations if needed.


import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
import numpy as np

Defining Parameters and Loading Data

Now, let’s define some of the key parameters we’ll be using throughout the WGAN architecture. We’ll also create some reusable neural network building blocks, like the convolutional block and the upsample block. On top of that, we’ll load the MNIST dataset—this will be our sample data for generating images of digits.

To kick things off, let’s define the image size for the MNIST data. Each image is 28 x 28 pixels, and it has just one color channel, so it’s grayscale. We’ll also define a base batch size and the noise dimension, which the generator will use to create the digit images.


IMG_SHAPE = (28, 28, 1)
BATCH_SIZE = 512
noise_dim = 128

Next, we’ll load the MNIST dataset, which is easily available from TensorFlow and Keras’ free example datasets. This dataset has 60,000 images, and we’ll split them into training images, training labels, test images, and test labels. We’ll also normalize the images so that they fit within a range that’s easier for our training model to handle.


MNIST_DATA = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = MNIST_DATA.load_data()
print(f"Number of examples: {len(train_images)}")
print(f"Shape of the images in the dataset: {train_images.shape[1:]}")
train_images = train_images.reshape(train_images.shape[0], *IMG_SHAPE).astype("float32")
train_images = (train_images - 127.5) / 127.5

Once the data is prepped, we can start defining the neural network blocks that will help build both the discriminator and generator models. First, we’ll make a function for the convolutional block, which will be used mainly in the discriminator. The convolutional block function will handle a few different parameters for setting up the 2D convolution layer, and it also gives us the option to add batch normalization or dropout layers. These extra layers can help the model generalize better and prevent overfitting.


def conv_block(x, filters, activation, kernel_size=(3, 3), strides=(1, 1), padding="same",
               use_bias=True, use_bn=False, use_dropout=False, drop_value=0.5):
    x = layers.Conv2D(filters, kernel_size, strides=strides, padding=padding, use_bias=use_bias)(x)
    if use_bn:
        x = layers.BatchNormalization()(x)
    x = activation(x)
    if use_dropout:
        x = layers.Dropout(drop_value)(x)
    return x

Similarly, we’ll create a function for the upsample block, which will be used mostly in the generator. This block is responsible for increasing the spatial resolution of the image, essentially upscaling it. Just like the convolutional block, we can optionally add batch normalization or dropout layers to it. Plus, each upsample block is followed by a regular convolutional layer, which ensures the quality of the generated images.


def upsample_block(x, filters, activation, kernel_size=(3, 3), strides=(1, 1), up_size=(2, 2), padding="same",
                   use_bn=False, use_bias=True, use_dropout=False, drop_value=0.3):
    x = layers.UpSampling2D(up_size)(x)
    x = layers.Conv2D(filters, kernel_size, strides=strides, padding=padding, use_bias=use_bias)(x)
    if use_bn:
        x = layers.BatchNormalization()(x)
    if activation:
        x = activation(x)
    if use_dropout:
        x = layers.Dropout(drop_value)(x)
    return x

In the next sections, we’ll put both the convolutional block and the upsample block to work in building the generator and discriminator models. These models will be the heart of our WGAN architecture, and we’ll train them to generate realistic images from the MNIST dataset. Let’s dive into how to create these models!

For more insights into building and training generative models like WGANs, refer to the comprehensive guide on Wasserstein GANs with Gradient Penalty.

Importing the essential libraries

To build the WGAN architecture, we’ll be using TensorFlow and Keras, two powerful deep learning frameworks. These tools make it much easier to build and train neural networks. TensorFlow is an all-in-one, open-source platform for machine learning, while Keras is its high-level API, designed to help you create complex models without too much hassle. If you’re not yet familiar with them, I’d definitely recommend checking out my previous articles, where I dive deep into these topics, including how to use them effectively for machine learning tasks.

Besides TensorFlow and Keras, we’ll also bring in numpy. This library is super important for handling numerical data and performing array-based computations, which are pretty common in machine learning workflows. Numpy makes it easy to handle large datasets and do the math operations needed for neural networks. So, it’s a must-have!

On top of that, we’ll use matplotlib, which is a popular plotting library for Python. It’ll help us visualize the results of our experiments when needed. Visualization is key to understanding how the training is going and evaluating the quality of generated images, especially when you’re working with generative models like WGANs.

Here’s the code that shows how to import all these libraries:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
import numpy as np

With this setup, you’ll have all the tools you need to develop and train the WGAN model. Combining TensorFlow, Keras, numpy, and matplotlib gives us everything we need to create a solid, efficient machine learning model.

For more information on setting up deep learning frameworks and essential libraries, refer to the detailed guide on Installing TensorFlow and Keras for Machine Learning Projects.

Defining Parameters and Loading Data

In this section, we’re going to walk through the essential steps to set up the WGAN network. We’ll start by defining some key parameters, building important neural network blocks that we’ll use throughout the project, and loading up the MNIST dataset. These steps are all necessary for getting the model up and running smoothly.

Let’s kick things off by defining a few basic parameters that are crucial for both the MNIST dataset and the WGAN model. The MNIST dataset is made up of 28×28 grayscale images, each with a single channel. So, we can define the image dimensions like this:


IMG_SHAPE = (28, 28, 1)

Next up, we have the BATCH_SIZE. This is the number of images the model processes in one go. We’ll use a batch size of 512 here, which lets the model learn efficiently without maxing out your memory. We’ll also need to define a noise_dim. This represents the dimension of the latent space that the generator will sample from to create new images. Think of it as the “ingredients” for generating fresh new images. Here’s how it looks in code:


BATCH_SIZE = 512
noise_dim = 128

Now, let’s get to the fun part—loading the MNIST dataset. Luckily, this dataset is conveniently available in Keras, so it’s a breeze to load. The MNIST dataset contains 60,000 training images and 10,000 testing images. We’ll split this dataset into training images, training labels, test images, and test labels. But here’s the thing: the images come as arrays with pixel values ranging from 0 to 255. We’ll need to normalize these values to make them easier for the neural network to process. The goal is to scale the pixel values to a range between -1 and 1.

Here’s the code that does just that:


MNIST_DATA = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = MNIST_DATA.load_data()
print(f"Number of examples: {len(train_images)}")
print(f"Shape of the images in the dataset: {train_images.shape[1:]}")
train_images = train_images.reshape(train_images.shape[0], *IMG_SHAPE).astype("float32")
train_images = (train_images - 127.5) / 127.5

This code reshapes the training images so they match the expected input format for the neural network—28x28x1 arrays. Then, we normalize the pixel values by subtracting 127.5 and dividing by 127.5. This ensures the pixel values fall into the desired range of [-1, 1], which is ideal for neural network training (and matches the tanh output range of the generator we’ll build later).

For a deeper dive into working with datasets and defining parameters in machine learning projects, check out the comprehensive guide on Loading and Preprocessing Data for TensorFlow Models.

Constructing The Generator Architecture

Building the generator architecture is like putting together the building blocks of a cool new creation. We’re going to use those upsampling blocks we defined earlier as the base to construct the generator model. This model will be the one responsible for generating images for our project, and it all starts with setting up the necessary parameters. For example, we’ve already decided on the latent noise dimension, which is crucial for our setup.

Now, let’s talk about the latent noise for a second. This is basically the starting point for our image generation process. You can think of it as the “ingredients” for creating a synthetic image. It’s a random vector that’s sampled from a Gaussian distribution, and it will serve as the input for generating all sorts of new data. To kick off the generator model, we’ll take this noise and pass it through a few layers. First, a fully connected (dense) layer, followed by batch normalization to keep things smooth and stabilize the training process. Then, we’ll throw in a Leaky ReLU activation to add some non-linearity and help the model learn more complex patterns.

Once the noise passes through those layers, we need to reshape the output into a 4x4x256 tensor, which serves as the starting feature map for the generator. This tensor will then be processed through a series of upsampling blocks. Each block increases the spatial resolution of the data—think of it as stretching the image to make it bigger. After passing through three of these blocks, we should end up with a 32×32 image. But, hold up! The MNIST dataset images are only 28×28, so we’ll need to crop the generated image using a Cropping2D layer to make it match the size of the MNIST images.

Here’s the code to define this process:


def get_generator_model():
    noise = layers.Input(shape=(noise_dim,))  # Input layer for the noise vector
    x = layers.Dense(4 * 4 * 256, use_bias=False)(noise)  # Fully connected layer
    x = layers.BatchNormalization()(x)  # Batch normalization to stabilize training
    x = layers.LeakyReLU(0.2)(x)  # Leaky ReLU activation

    # Reshape the output into a 4x4x256 tensor
    x = layers.Reshape((4, 4, 256))(x)

    # Pass through the upsampling blocks
    x = upsample_block(x, 128, layers.LeakyReLU(0.2), strides=(1, 1), use_bias=False, use_bn=True, padding="same", use_dropout=False)
    x = upsample_block(x, 64, layers.LeakyReLU(0.2), strides=(1, 1), use_bias=False, use_bn=True, padding="same", use_dropout=False)
    x = upsample_block(x, 1, layers.Activation("tanh"), strides=(1, 1), use_bias=False, use_bn=True)

    # Crop the output to 28x28 dimensions
    x = layers.Cropping2D((2, 2))(x)

    # Define the generator model
    g_model = keras.models.Model(noise, x, name="generator")

    return g_model

g_model = get_generator_model()  # Instantiate the generator model
g_model.summary()  # Display the summary of the generator model

In this code, we’ve defined the generator model using Keras layers. The first step processes the noise vector through a dense layer, which helps map it into a high-dimensional space. Then, the batch normalization layer comes in to keep everything stable, normalizing the activations across the layers. The Leaky ReLU activation introduces non-linearity, which allows the model to learn complex patterns better.

Next, the reshaped tensor goes through those upsampling blocks, which increase the image resolution. Each of these blocks includes an upsampling layer followed by a convolutional layer. The final output is passed through a tanh activation to ensure that the pixel values stay within the right range for generating realistic images.

Finally, we crop the image down to 28×28 pixels using the Cropping2D layer, making sure the output matches the size of the MNIST images. Once all that’s done, the generator model is ready to go, and it’s all set to create new images based on random noise input.

For a detailed understanding of building and training neural networks with Keras, check out the comprehensive guide on Building Generative Adversarial Networks with TensorFlow.

Constructing The Discriminator Architecture

Now that we’ve got the generator model all set up, it’s time to move on to the discriminator network. In the world of Wasserstein GANs (WGANs), we call this the “critic.” The job of the critic is simple but crucial: it has to figure out which images are real and which ones are fake—generated by the model.

Here’s the thing, though. The images in the MNIST dataset are 28×28 pixel grayscale images, but after passing through a few layers in the network, the dimensions get a bit tricky. So, to keep things smooth, we’re going to adjust the image size to 32×32. Why? Well, this ensures that after a couple of strides in the convolution layers, we don’t end up with uneven dimensions. To do this, we simply add a zero-padding layer at the beginning of the network, which helps us keep the dimensions intact during the convolution operation.

Once the image dimensions are good to go, we dive into the real action. We start by defining a series of convolutional blocks that will help the critic identify features from the input images—whether they’re real or generated. Each convolutional block is followed by a Leaky ReLU activation function, which helps add a little non-linearity to the mix. Batch normalization also gets added to keep everything stable during training. Oh, and some layers include dropout to prevent overfitting and help the model generalize better.

After four convolutional blocks, we flatten the output into a 1D vector and apply another dropout layer. Then, we add a dense layer to produce the final output—just one number that tells us whether the image is real or fake. But here’s the twist: unlike traditional GANs that use a sigmoid activation function to make this decision, the WGAN discriminator (or critic) outputs a continuous value. This is because we’re using the Wasserstein loss function, which works better with continuous values.

Let’s take a look at the code that defines this process:


def get_discriminator_model():
    img_input = layers.Input(shape=IMG_SHAPE)  # Input layer for the image
    x = layers.ZeroPadding2D((2, 2))(img_input)  # Padding to adjust the image dimensions to 32x32

    # First convolutional block
    x = conv_block(x, 64, kernel_size=(5, 5), strides=(2, 2), use_bn=False, use_bias=True,
                   activation=layers.LeakyReLU(0.2), use_dropout=False, drop_value=0.3)
    # Second convolutional block
    x = conv_block(x, 128, kernel_size=(5, 5), strides=(2, 2), use_bn=False, use_bias=True,
                   activation=layers.LeakyReLU(0.2), use_dropout=True, drop_value=0.3)
    # Third convolutional block
    x = conv_block(x, 256, kernel_size=(5, 5), strides=(2, 2), use_bn=False, use_bias=True,
                   activation=layers.LeakyReLU(0.2), use_dropout=True, drop_value=0.3)
    # Fourth convolutional block
    x = conv_block(x, 512, kernel_size=(5, 5), strides=(2, 2), use_bn=False, use_bias=True,
                   activation=layers.LeakyReLU(0.2), use_dropout=False, drop_value=0.3)

    x = layers.Flatten()(x)  # Flatten the output into a 1D vector
    x = layers.Dropout(0.2)(x)  # Dropout to reduce overfitting
    x = layers.Dense(1)(x)  # Dense layer producing a single continuous output

    # Define the discriminator (critic) model
    d_model = keras.models.Model(img_input, x, name="discriminator")
    return d_model

# Instantiate the discriminator model
d_model = get_discriminator_model()
d_model.summary()  # Display a summary of the discriminator model

In this code, we start with the input image and apply padding to adjust its dimensions. Then, we pass the image through four convolutional blocks. Each block helps the model learn more abstract features, with Leaky ReLU activations ensuring some information flows even when certain neurons aren’t activated. Dropout layers are included to prevent overfitting—especially in the layers with more filters.

Once the image goes through all the convolutional blocks, it’s flattened into a 1D vector. After that, we apply a dropout layer and pass the vector through a dense layer, which outputs a single value indicating whether the image is real or fake. This continuous output is ideal for the WGAN’s Wasserstein loss function.

Finally, we return the discriminator model and check out its summary using d_model.summary(). This gives us a quick overview of the model, including details like the number of parameters and the output shape of each layer.

For more on constructing and training neural network architectures with a focus on discriminators and generators, refer to the tutorial on Building Generative Adversarial Networks with TensorFlow.

Creating the overall WGAN model

Next up in building the Wasserstein GAN (WGAN) network is creating the overall structure of the model. We’re going to break the WGAN architecture into three main parts: the discriminator, the generator, and the training process. This breakdown will make it easier to see how everything fits together. Let’s get started by defining the parameters we’ll be using throughout the WGAN class, since all the functions, including the handling of the generator and discriminator, will live inside this class. The class extends Keras’ Model class, making it easy for us to build and compile the whole network.

Here’s the code that defines the core WGAN class:


class WGAN(keras.Model):
    def __init__(self, discriminator, generator, latent_dim, discriminator_extra_steps=3, gp_weight=10.0):
        super(WGAN, self).__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        self.d_steps = discriminator_extra_steps  # Number of critic updates per generator update
        self.gp_weight = gp_weight  # Gradient penalty weight (lambda)

    def compile(self, d_optimizer, g_optimizer, d_loss_fn, g_loss_fn):
        super(WGAN, self).compile()
        self.d_optimizer = d_optimizer  # Optimizer for the discriminator
        self.g_optimizer = g_optimizer  # Optimizer for the generator
        self.d_loss_fn = d_loss_fn  # Loss function for the discriminator
        self.g_loss_fn = g_loss_fn  # Loss function for the generator

In the __init__ method, we initialize the discriminator, generator, latent dimension, number of steps for the discriminator (d_steps), and the gradient penalty weight (gp_weight). Then, the compile method sets up the optimizers and loss functions for both the generator and the discriminator.

Gradient Penalty Method

Next up, let’s dive into the gradient penalty method. This method is really important because it ensures the Lipschitz continuity constraint for the WGAN, which helps keep things stable during training by making the gradients behave smoothly during backpropagation. The gradient penalty is calculated using an interpolated image, which is a mix of real and fake images. This penalty gets added to the discriminator loss. Here’s how we can implement the gradient penalty:


def gradient_penalty(self, batch_size, real_images, fake_images):
    # Get the interpolated image between real and fake images
    alpha = tf.random.normal([batch_size, 1, 1, 1], 0.0, 1.0)
    diff = fake_images - real_images
    interpolated = real_images + alpha * diff

    # Compute the critic's predictions for the interpolated images
    with tf.GradientTape() as gp_tape:
        gp_tape.watch(interpolated)
        pred = self.discriminator(interpolated, training=True)

    # Calculate the gradients of the predictions with respect to the interpolated images
    grads = gp_tape.gradient(pred, [interpolated])[0]
    norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))

    # Compute the gradient penalty
    gp = tf.reduce_mean((norm - 1.0) ** 2)
    return gp

In this function:

  • We sample a random alpha coefficient for each image in the batch to mix the real and fake images.
  • The interpolated image is created by blending the real and fake images.
  • We use GradientTape to calculate the gradients of the discriminator’s prediction based on the interpolated image.
  • After calculating the gradients, we find the norm and compute the gradient penalty by squaring the difference from 1.

Training Step Method

Now for the final step—defining the training method. This function alternates between training the generator and the discriminator, and here’s how it works:

  • We train the discriminator for a set number of steps (d_steps).
  • We compute the losses for both the discriminator and the generator.
  • We calculate and apply the gradient penalty to the discriminator’s loss.

Here’s the code for the train_step method:


def train_step(self, real_images):
    if isinstance(real_images, tuple):
        real_images = real_images[0]
    batch_size = tf.shape(real_images)[0]

    # Train the discriminator (critic) for d_steps iterations
    for i in range(self.d_steps):
        # Generate random latent vectors for the generator
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        # Train the discriminator on real and fake images
        with tf.GradientTape() as tape:
            fake_images = self.generator(random_latent_vectors, training=True)
            fake_logits = self.discriminator(fake_images, training=True)
            real_logits = self.discriminator(real_images, training=True)
            # Calculate the discriminator loss using the real and fake image logits
            d_cost = self.d_loss_fn(real_img=real_logits, fake_img=fake_logits)
            # Calculate the gradient penalty and add it to the discriminator loss
            gp = self.gradient_penalty(batch_size, real_images, fake_images)
            d_loss = d_cost + gp * self.gp_weight
        # Compute and apply the gradients with respect to the discriminator loss
        d_gradient = tape.gradient(d_loss, self.discriminator.trainable_variables)
        self.d_optimizer.apply_gradients(zip(d_gradient, self.discriminator.trainable_variables))

    # Train the generator
    random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
    with tf.GradientTape() as tape:
        generated_images = self.generator(random_latent_vectors, training=True)
        gen_img_logits = self.discriminator(generated_images, training=True)
        # Calculate the generator loss
        g_loss = self.g_loss_fn(gen_img_logits)
    # Compute and apply the gradients with respect to the generator loss
    gen_gradient = tape.gradient(g_loss, self.generator.trainable_variables)
    self.g_optimizer.apply_gradients(zip(gen_gradient, self.generator.trainable_variables))

    return {"d_loss": d_loss, "g_loss": g_loss}

In this function:

  • First, we train the discriminator multiple times (d_steps).
  • For each step, we generate fake images and get the logits for both real and fake images using the discriminator.
  • The discriminator’s loss is then calculated, and we add the gradient penalty.
  • After that, the generator is trained by generating new fake images and calculating the loss using the discriminator’s logits for those fake images.
  • Finally, we compute the gradients for both the generator and the discriminator, and apply those gradients to update their weights.

This setup ensures that both the generator and discriminator are trained in line with the principles of Wasserstein GANs, making it possible for the model to generate high-quality images over time.

For a deeper dive into WGANs and their implementation, check out this comprehensive guide on Creating and Training a Wasserstein GAN.

Training the model

Alright, so now we’re at the final stretch of building the WGAN (Wasserstein Generative Adversarial Network) model, and it’s time to train the thing to generate high-quality results. We’ll break this down into a few key steps. First, we need to create a custom callback for the WGAN model, which is going to allow us to save the generated images as we train. This will help us track how things are going and give us a way to see the progress of our generator at different stages of the training process.

Here’s the code that shows how to create that callback:


class GANMonitor(keras.callbacks.Callback):
    def __init__(self, num_img=6, latent_dim=128):
        self.num_img = num_img  # Number of images to generate and save per epoch
        self.latent_dim = latent_dim  # Latent space dimension

    def on_epoch_end(self, epoch, logs=None):
        # Generate random latent vectors
        random_latent_vectors = tf.random.normal(shape=(self.num_img, self.latent_dim))
        # Generate images from the latent vectors using the model's generator
        generated_images = self.model.generator(random_latent_vectors)
        # Scale the generated images back to the range [0, 255]
        generated_images = (generated_images * 127.5) + 127.5
        # Save the generated images
        for i in range(self.num_img):
            img = generated_images[i].numpy()
            img = keras.preprocessing.image.array_to_img(img)  # Convert the array to an image format
            img.save(f"generated_img_{i}_{epoch}.png")  # Save the image with the epoch number in the filename

This callback will generate a set of images after each epoch and save them as PNG files. It uses random latent vectors to create the images and scales the pixel values back to the [0, 255] range before saving them. This way, we can visually track how the generator is doing as it trains.

Setting Up Optimizers and Loss Functions

Next up, we need to define the optimizers and loss functions for both the generator and the discriminator. For this, we’re using the Adam optimizer, which is pretty popular when it comes to training GANs, including WGANs. The hyperparameters like learning rate and momentum values are picked based on best practices mentioned in the WGAN research paper.

Here’s how we set up the optimizers and loss functions:


# Define optimizers for both the generator and the discriminator
generator_optimizer = keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5, beta_2=0.9)
discriminator_optimizer = keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5, beta_2=0.9)

# Discriminator (critic) loss function
def discriminator_loss(real_img, fake_img):
    # The discriminator aims to correctly separate real and fake images
    real_loss = tf.reduce_mean(real_img)
    fake_loss = tf.reduce_mean(fake_img)
    return fake_loss - real_loss  # WGAN discriminator loss formula

# Generator loss function
def generator_loss(fake_img):
    # The generator aims to fool the discriminator into classifying fake images as real
    return -tf.reduce_mean(fake_img)  # WGAN generator loss formula

The discriminator’s loss function calculates the difference between the average values (or logits) for real and fake images. The goal is to maximize that difference so the discriminator can better tell the difference between real and fake. On the flip side, the generator’s loss function tries to minimize the average value for the fake images, helping the generator make better images that resemble real ones.

Model Training Process

Now that we have the optimizers and loss functions set up, it’s time to instantiate the WGAN model and get it ready for training. We’re going to train the model for a total of 20 epochs, but feel free to adjust this based on how much time and computational resources you have.

Here’s the code to kick off the training:


epochs = 20  # Define the number of epochs for training

# Instantiate the custom callback
cbk = GANMonitor(num_img=3, latent_dim=noise_dim)

# Instantiate the WGAN model
wgan = WGAN(discriminator=d_model, generator=g_model, latent_dim=noise_dim, discriminator_extra_steps=3)

# Compile the WGAN model with the defined optimizers and loss functions
wgan.compile(
    d_optimizer=discriminator_optimizer,
    g_optimizer=generator_optimizer,
    g_loss_fn=generator_loss,
    d_loss_fn=discriminator_loss
)

# Start the training process
wgan.fit(train_images, batch_size=BATCH_SIZE, epochs=epochs, callbacks=[cbk])

In this snippet:

  • We define the number of epochs (20).
  • We instantiate the callback, which will generate and save images periodically.
  • We create the WGAN model by passing in the discriminator and generator models (d_model and g_model).
  • We compile the model by specifying the optimizers (discriminator_optimizer and generator_optimizer) and loss functions (discriminator_loss and generator_loss).
  • Finally, we train the model using fit(), passing the training images, batch size, and number of epochs.

Evaluating the Model

Once we’ve trained the model for the specified number of epochs, we should start to see some pretty convincing images that look a lot like the real MNIST digits. This shows us that the WGAN architecture is working its magic.

After just a few epochs of training, the samples saved by the GANMonitor callback should already start to resemble handwritten digits.
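If you want to inspect those saved samples yourself, here is a minimal sketch (my own addition, not part of the original code) that loads the PNG files written by the GANMonitor callback and displays them with matplotlib; the epoch index and image count below are just example values to adjust:

import matplotlib.pyplot as plt
import matplotlib.image as mpimg

epoch_to_show = 19  # example: the last epoch of a 20-epoch run
num_img = 3         # matches num_img passed to GANMonitor above

fig, axes = plt.subplots(1, num_img, figsize=(9, 3))
for i, ax in enumerate(axes):
    # Filenames follow the pattern used in GANMonitor.on_epoch_end
    img = mpimg.imread(f"generated_img_{i}_{epoch_to_show}.png")
    ax.imshow(img, cmap="gray")
    ax.axis("off")
plt.show()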

Even though the results might look good after just a few epochs, you’ll get even better images with more training. I mean, who doesn’t like a bit more fine-tuning, right? If you’ve got the time and resources, I highly recommend running the model for more epochs to really let it shine. The longer you train it, the more realistic the generated images will become. You’ll definitely see the generator getting better at mimicking the true MNIST distribution.

To further explore how to fine-tune your model training, check out an in-depth tutorial on the subject.

Conclusion

In conclusion, Wasserstein GANs (WGANs) offer a powerful advancement in generative adversarial networks by addressing issues like mode collapse and unstable training. By utilizing the Wasserstein distance and incorporating techniques such as weight clipping and gradient penalties, WGANs provide smoother training, resulting in higher-quality output in image, audio, and text generation. These innovations ensure more reliable and efficient generative models, making WGANs a go-to solution for many fields. As the field of AI continues to evolve, WGANs are poised to play a key role in driving the next generation of high-quality generative models. In the future, we can expect further refinements in WGAN architectures that continue to enhance model performance and expand their applicability in various creative and technical industries.
