Diffusion Models Using a Single Equation
01 Dec 2021 | generative modeling, denoising diffusion, DDIM, DDPM
Introduction
Recently, denoising diffusion models Ho et al. (2020) and noise conditional score networks Song et al. (2019) have been shown to be a powerful class of generative models that can rival even generative adversarial networks (GANs) in image synthesis quality Dhariwal & Nichol (2021), while being more stable to train than GANs Goodfellow et al. (2014). One of their drawbacks, however, is that they are slower to sample from, because they require multiple, sometimes even hundreds of forward passes per generated image Ho et al. (2020).
Motivation
A difficulty I faced when learning about diffusion models was that the entry barrier to understanding them is quite high: their theoretical background is mathematically involved and can be difficult to follow, sometimes even with conflicting notations across works. My aim with this blogpost is to lower this barrier and to make diffusion models, specifically Denoising Diffusion Implicit Models (DDIMs) Song et al. (2020), more accessible by providing an explanation of their inner workings that is mathematically very light (it uses only a single equation and carefully named variables instead of single-letter notation), and to understand them from a different point of view compared to their paper. I would also like to give readers novel intuitions about these models, and to provide an end-to-end, simple-to-use implementation of diffusion models that is easy to customize and can be a good starting point for future projects on this subject.
Intuitions
A well known intuition is that what these models actually do is estimate the data distribution Song (2021), and that they can generate from this distribution by following the gradient of the log probability density function of the data.
In this work my aim is to present a different, more practical intuition: we force these models to sample from the data distribution by misleading them.
Diffusion Process
Let us imagine a picture, which over time gets mixed with more and more noise, so much so that after a sufficiently long time it is indistinguishable from pure pixel noise. This is a diffusion process, where at the start (at diffusion time = 0), all of our current signal consists of the original image, and as time goes by, all of the signal becomes noise (at diffusion time = 1).
We model this process by describing it with a diffusion schedule, which maps diffusion timestamps to corresponding signal rates and noise rates. These rates describe the power ratio (in the signal processing sense, assuming that both components have the same power on their own) of the signal and the noise, respectively, in comparison to the combined noisy signal.
signal_rate = 1 - noise_rate
noise_rate = 1 - signal_rate
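As a concrete example, here is a minimal sketch of what such a schedule could look like: a simple linear schedule, shown for illustration only (the cosine schedule actually used appears in a later section).

# a minimal example schedule (illustration only): the noise rate is simply the diffusion time
def linear_diffusion_schedule(diffusion_times):
    noise_rates = diffusion_times
    signal_rates = 1.0 - noise_rates
    return signal_rates, noise_rates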
Mixing Equation
The equation that we will use in this work is the following, the mixing equation:
noisy_images = signal_rates ** 0.5 * images + noise_rates ** 0.5 * noises
It describes how the signal gets combined with noise, creating a noisy signal. In this case I consider images as the signal, and pixel-noise as the noise, as this is what we will be dealing with in this work.
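For example, with a toy image tensor and a noise rate of 0.2 (and therefore a signal rate of 0.8), the mixing could look like this (illustrative values only, assuming TensorFlow is imported as tf):

# toy example of the mixing equation (illustrative values only)
images = tf.random.uniform(shape=(1, 64, 64, 3), minval=-1.0, maxval=1.0)
noises = tf.random.normal(shape=(1, 64, 64, 3))
signal_rates, noise_rates = 0.8, 0.2
noisy_images = signal_rates ** 0.5 * images + noise_rates ** 0.5 * noises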
Algorithm
In the following sections, I will use simplified Keras-based code snippets to present the algorithm of a Denoising Diffusion Implicit Model Song et al. (2020). Note that these code snippets would not run on their own, as I have omitted some variable and method declarations, function arguments, and shape manipulations for clarity. At the end of this blogpost I present a feature-complete end-to-end implementation, which can be used for experimentation.
This is the model class that we will use for implementing the algorithm:
class DiffusionModel(keras.Model):
    def __init__(self, diffusion_steps, time_margin):
        super().__init__()
        self.diffusion_steps = diffusion_steps
        self.time_margin = time_margin
        self.network = build_network()

    def compile(self, optimizer):
        super().compile()
        self.optimizer = optimizer
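To give a sense of how this skeleton will be used later, here is a hypothetical usage example (the argument values are only examples, matching the hyperparameters of the complete implementation at the end):

# example usage of the skeleton above (values are illustrative)
model = DiffusionModel(diffusion_steps=100, time_margin=0.05)
model.compile(optimizer=keras.optimizers.Adam(learning_rate=1e-3))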
Task
Diffusion models are trained to solve the task of signal denoising, or more precisely signal-noise separation. Their input is a noisy signal and additional information about the signal and noise rates used to produce it, and they have to predict the original signal and noise values that we mixed together.
There are two ways to solve this task while making sure that the outputs obey the mixing equation:
- The network can predict the original signal, which we can substitute into the mixing equation to calculate the corresponding noise that, mixed with the predicted signal, could have created the noisy input.
- Equivalently, it can predict the noise that was mixed with the original signal, and we can use the mixing equation to calculate the corresponding original signal.
Though theoretically both of these solutions are equally sound, in practice we usually use the latter one, predicting the noise, as empirically it leads to more stable training and higher generation quality Ho et al. (2020). Note that I investigate the swapped task of predicting the original images instead of the noise in one of the ablations.
So in practice, based on the noisy image and the signal rate used, we predict the noise and simply rearrange the mixing equation to get the predicted original image from the predicted noise:
def denoise(self, noisy_images, signal_rates, noise_rates):
    pred_noises = self.network([noisy_images, signal_rates])
    pred_images = signal_rates ** -0.5 * (noisy_images - noise_rates ** 0.5 * pred_noises)
    return pred_images, pred_noises
We can see that if a signal rate of 0 were used, we would have a division by zero at these lines, but in that case the task would be ill-posed too, as it is impossible to predict the original signal if the noisy signal does not contain any of it, only pure noise. To avoid this issue, I will use time margins when sampling diffusion times, to stay away from the two limits of the process, where our quantities might blow up.
Training
How can we train these diffusion models?
- For each training sample, we sample normally distributed noise with the same shape as the images. These will later be mixed to create the noisy signals, and with that the input-output pairs for training these models.
def train_step(self, images):
    noises = tf.random.normal(shape=(batch_size, image_size, image_size, image_channels))
- We define the diffusion schedule, which maps a uniformly distributed variable, the diffusion time, to the corresponding signal and noise rates, then we use it to sample a signal and noise rate for each training sample.
diffusion_times = tf.random.uniform(
    shape=(batch_size,), minval=self.time_margin, maxval=1.0 - self.time_margin
)
signal_rates, noise_rates = self.diffusion_schedule(diffusion_times)
- Finally, we use these rates according to the mixing equation to mix the training samples with noise, creating the noisy signals.
noisy_images = signal_rates ** 0.5 * images + noise_rates ** 0.5 * noises
- These noisy signals are then fed into the network, which tries to separate them into the original signal and noise as described in the previous section.
with tf.GradientTape() as tape:
    pred_images, pred_noises = self.denoise(noisy_images, signal_rates, noise_rates)
- We then use the (true noise, predicted noise) pairs to calculate a reconstruction loss for each training sample. Theoretically, mean squared error (MSE) should be used here; however, in practice mean absolute error (MAE) seems to produce better results Chen & Zhang et al. (2020). While MSE seems to lead to more diverse outputs, MAE produces more conservative ones Saharia et al. (2021), which lines up with my experience as well, so this is what I use here.
noise_loss = keras.losses.mean_absolute_error(noises, pred_noises)
- The loss gets backpropagated, and gradient-based optimization is applied to the network's weights.
gradients = tape.gradient(noise_loss, self.network.trainable_weights)
self.optimizer.apply_gradients(zip(gradients, self.network.trainable_weights))
return noise_loss
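For reference, here is how the snippets above fit together into a single train_step. This is still a sketch: as before, the reshaping of the per-sample rates so that they broadcast over the image dimensions is omitted.

def train_step(self, images):
    # sample noises and per-sample diffusion times, and map the times to rates
    noises = tf.random.normal(shape=(batch_size, image_size, image_size, image_channels))
    diffusion_times = tf.random.uniform(
        shape=(batch_size,), minval=self.time_margin, maxval=1.0 - self.time_margin
    )
    signal_rates, noise_rates = self.diffusion_schedule(diffusion_times)
    # mix the images with the noises according to the mixing equation
    noisy_images = signal_rates ** 0.5 * images + noise_rates ** 0.5 * noises
    # separate the noisy images and compute the noise reconstruction loss
    with tf.GradientTape() as tape:
        pred_images, pred_noises = self.denoise(noisy_images, signal_rates, noise_rates)
        noise_loss = keras.losses.mean_absolute_error(noises, pred_noises)
    # backpropagate and update the network weights
    gradients = tape.gradient(noise_loss, self.network.trainable_weights)
    self.optimizer.apply_gradients(zip(gradients, self.network.trainable_weights))
    return noise_loss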
Sampling
Now, how can we turn these models into generative models? How can we utilize this denoising behaviour for generating from the data distribution?
My intuition, and the main point of this work, is that we lie to them, we mislead them, as I will explain in this section.
We sample normally distributed noise and make the network iteratively denoise it, making it hallucinate a realistic signal out of pure noise. The network is thus used to estimate a reverse diffusion process starting from pure noise. This is done via an iterative procedure, where in each step we use the network to remove a small amount of noise from the signal, using signal and noise rates given by diffusion times that slowly move back in time, decreasing from almost 1 (completely noise) to almost 0 (completely signal).
But what do we do in the first step?
Recall from the section describing the task that we cannot have a signal rate of exactly zero, as we would run into a division by zero during the denoising step, and also the task would be ill-defined. But in the first step, we have a signal rate of exactly zero.
To resolve this issue, we trick the network by telling it that there is a small amount of signal even in the pure noise: we input pure noise as the noisy signal, while setting the signal rate slightly above zero and the noise rate slightly below one. By using time margins, the initial diffusion time will be slightly below one, which will cause the diffusion schedule to output a non-zero value as the starting signal rate.
def diffusion_process(self, initial_noise):
    noisy_images = initial_noise
    diffusion_times = tf.linspace(1 - self.time_margin, self.time_margin, self.diffusion_steps + 1)
    for step in range(self.diffusion_steps):
        signal_rates, noise_rates = self.diffusion_schedule(diffusion_times[step])
        pred_images, pred_noises = self.denoise(noisy_images, signal_rates, noise_rates)
But what should be the noisy signal in the following step?
Since we have the latest estimate of the signal and noise in the current step, we can recombine these using the signal and noise rates of the following step, to get our best estimate of what the noisy signal would be if it had a slightly different signal-to-noise ratio.
Iterating these steps from a diffusion time of almost one (pure noise) to almost zero (pure signal), will utilize the network to gradually denoise pure noise into something that it considers a real signal.
        next_signal_rates, next_noise_rates = self.diffusion_schedule(diffusion_times[step + 1])
        noisy_images = next_signal_rates ** 0.5 * pred_images + next_noise_rates ** 0.5 * pred_noises
    return pred_images
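To actually generate images, we only need to sample the initial noise and run the reverse diffusion process on it. The following sketch mirrors the generate() method of the complete implementation at the end:

def generate(self, num_images):
    # start the reverse diffusion process from pure noise
    initial_noise = tf.random.normal(shape=(num_images, image_size, image_size, image_channels))
    return self.diffusion_process(initial_noise)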
And with that, we arrive at a sampling procedure that is equivalent to the one described in the Denoising Diffusion Implicit Models paper, but based on very different considerations.
Network
Ho et al. (2020) propose the use of a U-Net Ronneberger et al. (2015), which makes our network an overcomplete denoising autoencoder: overcomplete because its latent dimensionality is higher than that of the data, and denoising because it tries to reconstruct the original data from its corrupted inputs.
During my experiments I found the following 3 properties of the neural network architecture to be the most important:
- The network should be a U-Net, i.e. it should downsample and then upsample the input data, while also containing skip connections from the layers in its first half to the layers in its second half that operate at the same resolution.
- Each stage of the network (a contiguous set of layers that operate at the same resolution) should consist of residual blocks, so that the flow of information in the network is helped not only by the large skip connections but by the small residual connections as well.
- The signal rates should be embedded, using a sinusoidal embedding layer, which is known as positional encoding in Transformers Vaswani et al. (2017) and Neural Radiance Fields Mildenhall & Srinivasan & Tancik et al. (2020).
Also, I ran into occasional training divergences, especially when increasing the network's size, which was surprising given the simplicity of the training procedure. I have found that the following methods help with training stability:
- weight decay (using AdamW Loshchilov et al. (2017) instead of Adam Kingma and Ba (2014) )
- layer normalization Ba et al. (2016)
- batch normalization Ioffe et al. (2015)

In the reference implementation at the end of this blogpost I use a combination of weight decay and batch normalization.
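In code, using weight decay simply means swapping the optimizer, for example (the values here match the hyperparameters of the complete implementation):

# AdamW from TensorFlow Addons: Adam with decoupled weight decay
optimizer = tfa.optimizers.AdamW(learning_rate=1e-3, weight_decay=1e-4)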
The following is a simplified implementation of the recommended neural network architecture, using the Keras Functional API:
def build_network():
    images = keras.Input(shape=(image_size, image_size, image_channels))
    signal_rates = keras.Input(shape=(1,))
    signal_rate_embeddings = layers.Lambda(sinusoidal_embedding)(signal_rates)
    x = layers.Concatenate()([images, signal_rate_embeddings])
    skips = [None] * depth
    for i in range(depth):
        x, skips[i] = DownStage(residual=True)(x)
    for i in reversed(range(depth)):
        x = UpStage(residual=True)([x, skips[i]])
    output_signal = layers.Conv2D(image_channels, kernel_size=1)(x)
    return keras.Model([images, signal_rates], output_signal, name="unet")
Why is the sinusoidal embedding a crucial part of the architecture? What is its role?
To sum it up, it helps the network be highly sensitive to the signal rate value (I recommend Tancik & Srinivasan & Mildenhall et al. (2020) to interested readers), which is useful, as theory suggests that ideally a separate network should be used for each denoising step, which is not realistic in practice.
The following code snippet shows a minimalistic implementation of sinusoidal embeddings:
def sinusoidal_embedding(x):
    log_frequencies = tf.linspace(tf.math.log(min_frequency), tf.math.log(max_frequency), num_frequencies)
    frequencies = tf.exp(log_frequencies)
    angular_speeds = 2 * math.pi * frequencies
    embeddings = tf.concat([tf.sin(angular_speeds * x), tf.cos(angular_speeds * x)], axis=-1)
    return embeddings
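As a quick illustration of its output (assuming min_frequency, max_frequency and num_frequencies are defined appropriately, e.g. as in the complete implementation at the end), embedding a batch of three signal rates could look like this:

# hypothetical usage: each scalar signal rate becomes a 2 * num_frequencies dimensional vector
signal_rates = tf.reshape(tf.constant([0.1, 0.5, 0.9]), shape=(3, 1, 1, 1))
embeddings = sinusoidal_embedding(signal_rates)
print(embeddings.shape)  # (3, 1, 1, 2 * num_frequencies)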
Stochastic Sampling
Even though historically Denoising Diffusion Probabilistic Models (DDPMs) preceded the implicit models (DDIMs), following the line of reasoning proposed in this work, they can be interpreted as an extension of DDIMs. I should note here that training DDPMs is exactly the same procedure as training DDIMs; the only real difference between the two is the sampling procedure.
Though I did not manage to derive exactly the DDPM sampling procedure from the type of reasoning presented here, I can still provide some intuition about how I think DDPM models work by dissecting their sampling procedure into smaller, easier-to-reason-about parts.
What is a common issue with autoregressive generative models that use their own previous outputs as inputs over an iterative procedure?
If the quality of one of their outputs turns out to be suboptimal, then in the next step, since the input no longer comes from the training data distribution, the quality of the output might get even worse; the generative model can quickly wander off the distribution where it works reliably and can start generating low quality outputs.
How could we counteract that in our sampling procedure?
Let us take advantage of the fact that the sum of two independent normally distributed random variables is also normally distributed! This means that if the predicted noise is normally distributed, we remain in distribution if we add some extra normally distributed noise, provided that we slightly downscale the predicted noise so that the rate of the resulting combined noise component stays as intended. If the predicted noise was not normally distributed, my intuition is that it gets pushed slightly closer to a normal distribution, which helps with increasing sample quality.
This modified sampling procedure might improve sample quality as described above, but one of its downsides is that with it the sampling process becomes stochastic, while the DDIM sampling procedure was fully deterministic.
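Before looking at the implementation, here is a quick numerical sanity check of the "sum of normals" argument above (illustrative only): splitting a noise rate into a smaller rate plus an extra rate, each under its own square root as in the mixing equation, yields a combined noise component with the original variance.

# sanity check: variances (noise rates) of independent normal noises add up
noise_rate = 0.5
extra_noise_rate = 0.1
remaining_noise_rate = noise_rate - extra_noise_rate
noise_a = tf.random.normal(shape=(1000000,))
noise_b = tf.random.normal(shape=(1000000,))
combined_noises = remaining_noise_rate ** 0.5 * noise_a + extra_noise_rate ** 0.5 * noise_b
print(float(tf.math.reduce_variance(combined_noises)))  # close to noise_rate = 0.5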
def diffusion_process(self, initial_noise):
    diffusion_times = tf.linspace(1.0 - self.time_margin, self.time_margin, self.diffusion_steps + 1)
    noisy_images = initial_noise
    for step in range(self.diffusion_steps):
        signal_rates, noise_rates = self.diffusion_schedule(diffusion_times[step])
        next_signal_rates, next_noise_rates = self.diffusion_schedule(diffusion_times[step + 1])
        pred_images, pred_noises = self.denoise(noisy_images, signal_rates, noise_rates)
The following lines present an implementation of DDPM sampling in this framework.
- we generate extra noise, and also calculate its rate to be a small value
- then we also calculate a noise rate multiplier which should be slightly below one
- finally, we mix the components based on a modified version of the single equation, where the noise rate is decreased while an additional extra noise rate is added, to estimate the next noisy image
        extra_noises = tf.random.normal(shape=(batch_size, image_size, image_size, 3))
        extra_noise_rates = 1.0 - signal_rates / next_signal_rates
        noise_rate_multipliers = 1.0 - extra_noise_rates / noise_rates
        noisy_images = (
            next_signal_rates ** 0.5 * pred_images
            + (next_noise_rates * noise_rate_multipliers) ** 0.5 * pred_noises
            + extra_noise_rates ** 0.5 * extra_noises
        )
    return pred_images
Applications
The fact that the DDIM sampling procedure is deterministic lends itself to some interesting use cases.
Noise Space Interpolation
One can carry out interpolation between two images in noise space by calculating intermediate points between their starting noise values and running the reverse diffusion process on them.
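A minimal sketch of how this could look in code (a hypothetical helper, not part of the snippets above; plain linear interpolation is used here for simplicity, though spherical interpolation between noise samples is a common alternative):

def interpolate(self, num_images):
    # generate images whose starting noises lie between two random noise samples
    noise_a = tf.random.normal(shape=(1, image_size, image_size, image_channels))
    noise_b = tf.random.normal(shape=(1, image_size, image_size, image_channels))
    weights = tf.reshape(tf.linspace(0.0, 1.0, num_images), shape=(num_images, 1, 1, 1))
    initial_noises = (1.0 - weights) * noise_a + weights * noise_b
    return self.diffusion_process(initial_noises)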
Encoding and Decoding Images
We can iteratively encode images into pure noise using a forward diffusion process, and then decode them and compare the reconstructions with the original images. In the corresponding figure, the top row contains the original images, the middle row their encoded final noise values, and the bottom row their reconstructions, using these noises as starting points of a reverse diffusion process.
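A sketch of what such an encoder could look like in this framework (a hypothetical helper: it reuses the denoising update from the sampling loop, but with the diffusion times increasing instead of decreasing, starting from the image itself):

def encoding_process(self, images):
    noisy_images = images
    # run the diffusion times forward, from almost pure signal to almost pure noise
    diffusion_times = tf.linspace(self.time_margin, 1.0 - self.time_margin, self.diffusion_steps + 1)
    for step in range(self.diffusion_steps):
        signal_rates, noise_rates = self.diffusion_schedule(diffusion_times[step])
        next_signal_rates, next_noise_rates = self.diffusion_schedule(diffusion_times[step + 1])
        pred_images, pred_noises = self.denoise(noisy_images, signal_rates, noise_rates)
        # recombine the current estimates at the next (noisier) rates
        noisy_images = next_signal_rates ** 0.5 * pred_images + next_noise_rates ** 0.5 * pred_noises
    return noisy_images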
Ablations
Baseline
The following is the baseline image generation quality, following exactly the settings found in the end-to-end implementation at the end of this blogpost (training takes around an hour on a single A100 GPU).
Training Ablations
Predicting Images Instead of Noise
In my experience, the version of diffusion models that predicts the images instead of the noise is tougher to train, with complete and temporary divergence events being much more common. It also leads to slightly worse quality, though it is possible to train.
def denoise(self, noisy_images, signal_rates, noise_rates):
    pred_images = self.network([noisy_images, signal_rates])
    pred_noises = noise_rates ** -0.5 * (noisy_images - signal_rates ** 0.5 * pred_images)
    return pred_images, pred_noises
Using Mean Squared Error Instead of Mean Absolute Error
Mean squared error leads to lower quality but more diverse samples, as detailed in an earlier section.
Sampling Ablations
Stochastic Sampling (DDPM)
Stochastic (DDPM) sampling improves generation quality at a high enough number of diffusion steps (>50).
Varying the Number of Sampling Steps
By comparing DDIM and DDPM sampling, we see wildly different behaviour at different numbers of sampling steps. While DDIM produces reasonable results even at very low numbers of reverse diffusion steps, DDPM seems to need more steps to work well.
Varying the Time Margin
Too small a time margin seems to lead to saturated colors in this case.
Different Diffusion Schedules
def diffusion_schedule(self, diffusion_times):
    # cosine schedule
    signal_rates = tf.cos(math.pi / 2 * diffusion_times) ** 2
    noise_rates = 1 - signal_rates
    return signal_rates, noise_rates

def diffusion_schedule(self, diffusion_times, min_signal_rate=0.01):
    # gaussian schedule
    signal_rates = min_signal_rate ** (diffusion_times ** 2)
    noise_rates = 1 - signal_rates
    return signal_rates, noise_rates

def diffusion_schedule(self, diffusion_times, min_noise_rate=0.01):
    # flipped gaussian schedule
    noise_rates = min_noise_rate ** ((1 - diffusion_times) ** 2)
    signal_rates = 1 - noise_rates
    return signal_rates, noise_rates
I have also experimented with different diffusion schedules, but my general experience was that they produced similar results. The rows correspond to the respective schedules in the code snippets above.
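One simple way to compare the schedules is to plot their signal rates over diffusion time (a sketch, assuming the three schedules above are available as standalone functions named cosine_schedule, gaussian_schedule and flipped_gaussian_schedule):

# plot the signal rates of the three schedules over diffusion time
diffusion_times = tf.linspace(0.0, 1.0, 101)
for schedule in [cosine_schedule, gaussian_schedule, flipped_gaussian_schedule]:
    signal_rates, noise_rates = schedule(diffusion_times)
    plt.plot(diffusion_times, signal_rates)
plt.xlabel("diffusion time")
plt.ylabel("signal rate")
plt.show()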
Network Architecture Ablations
Omitting Signal Rate Embedding
Omitting the signal rate embedding from the neural network completely can lead to overly noisy or overly smoothed results.
Omitting Skip Connections
Omitting skip connections causes the network to fail to learn completely.
Omitting Residual Connections
Omitting residual connections from the network slightly degraded performance.
Conclusion
In this work I provided a novel viewpoint on denoising diffusion models, along with a simple explanation and implementation. My hope is that this will lower the barrier required to step into the field and start experimenting with these generative models.
A Complete Implementation
import math
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_addons as tfa
from tensorflow import keras
from tensorflow.keras import layers
# hyperparameters
# data
crop_size = 140 # center crop size of the images
image_size = 64 # training resolution
# network
num_resolutions = 3 # number of stages in the network
blocks_per_stage = 2 # number of residual blocks in a stage
base_width = 64 # number of filters at the highest resolution
min_frequency = 1.0 # minimal embedding frequency
max_frequency = 1000.0 # maximal embedding frequency
# optimization
num_epochs = 20
batch_size = 64
learning_rate = 1e-3
weight_decay = 1e-4
ema = 0.999
# sampling
diffusion_steps = 100
time_margin = 0.05
def preprocess_image(data):
    # original image dimensions
    height = 218
    width = 178
    # center crop
    image = tf.image.crop_to_bounding_box(
        data["image"],
        (height - crop_size) // 2,
        (width - crop_size) // 2,
        crop_size,
        crop_size,
    )
    # resize
    image = tf.image.resize(
        image, size=[image_size, image_size], method="bicubic", antialias=True
    )
    # scale pixel values in the -1 - 1 range
    return tf.clip_by_value(image / 127.5 - 1, -1, 1)

def prepare_dataset(split):
    # load celeb_a dataset split
    # note: the automatic download can fail sometimes
    return (
        tfds.load("celeb_a", split=split, shuffle_files=True)
        .map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
        .cache()
        .shuffle(10 * batch_size)
        .batch(batch_size, drop_remainder=True)
        .prefetch(buffer_size=tf.data.AUTOTUNE)
    )

train_dataset = prepare_dataset("train")
val_dataset = prepare_dataset("validation")

# augmentation module: only horizontal flips
def build_augmenter():
    return keras.Sequential(
        [
            layers.InputLayer(input_shape=(image_size, image_size, 3)),
            layers.RandomFlip(mode="horizontal"),
        ],
        name="augmenter",
    )
# network: residual UNet with sinusoidal signal_rate embedding
def build_network():
    def EmbeddingLayer(num_frequencies):
        def sinusoidal_embedding(x):
            log_frequencies = tf.linspace(
                tf.math.log(min_frequency), tf.math.log(max_frequency), num_frequencies
            )
            frequencies = tf.exp(log_frequencies)
            angular_speeds = 2 * math.pi * frequencies
            embeddings = tf.concat(
                [tf.sin(angular_speeds * x), tf.cos(angular_speeds * x)], axis=3
            )
            return embeddings

        def forward(x):
            x = layers.Lambda(sinusoidal_embedding)(x)
            return x

        return forward

    def ResidualBlock(width):
        def forward(x):
            input_width = x.shape[3]
            if input_width == width:
                residual = x
            else:
                residual = layers.Conv2D(width, kernel_size=1)(x)
            x = layers.BatchNormalization(center=False, scale=False)(x)
            x = layers.Conv2D(
                width, kernel_size=3, padding="same", activation=keras.activations.swish
            )(x)
            x = layers.Conv2D(width, kernel_size=3, padding="same")(x)
            x = layers.Add()([residual, x])
            return x

        return forward

    def DownStage(width):
        def forward(x):
            x, skips = x
            for _ in range(blocks_per_stage):
                x = ResidualBlock(width)(x)
                skips.append(x)
            x = layers.AveragePooling2D(pool_size=2)(x)
            return x

        return forward

    def UpStage(width):
        def forward(x):
            x, skips = x
            x = layers.UpSampling2D(size=2, interpolation="bilinear")(x)
            for _ in range(blocks_per_stage):
                x = layers.Concatenate()([x, skips.pop()])
                x = ResidualBlock(width)(x)
            return x

        return forward

    images = keras.Input(shape=(image_size, image_size, 3))
    signal_rates = keras.Input(shape=(1, 1, 1))

    x = layers.Conv2D(base_width, kernel_size=1)(images)
    skips = [x]

    e = EmbeddingLayer(num_frequencies=base_width // 2)(signal_rates)
    e = layers.UpSampling2D(size=image_size, interpolation="nearest")(e)
    x = layers.Concatenate()([x, e])

    for i in range(num_resolutions):
        x = DownStage((i + 1) * base_width)([x, skips])

    for _ in range(blocks_per_stage):
        x = ResidualBlock((num_resolutions + 1) * base_width)(x)

    for i in reversed(range(num_resolutions)):
        x = UpStage((i + 1) * base_width)([x, skips])

    x = layers.Concatenate()([x, skips.pop()])  # skips is empty after that
    output_signal = layers.Conv2D(3, kernel_size=1)(x)

    return keras.Model([images, signal_rates], output_signal, name="residual_unet")
class DiffusionModel(keras.Model):
    def __init__(self):
        super().__init__()
        self.augmenter = build_augmenter()
        self.network = build_network()
        self.ema_network = keras.models.clone_model(self.network)

    def compile(self, **kwargs):
        super().compile(**kwargs)
        # the noise and image reconstruction losses are tracked as metrics
        self.noise_loss_tracker = keras.metrics.Mean(name="n_loss")
        self.image_loss_tracker = keras.metrics.Mean(name="i_loss")

    @property
    def metrics(self):
        return [self.noise_loss_tracker, self.image_loss_tracker]

    def diffusion_schedule(self, diffusion_times):
        # cosine schedule
        signal_rates = tf.cos(math.pi / 2 * diffusion_times) ** 2
        noise_rates = 1 - signal_rates
        return signal_rates, noise_rates

    def denoise(self, noisy_images, signal_rates, noise_rates, training):
        # exponential moving average of weights is used during inference
        if training:
            network = self.network
        else:
            network = self.ema_network
        pred_noises = network([noisy_images, signal_rates], training=training)
        pred_images = signal_rates ** -0.5 * (
            noisy_images - noise_rates ** 0.5 * pred_noises
        )
        return pred_images, pred_noises

    def diffusion_process(self, initial_noise):
        batch_size = tf.shape(initial_noise)[0]
        diffusion_times = tf.linspace(1 - time_margin, time_margin, diffusion_steps + 1)
        diffusion_times = tf.reshape(
            diffusion_times, shape=(diffusion_steps + 1, 1, 1, 1, 1)
        )
        diffusion_times = tf.broadcast_to(
            diffusion_times, shape=(diffusion_steps + 1, batch_size, 1, 1, 1)
        )
        noisy_images = initial_noise
        for step in range(diffusion_steps):
            signal_rates, noise_rates = self.diffusion_schedule(diffusion_times[step])
            next_signal_rates, next_noise_rates = self.diffusion_schedule(
                diffusion_times[step + 1]
            )
            pred_images, pred_noises = self.denoise(
                noisy_images, signal_rates, noise_rates, training=False
            )
            noisy_images = (
                next_signal_rates ** 0.5 * pred_images
                + next_noise_rates ** 0.5 * pred_noises
            )
        return pred_images

    def generate(self, num_images):
        initial_noise = tf.random.normal(shape=(num_images, image_size, image_size, 3))
        return self.diffusion_process(initial_noise)

    def train_step(self, images):
        images = self.augmenter(images, training=True)
        diffusion_times = tf.random.uniform(
            shape=(batch_size, 1, 1, 1), minval=time_margin, maxval=1 - time_margin
        )
        signal_rates, noise_rates = self.diffusion_schedule(diffusion_times)
        noises = tf.random.normal(shape=(batch_size, image_size, image_size, 3))
        noisy_images = signal_rates ** 0.5 * images + noise_rates ** 0.5 * noises
        with tf.GradientTape() as tape:
            pred_images, pred_noises = self.denoise(
                noisy_images, signal_rates, noise_rates, training=True
            )
            noise_loss = self.loss(noises, pred_noises)
            image_loss = self.loss(images, pred_images)
        gradients = tape.gradient(noise_loss, self.network.trainable_weights)
        self.optimizer.apply_gradients(zip(gradients, self.network.trainable_weights))
        self.noise_loss_tracker.update_state(noise_loss)
        self.image_loss_tracker.update_state(image_loss)
        for weight, ema_weight in zip(self.network.weights, self.ema_network.weights):
            ema_weight.assign(ema * ema_weight + (1 - ema) * weight)
        return {m.name: m.result() for m in self.metrics}

    def test_step(self, images):
        images = self.augmenter(images, training=False)
        diffusion_times = tf.random.uniform(
            shape=(batch_size, 1, 1, 1), minval=time_margin, maxval=1 - time_margin
        )
        signal_rates, noise_rates = self.diffusion_schedule(diffusion_times)
        noises = tf.random.normal(shape=(batch_size, image_size, image_size, 3))
        noisy_images = signal_rates ** 0.5 * images + noise_rates ** 0.5 * noises
        pred_images, pred_noises = self.denoise(
            noisy_images, signal_rates, noise_rates, training=False
        )
        noise_loss = self.loss(noises, pred_noises)
        image_loss = self.loss(images, pred_images)
        self.noise_loss_tracker.update_state(noise_loss)
        self.image_loss_tracker.update_state(image_loss)
        return {m.name: m.result() for m in self.metrics}

    def plot_images(self, epoch=-1, logs=None, num_rows=4, num_cols=8):
        # plotting a batch of generated images
        num_images = num_rows * num_cols
        generated_images = self.generate(num_images)
        generated_images = (1 + generated_images) / 2
        generated_images = tf.clip_by_value(generated_images, 0, 1)
        plt.figure(figsize=(num_cols * 1.5, num_rows * 1.5))
        for row in range(num_rows):
            for col in range(num_cols):
                index = row * num_cols + col
                plt.subplot(num_rows, num_cols, index + 1)
                plt.imshow(generated_images[index])
                plt.axis("off")
        plt.tight_layout()
        plt.savefig("images/{}.png".format(epoch + 1))
        plt.close()
model = DiffusionModel()
# using Adam optimizer with weight decay and mean absolute error as reconstruction loss
model.compile(
    optimizer=tfa.optimizers.AdamW(
        learning_rate=learning_rate, weight_decay=weight_decay
    ),
    loss=keras.losses.mean_absolute_error,
)

checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath="checkpoints/model",
    save_weights_only=True,
    monitor="val_n_loss",
    mode="min",
    save_best_only=True,
)

model.plot_images()

model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=num_epochs,
    callbacks=[
        keras.callbacks.LambdaCallback(on_epoch_end=model.plot_images),
        checkpoint_callback,
    ],
)