<!DOCTYPE html>
<html lang="en-us">

  <head>
  <link href="http://gmpg.org/xfn/11" rel="profile">
  <meta http-equiv="content-type" content="text/html; charset=utf-8">

  <meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1">

  <title>
    
      Diffusion Models Using a Single Equation &middot; The ICLR Blog Track
    
  </title>

  
  <link rel="canonical" href="https://iclr.iro.umontreal.ca/298b26fe-cc4f-4f10-9499-daf9fb61bf84_1642249335/2021/12/01/diffusion-single-equation/">
  

  <link rel="stylesheet" href="https://iclr.iro.umontreal.ca/298b26fe-cc4f-4f10-9499-daf9fb61bf84_1642249335/public/css/poole.css">
  <link rel="stylesheet" href="https://iclr.iro.umontreal.ca/298b26fe-cc4f-4f10-9499-daf9fb61bf84_1642249335/public/css/syntax.css">
  <link rel="stylesheet" href="https://iclr.iro.umontreal.ca/298b26fe-cc4f-4f10-9499-daf9fb61bf84_1642249335/public/css/lanyon.css">
  <link rel="stylesheet" href="https://iclr.iro.umontreal.ca/298b26fe-cc4f-4f10-9499-daf9fb61bf84_1642249335/public/css/custom.css">
  <link rel="stylesheet" href="https://fonts.googleapis.com/css?family=PT+Serif:400,400italic,700%7CPT+Sans:400">

  <link rel="apple-touch-icon-precomposed" sizes="144x144" href="https://iclr.iro.umontreal.ca/298b26fe-cc4f-4f10-9499-daf9fb61bf84_1642249335/public/apple-touch-icon-precomposed.png">
  <link rel="shortcut icon" href="https://iclr.iro.umontreal.ca/298b26fe-cc4f-4f10-9499-daf9fb61bf84_1642249335/public/favicon.ico">

  <link rel="alternate" type="application/rss+xml" title="RSS" href="https://iclr.iro.umontreal.ca/298b26fe-cc4f-4f10-9499-daf9fb61bf84_1642249335/atom.xml">

  

  <script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML" type="text/javascript" ></script>
 <!-- <script type="text/x-mathjax-config"> MathJax.Hub.Config({ TeX: { equationNumbers: { autoNumber: "AMS" } } }); </script> -->
  <script type="text/x-mathjax-config">
      MathJax.Hub.Config({
        tex2jax: { inlineMath: [ ['$','$'], ["\\(","\\)"] ],
         processEscapes: false
        }
      });
</script>
</head>


  <body>

    <!-- Target for toggling the sidebar `.sidebar-checkbox` is for regular
     styles, `#sidebar-checkbox` for behavior. -->
<input type="checkbox" class="sidebar-checkbox" id="sidebar-checkbox">
<!-- <input type="checkbox" class="sidebar-checkbox" id="sidebar-checkbox" > -->

<!-- Toggleable sidebar -->
<div class="sidebar" id="sidebar">
  <div class="sidebar-item">
    <p>For short-term, peer-sourced tests of time, generalizations, specializations, reproductions, etc.!</p>
  </div>

  <nav class="sidebar-nav">

    

    
    
      
        
          <a class="sidebar-nav-item" href="https://iclr.iro.umontreal.ca/298b26fe-cc4f-4f10-9499-daf9fb61bf84_1642249335/">ICLR 2022 Blog Track</a>
        
      
    
      
        
      
    
      
        
          <a class="sidebar-nav-item" href="https://iclr.iro.umontreal.ca/298b26fe-cc4f-4f10-9499-daf9fb61bf84_1642249335/about/">About</a>
        
      
    
      
    
      
        
      
    
      
        
          <a class="sidebar-nav-item" href="https://iclr.iro.umontreal.ca/298b26fe-cc4f-4f10-9499-daf9fb61bf84_1642249335/submitting/">Submitting</a>
        
      
    
      
        
          <a class="sidebar-nav-item" href="https://iclr.iro.umontreal.ca/298b26fe-cc4f-4f10-9499-daf9fb61bf84_1642249335/tags/">Tags</a>
        
      
    

    <a class="sidebar-nav-item" href="https://github.com/iclr-blog-track/iclr-blog-track.github.io">GitHub project</a>
    <span class="sidebar-nav-item">Currently vICLR Spring 2021</span>
  </nav>

  <div class="sidebar-item">
    <p>
      &copy; 2022. All rights reserved.
    </p>
  </div>
</div>


    <!-- Wrap is the content to shift when toggling the sidebar. We wrap the
         content to avoid any CSS collisions with our real content. -->
    <div class="wrap">
      <div class="masthead">
        <div class="container">
          <h3 class="masthead-title">
            <a href="/" title="Home">The ICLR Blog Track</a>
            <small></small>
          </h3>
        </div>
      </div>

      <div class="container content">
        <div class="post">
  <h1 id="iclr-post-title" class="post-title">Diffusion Models Using a Single Equation</h1>
  <span class="post-date">01 Dec 2021 | 
    <a class="content-tag" href="/tags/#generative-modeling"> generative modeling </a>
  
    <a class="content-tag" href="/tags/#denoising-diffusion"> denoising diffusion </a>
  
    <a class="content-tag" href="/tags/#ddim"> DDIM </a>
  
    <a class="content-tag" href="/tags/#ddpm"> DDPM </a>
  </span>

  <span id="iclr-post-authors" class="post-date">Anonymous</span>
  <h1 id="introduction">Introduction</h1>

<p>Recently, <strong>denoising diffusion models</strong> <a href="https://arxiv.org/abs/2006.11239">Ho et al. (2020)</a> and <strong>noise conditional score networks</strong> <a href="https://arxiv.org/abs/1907.05600">Song &amp; Ermon (2019)</a> have been shown to be a powerful class of generative models that can rival even generative adversarial networks (GANs) <a href="https://arxiv.org/abs/1406.2661">Goodfellow et al. (2014)</a> in image synthesis quality <a href="https://arxiv.org/abs/2105.05233">Dhariwal &amp; Nichol (2021)</a>, while being more stable to train than GANs. One of their drawbacks, however, is that they are slower to sample from, because they require many network forward passes per generated image, up to a thousand in <a href="https://arxiv.org/abs/2006.11239">Ho et al. (2020)</a>.</p>

<h2 id="motivation">Motivation</h2>

<p>A difficulty I faced when learning about diffusion models was that the barrier to entry is quite high: their theoretical background is mathematically involved and can be difficult to follow, sometimes even with conflicting notations across works. My aim with this blog post is to lower this barrier and make diffusion models, specifically Denoising Diffusion Implicit Models (DDIMs) <a href="https://arxiv.org/abs/2010.02502">Song et al. (2020)</a>, more accessible. I do this by providing a mathematically very light explanation of their inner workings (it uses only a single equation and descriptively named variables instead of single-letter notation), and by approaching them from a different point of view than the original paper. I would also like to give readers novel intuitions about these models, and to provide a simple, end-to-end implementation of diffusion models that is easy to customize and can serve as a good starting point for future projects on this subject.</p>

<h2 id="intuitions">Intuitions</h2>

<p>A well-known intuition is that these models estimate the data distribution <a href="https://yang-song.github.io/blog/2021/score/">Song (2021)</a>, or more precisely the gradient of the log probability density function of the data, and that they generate samples from this distribution by following this gradient.</p>
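<p>As a small, self-contained illustration of this well-known view (not of the approach taken in this post), the sketch below draws samples from a toy one-dimensional Gaussian by following its score, i.e. the gradient of its log probability density, using unadjusted Langevin dynamics: each step moves along the gradient and adds a little fresh noise. The target distribution, the step size and the number of steps are illustrative assumptions, and the score is known in closed form here instead of being learned by a network.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import math
import random

# toy one-dimensional target distribution (an assumption): mean 2.0, standard deviation 0.5
TARGET_MEAN, TARGET_STD = 2.0, 0.5

def score(x):
    # gradient of the log probability density of the Gaussian target
    return -(x - TARGET_MEAN) / TARGET_STD ** 2

def langevin_sample(rng, num_steps=500, step_size=0.01):
    x = rng.gauss(0.0, 1.0)  # start from arbitrary noise
    for _ in range(num_steps):
        # follow the score, and add a small amount of fresh Gaussian noise
        x = x + 0.5 * step_size * score(x) + math.sqrt(step_size) * rng.gauss(0.0, 1.0)
    return x

rng = random.Random(0)
samples = [langevin_sample(rng) for _ in range(1000)]
mean = sum(samples) / len(samples)
std = math.sqrt(sum((s - mean) ** 2 for s in samples) / len(samples))
print(f"mean={mean:.2f}, std={std:.2f}")  # close to the target mean and standard deviation
</code></pre></div></div>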

<p>In this work my aim is to present a different, more practical intuition: we force these models to sample from the data distribution by <strong>misleading</strong> them.</p>

<h2 id="diffusion-process">Diffusion Process</h2>

<p>Let us imagine a picture that over time gets mixed with more and more noise, so much so that after a sufficiently long time it is indistinguishable from pure pixel noise. This is a diffusion process: at the start (at diffusion time = 0), our signal consists entirely of the original image, and as time goes by, it is gradually replaced by noise, until it is pure noise (at diffusion time = 1).</p>

<p>We model this process with a diffusion schedule, which maps diffusion times to corresponding signal rates and noise rates. These describe, in the signal processing sense, what fraction of the power of the combined noisy signal comes from the signal and from the noise, respectively (assuming that the signal and the noise have the same power on their own). Since together they make up the whole noisy signal, the two rates always sum to one:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">signal_rate</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">-</span> <span class="n">noise_rate</span>
<span class="n">noise_rate</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">-</span> <span class="n">signal_rate</span>
</code></pre></div></div>
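<p>This section does not commit to a particular diffusion schedule, so as an illustration, the following sketch uses one possible choice (an assumption, not necessarily the schedule of the reference implementation): a cosine-shaped schedule, where the signal rate falls smoothly from one at diffusion time 0 to zero at diffusion time 1, and the noise rate is its complement.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import math

def diffusion_schedule(diffusion_times):
    # hypothetical cosine-shaped schedule, mapping diffusion times in [0, 1] to rates
    signal_rates = [math.cos(t * math.pi / 2) ** 2 for t in diffusion_times]
    noise_rates = [1.0 - s for s in signal_rates]  # the two rates always sum to one
    return signal_rates, noise_rates

signal_rates, noise_rates = diffusion_schedule([0.0, 0.25, 0.5, 0.75, 1.0])
print([round(s, 3) for s in signal_rates])  # [1.0, 0.854, 0.5, 0.146, 0.0]
</code></pre></div></div>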

<h2 id="mixing-equation">Mixing Equation</h2>

<p>The single equation that we will use in this work is the mixing equation:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">noisy_images</span> <span class="o">=</span> <span class="n">signal_rates</span> <span class="o">**</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">images</span> <span class="o">+</span> <span class="n">noise_rates</span> <span class="o">**</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">noises</span>
</code></pre></div></div>

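<p>It describes how the signal gets combined with noise, creating a noisy signal. In this work, I consider images as the signal and pixel noise as the noise, as this is what we will be dealing with. The rates appear under square roots because they are power ratios, while the equation mixes amplitudes, and the power of a scaled signal grows with the square of the scaling factor.</p>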
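<p>The following plain-Python sketch checks numerically that, with independent, unit-variance images and noise (here simply standard normal samples, an assumption made for illustration), the mixing equation produces noisy images with approximately unit variance for any signal rate, because the variances add up to signal_rate + noise_rate = 1.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import random

def variance(values):
    mean = sum(values) / len(values)
    return sum((value - mean) ** 2 for value in values) / len(values)

rng = random.Random(42)
num_pixels = 50_000
# independent, unit-variance stand-ins for image pixels and noise (an assumption)
images = [rng.gauss(0.0, 1.0) for _ in range(num_pixels)]
noises = [rng.gauss(0.0, 1.0) for _ in range(num_pixels)]

variances = {}
for signal_rate in [0.9, 0.5, 0.1]:
    noise_rate = 1.0 - signal_rate
    noisy_images = [signal_rate ** 0.5 * image + noise_rate ** 0.5 * noise
                    for image, noise in zip(images, noises)]
    variances[signal_rate] = variance(noisy_images)

print({rate: round(v, 2) for rate, v in variances.items()})  # all values close to 1.0
</code></pre></div></div>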

<h1 id="algorithm">Algorithm</h1>
<p>In the following sections, I will use simplified Keras-based code snippets to present the algorithm of a <strong>Denoising Diffusion Implicit Model</strong> <a href="https://arxiv.org/abs/2010.02502">Song et al. (2020)</a>. Note that these code snippets would not run on their own, as I have omitted some variable and method declarations, function arguments and shape manipulations for clarity. At the end of this blog post, I present a feature-complete, end-to-end implementation that can be used for experimentation.</p>

<p>This is the model class that we will use for implementing the algorithm:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">DiffusionModel</span><span class="p">(</span><span class="n">keras</span><span class="p">.</span><span class="n">Model</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">diffusion_steps</span><span class="p">,</span> <span class="n">time_margin</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">()</span>

        <span class="bp">self</span><span class="p">.</span><span class="n">diffusion_steps</span> <span class="o">=</span> <span class="n">diffusion_steps</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">time_margin</span> <span class="o">=</span> <span class="n">time_margin</span>

        <span class="bp">self</span><span class="p">.</span><span class="n">network</span> <span class="o">=</span> <span class="n">build_network</span><span class="p">()</span>

    <span class="k">def</span> <span class="nf">compile</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">optimizer</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">().</span><span class="nb">compile</span><span class="p">()</span>

        <span class="bp">self</span><span class="p">.</span><span class="n">optimizer</span> <span class="o">=</span> <span class="n">optimizer</span>
</code></pre></div></div>

<h2 id="task">Task</h2>

<p>Diffusion models are trained to solve the task of signal denoising, or more precisely, signal-noise separation. Their input is a noisy signal, together with the signal and noise rates used to produce it, and they have to predict the original signal and noise values that were mixed together.</p>

<p>There are two equivalent ways to solve this task while making sure that the outputs obey the mixing equation:</p>
<ol>
  <li>The network can predict the original signal, which we then substitute into the mixing equation to calculate the noise that, when mixed with the predicted signal, would have produced the noisy input.</li>
  <li>Alternatively, it can predict the noise that was mixed with the original signal, and we can similarly use the mixing equation to calculate the corresponding original signal.</li>
</ol>
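<p>Both options boil down to rearranging the mixing equation. The sketch below, written in plain Python on single scalar "pixels" with hypothetical helper names, computes the noise from a predicted signal (first option) and the signal from a predicted noise (second option), and checks that both recover the true values when the predictions are exact.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import math

def noise_from_signal(noisy_image, pred_image, signal_rate, noise_rate):
    # option 1: the network predicts the signal, the mixing equation gives the noise
    return (noisy_image - signal_rate ** 0.5 * pred_image) / noise_rate ** 0.5

def signal_from_noise(noisy_image, pred_noise, signal_rate, noise_rate):
    # option 2: the network predicts the noise, the mixing equation gives the signal
    return (noisy_image - noise_rate ** 0.5 * pred_noise) / signal_rate ** 0.5

signal_rate, noise_rate = 0.3, 0.7
image, noise = 0.8, -1.2
noisy_image = signal_rate ** 0.5 * image + noise_rate ** 0.5 * noise

recovered_noise = noise_from_signal(noisy_image, image, signal_rate, noise_rate)
recovered_image = signal_from_noise(noisy_image, noise, signal_rate, noise_rate)
print(math.isclose(recovered_noise, noise), math.isclose(recovered_image, image))  # True True
</code></pre></div></div>
<p>Note that the first option divides by the square root of the noise rate and the second by the square root of the signal rate, so each becomes unusable at the opposite end of the diffusion process.</p>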

<p>Although both solutions are theoretically equally sound, in practice we usually use the latter, predicting the noise, as it empirically leads to more stable training and higher generation quality <a href="https://arxiv.org/abs/2006.11239">Ho et al. (2020)</a>. Note that I investigate the swapped task of predicting the original images instead of the noise in one of the ablations.</p>

<p>In practice, we therefore predict the noise based on the noisy images and their signal rates, and rearrange the mixing equation to obtain the predicted original images from the predicted noise:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>    <span class="k">def</span> <span class="nf">denoise</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">noisy_images</span><span class="p">,</span> <span class="n">signal_rates</span><span class="p">,</span> <span class="n">noise_rates</span><span class="p">):</span>
        <span class="n">pred_noises</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">network</span><span class="p">([</span><span class="n">noisy_images</span><span class="p">,</span> <span class="n">signal_rates</span><span class="p">])</span>
        <span class="n">pred_images</span> <span class="o">=</span> <span class="n">signal_rates</span> <span class="o">**</span> <span class="o">-</span><span class="mf">0.5</span> <span class="o">*</span> <span class="p">(</span><span class="n">noisy_images</span> <span class="o">-</span> <span class="n">noise_rates</span> <span class="o">**</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">pred_noises</span><span class="p">)</span>

        <span class="k">return</span> <span class="n">pred_images</span><span class="p">,</span> <span class="n">pred_noises</span>
</code></pre></div></div>

<p>If a signal rate of 0 were used, the <code class="language-plaintext highlighter-rouge">signal_rates ** -0.5</code> term would cause a division by zero. In that case the task would also be ill-posed, as it is impossible to predict the original signal from a noisy signal that contains none of it, only pure noise. To avoid this issue, I use time margins when sampling diffusion times, to stay away from the two limits of the process, where our quantities might blow up.</p>
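<p>To get a feeling for why the time margins matter, the sketch below computes how strongly an error in the predicted noise is amplified when it is converted into the predicted image, which by the rearranged mixing equation is a factor of sqrt(noise_rate / signal_rate). It assumes a cosine-shaped schedule, signal_rate = cos(diffusion_time · π / 2)², and a time margin of 0.02; both are illustrative choices. At a diffusion time of exactly 1 the signal rate is zero mathematically, while in floating point it comes out as a tiny positive number, so the factor explodes instead of raising an error.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import math

def signal_rate_at(diffusion_time):
    # hypothetical cosine-shaped diffusion schedule
    return math.cos(diffusion_time * math.pi / 2) ** 2

def error_amplification(diffusion_time):
    # pred_images = (noisy_images - noise_rates ** 0.5 * pred_noises) / signal_rates ** 0.5,
    # so an error in the predicted noise is scaled by sqrt(noise_rate / signal_rate)
    signal_rate = signal_rate_at(diffusion_time)
    return math.sqrt((1.0 - signal_rate) / signal_rate)

time_margin = 0.02  # illustrative value
for diffusion_time in [0.5, 0.9, 1.0 - time_margin, 1.0]:  # the last one has no margin
    print(diffusion_time, error_amplification(diffusion_time))
</code></pre></div></div>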

<h2 id="training">Training</h2>

<p>How can we train these diffusion models?</p>

<ul>
  <li>For each training image, we sample normally distributed noise with the same shape as the image. The two will later be mixed to create the noisy signal, providing the input-output pairs for training.
    <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>  <span class="k">def</span> <span class="nf">train_step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">images</span><span class="p">):</span>
      <span class="n">noises</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">normal</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">image_size</span><span class="p">,</span> <span class="n">image_size</span><span class="p">,</span> <span class="n">image_channels</span><span class="p">))</span>
</code></pre></div>    </div>
  </li>
  <li>We sample a uniformly distributed diffusion time for each training image, staying within the time margins, and use the diffusion schedule to map it to the corresponding signal and noise rates.
    <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>      <span class="n">diffusion_times</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,),</span> <span class="n">minval</span><span class="o">=</span><span class="bp">self</span><span class="p">.</span><span class="n">time_margin</span><span class="p">,</span> <span class="n">maxval</span><span class="o">=</span><span class="mf">1.0</span> <span class="o">-</span> <span class="bp">self</span><span class="p">.</span><span class="n">time_margin</span><span class="p">)</span>
      <span class="n">signal_rates</span><span class="p">,</span> <span class="n">noise_rates</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">diffusion_schedule</span><span class="p">(</span><span class="n">diffusion_times</span><span class="p">)</span>
</code></pre></div>    </div>
  </li>
  <li>We then mix the training images with the noise according to the mixing equation, using these rates, to create the noisy images.
    <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>      <span class="n">noisy_images</span> <span class="o">=</span> <span class="n">signal_rates</span> <span class="o">**</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">images</span> <span class="o">+</span> <span class="n">noise_rates</span> <span class="o">**</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">noises</span>
</code></pre></div>    </div>
  </li>
  <li>These noisy signals are then fed into the network, which tries to separate them into the original signal and noise as described in the previous section.
    <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>      <span class="k">with</span> <span class="n">tf</span><span class="p">.</span><span class="n">GradientTape</span><span class="p">()</span> <span class="k">as</span> <span class="n">tape</span><span class="p">:</span>
          <span class="n">pred_images</span><span class="p">,</span> <span class="n">pred_noises</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">denoise</span><span class="p">(</span><span class="n">noisy_images</span><span class="p">,</span> <span class="n">signal_rates</span><span class="p">)</span>
</code></pre></div>    </div>
  </li>
  <li>We then use the (true noise, predicted noise) pairs to calculate a reconstruction loss for each training sample. In theory, the mean squared error (MSE) should be used here, as it follows from the variational training objective <a href="https://arxiv.org/abs/2006.11239">Ho et al. (2020)</a>, but in practice the mean absolute error (MAE) seems to produce better results <a href="https://arxiv.org/abs/2009.00713">Chen &amp; Zhang et al. (2020)</a>. MSE tends to lead to more diverse outputs, while MAE produces more conservative ones <a href="https://arxiv.org/abs/2111.05826">Saharia et al. (2021)</a>. This matches my experience as well, so I use MAE here.
    <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>          <span class="n">noise_loss</span> <span class="o">=</span> <span class="n">keras</span><span class="p">.</span><span class="n">losses</span><span class="p">.</span><span class="n">mean_absolute_error</span><span class="p">(</span><span class="n">noises</span><span class="p">,</span> <span class="n">pred_noises</span><span class="p">)</span>
</code></pre></div>    </div>
  </li>
  <li>The loss gets backpropagated, and gradient-based optimization is applied on the network’s weights.
    <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>      <span class="n">gradients</span> <span class="o">=</span> <span class="n">tape</span><span class="p">.</span><span class="n">gradient</span><span class="p">(</span><span class="n">noise_loss</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">network</span><span class="p">.</span><span class="n">trainable_weights</span><span class="p">)</span>
      <span class="bp">self</span><span class="p">.</span><span class="n">optimizer</span><span class="p">.</span><span class="n">apply_gradients</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span class="n">gradients</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">network</span><span class="p">.</span><span class="n">trainable_weights</span><span class="p">))</span>

      <span class="k">return</span> <span class="n">noise_loss</span>
</code></pre></div>    </div>
  </li>
</ul>
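<p>Stripped of Keras and TensorFlow, the data flow of one training step can be sketched in plain Python as follows. This is a simplified illustration under several assumptions: images are short flattened lists of pixel values, the diffusion schedule is cosine-shaped, and <code class="language-plaintext highlighter-rouge">predict_noise</code> is a hypothetical stand-in for the network. A "perfect" predictor that cheats by knowing the clean image reaches a loss of (almost) zero and recovers the image exactly, while a predictor that always outputs zero noise does much worse.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import math
import random

def diffusion_schedule(diffusion_time):
    # hypothetical cosine-shaped diffusion schedule
    signal_rate = math.cos(diffusion_time * math.pi / 2) ** 2
    return signal_rate, 1.0 - signal_rate

def mean_absolute_error(targets, predictions):
    return sum(abs(t - p) for t, p in zip(targets, predictions)) / len(targets)

def training_loss(image, predict_noise, time_margin, rng):
    # sample noise with the same shape as the image, and a diffusion time within the margins
    noise = [rng.gauss(0.0, 1.0) for _ in image]
    diffusion_time = rng.uniform(time_margin, 1.0 - time_margin)
    signal_rate, noise_rate = diffusion_schedule(diffusion_time)
    # mix image and noise according to the mixing equation
    noisy_image = [signal_rate ** 0.5 * x + noise_rate ** 0.5 * n for x, n in zip(image, noise)]
    # the "network" predicts the noise, the image follows from the mixing equation
    pred_noise = predict_noise(noisy_image, signal_rate)
    pred_image = [(y - noise_rate ** 0.5 * n) / signal_rate ** 0.5
                  for y, n in zip(noisy_image, pred_noise)]
    return mean_absolute_error(noise, pred_noise), pred_image

rng = random.Random(0)
image = [rng.uniform(-1.0, 1.0) for _ in range(64)]  # a tiny hypothetical flattened "image"

def perfect_predictor(noisy_image, signal_rate):
    # cheats by using the known clean image, to show the best achievable loss
    noise_rate = 1.0 - signal_rate
    return [(y - signal_rate ** 0.5 * x) / noise_rate ** 0.5 for y, x in zip(noisy_image, image)]

def zero_predictor(noisy_image, signal_rate):
    # a trivial baseline that always predicts zero noise
    return [0.0] * len(noisy_image)

loss, pred_image = training_loss(image, perfect_predictor, time_margin=0.02, rng=rng)
loss_zero, _ = training_loss(image, zero_predictor, time_margin=0.02, rng=rng)
print(loss, loss_zero)
</code></pre></div></div>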

<h2 id="sampling">Sampling</h2>

<p>Now, how can we turn these models into generative models? How can we use this denoising behaviour to generate samples from the data distribution?</p>

<p>My intuition, and the main point of this work, is that we <strong>lie</strong> to them: we mislead them, as I explain in this section.</p>

<p>We sample normally distributed noise and make the network iteratively denoise it, so that it hallucinates a realistic signal from pure noise. In other words, the network is used to simulate a reverse diffusion process that starts from pure noise. In each step of this iterative process, we use the network to remove a small amount of noise from the signal, using signal and noise rates given by diffusion times that move backwards, decreasing from almost 1 (pure noise) to almost 0 (pure signal).</p>

<p>But what do we do in the first step?</p>

<p>Recall from the section describing the task that we cannot use a signal rate of exactly zero, as it would cause a division by zero in the denoising step, and the task would also be ill-posed. But in the first step, our input is pure noise, so its actual signal rate is exactly zero.</p>

<p>To resolve this issue, we trick the network: we input pure noise as the noisy signal, while telling it that there is a small amount of signal in it, by setting the signal rate slightly above zero and the noise rate slightly below one. Because of the time margins, the initial diffusion time is slightly below one, so the diffusion schedule outputs a non-zero starting signal rate.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>    <span class="k">def</span> <span class="nf">diffusion_process</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">initial_noise</span><span class="p">):</span>

        <span class="n">noisy_images</span> <span class="o">=</span> <span class="n">initial_noise</span>

        <span class="n">diffusion_times</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="p">.</span><span class="n">time_margin</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">time_margin</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">diffusion_steps</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
        <span class="k">for</span> <span class="n">step</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">diffusion_steps</span><span class="p">):</span>
            <span class="n">signal_rates</span><span class="p">,</span> <span class="n">noise_rates</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">diffusion_schedule</span><span class="p">(</span><span class="n">diffusion_times</span><span class="p">[</span><span class="n">step</span><span class="p">])</span>
            <span class="n">pred_images</span><span class="p">,</span> <span class="n">pred_noises</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">denoise</span><span class="p">(</span><span class="n">noisy_images</span><span class="p">,</span> <span class="n">signal_rates</span><span class="p">,</span> <span class="n">noise_rates</span><span class="p">)</span>
</code></pre></div></div>
<p>But what should be the noisy signal in the following step?</p>

<p>Since we have our latest estimates of the signal and the noise in the current step, we can recombine them using the signal and noise rates of the following step, to get our best estimate of what the noisy signal would look like at a slightly higher signal-to-noise ratio.</p>

<p>Iterating these steps from a diffusion time of almost one (pure noise) to almost zero (pure signal) uses the network to gradually denoise pure noise into something that it considers a real signal.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>            <span class="n">next_signal_rates</span><span class="p">,</span> <span class="n">next_noise_rates</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">diffusion_schedule</span><span class="p">(</span><span class="n">diffusion_times</span><span class="p">[</span><span class="n">step</span> <span class="o">+</span> <span class="mi">1</span><span class="p">])</span>
            <span class="n">noisy_images</span> <span class="o">=</span> <span class="n">next_signal_rates</span> <span class="o">**</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">pred_images</span> <span class="o">+</span> <span class="n">next_noise_rates</span> <span class="o">**</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">pred_noises</span>

        <span class="k">return</span> <span class="n">pred_images</span>
</code></pre></div></div>
<p>And with that, we arrive at a sampling procedure that is equivalent to the deterministic variant of the one described in the Denoising Diffusion Implicit Models paper, but based on very different considerations.</p>
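<p>To see the sampling loop work end to end without a trained network, the following plain-Python sketch replaces the network with an exact "oracle" denoiser for a toy one-dimensional dataset that contains only the values -1 and +1: the oracle returns the posterior mean of the clean value given the noisy one. The dataset, the cosine-shaped schedule, the number of steps and the time margin are illustrative assumptions. Starting from pure noise and repeating the same two operations as in the sampling code above (denoise, then remix with the next step's rates), the samples typically end up at, or extremely close to, one of the two data points.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import math
import random

DATA_POINTS = [-1.0, 1.0]  # hypothetical toy dataset of one-dimensional "images"

def diffusion_schedule(diffusion_time):
    # hypothetical cosine-shaped diffusion schedule
    signal_rate = math.cos(diffusion_time * math.pi / 2) ** 2
    return signal_rate, 1.0 - signal_rate

def denoise(noisy_image, signal_rate, noise_rate):
    # stand-in for the network: the posterior mean of the clean value given the noisy one,
    # computed exactly because the dataset is known
    log_weights = [-(noisy_image - signal_rate ** 0.5 * x) ** 2 / (2 * noise_rate) for x in DATA_POINTS]
    max_log_weight = max(log_weights)
    weights = [math.exp(w - max_log_weight) for w in log_weights]
    pred_image = sum(w * x for w, x in zip(weights, DATA_POINTS)) / sum(weights)
    pred_noise = (noisy_image - signal_rate ** 0.5 * pred_image) / noise_rate ** 0.5
    return pred_image, pred_noise

def generate(initial_noise, diffusion_steps=50, time_margin=0.02):
    # diffusion times decreasing from 1 - time_margin to time_margin
    diffusion_times = [1.0 - time_margin - (1.0 - 2.0 * time_margin) * step / diffusion_steps
                       for step in range(diffusion_steps + 1)]
    noisy_image = initial_noise
    for step in range(diffusion_steps):
        signal_rate, noise_rate = diffusion_schedule(diffusion_times[step])
        pred_image, pred_noise = denoise(noisy_image, signal_rate, noise_rate)
        # remix the current estimates using the rates of the next step
        next_signal_rate, next_noise_rate = diffusion_schedule(diffusion_times[step + 1])
        noisy_image = next_signal_rate ** 0.5 * pred_image + next_noise_rate ** 0.5 * pred_noise
    return pred_image

rng = random.Random(0)
samples = [generate(rng.gauss(0.0, 1.0)) for _ in range(5)]
print([round(s, 3) for s in samples])
</code></pre></div></div>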

<h2 id="network">Network</h2>
<p><a href="https://arxiv.org/abs/2006.11239">Ho et al. (2020)</a> proposes the usage of a U-Net network <a href="https://arxiv.org/abs/1505.04597">Ronneberger et al. 2015</a>, which makes our network an overcomplete denoising autoencoder, overcomplete because its latent dimensionality is higher than that of the data, denoising because it tries to reconstruct the original data from its corrupted inputs.</p>

<p>In my experiments, I found the following three properties of the network architecture to be the most important:</p>
<ul>
  <li>The network should be a U-Net, i.e. it should first downsample and then upsample the input, with skip connections from the layers in its first half to the layers of matching resolution in its second half.</li>
  <li>Each stage of the network (a contiguous set of layers operating at the same resolution) should consist of residual blocks, so that the flow of information is helped not only by the long skip connections but also by short residual connections.</li>
  <li>The signal rates should be embedded using a sinusoidal embedding layer, known as positional encoding in Transformers <a href="https://arxiv.org/abs/1706.03762">Vaswani et al. (2017)</a> and Neural Radiance Fields <a href="https://arxiv.org/abs/2003.08934">Mildenhall &amp; Srinivasan &amp; Tancik et al. (2020)</a>.</li>
</ul>

<p>I also ran into occasionally diverging training runs, especially when increasing the network’s size, which was surprising given the simplicity of the training procedure. I found that the following methods help with training stability:</p>
<ul>
  <li>weight decay (using AdamW <a href="https://arxiv.org/abs/1711.05101">Loshchilov &amp; Hutter (2017)</a> instead of Adam <a href="https://arxiv.org/abs/1412.6980">Kingma &amp; Ba (2014)</a>)</li>
  <li>layer normalization <a href="https://arxiv.org/abs/1607.06450">Ba et al. (2016)</a></li>
  <li>batch normalization <a href="https://arxiv.org/abs/1502.03167">Ioffe &amp; Szegedy (2015)</a></li>
</ul>

<p>In the reference implementation at the end of this blog post, I use a combination of weight decay and batch normalization.</p>
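<p>Weight decay in AdamW differs from an L2 penalty used with Adam in that it is decoupled from the adaptive gradient rescaling: the weights are shrunk directly, instead of the decay term being added to the gradient. The following plain-Python sketch of a single update on one scalar weight illustrates the difference; the hyperparameter values are illustrative, and library implementations may differ in details, for example in how the decay is scaled (here it is multiplied by the learning rate).</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import math

def adam_direction(grad, m, v, step, beta_1=0.9, beta_2=0.999, epsilon=1e-7):
    # standard Adam moment estimates with bias correction
    m = beta_1 * m + (1 - beta_1) * grad
    v = beta_2 * v + (1 - beta_2) * grad ** 2
    m_hat = m / (1 - beta_1 ** step)
    v_hat = v / (1 - beta_2 ** step)
    return m_hat / (math.sqrt(v_hat) + epsilon), m, v

def adam_l2_step(weight, grad, m, v, step, learning_rate, weight_decay):
    # L2 penalty: the decay term is added to the gradient, so Adam rescales it too
    direction, m, v = adam_direction(grad + weight_decay * weight, m, v, step)
    return weight - learning_rate * direction, m, v

def adamw_step(weight, grad, m, v, step, learning_rate, weight_decay):
    # decoupled weight decay: the weight is shrunk directly, outside of Adam's rescaling
    direction, m, v = adam_direction(grad, m, v, step)
    return weight - learning_rate * (direction + weight_decay * weight), m, v

# first step with a large gradient, starting from weight = 2.0
l2_weight, _, _ = adam_l2_step(2.0, grad=10.0, m=0.0, v=0.0, step=1, learning_rate=0.1, weight_decay=0.5)
adamw_weight, _, _ = adamw_step(2.0, grad=10.0, m=0.0, v=0.0, step=1, learning_rate=0.1, weight_decay=0.5)
print(round(l2_weight, 3), round(adamw_weight, 3))
</code></pre></div></div>
<p>On this step, Adam's normalization absorbs the L2 term almost entirely, so the L2 version moves the weight about as far as it would without any decay, while AdamW applies the full decay on top of the gradient step.</p>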

<p>The following is a simplified implementation of the recommended neural network architecture, using the <a href="https://keras.io/guides/functional_api/">Keras Functional API</a>:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">build_network</span><span class="p">():</span>
    <span class="n">images</span> <span class="o">=</span> <span class="n">keras</span><span class="p">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">image_size</span><span class="p">,</span> <span class="n">image_size</span><span class="p">,</span> <span class="n">image_channels</span><span class="p">))</span>
    <span class="n">signal_rates</span> <span class="o">=</span> <span class="n">keras</span><span class="p">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,))</span>
    <span class="n">signal_rate_embeddings</span> <span class="o">=</span> <span class="n">layers</span><span class="p">.</span><span class="n">Lambda</span><span class="p">(</span><span class="n">sinusoidal_embedding</span><span class="p">)(</span><span class="n">signal_rates</span><span class="p">)</span>

    <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="p">.</span><span class="n">Concatenate</span><span class="p">()([</span><span class="n">images</span><span class="p">,</span> <span class="n">signal_rate_embeddings</span><span class="p">])</span>
    <span class="n">skips</span> <span class="o">=</span> <span class="p">[</span><span class="bp">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">depth</span>

    <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">depth</span><span class="p">):</span>
        <span class="n">x</span><span class="p">,</span> <span class="n">skips</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">DownStage</span><span class="p">(</span><span class="n">residual</span><span class="o">=</span><span class="bp">True</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span>

    <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">reversed</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">depth</span><span class="p">)):</span>
        <span class="n">x</span> <span class="o">=</span> <span class="n">UpStage</span><span class="p">(</span><span class="n">residual</span><span class="o">=</span><span class="bp">True</span><span class="p">)([</span><span class="n">x</span><span class="p">,</span> <span class="n">skips</span><span class="p">[</span><span class="n">i</span><span class="p">]])</span>

    <span class="n">output_signal</span> <span class="o">=</span> <span class="n">layers</span><span class="p">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="n">image_channels</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">1</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">keras</span><span class="p">.</span><span class="n">Model</span><span class="p">([</span><span class="n">images</span><span class="p">,</span> <span class="n">signal_rates</span><span class="p">],</span> <span class="n">output_signal</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s">"unet"</span><span class="p">)</span>
</code></pre></div></div>

<p>Why is the sinusoidal embedding a crucial part of the architecture? What is its role?</p>

<p>In short, it makes the network highly sensitive to the value of the signal rate (I recommend <a href="https://arxiv.org/abs/2006.10739">Tancik, Srinivasan, Mildenhall et al. (2020)</a> to interested readers). This is useful because, in theory, a separate network should ideally be used for each denoising step, which is not realistic in practice.</p>

<p>The following code snippet shows a minimalistic implementation of sinusoidal embeddings:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">math</span>
<span class="kn">import</span> <span class="nn">tensorflow</span> <span class="k">as</span> <span class="n">tf</span>

<span class="k">def</span> <span class="nf">sinusoidal_embedding</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">num_frequencies</span><span class="p">,</span> <span class="n">min_frequency</span><span class="o">=</span><span class="mf">1.0</span><span class="p">,</span> <span class="n">max_frequency</span><span class="o">=</span><span class="mf">1000.0</span><span class="p">):</span>
    <span class="c1"># geometrically spaced frequencies between min_frequency and max_frequency
</span>    <span class="n">log_frequencies</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">math</span><span class="p">.</span><span class="n">log</span><span class="p">(</span><span class="n">min_frequency</span><span class="p">),</span> <span class="n">tf</span><span class="p">.</span><span class="n">math</span><span class="p">.</span><span class="n">log</span><span class="p">(</span><span class="n">max_frequency</span><span class="p">),</span> <span class="n">num_frequencies</span><span class="p">)</span>
    <span class="n">frequencies</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">exp</span><span class="p">(</span><span class="n">log_frequencies</span><span class="p">)</span>
    <span class="n">angular_speeds</span> <span class="o">=</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">math</span><span class="p">.</span><span class="n">pi</span> <span class="o">*</span> <span class="n">frequencies</span>
    <span class="c1"># sines and cosines concatenated along the last (channel) axis
</span>    <span class="n">embeddings</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">concat</span><span class="p">([</span><span class="n">tf</span><span class="p">.</span><span class="n">sin</span><span class="p">(</span><span class="n">angular_speeds</span> <span class="o">*</span> <span class="n">x</span><span class="p">),</span> <span class="n">tf</span><span class="p">.</span><span class="n">cos</span><span class="p">(</span><span class="n">angular_speeds</span> <span class="o">*</span> <span class="n">x</span><span class="p">)],</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">embeddings</span>
</code></pre></div></div>
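<p>To make the sensitivity argument concrete, the following pure-Python sketch reimplements the same embedding for a single scalar signal rate and compares the embeddings of two signal rates that differ by only 0.01. It assumes 16 frequencies (a hypothetical choice) and the <code>min_frequency</code> and <code>max_frequency</code> values of the complete implementation at the end of this post; it is an illustration, not the code used for the experiments.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import math


def sinusoidal_embedding(x, num_frequencies=16, min_frequency=1.0, max_frequency=1000.0):
    """Embed a scalar x as sines and cosines at log-spaced frequencies."""
    # log-spaced frequencies, mirroring tf.linspace over [log(min), log(max)]
    log_min, log_max = math.log(min_frequency), math.log(max_frequency)
    step = (log_max - log_min) / (num_frequencies - 1)
    frequencies = [math.exp(log_min + i * step) for i in range(num_frequencies)]
    angular_speeds = [2 * math.pi * f for f in frequencies]
    return [math.sin(w * x) for w in angular_speeds] + [math.cos(w * x) for w in angular_speeds]


def distance(a, b):
    # Euclidean distance between two embedding vectors
    return math.sqrt(sum((u - v) ** 2 for u, v in zip(a, b)))


# two signal rates that differ by only 0.01
emb_a = sinusoidal_embedding(0.50)
emb_b = sinusoidal_embedding(0.51)
print(len(emb_a), distance(emb_a, emb_b))
</code></pre></div></div>

<p>The high-frequency components change quickly as the signal rate changes, so even nearby signal rates map to clearly different embeddings, while the low-frequency components vary smoothly with the same value.</p>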

<h1 id="stochastic-sampling">Stochastic Sampling</h1>

<p>Even though Denoising Diffusion Probabilistic Models (DDPMs) historically preceded the implicit models (DDIMs), following the line of reasoning proposed in this work, they can be interpreted as an extension of DDIMs. I should note that training DDPMs is exactly the same procedure as training DDIMs; the only real difference between the two is the sampling procedure.</p>

<p>Although I did not manage to derive the exact DDPM sampling procedure from the type of reasoning presented here, I can still provide some intuition about how I think DDPM models work by dissecting their sampling procedure into smaller, easier-to-reason-about parts.</p>

<p>What is a common issue with autoregressive generative models, which iteratively use their own previous outputs as inputs?</p>

<p>If one of their outputs turns out to be of suboptimal quality, the next input is no longer contained in the training data distribution, so the quality of the next output might get even worse. The generative model can then quickly wander off the distribution where it works reliably and start generating low-quality outputs.</p>

<p>How could we counteract that in our sampling procedure?</p>

<p>Let us take advantage of the fact that the sum of two normally distributed random variables is also normally distributed (<a href="https://en.wikipedia.org/wiki/Sum_of_normally_distributed_random_variables">if they are independent</a>)! This means that if the predicted noise is normally distributed, we remain in distribution when we add some extra normally distributed noise, provided that we slightly downscale the predicted noise so that the combined noise rate approximately follows the diffusion schedule. If the predicted noise is not normally distributed, my intuition is that it gets pushed slightly closer to a normal distribution, which helps increase sample quality.</p>
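<p>The variance part of this argument can be checked numerically. The pure-Python sketch below draws two independent standard normal noise components, weights them by the square roots of two hypothetical rates, and checks that the mean of the mixture is close to zero and its variance close to the sum of the rates, so the mixture behaves like a single noise component with the combined noise rate.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import random

random.seed(0)
num_samples = 200_000

# hypothetical rates: a downscaled predicted-noise rate and a small extra-noise rate
scaled_noise_rate = 0.45
extra_noise_rate = 0.05

# two independent standard normal noise components
pred_noises = [random.gauss(0.0, 1.0) for _ in range(num_samples)]
extra_noises = [random.gauss(0.0, 1.0) for _ in range(num_samples)]

# mix them with square-root weights, as in the single equation
mixed_noises = [
    scaled_noise_rate ** 0.5 * p + extra_noise_rate ** 0.5 * e
    for p, e in zip(pred_noises, extra_noises)
]

mean = sum(mixed_noises) / num_samples
variance = sum((n - mean) ** 2 for n in mixed_noises) / num_samples
print(mean, variance, scaled_noise_rate + extra_noise_rate)
</code></pre></div></div>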

<p>This modified sampling procedure might improve sample quality as described above, but one of its downsides is that the sampling process becomes stochastic, while the DDIM sampling procedure is fully deterministic.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>    <span class="k">def</span> <span class="nf">diffusion_process</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">initial_noise</span><span class="p">):</span>
        <span class="n">diffusion_times</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="bp">self</span><span class="p">.</span><span class="n">time_margin</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">time_margin</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">diffusion_steps</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>

        <span class="n">noisy_images</span> <span class="o">=</span> <span class="n">initial_noise</span>

        <span class="k">for</span> <span class="n">step</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">diffusion_steps</span><span class="p">):</span>
            <span class="n">signal_rates</span><span class="p">,</span> <span class="n">noise_rates</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">diffusion_schedule</span><span class="p">(</span><span class="n">diffusion_times</span><span class="p">[</span><span class="n">step</span><span class="p">])</span>
            <span class="n">next_signal_rates</span><span class="p">,</span> <span class="n">next_noise_rates</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">diffusion_schedule</span><span class="p">(</span><span class="n">diffusion_times</span><span class="p">[</span><span class="n">step</span> <span class="o">+</span> <span class="mi">1</span><span class="p">])</span> 

            <span class="n">pred_images</span><span class="p">,</span> <span class="n">pred_noises</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">denoise</span><span class="p">(</span><span class="n">noisy_images</span><span class="p">,</span> <span class="n">signal_rates</span><span class="p">,</span> <span class="n">noise_rates</span><span class="p">)</span>
</code></pre></div></div>
<p>The following lines complete the loop with an implementation of DDPM sampling in this framework:</p>
<ul>
  <li>we generate extra noise and compute its rate, which should be a small value</li>
  <li>then we calculate a noise rate multiplier, which should be slightly below one</li>
  <li>finally, we mix the components using a modified version of the single equation, in which the rate of the predicted noise is decreased and the extra noise is added with its own rate, to estimate the next noisy images
    <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>          <span class="n">extra_noises</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">normal</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">image_size</span><span class="p">,</span> <span class="n">image_size</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span>        
          <span class="n">extra_noise_rates</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="o">-</span> <span class="n">signal_rates</span> <span class="o">/</span> <span class="n">next_signal_rates</span>
          <span class="n">noise_rate_multipliers</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="o">-</span> <span class="n">extra_noise_rates</span> <span class="o">/</span> <span class="n">noise_rates</span>

          <span class="n">noisy_images</span> <span class="o">=</span> <span class="p">(</span>
              <span class="n">next_signal_rates</span> <span class="o">**</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">pred_images</span>
              <span class="o">+</span> <span class="p">(</span><span class="n">next_noise_rates</span> <span class="o">*</span> <span class="n">noise_rate_multipliers</span><span class="p">)</span> <span class="o">**</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">pred_noises</span>
              <span class="o">+</span> <span class="n">extra_noise_rates</span> <span class="o">**</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">extra_noises</span>
          <span class="p">)</span>

      <span class="k">return</span> <span class="n">pred_images</span>
</code></pre></div>    </div>
  </li>
</ul>
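<p>To see how large these quantities are in practice, the sketch below recomputes <code>extra_noise_rates</code> and <code>noise_rate_multipliers</code> with the formulas of the loop above for every step of a sampling run, using plain Python scalars instead of tensors. It assumes the cosine schedule shown in the ablations and the <code>diffusion_steps</code> and <code>time_margin</code> values of the complete implementation (100 and 0.05).</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import math

diffusion_steps = 100
time_margin = 0.05


def cosine_schedule(diffusion_time):
    # cosine schedule from the ablation section: signal rate = cos(pi/2 * t)^2
    signal_rate = math.cos(math.pi / 2 * diffusion_time) ** 2
    return signal_rate, 1.0 - signal_rate


# diffusion times from (1 - time_margin) down to time_margin, like tf.linspace in diffusion_process
diffusion_times = [
    (1.0 - time_margin) + (2 * time_margin - 1.0) * i / diffusion_steps
    for i in range(diffusion_steps + 1)
]

extra_noise_rates, noise_rate_multipliers = [], []
for step in range(diffusion_steps):
    signal_rate, noise_rate = cosine_schedule(diffusion_times[step])
    next_signal_rate, _ = cosine_schedule(diffusion_times[step + 1])
    # same formulas as in the DDPM sampling loop
    extra_noise_rate = 1.0 - signal_rate / next_signal_rate
    extra_noise_rates.append(extra_noise_rate)
    noise_rate_multipliers.append(1.0 - extra_noise_rate / noise_rate)

print(min(extra_noise_rates), max(extra_noise_rates))
print(min(noise_rate_multipliers), max(noise_rate_multipliers))
</code></pre></div></div>

<p>In this setting the extra noise rate is largest at the first, noisiest step and decreases toward the end of sampling, while both quantities stay strictly between zero and one at every step.</p>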

<h1 id="applications">Applications</h1>

<p>The fact that the DDIM sampling procedure is deterministic lends itself to some interesting use cases.</p>

<h2 id="noise-space-interpolation">Noise Space Interpolation</h2>
<p>One can interpolate between two images in noise space by calculating intermediate points between their initial noise values, then generating an image from each of them with the deterministic sampling procedure.</p>

<p><img src="https://iclr.iro.umontreal.ca/298b26fe-cc4f-4f10-9499-daf9fb61bf84_1642249335/public/images/2021-12-01-diffusion-single-equation/interpolation.png" alt="noise space interpolation" /></p>
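<p>This post does not state how the intermediate noise values are computed. The pure-Python sketch below assumes spherical linear interpolation (slerp), a common choice for Gaussian noise, because plain linear interpolation between two independent noise samples shrinks their norm near the midpoint (by roughly a factor of the square root of 2 in high dimensions). The two noise vectors are hypothetical; each interpolated vector would serve as the initial noise of a deterministic DDIM sampling run.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import math
import random


def slerp(noise_a, noise_b, t):
    """Spherical linear interpolation between two flattened noise vectors."""
    dot = sum(a * b for a, b in zip(noise_a, noise_b))
    norm_a = math.sqrt(sum(a * a for a in noise_a))
    norm_b = math.sqrt(sum(b * b for b in noise_b))
    # angle between the two vectors, clamped to guard against rounding errors
    angle = math.acos(max(-1.0, min(1.0, dot / (norm_a * norm_b))))
    if math.isclose(math.sin(angle), 0.0, abs_tol=1e-12):
        # (anti)parallel vectors: fall back to linear interpolation
        return [(1 - t) * a + t * b for a, b in zip(noise_a, noise_b)]
    weight_a = math.sin((1 - t) * angle) / math.sin(angle)
    weight_b = math.sin(t * angle) / math.sin(angle)
    return [weight_a * a + weight_b * b for a, b in zip(noise_a, noise_b)]


def norm(vector):
    return math.sqrt(sum(x * x for x in vector))


random.seed(0)
dim = 64 * 64 * 3  # a flattened 64x64 RGB noise image
noise_a = [random.gauss(0.0, 1.0) for _ in range(dim)]
noise_b = [random.gauss(0.0, 1.0) for _ in range(dim)]

# five evenly spaced points from noise_a (t = 0) to noise_b (t = 1)
num_points = 5
interpolated = [slerp(noise_a, noise_b, i / (num_points - 1)) for i in range(num_points)]
linear_midpoint = [(a + b) / 2 for a, b in zip(noise_a, noise_b)]
print(norm(noise_a), norm(interpolated[2]), norm(linear_midpoint))
</code></pre></div></div>

<p>For 64x64 RGB noise, the slerp midpoint keeps roughly the norm of its endpoints, while the linear midpoint is noticeably shorter.</p>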

<h2 id="encoding-and-decoding-images">Encoding and Decoding Images</h2>
<p>We can iteratively encode images into pure noise using a forward diffusion process, then decode them and compare the reconstructions with the original images. In this figure, the top row contains the original images, the middle row their encoded final noise values, and the bottom row their reconstructions, obtained by using these noises as the starting point of a reverse diffusion process.</p>

<p><img src="https://iclr.iro.umontreal.ca/298b26fe-cc4f-4f10-9499-daf9fb61bf84_1642249335/public/images/2021-12-01-diffusion-single-equation/encode.png" alt="encode decode" /></p>
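<p>Because each DDIM step is deterministic, the same update can be run with increasing diffusion times to encode an image into noise and with decreasing times to decode it. The pure-Python sketch below illustrates only this bookkeeping, using the single equation and the cosine schedule from the ablations. The trained network is replaced by a hypothetical denoiser that always predicts the same fixed noise vector; for this stand-in every step is exactly invertible, so the round trip recovers the input up to floating-point error. Unlike <code>diffusion_process</code> above, the sketch returns the noisy image at the final diffusion time rather than the last image prediction.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import math
import random


def cosine_schedule(diffusion_time):
    # cosine schedule from the ablation section
    signal_rate = math.cos(math.pi / 2 * diffusion_time) ** 2
    return signal_rate, 1.0 - signal_rate


def make_fixed_noise_denoiser(fixed_noise):
    """Hypothetical stand-in for the trained network: always predicts fixed_noise."""
    def denoise(noisy_image, signal_rate, noise_rate):
        pred_noise = list(fixed_noise)
        # recover the image component from the single equation
        pred_image = [
            (x - noise_rate ** 0.5 * eps) / signal_rate ** 0.5
            for x, eps in zip(noisy_image, pred_noise)
        ]
        return pred_image, pred_noise
    return denoise


def ddim_process(image, diffusion_times, denoise):
    """Deterministic DDIM steps: increasing times encode, decreasing times decode."""
    for time, next_time in zip(diffusion_times[:-1], diffusion_times[1:]):
        signal_rate, noise_rate = cosine_schedule(time)
        next_signal_rate, next_noise_rate = cosine_schedule(next_time)
        pred_image, pred_noise = denoise(image, signal_rate, noise_rate)
        # single equation: remix the predicted components at the next noise level
        image = [
            next_signal_rate ** 0.5 * x0 + next_noise_rate ** 0.5 * eps
            for x0, eps in zip(pred_image, pred_noise)
        ]
    return image


random.seed(0)
time_margin, diffusion_steps = 0.05, 20
# increasing diffusion times from time_margin to 1 - time_margin
times = [time_margin + (1 - 2 * time_margin) * i / diffusion_steps for i in range(diffusion_steps + 1)]

original = [random.uniform(-1.0, 1.0) for _ in range(8)]  # a tiny hypothetical "image"
denoise = make_fixed_noise_denoiser([random.gauss(0.0, 1.0) for _ in range(8)])

encoded = ddim_process(original, times, denoise)  # forward (encoding) direction
reconstructed = ddim_process(encoded, times[::-1], denoise)  # reverse (decoding) direction
print(max(abs(a - b) for a, b in zip(original, reconstructed)))
</code></pre></div></div>

<p>With a trained network, the noise predictions change from step to step, so the reconstructions are only approximately equal to the originals.</p>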

<h1 id="ablations">Ablations</h1>

<h2 id="baseline">Baseline</h2>
<p>The following figure shows the baseline image generation quality, using exactly the settings of the complete implementation at the end of this blog post (training takes around an hour on a single A100 GPU).</p>

<p><img src="https://iclr.iro.umontreal.ca/298b26fe-cc4f-4f10-9499-daf9fb61bf84_1642249335/public/images/2021-12-01-diffusion-single-equation/baseline.png" alt="baseline visualization" /></p>

<h2 id="training-ablations">Training Ablations</h2>

<h3 id="predicting-images-instead-of-noise">Predicting Images Instead of Noise</h3>
<p>In my experience, diffusion models that predict images instead of noise are harder to train, with both complete and temporary divergence events being much more common. They also lead to slightly worse sample quality, though they can still be trained.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>    <span class="k">def</span> <span class="nf">denoise</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">noisy_images</span><span class="p">,</span> <span class="n">signal_rates</span><span class="p">,</span> <span class="n">noise_rates</span><span class="p">):</span>
        <span class="n">pred_images</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">network</span><span class="p">([</span><span class="n">noisy_images</span><span class="p">,</span> <span class="n">signal_rates</span><span class="p">])</span>
        <span class="n">pred_noises</span> <span class="o">=</span> <span class="n">noise_rates</span> <span class="o">**</span> <span class="o">-</span><span class="mf">0.5</span> <span class="o">*</span> <span class="p">(</span><span class="n">noisy_images</span> <span class="o">-</span> <span class="n">signal_rates</span> <span class="o">**</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">pred_images</span><span class="p">)</span>

        <span class="k">return</span> <span class="n">pred_images</span><span class="p">,</span> <span class="n">pred_noises</span>
</code></pre></div></div>

<p><img src="https://iclr.iro.umontreal.ca/298b26fe-cc4f-4f10-9499-daf9fb61bf84_1642249335/public/images/2021-12-01-diffusion-single-equation/pred_image.png" alt="training predict images" /></p>
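<p>The image-predicting and noise-predicting variants are linked by the single equation: either prediction determines the other, which is what the <code>denoise</code> method above relies on. The pure-Python sketch below, with hypothetical helper names and scalar values for a single pixel, checks both conversions. It also shows, as a plain algebraic fact, that an error in a predicted image is scaled by <code>(signal_rate / noise_rate) ** 0.5</code> when it is converted into a noise prediction, a factor that grows as the noise rate shrinks.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>def noise_from_image(noisy_image, pred_image, signal_rate, noise_rate):
    # image-predicting variant (as in denoise above): solve the single equation for the noise
    return (noisy_image - signal_rate ** 0.5 * pred_image) / noise_rate ** 0.5


def image_from_noise(noisy_image, pred_noise, signal_rate, noise_rate):
    # noise-predicting variant: solve the single equation for the image
    return (noisy_image - noise_rate ** 0.5 * pred_noise) / signal_rate ** 0.5


# hypothetical scalar values for a single pixel at a low noise level
signal_rate, noise_rate = 0.99, 0.01
image, noise = 0.3, -1.2
noisy_image = signal_rate ** 0.5 * image + noise_rate ** 0.5 * noise

recovered_noise = noise_from_image(noisy_image, image, signal_rate, noise_rate)
recovered_image = image_from_noise(noisy_image, noise, signal_rate, noise_rate)

# an error of 0.01 in the predicted image, converted into a noise prediction
image_error = 0.01
noise_error = noise_from_image(noisy_image, image + image_error, signal_rate, noise_rate) - noise
print(recovered_noise, recovered_image, noise_error)
</code></pre></div></div>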
<h3 id="using-mean-squared-error-instead-of-mean-absolute-error">Using Mean Squared Error instead of Mean Absolute Error</h3>

<p>Mean squared error leads to lower-quality but more diverse samples, as detailed in an earlier section.</p>

<p><img src="https://iclr.iro.umontreal.ca/298b26fe-cc4f-4f10-9499-daf9fb61bf84_1642249335/public/images/2021-12-01-diffusion-single-equation/mse.png" alt="training mse loss" /></p>

<h2 id="sampling-ablations">Sampling Ablations</h2>

<h3 id="stochastic-sampling-ddpm">Stochastic Sampling (DDPM)</h3>

<p>Stochastic (DDPM) sampling improves generation quality at a sufficiently high number of diffusion steps (&gt;50).</p>

<p><img src="https://iclr.iro.umontreal.ca/298b26fe-cc4f-4f10-9499-daf9fb61bf84_1642249335/public/images/2021-12-01-diffusion-single-equation/ddpm.png" alt="stochastic sampling ddpm" /></p>

<h3 id="varying-the-number-of-sampling-steps">Varying the Number of Sampling Steps</h3>

<p>Comparing DDIM and DDPM sampling, we see wildly different behaviour at different numbers of sampling steps. While DDIM produces reasonable results even with very few reverse diffusion steps, DDPM seems to need more steps to work well.</p>

<p><img src="https://iclr.iro.umontreal.ca/298b26fe-cc4f-4f10-9499-daf9fb61bf84_1642249335/public/images/2021-12-01-diffusion-single-equation/steps_ddim.png" alt="sampling steps ddim" /></p>

<p><img src="https://iclr.iro.umontreal.ca/298b26fe-cc4f-4f10-9499-daf9fb61bf84_1642249335/public/images/2021-12-01-diffusion-single-equation/steps_ddpm.png" alt="sampling steps ddpm" /></p>

<h3 id="varying-the-time-margin">Varying the Time Margin</h3>

<p>A time margin that is too small seems to lead to saturated colors in this case.</p>

<p><img src="https://iclr.iro.umontreal.ca/298b26fe-cc4f-4f10-9499-daf9fb61bf84_1642249335/public/images/2021-12-01-diffusion-single-equation/margin_steps.png" alt="margin 10 steps" /></p>
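<p>The time margin sets the diffusion times at which sampling starts and ends. The pure-Python sketch below, assuming the cosine schedule from the next subsection, computes the signal rate of the first step and the noise rate of the last step for a few margins (the margin values are arbitrary examples).</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import math


def cosine_schedule(diffusion_time):
    # cosine schedule from the next subsection: signal rate = cos(pi/2 * t)^2
    signal_rate = math.cos(math.pi / 2 * diffusion_time) ** 2
    return signal_rate, 1.0 - signal_rate


endpoint_rates = {}
for time_margin in [0.0, 0.01, 0.05, 0.1]:
    # sampling runs from diffusion time (1 - time_margin) down to time_margin
    first_signal_rate, _ = cosine_schedule(1.0 - time_margin)
    _, last_noise_rate = cosine_schedule(time_margin)
    endpoint_rates[time_margin] = (first_signal_rate, last_noise_rate)
    print(f"time_margin={time_margin:.2f}: first signal rate {first_signal_rate:.5f}, last noise rate {last_noise_rate:.5f}")
</code></pre></div></div>

<p>With a margin of zero, sampling starts from a signal rate of practically zero, where recovering the image from a noise prediction via the single equation divides by nearly zero, and it ends at a noise rate of exactly zero. A positive margin keeps both rates away from these extremes.</p>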

<h3 id="different-diffusion-schedules">Different Diffusion Schedules</h3>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>    <span class="k">def</span> <span class="nf">diffusion_schedule</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">diffusion_times</span><span class="p">):</span>
        <span class="c1"># cosine schedule
</span>        <span class="n">signal_rates</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">cos</span><span class="p">(</span><span class="n">math</span><span class="p">.</span><span class="n">pi</span> <span class="o">/</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">diffusion_times</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span>
        <span class="n">noise_rates</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">-</span> <span class="n">signal_rates</span>
        <span class="k">return</span> <span class="n">signal_rates</span><span class="p">,</span> <span class="n">noise_rates</span>
</code></pre></div></div>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>    <span class="k">def</span> <span class="nf">diffusion_schedule</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">diffusion_times</span><span class="p">,</span> <span class="n">min_signal_rate</span><span class="o">=</span><span class="mf">0.01</span><span class="p">):</span>
        <span class="c1"># gaussian schedule
</span>        <span class="n">signal_rates</span> <span class="o">=</span> <span class="n">min_signal_rate</span> <span class="o">**</span> <span class="p">(</span><span class="n">diffusion_times</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span>
        <span class="n">noise_rates</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">-</span> <span class="n">signal_rates</span>
        <span class="k">return</span> <span class="n">signal_rates</span><span class="p">,</span> <span class="n">noise_rates</span>
</code></pre></div></div>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>    <span class="k">def</span> <span class="nf">diffusion_schedule</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">diffusion_times</span><span class="p">,</span> <span class="n">min_noise_rate</span><span class="o">=</span><span class="mf">0.01</span><span class="p">):</span>
        <span class="c1"># flipped gaussian schedule
</span>        <span class="n">noise_rates</span> <span class="o">=</span> <span class="n">min_noise_rate</span> <span class="o">**</span> <span class="p">((</span><span class="mi">1</span> <span class="o">-</span> <span class="n">diffusion_times</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span>
        <span class="n">signal_rates</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">-</span> <span class="n">noise_rates</span>
        <span class="k">return</span> <span class="n">signal_rates</span><span class="p">,</span> <span class="n">noise_rates</span>
</code></pre></div></div>

<p>I have also experimented with different diffusion schedules, but in my experience they produced similar results. The rows of the figure correspond to the schedules in the code snippets above, in the same order.</p>

<p><img src="https://iclr.iro.umontreal.ca/298b26fe-cc4f-4f10-9499-daf9fb61bf84_1642249335/public/images/2021-12-01-diffusion-single-equation/schedules.png" alt="diffusion schedules" /></p>
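<p>To compare the three schedules numerically, the sketch below evaluates scalar, pure-Python versions of the functions above (with the same default minimum rates) at a few diffusion times and records their signal rates.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import math


# scalar versions of the three schedules above, with the same default minimum rates
def cosine_schedule(t):
    signal_rate = math.cos(math.pi / 2 * t) ** 2
    return signal_rate, 1.0 - signal_rate


def gaussian_schedule(t, min_signal_rate=0.01):
    signal_rate = min_signal_rate ** (t ** 2)
    return signal_rate, 1.0 - signal_rate


def flipped_gaussian_schedule(t, min_noise_rate=0.01):
    noise_rate = min_noise_rate ** ((1 - t) ** 2)
    return 1.0 - noise_rate, noise_rate


schedules = {
    "cosine": cosine_schedule,
    "gaussian": gaussian_schedule,
    "flipped gaussian": flipped_gaussian_schedule,
}
diffusion_times = [0.0, 0.25, 0.5, 0.75, 1.0]

signal_rate_table = {}
for name, schedule in schedules.items():
    signal_rate_table[name] = [schedule(t)[0] for t in diffusion_times]
    print(name.ljust(16), [round(s, 3) for s in signal_rate_table[name]])
</code></pre></div></div>

<p>Unlike the cosine schedule, which goes from a signal rate of 1 at t = 0 to practically 0 at t = 1, the gaussian schedule bottoms out at <code>min_signal_rate</code> at t = 1, while the flipped gaussian schedule starts from a signal rate of 1 - <code>min_noise_rate</code> at t = 0.</p>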

<h2 id="network-architecture-ablations">Network Architecture Ablations</h2>

<h3 id="omitting-signal-rate-embedding">Omitting Signal Rate Embedding</h3>

<p>Omitting the signal rate embedding from the neural network completely can lead to overly noisy and overly smoothed results.</p>

<p><img src="https://iclr.iro.umontreal.ca/298b26fe-cc4f-4f10-9499-daf9fb61bf84_1642249335/public/images/2021-12-01-diffusion-single-equation/no_emb.png" alt="omitted signal rate embedding" /></p>

<h3 id="omitting-skip-connections">Omitting Skip Connections</h3>

<p>Omitting skip connections from the network causes learning to fail completely.</p>

<p><img src="https://iclr.iro.umontreal.ca/298b26fe-cc4f-4f10-9499-daf9fb61bf84_1642249335/public/images/2021-12-01-diffusion-single-equation/no_skip.png" alt="omitted skip connections" /></p>

<h3 id="omitting-residual-connections">Omitting Residual Connections</h3>

<p>Omitting residual connections from the network slightly degraded performance.</p>

<p><img src="https://iclr.iro.umontreal.ca/298b26fe-cc4f-4f10-9499-daf9fb61bf84_1642249335/public/images/2021-12-01-diffusion-single-equation/no_res.png" alt="omitted residual connections" /></p>

<h1 id="conclusion">Conclusion</h1>

<p>In this work, I provided a novel viewpoint on denoising diffusion models, along with a simple explanation and implementation of them. My hope is that this will lower the barrier to entering the field and experimenting with these generative models.</p>

<h1 id="a-complete-implementation">A Complete Implementation</h1>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>
<span class="kn">import</span> <span class="nn">math</span>
<span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="n">plt</span>
<span class="kn">import</span> <span class="nn">tensorflow</span> <span class="k">as</span> <span class="n">tf</span>
<span class="kn">import</span> <span class="nn">tensorflow_datasets</span> <span class="k">as</span> <span class="n">tfds</span>
<span class="kn">import</span> <span class="nn">tensorflow_addons</span> <span class="k">as</span> <span class="n">tfa</span>

<span class="kn">from</span> <span class="nn">tensorflow</span> <span class="kn">import</span> <span class="n">keras</span>
<span class="kn">from</span> <span class="nn">tensorflow.keras</span> <span class="kn">import</span> <span class="n">layers</span>

<span class="c1"># hyperparameters
</span>
<span class="c1"># data
</span><span class="n">crop_size</span> <span class="o">=</span> <span class="mi">140</span>  <span class="c1"># center crop size of the images
</span><span class="n">image_size</span> <span class="o">=</span> <span class="mi">64</span>  <span class="c1"># training resolution
</span>
<span class="c1"># network
</span><span class="n">num_resolutions</span> <span class="o">=</span> <span class="mi">3</span>  <span class="c1"># number of stages in the network
</span><span class="n">blocks_per_stage</span> <span class="o">=</span> <span class="mi">2</span>  <span class="c1"># number of residual blocks in a stage
</span><span class="n">base_width</span> <span class="o">=</span> <span class="mi">64</span>  <span class="c1"># number of filters at the highest resolution
</span><span class="n">min_frequency</span> <span class="o">=</span> <span class="mf">1.0</span>  <span class="c1"># minimal embedding frequency
</span><span class="n">max_frequency</span> <span class="o">=</span> <span class="mf">1000.0</span>  <span class="c1"># maximal embedding frequency
</span>
<span class="c1"># optimization
</span><span class="n">num_epochs</span> <span class="o">=</span> <span class="mi">20</span>
<span class="n">batch_size</span> <span class="o">=</span> <span class="mi">64</span>
<span class="n">learning_rate</span> <span class="o">=</span> <span class="mf">1e-3</span>
<span class="n">weight_decay</span> <span class="o">=</span> <span class="mf">1e-4</span>
<span class="n">ema</span> <span class="o">=</span> <span class="mf">0.999</span>

<span class="c1"># sampling
</span><span class="n">diffusion_steps</span> <span class="o">=</span> <span class="mi">100</span>
<span class="n">time_margin</span> <span class="o">=</span> <span class="mf">0.05</span>


<span class="k">def</span> <span class="nf">preprocess_image</span><span class="p">(</span><span class="n">data</span><span class="p">):</span>
    <span class="c1"># original image dimensions
</span>    <span class="n">height</span> <span class="o">=</span> <span class="mi">218</span>
    <span class="n">width</span> <span class="o">=</span> <span class="mi">178</span>
    <span class="c1"># center crop
</span>    <span class="n">image</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">image</span><span class="p">.</span><span class="n">crop_to_bounding_box</span><span class="p">(</span>
        <span class="n">data</span><span class="p">[</span><span class="s">"image"</span><span class="p">],</span>
        <span class="p">(</span><span class="n">height</span> <span class="o">-</span> <span class="n">crop_size</span><span class="p">)</span> <span class="o">//</span> <span class="mi">2</span><span class="p">,</span>
        <span class="p">(</span><span class="n">width</span> <span class="o">-</span> <span class="n">crop_size</span><span class="p">)</span> <span class="o">//</span> <span class="mi">2</span><span class="p">,</span>
        <span class="n">crop_size</span><span class="p">,</span>
        <span class="n">crop_size</span><span class="p">,</span>
    <span class="p">)</span>
    <span class="c1"># resize
</span>    <span class="n">image</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">image</span><span class="p">.</span><span class="n">resize</span><span class="p">(</span>
        <span class="n">image</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="p">[</span><span class="n">image_size</span><span class="p">,</span> <span class="n">image_size</span><span class="p">],</span> <span class="n">method</span><span class="o">=</span><span class="s">"bicubic"</span><span class="p">,</span> <span class="n">antialias</span><span class="o">=</span><span class="bp">True</span>
    <span class="p">)</span>
    <span class="c1"># scale pixel values to the [-1, 1] range
</span>    <span class="k">return</span> <span class="n">tf</span><span class="p">.</span><span class="n">clip_by_value</span><span class="p">(</span><span class="n">image</span> <span class="o">/</span> <span class="mf">127.5</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>


<span class="k">def</span> <span class="nf">prepare_dataset</span><span class="p">(</span><span class="n">split</span><span class="p">):</span>
    <span class="c1"># load celeb_a dataset split
</span>    <span class="c1"># note: the automatic download can fail sometimes
</span>    <span class="k">return</span> <span class="p">(</span>
        <span class="n">tfds</span><span class="p">.</span><span class="n">load</span><span class="p">(</span><span class="s">"celeb_a"</span><span class="p">,</span> <span class="n">split</span><span class="o">=</span><span class="n">split</span><span class="p">,</span> <span class="n">shuffle_files</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
        <span class="p">.</span><span class="nb">map</span><span class="p">(</span><span class="n">preprocess_image</span><span class="p">,</span> <span class="n">num_parallel_calls</span><span class="o">=</span><span class="n">tf</span><span class="p">.</span><span class="n">data</span><span class="p">.</span><span class="n">AUTOTUNE</span><span class="p">)</span>
        <span class="p">.</span><span class="n">cache</span><span class="p">()</span>
        <span class="p">.</span><span class="n">shuffle</span><span class="p">(</span><span class="mi">10</span> <span class="o">*</span> <span class="n">batch_size</span><span class="p">)</span>
        <span class="p">.</span><span class="n">batch</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">drop_remainder</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
        <span class="p">.</span><span class="n">prefetch</span><span class="p">(</span><span class="n">buffer_size</span><span class="o">=</span><span class="n">tf</span><span class="p">.</span><span class="n">data</span><span class="p">.</span><span class="n">AUTOTUNE</span><span class="p">)</span>
    <span class="p">)</span>


<span class="n">train_dataset</span> <span class="o">=</span> <span class="n">prepare_dataset</span><span class="p">(</span><span class="s">"train"</span><span class="p">)</span>
<span class="n">val_dataset</span> <span class="o">=</span> <span class="n">prepare_dataset</span><span class="p">(</span><span class="s">"validation"</span><span class="p">)</span>

<span class="c1"># augmentation module: only horizontal flips
</span><span class="k">def</span> <span class="nf">build_augmenter</span><span class="p">():</span>
    <span class="k">return</span> <span class="n">keras</span><span class="p">.</span><span class="n">Sequential</span><span class="p">(</span>
        <span class="p">[</span>
            <span class="n">layers</span><span class="p">.</span><span class="n">InputLayer</span><span class="p">(</span><span class="n">input_shape</span><span class="o">=</span><span class="p">(</span><span class="n">image_size</span><span class="p">,</span> <span class="n">image_size</span><span class="p">,</span> <span class="mi">3</span><span class="p">)),</span>
            <span class="n">layers</span><span class="p">.</span><span class="n">RandomFlip</span><span class="p">(</span><span class="n">mode</span><span class="o">=</span><span class="s">"horizontal"</span><span class="p">),</span>
        <span class="p">],</span>
        <span class="n">name</span><span class="o">=</span><span class="s">"augmenter"</span><span class="p">,</span>
    <span class="p">)</span>


<span class="c1"># network: residual UNet with sinusoidal signal_rate embedding
</span><span class="k">def</span> <span class="nf">build_network</span><span class="p">():</span>
    <span class="k">def</span> <span class="nf">EmbeddingLayer</span><span class="p">(</span><span class="n">num_frequencies</span><span class="p">):</span>
        <span class="k">def</span> <span class="nf">sinusoidal_embedding</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
            <span class="n">log_frequencies</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span>
                <span class="n">tf</span><span class="p">.</span><span class="n">math</span><span class="p">.</span><span class="n">log</span><span class="p">(</span><span class="n">min_frequency</span><span class="p">),</span> <span class="n">tf</span><span class="p">.</span><span class="n">math</span><span class="p">.</span><span class="n">log</span><span class="p">(</span><span class="n">max_frequency</span><span class="p">),</span> <span class="n">num_frequencies</span>
            <span class="p">)</span>
            <span class="n">frequencies</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">exp</span><span class="p">(</span><span class="n">log_frequencies</span><span class="p">)</span>
            <span class="n">angular_speeds</span> <span class="o">=</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">math</span><span class="p">.</span><span class="n">pi</span> <span class="o">*</span> <span class="n">frequencies</span>
            <span class="n">embeddings</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">concat</span><span class="p">(</span>
                <span class="p">[</span><span class="n">tf</span><span class="p">.</span><span class="n">sin</span><span class="p">(</span><span class="n">angular_speeds</span> <span class="o">*</span> <span class="n">x</span><span class="p">),</span> <span class="n">tf</span><span class="p">.</span><span class="n">cos</span><span class="p">(</span><span class="n">angular_speeds</span> <span class="o">*</span> <span class="n">x</span><span class="p">)],</span> <span class="n">axis</span><span class="o">=</span><span class="mi">3</span>
            <span class="p">)</span>
            <span class="k">return</span> <span class="n">embeddings</span>

        <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
            <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="p">.</span><span class="n">Lambda</span><span class="p">(</span><span class="n">sinusoidal_embedding</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span>
            <span class="k">return</span> <span class="n">x</span>

        <span class="k">return</span> <span class="n">forward</span>

    <span class="k">def</span> <span class="nf">ResidualBlock</span><span class="p">(</span><span class="n">width</span><span class="p">):</span>
        <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
            <span class="n">input_width</span> <span class="o">=</span> <span class="n">x</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">3</span><span class="p">]</span>
            <span class="k">if</span> <span class="n">input_width</span> <span class="o">==</span> <span class="n">width</span><span class="p">:</span>
                <span class="n">residual</span> <span class="o">=</span> <span class="n">x</span>
            <span class="k">else</span><span class="p">:</span>
                <span class="c1"># a 1x1 convolution matches the channel count so the residual can be added
</span>                <span class="n">residual</span> <span class="o">=</span> <span class="n">layers</span><span class="p">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="n">width</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">1</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span>
            <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="p">.</span><span class="n">BatchNormalization</span><span class="p">(</span><span class="n">center</span><span class="o">=</span><span class="bp">False</span><span class="p">,</span> <span class="n">scale</span><span class="o">=</span><span class="bp">False</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span>
            <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="p">.</span><span class="n">Conv2D</span><span class="p">(</span>
                <span class="n">width</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="s">"same"</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="n">keras</span><span class="p">.</span><span class="n">activations</span><span class="p">.</span><span class="n">swish</span>
            <span class="p">)(</span><span class="n">x</span><span class="p">)</span>
            <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="p">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="n">width</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="s">"same"</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span>
            <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="p">.</span><span class="n">Add</span><span class="p">()([</span><span class="n">residual</span><span class="p">,</span> <span class="n">x</span><span class="p">])</span>
            <span class="k">return</span> <span class="n">x</span>

        <span class="k">return</span> <span class="n">forward</span>

    <span class="k">def</span> <span class="nf">DownStage</span><span class="p">(</span><span class="n">width</span><span class="p">):</span>
        <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
            <span class="n">x</span><span class="p">,</span> <span class="n">skips</span> <span class="o">=</span> <span class="n">x</span>
            <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">blocks_per_stage</span><span class="p">):</span>
                <span class="n">x</span> <span class="o">=</span> <span class="n">ResidualBlock</span><span class="p">(</span><span class="n">width</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span>
                <span class="n">skips</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
            <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="p">.</span><span class="n">AveragePooling2D</span><span class="p">(</span><span class="n">pool_size</span><span class="o">=</span><span class="mi">2</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span>
            <span class="k">return</span> <span class="n">x</span>

        <span class="k">return</span> <span class="n">forward</span>

    <span class="k">def</span> <span class="nf">UpStage</span><span class="p">(</span><span class="n">width</span><span class="p">):</span>
        <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
            <span class="n">x</span><span class="p">,</span> <span class="n">skips</span> <span class="o">=</span> <span class="n">x</span>
            <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="p">.</span><span class="n">UpSampling2D</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">interpolation</span><span class="o">=</span><span class="s">"bilinear"</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span>
            <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">blocks_per_stage</span><span class="p">):</span>
                <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="p">.</span><span class="n">Concatenate</span><span class="p">()([</span><span class="n">x</span><span class="p">,</span> <span class="n">skips</span><span class="p">.</span><span class="n">pop</span><span class="p">()])</span>
                <span class="n">x</span> <span class="o">=</span> <span class="n">ResidualBlock</span><span class="p">(</span><span class="n">width</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span>
            <span class="k">return</span> <span class="n">x</span>

        <span class="k">return</span> <span class="n">forward</span>

    <span class="n">images</span> <span class="o">=</span> <span class="n">keras</span><span class="p">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">image_size</span><span class="p">,</span> <span class="n">image_size</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span>
    <span class="n">signal_rates</span> <span class="o">=</span> <span class="n">keras</span><span class="p">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>

    <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="p">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="n">base_width</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">1</span><span class="p">)(</span><span class="n">images</span><span class="p">)</span>
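    <span class="c1"># skip connections are pushed on the way down and popped in reverse order on the way up
</span>    <span class="n">skips</span> <span class="o">=</span> <span class="p">[</span><span class="n">x</span><span class="p">]</span>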

    <span class="c1"># embed the signal rate and repeat it at every pixel so it can be concatenated with the image features
</span>    <span class="n">e</span> <span class="o">=</span> <span class="n">EmbeddingLayer</span><span class="p">(</span><span class="n">num_frequencies</span><span class="o">=</span><span class="n">base_width</span> <span class="o">//</span> <span class="mi">2</span><span class="p">)(</span><span class="n">signal_rates</span><span class="p">)</span>
    <span class="n">e</span> <span class="o">=</span> <span class="n">layers</span><span class="p">.</span><span class="n">UpSampling2D</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="n">image_size</span><span class="p">,</span> <span class="n">interpolation</span><span class="o">=</span><span class="s">"nearest"</span><span class="p">)(</span><span class="n">e</span><span class="p">)</span>
    <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="p">.</span><span class="n">Concatenate</span><span class="p">()([</span><span class="n">x</span><span class="p">,</span> <span class="n">e</span><span class="p">])</span>

    <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_resolutions</span><span class="p">):</span>
        <span class="n">x</span> <span class="o">=</span> <span class="n">DownStage</span><span class="p">((</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">base_width</span><span class="p">)([</span><span class="n">x</span><span class="p">,</span> <span class="n">skips</span><span class="p">])</span>

    <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">blocks_per_stage</span><span class="p">):</span>
        <span class="n">x</span> <span class="o">=</span> <span class="n">ResidualBlock</span><span class="p">((</span><span class="n">num_resolutions</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">base_width</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span>

    <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">reversed</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">num_resolutions</span><span class="p">)):</span>
        <span class="n">x</span> <span class="o">=</span> <span class="n">UpStage</span><span class="p">((</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">base_width</span><span class="p">)([</span><span class="n">x</span><span class="p">,</span> <span class="n">skips</span><span class="p">])</span>

    <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="p">.</span><span class="n">Concatenate</span><span class="p">()([</span><span class="n">x</span><span class="p">,</span> <span class="n">skips</span><span class="p">.</span><span class="n">pop</span><span class="p">()])</span>  <span class="c1"># pops the last remaining skip; skips is empty afterwards
</span>    <span class="c1"># the output is the predicted noise (denoise uses it as pred_noises), not a clean image
</span>    <span class="n">output_signal</span> <span class="o">=</span> <span class="n">layers</span><span class="p">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">1</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span>

    <span class="k">return</span> <span class="n">keras</span><span class="p">.</span><span class="n">Model</span><span class="p">([</span><span class="n">images</span><span class="p">,</span> <span class="n">signal_rates</span><span class="p">],</span> <span class="n">output_signal</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s">"residual_unet"</span><span class="p">)</span>


<span class="k">class</span> <span class="nc">DiffusionModel</span><span class="p">(</span><span class="n">keras</span><span class="p">.</span><span class="n">Model</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">()</span>

        <span class="bp">self</span><span class="p">.</span><span class="n">augmenter</span> <span class="o">=</span> <span class="n">build_augmenter</span><span class="p">()</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">network</span> <span class="o">=</span> <span class="n">build_network</span><span class="p">()</span>
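        <span class="bp">self</span><span class="p">.</span><span class="n">ema_network</span> <span class="o">=</span> <span class="n">keras</span><span class="p">.</span><span class="n">models</span><span class="p">.</span><span class="n">clone_model</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">network</span><span class="p">)</span>
        <span class="c1"># clone_model creates freshly initialized weights, so copy them to start the average from the same point
</span>        <span class="bp">self</span><span class="p">.</span><span class="n">ema_network</span><span class="p">.</span><span class="n">set_weights</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">network</span><span class="p">.</span><span class="n">get_weights</span><span class="p">())</span>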

    <span class="k">def</span> <span class="nf">compile</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">().</span><span class="nb">compile</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>

        <span class="c1"># the noise and image reconstruction losses are tracked as metrics
</span>        <span class="bp">self</span><span class="p">.</span><span class="n">noise_loss_tracker</span> <span class="o">=</span> <span class="n">keras</span><span class="p">.</span><span class="n">metrics</span><span class="p">.</span><span class="n">Mean</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s">"n_loss"</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">image_loss_tracker</span> <span class="o">=</span> <span class="n">keras</span><span class="p">.</span><span class="n">metrics</span><span class="p">.</span><span class="n">Mean</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s">"i_loss"</span><span class="p">)</span>

    <span class="o">@</span><span class="nb">property</span>
    <span class="k">def</span> <span class="nf">metrics</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="k">return</span> <span class="p">[</span><span class="bp">self</span><span class="p">.</span><span class="n">noise_loss_tracker</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">image_loss_tracker</span><span class="p">]</span>

    <span class="k">def</span> <span class="nf">diffusion_schedule</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">diffusion_times</span><span class="p">):</span>
        <span class="c1"># cosine schedule: signal_rate = cos^2(pi/2 * t) and noise_rate = sin^2(pi/2 * t), which sum to 1
</span>        <span class="n">signal_rates</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">cos</span><span class="p">(</span><span class="n">math</span><span class="p">.</span><span class="n">pi</span> <span class="o">/</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">diffusion_times</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span>
        <span class="n">noise_rates</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">-</span> <span class="n">signal_rates</span>
        <span class="k">return</span> <span class="n">signal_rates</span><span class="p">,</span> <span class="n">noise_rates</span>

    <span class="k">def</span> <span class="nf">denoise</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">noisy_images</span><span class="p">,</span> <span class="n">signal_rates</span><span class="p">,</span> <span class="n">noise_rates</span><span class="p">,</span> <span class="n">training</span><span class="p">):</span>
        <span class="c1"># exponential moving average of weights is used during inference
</span>        <span class="k">if</span> <span class="n">training</span><span class="p">:</span>
            <span class="n">network</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">network</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="n">network</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">ema_network</span>

        <span class="n">pred_noises</span> <span class="o">=</span> <span class="n">network</span><span class="p">([</span><span class="n">noisy_images</span><span class="p">,</span> <span class="n">signal_rates</span><span class="p">],</span> <span class="n">training</span><span class="o">=</span><span class="n">training</span><span class="p">)</span>
        <span class="c1"># solve noisy = sqrt(signal_rate) * image + sqrt(noise_rate) * noise for the image
</span>        <span class="n">pred_images</span> <span class="o">=</span> <span class="n">signal_rates</span> <span class="o">**</span> <span class="o">-</span><span class="mf">0.5</span> <span class="o">*</span> <span class="p">(</span>
            <span class="n">noisy_images</span> <span class="o">-</span> <span class="n">noise_rates</span> <span class="o">**</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">pred_noises</span>
        <span class="p">)</span>

        <span class="k">return</span> <span class="n">pred_images</span><span class="p">,</span> <span class="n">pred_noises</span>

    <span class="k">def</span> <span class="nf">diffusion_process</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">initial_noise</span><span class="p">):</span>
        <span class="n">batch_size</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">shape</span><span class="p">(</span><span class="n">initial_noise</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
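        <span class="c1"># step from t = 1 - time_margin (nearly pure noise) down to t = time_margin (nearly clean);
</span>        <span class="c1"># the margin keeps signal_rates above 0, so signal_rates ** -0.5 in denoise stays finite
</span>        <span class="n">diffusion_times</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">time_margin</span><span class="p">,</span> <span class="n">time_margin</span><span class="p">,</span> <span class="n">diffusion_steps</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>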
        <span class="n">diffusion_times</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">reshape</span><span class="p">(</span>
            <span class="n">diffusion_times</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">diffusion_steps</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
        <span class="p">)</span>
        <span class="n">diffusion_times</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">broadcast_to</span><span class="p">(</span>
            <span class="n">diffusion_times</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">diffusion_steps</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
        <span class="p">)</span>

        <span class="n">noisy_images</span> <span class="o">=</span> <span class="n">initial_noise</span>
        <span class="k">for</span> <span class="n">step</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">diffusion_steps</span><span class="p">):</span>

            <span class="n">signal_rates</span><span class="p">,</span> <span class="n">noise_rates</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">diffusion_schedule</span><span class="p">(</span><span class="n">diffusion_times</span><span class="p">[</span><span class="n">step</span><span class="p">])</span>
            <span class="n">next_signal_rates</span><span class="p">,</span> <span class="n">next_noise_rates</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">diffusion_schedule</span><span class="p">(</span>
                <span class="n">diffusion_times</span><span class="p">[</span><span class="n">step</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]</span>
            <span class="p">)</span>

            <span class="n">pred_images</span><span class="p">,</span> <span class="n">pred_noises</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">denoise</span><span class="p">(</span>
                <span class="n">noisy_images</span><span class="p">,</span> <span class="n">signal_rates</span><span class="p">,</span> <span class="n">noise_rates</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="bp">False</span>
            <span class="p">)</span>

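            <span class="c1"># deterministic (DDIM-style) update: recombine the predicted image and noise at the next, lower noise level
</span>            <span class="n">noisy_images</span> <span class="o">=</span> <span class="p">(</span>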
                <span class="n">next_signal_rates</span> <span class="o">**</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">pred_images</span>
                <span class="o">+</span> <span class="n">next_noise_rates</span> <span class="o">**</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">pred_noises</span>
            <span class="p">)</span>

        <span class="k">return</span> <span class="n">pred_images</span>

    <span class="k">def</span> <span class="nf">generate</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">num_images</span><span class="p">):</span>
        <span class="n">initial_noise</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">normal</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">num_images</span><span class="p">,</span> <span class="n">image_size</span><span class="p">,</span> <span class="n">image_size</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span>
        <span class="k">return</span> <span class="bp">self</span><span class="p">.</span><span class="n">diffusion_process</span><span class="p">(</span><span class="n">initial_noise</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">train_step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">images</span><span class="p">):</span>
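        <span class="n">images</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">augmenter</span><span class="p">(</span><span class="n">images</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
        <span class="c1"># use the actual batch size, which can be smaller than the configured one for a final partial batch
</span>        <span class="n">batch_size</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">shape</span><span class="p">(</span><span class="n">images</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>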
        <span class="n">diffusion_times</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">uniform</span><span class="p">(</span>
            <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="n">minval</span><span class="o">=</span><span class="n">time_margin</span><span class="p">,</span> <span class="n">maxval</span><span class="o">=</span><span class="mi">1</span> <span class="o">-</span> <span class="n">time_margin</span>
        <span class="p">)</span>
        <span class="n">signal_rates</span><span class="p">,</span> <span class="n">noise_rates</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">diffusion_schedule</span><span class="p">(</span><span class="n">diffusion_times</span><span class="p">)</span>

        <span class="n">noises</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">normal</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">image_size</span><span class="p">,</span> <span class="n">image_size</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span>
        <span class="c1"># forward process: noise each image to its sampled diffusion time
</span>        <span class="n">noisy_images</span> <span class="o">=</span> <span class="n">signal_rates</span> <span class="o">**</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">images</span> <span class="o">+</span> <span class="n">noise_rates</span> <span class="o">**</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">noises</span>

        <span class="k">with</span> <span class="n">tf</span><span class="p">.</span><span class="n">GradientTape</span><span class="p">()</span> <span class="k">as</span> <span class="n">tape</span><span class="p">:</span>
            <span class="n">pred_images</span><span class="p">,</span> <span class="n">pred_noises</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">denoise</span><span class="p">(</span>
                <span class="n">noisy_images</span><span class="p">,</span> <span class="n">signal_rates</span><span class="p">,</span> <span class="n">noise_rates</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="bp">True</span>
            <span class="p">)</span>

            <span class="n">noise_loss</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">loss</span><span class="p">(</span><span class="n">noises</span><span class="p">,</span> <span class="n">pred_noises</span><span class="p">)</span>
            <span class="n">image_loss</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">loss</span><span class="p">(</span><span class="n">images</span><span class="p">,</span> <span class="n">pred_images</span><span class="p">)</span>

        <span class="c1"># only the noise loss is minimized; the image loss is tracked for monitoring
</span>        <span class="n">gradients</span> <span class="o">=</span> <span class="n">tape</span><span class="p">.</span><span class="n">gradient</span><span class="p">(</span><span class="n">noise_loss</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">network</span><span class="p">.</span><span class="n">trainable_weights</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">optimizer</span><span class="p">.</span><span class="n">apply_gradients</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span class="n">gradients</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">network</span><span class="p">.</span><span class="n">trainable_weights</span><span class="p">))</span>

        <span class="bp">self</span><span class="p">.</span><span class="n">noise_loss_tracker</span><span class="p">.</span><span class="n">update_state</span><span class="p">(</span><span class="n">noise_loss</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">image_loss_tracker</span><span class="p">.</span><span class="n">update_state</span><span class="p">(</span><span class="n">image_loss</span><span class="p">)</span>

        <span class="c1"># update the exponential moving average of all weights, including batch-norm statistics
</span>        <span class="k">for</span> <span class="n">weight</span><span class="p">,</span> <span class="n">ema_weight</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">network</span><span class="p">.</span><span class="n">weights</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">ema_network</span><span class="p">.</span><span class="n">weights</span><span class="p">):</span>
            <span class="n">ema_weight</span><span class="p">.</span><span class="n">assign</span><span class="p">(</span><span class="n">ema</span> <span class="o">*</span> <span class="n">ema_weight</span> <span class="o">+</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">ema</span><span class="p">)</span> <span class="o">*</span> <span class="n">weight</span><span class="p">)</span>

        <span class="k">return</span> <span class="p">{</span><span class="n">m</span><span class="p">.</span><span class="n">name</span><span class="p">:</span> <span class="n">m</span><span class="p">.</span><span class="n">result</span><span class="p">()</span> <span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span class="bp">self</span><span class="p">.</span><span class="n">metrics</span><span class="p">}</span>

    <span class="k">def</span> <span class="nf">test_step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">images</span><span class="p">):</span>
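        <span class="n">images</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">augmenter</span><span class="p">(</span><span class="n">images</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
        <span class="c1"># use the actual batch size, which can be smaller than the configured one for a final partial batch
</span>        <span class="n">batch_size</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">shape</span><span class="p">(</span><span class="n">images</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>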

        <span class="n">diffusion_times</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">uniform</span><span class="p">(</span>
            <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="n">minval</span><span class="o">=</span><span class="n">time_margin</span><span class="p">,</span> <span class="n">maxval</span><span class="o">=</span><span class="mi">1</span> <span class="o">-</span> <span class="n">time_margin</span>
        <span class="p">)</span>
        <span class="n">signal_rates</span><span class="p">,</span> <span class="n">noise_rates</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">diffusion_schedule</span><span class="p">(</span><span class="n">diffusion_times</span><span class="p">)</span>

        <span class="n">noises</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">normal</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">image_size</span><span class="p">,</span> <span class="n">image_size</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span>
        <span class="n">noisy_images</span> <span class="o">=</span> <span class="n">signal_rates</span> <span class="o">**</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">images</span> <span class="o">+</span> <span class="n">noise_rates</span> <span class="o">**</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">noises</span>

        <span class="n">pred_images</span><span class="p">,</span> <span class="n">pred_noises</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">denoise</span><span class="p">(</span>
            <span class="n">noisy_images</span><span class="p">,</span> <span class="n">signal_rates</span><span class="p">,</span> <span class="n">noise_rates</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="bp">False</span>
        <span class="p">)</span>

        <span class="n">noise_loss</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">loss</span><span class="p">(</span><span class="n">noises</span><span class="p">,</span> <span class="n">pred_noises</span><span class="p">)</span>
        <span class="n">image_loss</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">loss</span><span class="p">(</span><span class="n">images</span><span class="p">,</span> <span class="n">pred_images</span><span class="p">)</span>

        <span class="bp">self</span><span class="p">.</span><span class="n">noise_loss_tracker</span><span class="p">.</span><span class="n">update_state</span><span class="p">(</span><span class="n">noise_loss</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">image_loss_tracker</span><span class="p">.</span><span class="n">update_state</span><span class="p">(</span><span class="n">image_loss</span><span class="p">)</span>

        <span class="k">return</span> <span class="p">{</span><span class="n">m</span><span class="p">.</span><span class="n">name</span><span class="p">:</span> <span class="n">m</span><span class="p">.</span><span class="n">result</span><span class="p">()</span> <span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span class="bp">self</span><span class="p">.</span><span class="n">metrics</span><span class="p">}</span>

    <span class="k">def</span> <span class="nf">plot_images</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">epoch</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span> <span class="n">logs</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span> <span class="n">num_rows</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">num_cols</span><span class="o">=</span><span class="mi">8</span><span class="p">):</span>
        <span class="c1"># plotting a batch of generated images
</span>        <span class="n">num_images</span> <span class="o">=</span> <span class="n">num_rows</span> <span class="o">*</span> <span class="n">num_cols</span>

        <span class="n">generated_images</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">generate</span><span class="p">(</span><span class="n">num_images</span><span class="p">)</span>
        <span class="n">generated_images</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span> <span class="o">+</span> <span class="n">generated_images</span><span class="p">)</span> <span class="o">/</span> <span class="mi">2</span>
        <span class="n">generated_images</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">clip_by_value</span><span class="p">(</span><span class="n">generated_images</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>

        <span class="n">plt</span><span class="p">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="n">num_cols</span> <span class="o">*</span> <span class="mf">1.5</span><span class="p">,</span> <span class="n">num_rows</span> <span class="o">*</span> <span class="mf">1.5</span><span class="p">))</span>
        <span class="k">for</span> <span class="n">row</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_rows</span><span class="p">):</span>
            <span class="k">for</span> <span class="n">col</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_cols</span><span class="p">):</span>
                <span class="n">index</span> <span class="o">=</span> <span class="n">row</span> <span class="o">*</span> <span class="n">num_cols</span> <span class="o">+</span> <span class="n">col</span>
                <span class="n">plt</span><span class="p">.</span><span class="n">subplot</span><span class="p">(</span><span class="n">num_rows</span><span class="p">,</span> <span class="n">num_cols</span><span class="p">,</span> <span class="n">index</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
                <span class="n">plt</span><span class="p">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">generated_images</span><span class="p">[</span><span class="n">index</span><span class="p">])</span>
                <span class="n">plt</span><span class="p">.</span><span class="n">axis</span><span class="p">(</span><span class="s">"off"</span><span class="p">)</span>
        <span class="n">plt</span><span class="p">.</span><span class="n">tight_layout</span><span class="p">()</span>
        <span class="n">plt</span><span class="p">.</span><span class="n">savefig</span><span class="p">(</span><span class="s">"images/{}.png"</span><span class="p">.</span><span class="nb">format</span><span class="p">(</span><span class="n">epoch</span> <span class="o">+</span> <span class="mi">1</span><span class="p">))</span>
        <span class="n">plt</span><span class="p">.</span><span class="n">close</span><span class="p">()</span>


<span class="n">model</span> <span class="o">=</span> <span class="n">DiffusionModel</span><span class="p">()</span>
<span class="c1"># using Adam optimizer with weight decay and mean absolute error as reconstruction loss
</span><span class="n">model</span><span class="p">.</span><span class="nb">compile</span><span class="p">(</span>
    <span class="n">optimizer</span><span class="o">=</span><span class="n">tfa</span><span class="p">.</span><span class="n">optimizers</span><span class="p">.</span><span class="n">AdamW</span><span class="p">(</span>
        <span class="n">learning_rate</span><span class="o">=</span><span class="n">learning_rate</span><span class="p">,</span> <span class="n">weight_decay</span><span class="o">=</span><span class="n">weight_decay</span>
    <span class="p">),</span>
    <span class="n">loss</span><span class="o">=</span><span class="n">keras</span><span class="p">.</span><span class="n">losses</span><span class="p">.</span><span class="n">mean_absolute_error</span><span class="p">,</span>
<span class="p">)</span>

<span class="n">checkpoint_callback</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">keras</span><span class="p">.</span><span class="n">callbacks</span><span class="p">.</span><span class="n">ModelCheckpoint</span><span class="p">(</span>
    <span class="n">filepath</span><span class="o">=</span><span class="s">"checkpoints/model"</span><span class="p">,</span>
    <span class="n">save_weights_only</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span>
    <span class="n">monitor</span><span class="o">=</span><span class="s">"val_n_loss"</span><span class="p">,</span>
    <span class="n">mode</span><span class="o">=</span><span class="s">"min"</span><span class="p">,</span>
    <span class="n">save_best_only</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span>
<span class="p">)</span>

<span class="n">model</span><span class="p">.</span><span class="n">plot_images</span><span class="p">()</span>

<span class="n">model</span><span class="p">.</span><span class="n">fit</span><span class="p">(</span>
    <span class="n">train_dataset</span><span class="p">,</span>
    <span class="n">validation_data</span><span class="o">=</span><span class="n">val_dataset</span><span class="p">,</span>
    <span class="n">epochs</span><span class="o">=</span><span class="n">num_epochs</span><span class="p">,</span>
    <span class="n">callbacks</span><span class="o">=</span><span class="p">[</span>
        <span class="n">keras</span><span class="p">.</span><span class="n">callbacks</span><span class="p">.</span><span class="n">LambdaCallback</span><span class="p">(</span><span class="n">on_epoch_end</span><span class="o">=</span><span class="n">model</span><span class="p">.</span><span class="n">plot_images</span><span class="p">),</span>
        <span class="n">checkpoint_callback</span><span class="p">,</span>
    <span class="p">],</span>
<span class="p">)</span>
</code></pre></div></div>
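<p>The evaluation in <code>test_step</code> rests on one equation: a clean image and Gaussian noise are mixed with weights <code>signal_rates ** 0.5</code> and <code>noise_rates ** 0.5</code>. The NumPy sketch below is a minimal illustration, not part of the original code. It reproduces that mixing and inverts it. It assumes the rates are variances that sum to one, as the <code>** 0.5</code> factors suggest. The post's actual <code>diffusion_schedule</code> and <code>denoise</code> are defined earlier and may differ. For that reason <code>images_from_noise</code> is a hypothetical helper, and the random rates stand in for a real schedule.</p>

```python
import numpy as np


def add_noise(images, noises, signal_rates, noise_rates):
    # Forward mixing as written in test_step: the rates are treated as
    # variances, so their square roots scale the clean images and the noise.
    return np.sqrt(signal_rates) * images + np.sqrt(noise_rates) * noises


def images_from_noise(noisy_images, pred_noises, signal_rates, noise_rates):
    # Hypothetical helper (assumption): solve the same equation for the clean
    # image, given an estimate of the noise that was added.
    return (noisy_images - np.sqrt(noise_rates) * pred_noises) / np.sqrt(signal_rates)


def mean_absolute_error(y_true, y_pred):
    # Scalar MAE over every element. keras.losses.mean_absolute_error averages
    # over the last axis only; the Mean metric trackers then average the rest.
    return float(np.mean(np.abs(y_true - y_pred)))


rng = np.random.default_rng(0)
batch_size, image_size = 4, 8
images = rng.uniform(-1.0, 1.0, size=(batch_size, image_size, image_size, 3))
noises = rng.normal(size=images.shape)

# Assumed variance-preserving rates (one per image): signal + noise == 1.
signal_rates = rng.uniform(0.02, 0.98, size=(batch_size, 1, 1, 1))
noise_rates = 1.0 - signal_rates

noisy_images = add_noise(images, noises, signal_rates, noise_rates)

# A perfect noise prediction recovers the clean images (up to rounding).
pred_images = images_from_noise(noisy_images, noises, signal_rates, noise_rates)
print(mean_absolute_error(images, pred_images))

# A constant error of 0.1 in the predicted noise becomes an image error
# scaled by sqrt(noise_rates / signal_rates) for each image.
pred_noises = noises + 0.1
pred_images = images_from_noise(noisy_images, pred_noises, signal_rates, noise_rates)
print(mean_absolute_error(noises, pred_noises), mean_absolute_error(images, pred_images))
```

<p>Under these assumptions, the noise loss and the image loss tracked in <code>test_step</code> measure the same prediction. The image error, however, is the noise error multiplied by <code>sqrt(noise_rates / signal_rates)</code>. Heavily noised samples therefore inflate the image loss relative to the noise loss.</p>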

</div>

<div id="bibtex-container" class="related">
  For attribution in academic contexts, please cite this work as
  <pre id="bibtex-academic-attribution">

  </pre>

  BibTeX citation
  <pre id="bibtex-box">

  </pre>
</div>
<script>
  let authorsSpan = document.getElementById("iclr-post-authors");
  let authorsText = authorsSpan.textContent;
  let lnameFnameInstitution = authorsText.split(";");
  let lfiList = lnameFnameInstitution.map(lfi => lfi.split(",").map(item => item.trim()));
  let bibtexLFI = lfiList.map(lfi => lfi[0] + ", " + lfi[1]).join(" and ")
  let academicLFI = lfiList.map(lfi => lfi[0]);
  {
    if(academicLFI.length > 2) academicLFI = academicLFI[0] + ", et al.";
    else if(academicLFI.length == 2) academicLFI = academicLFI[0] + " & " + academicLFI[1];
    else academicLFI = academicLFI[0];
  }

  let titleSpan = document.getElementById("iclr-post-title");
  let titleText = titleSpan.textContent.trim();
  let bibtexTitleShorthand = (lfiList[0][1]+
    "2022"+
    titleText.split(" ").slice(0, 3).join("")
  ).replace(/\s+/g, "").replace(/[\p{P}$+<=>^`|~]/gu, '').toLowerCase().trim();

  let bibtexTemplate = `
@inproceedings{${bibtexTitleShorthand},
  author = {${bibtexLFI}},
  title = {${titleText}},
  booktitle = {ICLR Blog Track},
  year = {2022},
  note = {${window.location.href}},
  url  = {${window.location.href}}
}
  `.trim();
  document.getElementById("bibtex-box").innerText = bibtexTemplate;

  let academicTemplate = `
${academicLFI}, "${titleText}", ICLR Blog Track, 2022.
`.trim();
  document.getElementById("bibtex-academic-attribution").innerText = academicTemplate;

</script>






      </div>
    </div>

    <label for="sidebar-checkbox" class="sidebar-toggle"></label>

    <script src='/public/js/script.js'></script>
  </body>
</html>
