This notebook shows the training pipeline of a Variational Auto-Encoder (VAE) in Kokoyi. The VAE is an important class of generative model: after training on a collection of real images, for example, the model can generate new, look-alike images.
The Variational Auto-Encoder is one of the most popular architectures in the field of generative models (the GAN being another). A standard auto-encoder can compress high-dimensional data into a latent code (the so-called bottleneck layer at the output of the encoder), but it is deterministic. As such, it has limited power to generalize to unseen data drawn from the same distribution as the training set.
The core idea of the VAE is to compress the input $x$ into a multivariate Gaussian variable $z$, i.e. $q(z|x) = \mathcal{N}(\pmb{\mu}, \pmb{\sigma}^2\mathbf{I})$. You can alternatively think of $z$ as a stack of independent Gaussian variables, each with its own mean and variance. To generate a new sample of $x$, we sample $z$ and push it through the decoder.
The rigorous derivation of the VAE (by maximizing the probability of $x$) is interesting, and there is a fair amount of good material on it (c.f. the original paper and the many tutorial posts). But it can also be understood intuitively by looking at the two components of the total loss: a KL term that pulls the encoder's posterior $q_\phi(\mathbf{z}|\mathbf{x})$ toward the prior $p(\mathbf{z})$, and a reconstruction term that rewards the decoder for reproducing the input.
Their combination gives us the (negative) ELBO loss that we minimize: $$ \mathcal{L}(\theta, \phi; \mathbf{x}) = loss_{KL} + loss_{reconstruction} = D_{KL}(q_\phi(\mathbf{z}|\mathbf{x}) \,\|\, p(\mathbf{z})) - \mathbb{E}_{q_\phi(\mathbf{z}|\mathbf{x})}[\log p_\theta(\mathbf{x}|\mathbf{z})]$$
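For readers who want the missing step, here is a one-line sketch (standard VAE math, not specific to Kokoyi) of why minimizing $\mathcal{L}$ is sensible: the log-evidence decomposes into the ELBO plus a non-negative KL term, so minimizing $\mathcal{L}$ (the negative ELBO) tightens a lower bound on $\log p_\theta(\mathbf{x})$: $$ \log p_\theta(\mathbf{x}) = \underbrace{\mathbb{E}_{q_\phi(\mathbf{z}|\mathbf{x})}[\log p_\theta(\mathbf{x}|\mathbf{z})] - D_{KL}(q_\phi(\mathbf{z}|\mathbf{x}) \,\|\, p(\mathbf{z}))}_{-\mathcal{L}\ (\text{the ELBO})} + \underbrace{D_{KL}(q_\phi(\mathbf{z}|\mathbf{x}) \,\|\, p_\theta(\mathbf{z}|\mathbf{x}))}_{\geq\, 0} $$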
We will now start writing the model in Kokoyi. First, we define the sublayers used by the encoder and the decoder. Since in this example we are generating images, we use normalization and convolution modules.
%kokoyi
\Module {ConvBlock} {x; Conv2d, BatchNorm2d}
\Return \ReLU(BatchNorm2d(Conv2d(x))) \\
\EndModule
\Module {TransposedConvBlock} {x; ConvTranspose2d, BatchNorm2d}
\Return \ReLU(BatchNorm2d(ConvTranspose2d(x))) \\
\EndModule
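For intuition, the forward pass of $ConvBlock$ corresponds to something like the following plain-PyTorch function (a hypothetical sketch only; the actual binding of parameters to PyTorch modules is shown further below):

import torch

def conv_block_forward(x, conv2d, batchnorm2d):
    # Convolution, then batch normalization, then ReLU,
    # mirroring \ReLU(BatchNorm2d(Conv2d(x))) above.
    return torch.relu(batchnorm2d(conv2d(x)))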
The encoder uses convolution layers $ConvBlock$ to compute the mean $\pmb{\mu}$ and standard deviation $\pmb{\sigma}$ from input images $\mathbf{x}$. For numerical stability, we always work with the log-variance $logvar$ rather than with $\pmb{\sigma}$ directly.
%kokoyi
\Module {Encoder} {x; ConvBlocks, Linears}
L \gets |ConvBlocks| \\
h[0 \leq i \leq L] \gets \begin{cases}
x & i = 0 \\
ConvBlocks[i-1](h[i-1]) & otherwise \\
\end{cases} \\
\hat{h} \gets \Flatten(h[L]) \\
(Linear_\mu, Linear_v) \gets Linears \\
(\mu, v) \gets (Linear_\mu(\hat{h}), Linear_v(\hat{h})) \\
\Return (\mu, v) \\
\EndModule
The decoder uses transposed convolution layers $TransposedConvBlock$ to generate images $\mathbf{\hat{x}}$ from the latent vector $\mathbf{z}$. You may notice that the decoder has the same structure as the generator in a GAN.
%kokoyi
\Module {Decoder} {z; C, H, W, Linear, TransposedConvBlocks, Conv2d}
\hat{z} \gets Linear(z) \\
L \gets |TransposedConvBlocks| \\
h[0 \leq i \leq L] \gets \begin{cases}
\Reshape(\hat{z}, (C, H, W)) & i = 0 \\
TransposedConvBlocks[i-1](h[i-1]) & otherwise \\
\end{cases} \\
\hat{x} \gets \tanh(Conv2d(h[L])) \\
\Return \hat{x} \\
\EndModule
As mentioned above, the decoder takes the latent vector $\mathbf{z}$ as input, so we need to sample $\mathbf{z}$ from $ q_\phi(\mathbf{z}|\mathbf{x}) = \mathcal{N}(\pmb{\mu}, \pmb{\sigma}^2\mathbf{I})$. Gradients cannot pass through stochastic operations such as sampling; this is solved by the reparameterization trick (see the original VAE paper and the many posts on it): it suffices to let the gradient flow through the mean and variance of $z$:
$$ \mathbf{z} = \pmb{\sigma} \odot \pmb{\epsilon} + \pmb{\mu} $$where $\pmb{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$ and $\odot$ is the element-wise product. We will use `*` as the element-wise product operator in Kokoyi, which displays as $\sigma \circ \epsilon$.
%kokoyi
\Function {Reparameterize} {\mu, v}
\sigma \gets \Exp(0.5 * v) \\
\epsilon \gets \Rand(|\sigma|) \\
\Return \sigma * \epsilon + \mu \\
\EndFunction
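For reference, the same trick in plain PyTorch might look like the following (a minimal sketch, assuming $v$ is the log-variance as above):

import torch

def reparameterize(mu, logvar):
    # sigma = exp(0.5 * logvar); epsilon ~ N(0, I) with the same shape.
    sigma = torch.exp(0.5 * logvar)
    epsilon = torch.randn_like(sigma)
    # Gradients flow through mu and sigma, not through the random draw.
    return sigma * epsilon + mu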
class ConvBlock(torch.nn.Module):
    def __init__(self):
        """ Add your code for parameter initialization here (not necessarily the same names)."""
        super().__init__()
        self.Conv2d = None
        self.BatchNorm2d = None

    def get_parameters(self):
        """ Change the following code to return the parameters as a tuple in the order of (Conv2d, BatchNorm2d)."""
        return None

    forward = kokoyi.symbol["ConvBlock"]

class TransposedConvBlock(torch.nn.Module):
    def __init__(self):
        """ Add your code for parameter initialization here (not necessarily the same names)."""
        super().__init__()
        self.ConvTranspose2d = None
        self.BatchNorm2d = None

    def get_parameters(self):
        """ Change the following code to return the parameters as a tuple in the order of (ConvTranspose2d, BatchNorm2d)."""
        return None

    forward = kokoyi.symbol["TransposedConvBlock"]

class Encoder(torch.nn.Module):
    def __init__(self):
        """ Add your code for parameter initialization here (not necessarily the same names)."""
        super().__init__()
        self.ConvBlocks = None
        self.Linears = None

    def get_parameters(self):
        """ Change the following code to return the parameters as a tuple in the order of (ConvBlocks, Linears)."""
        return None

    forward = kokoyi.symbol["Encoder"]

class Decoder(torch.nn.Module):
    def __init__(self):
        """ Add your code for parameter initialization here (not necessarily the same names)."""
        super().__init__()
        self.C = None
        self.H = None
        self.W = None
        self.Linear = None
        self.TransposedConvBlocks = None
        self.Conv2d = None

    def get_parameters(self):
        """ Change the following code to return the parameters as a tuple in the order of (C, H, W, Linear, TransposedConvBlocks, Conv2d)."""
        return None

    forward = kokoyi.symbol["Decoder"]
Here are the completed module definitions:
import torch
import kokoyi
from kokoyi.nn import Linear, Conv2d, ConvTranspose2d, BatchNorm2d

class ConvBlock(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Stride-2 convolution halves the spatial resolution.
        self.Conv2d = Conv2d(in_channels, out_channels, 3, stride=2, padding=1)
        self.BatchNorm2d = BatchNorm2d(out_channels)

    def get_parameters(self):
        return self.Conv2d, self.BatchNorm2d

    forward = kokoyi.symbol["ConvBlock"]

class TransposedConvBlock(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Stride-2 transposed convolution doubles the spatial resolution.
        self.ConvTranspose2d = ConvTranspose2d(in_channels, out_channels, 3, stride=2, padding=1, output_padding=1)
        self.BatchNorm2d = BatchNorm2d(out_channels)

    def get_parameters(self):
        return self.ConvTranspose2d, self.BatchNorm2d

    forward = kokoyi.symbol["TransposedConvBlock"]

class Encoder(torch.nn.Module):
    def __init__(self, in_channels, latent_dim):
        super().__init__()
        # Five stride-2 blocks shrink a 64x64 image to 2x2 with 512 channels.
        self.ConvBlocks = torch.nn.ModuleList([
            ConvBlock(in_channels, 32),
            ConvBlock(32, 64),
            ConvBlock(64, 128),
            ConvBlock(128, 256),
            ConvBlock(256, 512)
        ])
        # Two heads: one for the mean, one for the log-variance.
        self.Linears = torch.nn.ModuleList([
            Linear(512 * 4, latent_dim),
            Linear(512 * 4, latent_dim)
        ])

    def get_parameters(self):
        return self.ConvBlocks, self.Linears

    forward = kokoyi.symbol["Encoder"]

class Decoder(torch.nn.Module):
    def __init__(self, C, H, W, out_channels, latent_dim):
        super().__init__()
        # (C, H, W) is the shape the projected latent vector is reshaped to.
        self.C = C
        self.H = H
        self.W = W
        self.Linear = Linear(latent_dim, 512 * 4)
        # Five stride-2 blocks grow the 2x2 map back to 64x64.
        self.TransposedConvBlocks = torch.nn.ModuleList([
            TransposedConvBlock(512, 256),
            TransposedConvBlock(256, 128),
            TransposedConvBlock(128, 64),
            TransposedConvBlock(64, 32),
            TransposedConvBlock(32, 32)
        ])
        self.Conv2d = Conv2d(32, out_channels, 3, padding=1)

    def get_parameters(self):
        return self.C, self.H, self.W, self.Linear, self.TransposedConvBlocks, self.Conv2d

    forward = kokoyi.symbol["Decoder"]
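As a quick sanity check on the architecture (a hypothetical snippet in plain PyTorch, mirroring the channel progression above): five stride-2 convolutions shrink a $64 \times 64$ image down to $2 \times 2$, which is why the linear heads take $512 \times 4 = 2048$ inputs.

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 64, 64)
for c_in, c_out in [(3, 32), (32, 64), (64, 128), (128, 256), (256, 512)]:
    x = F.relu(torch.nn.Conv2d(c_in, c_out, 3, stride=2, padding=1)(x))
print(x.shape)  # torch.Size([1, 512, 2, 2]); flattened size is 512 * 2 * 2 = 2048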
The VAE has two loss terms. For the KL loss, both the prior $p(\mathbf{z})$ and the posterior $q_\phi(\mathbf{z}|\mathbf{x})$ are Gaussian, so $loss_{KL}$ has a closed form:
$$ \begin{align} loss_{KL} = D_{KL}(q_\phi(\mathbf{z}|\mathbf{x}) \,\|\, p(\mathbf{z})) &= -\frac{1}{2}(1 + \log {(\pmb{\sigma}^2)} - \pmb{\mu}^2 - \pmb{\sigma}^2) \\ &=-\frac{1}{2}(1 + logvar - \pmb{\mu}^2 - e^{logvar}) \end{align}$$For the reconstruction loss, we sample $\mathbf{z}$ from the encoder's output and apply the decoder to reconstruct the image $\mathbf{\hat{x}}$ from $\mathbf{z}$. The mean squared error (squared L2 norm) between the reconstructed images $\mathbf{\hat{x}}$ and the input images $\mathbf{x}$ serves as $loss_{reconstruction}$. Hence, the loss takes the following form:
$$ loss = -\frac{\lambda}{2}(1 + logvar - \pmb{\mu}^2 - e^{logvar}) + \|\hat{x} - x\|^2 $$where $\lambda$ is the weight of $loss_{KL}$, which controls the balance between the two loss terms. Note how we can simply write $\|\hat{x} - x\|$ for the L2 norm and square it.
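For completeness, here is where the closed form comes from: per latent dimension, the standard Gaussian KL formula between $\mathcal{N}(\mu, \sigma^2)$ and the unit normal prior evaluates to $$ D_{KL}\big(\mathcal{N}(\mu, \sigma^2) \,\|\, \mathcal{N}(0, 1)\big) = \frac{1}{2}\left(\mu^2 + \sigma^2 - 1 - \log \sigma^2\right) = -\frac{1}{2}\left(1 + logvar - \mu^2 - e^{logvar}\right) $$Summing over the latent dimensions gives the $loss_{KL}$ computed below.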
%kokoyi
\Function {Loss} {x, Encoder, Decoder, Reparameterize, \lambda}
(\mu, v) \gets Encoder(x) \\
z \gets Reparameterize(\mu, v) \\
\hat{x} \gets Decoder(z) \\
J \gets |\mu| \\
loss_{KL} \gets \sum_{j=0}^{J-1}{-0.5 * (1 + v[j] - \mu[j] ** 2 - \exp(v[j]))} \\
loss_{rec} \gets \|\hat{x} - x\| ** 2 \\
\Return (\lambda * loss_{KL} + loss_{rec}, loss_{KL}, loss_{rec}) \\
\EndFunction
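For readers more familiar with PyTorch, the same loss could be sketched roughly as follows (a minimal, hypothetical stand-in for the Kokoyi function above, assuming per-sample tensors):

import torch

def vae_loss(x, x_hat, mu, logvar, lam):
    # Closed-form KL to the standard normal prior, summed over latent dims.
    loss_kl = torch.sum(-0.5 * (1 + logvar - mu ** 2 - torch.exp(logvar)))
    # Squared L2 norm between reconstruction and input.
    loss_rec = torch.sum((x_hat - x) ** 2)
    return lam * loss_kl + loss_rec, loss_kl, loss_rec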
Finally, we can set the hyper-parameters and start training!
Every 50 steps, we will reconstruct/sample a batch of images and save them to disk. Just to have some fun, we provide two generation functions in Kokoyi: one for reconstructing images from the encoder's output, and one for sampling from the latent space. Note how we access them from PyTorch with kokoyi.symbol['Reconstruction'] and kokoyi.symbol['Sample'].
%kokoyi
\Function {Reconstruction} {x, Encoder, Decoder, Reparameterize}
(\mu, v) \gets Encoder(x) \\
z \gets Reparameterize(\mu, v) \\
\Return Decoder(z) \\
\EndFunction
\Function {Sample} {Decoder, z}
\Return Decoder(z) \\
\EndFunction
import os
import torch
from torch import optim
from torch.nn import functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import kokoyi
The CelebA dataset from torchvision is used to train the model. The dataset consists of more than 200K celebrity images; each image is cropped at the center and resized to $64 \times 64$. You can skip the code cells below if you're not interested in data preparation.
img_size = 64
batch_size = 144
os.makedirs('data', exist_ok=True)
os.makedirs('vae/sample', exist_ok=True)
os.makedirs('vae/recons', exist_ok=True)
SetRange = transforms.Lambda(lambda X: 2 * X - 1.)
transform = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.CenterCrop(148),
transforms.Resize(img_size),
transforms.ToTensor(),
SetRange
])
celeba = torchvision.datasets.CelebA(root='data', split='train', transform=transform, download=True)
celeba_test = torchvision.datasets.CelebA(root='data', split='test', transform=transform, download=False)
num_train = len(celeba)
dataloader = DataLoader(celeba, batch_size=batch_size, shuffle=True, drop_last=True)
test_dataloader = DataLoader(celeba_test, batch_size=batch_size, shuffle=False, drop_last=True)
num_batch = len(dataloader)
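A quick, optional sanity check on the pipeline (a hypothetical snippet): each batch should be a tensor of shape (batch_size, 3, 64, 64) with values in $[-1, 1]$ after SetRange.

imgs, _ = next(iter(dataloader))
print(imgs.shape, imgs.min().item(), imgs.max().item())
# Expected: torch.Size([144, 3, 64, 64]) with values in [-1.0, 1.0]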
if torch.cuda.is_available():
device_name = 'cuda:0'
else:
device_name = 'cpu'
print('Using device: ', device_name)
device = torch.device(device_name)
kokoyi.set_rt_device(device)
Now the full training loop:
num_epochs = 30
channels = 3
latent_dim = 128
lr = 0.005
C, H, W = 512, 2, 2
encoder = Encoder(channels, latent_dim).to(device)
decoder = Decoder(C, H, W, channels, latent_dim).to(device)
optimizer = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=lr)
test_imgs = next(iter(test_dataloader))[0].to(device)
w = batch_size / num_train  # KL weight (the lambda in the loss above)
reparameter = kokoyi.symbol['Reparameterize']
for epoch in range(num_epochs):
    for i, (imgs, _) in enumerate(dataloader):
        imgs = imgs.to(device)
        optimizer.zero_grad()
        loss, kl_loss, reconstruction_loss = kokoyi.symbol['Loss'](imgs, encoder, decoder, reparameter, w, batch_level=[1,0,0,0,0])
        kl_loss = torch.mean(kl_loss)
        reconstruction_loss = torch.mean(reconstruction_loss)
        loss = torch.mean(loss)
        loss.backward()
        optimizer.step()
        if i % 50 == 0:
            print('[%d/%d][%d/%d] Loss: %.4f KL_loss: %.4f Reconstruction_loss: %.4f' % (epoch, num_epochs, i, num_batch, loss.item(), kl_loss.item(), reconstruction_loss.item()))
            with torch.no_grad():
                # Reconstruct a fixed batch of test images.
                recons = kokoyi.symbol['Reconstruction'](test_imgs, encoder, decoder, reparameter, batch_level=[1,0,0,0])
                save_image(recons.cpu().data, f"vae/recons/{epoch}_{i}.png", normalize=True, nrow=12)
                # Sample new images from the prior.
                z = torch.randn(batch_size, latent_dim).to(device)
                samples = kokoyi.symbol['Sample'](decoder, z, batch_level=[0,1])
                save_image(samples.cpu().data, f"vae/sample/{epoch}_{i}.png", normalize=True, nrow=12)