<!-- #region -->
<center><h1> The Annotated S4 </h1></center>

<center>
<p><a href="https://arxiv.org/abs/2111.00396">Efficiently Modeling Long Sequences with Structured State Spaces</a></p>
</center>

<center>
<p> Albert Gu, Karan Goel, and Christopher Ré.</p>
</center>
<p><img src="https://iclr.iro.umontreal.ca/4de73081-c045-4ef8-8a76-5783bb8f6e9c_1642114847/public/images/2021-12-01-annotated-s4/images/hero.png" width="100%" />
<!-- #endregion --></p>

<p><em>Blog Post and Library …</em></p>

<p>The <a href="https://arxiv.org/abs/2111.00396">Structured State Space for Sequence
Modeling</a> (S4) architecture is a new approach to very
long-range sequence modeling tasks for vision,
language, and audio, showing a capacity to capture dependencies over tens
of thousands of steps. Especially impressive are the model’s results on the challenging
<a href="https://github.com/google-research/long-range-arena">Long Range Arena</a> benchmark, where it
reasons over sequences of <strong>16,000+</strong> elements with high accuracy.</p>

<p><img src="https://iclr.iro.umontreal.ca/4de73081-c045-4ef8-8a76-5783bb8f6e9c_1642114847/public/images/2021-12-01-annotated-s4/images/table.png" width="100%" /></p>

<p>The paper is also a refreshing departure from Transformers, taking
a very different approach to an important problem space. However,
several of our colleagues have also noted privately (and on
<a href="https://twitter.com/sleepinyourhat/status/1468037897446121483">Twitter</a>!)
the difficulty of gaining intuition for the model. This blog post is a first
step towards building that intuition, linking concrete code implementations
with explanations from the S4 paper – very much in the style of <a href="https://nlp.seas.harvard.edu/2018/04/03/attention.html">The Annotated
Transformer</a>.
Hopefully this combination of code and literate explanations helps you follow the
details of the model.</p>

<h2 id="table-of-contents">Table of Contents</h2>

<ul>
  <li>Part 1: <a href="#part-1-state-space-models">State Space Models</a></li>
  <li>Part 2: <a href="#part-2-implementing-s4">Implementing S4</a></li>
  <li>Part 3: <a href="#part-3-s4-in-practice">S4 in Practice</a></li>
</ul>

<p>Note that this project uses <a href="https://github.com/google/jax/">JAX</a>
with the <a href="https://github.com/google/flax">Flax</a> NN library. While
we mainly use Torch ourselves, the functional nature of JAX is a good
fit for some of the complexities of S4. We make heavy use of
<a href="https://jax.readthedocs.io/en/latest/jax.html#jax.vmap">vmap</a>,
<a href="https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html">scan</a>,
their <a href="https://flax.readthedocs.io/en/latest/flax.linen.html#module-flax.linen.transforms">NN
cousins</a>,
and most importantly
<a href="https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#jit-mechanics-tracing-and-static-variables">jax.jit</a>
to compile fast and efficient S4 layers.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">functools</span> <span class="kn">import</span> <span class="n">partial</span>
<span class="kn">import</span> <span class="nn">jax</span>
<span class="kn">import</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="n">np</span>
<span class="kn">from</span> <span class="nn">flax</span> <span class="kn">import</span> <span class="n">linen</span> <span class="k">as</span> <span class="n">nn</span>
<span class="kn">from</span> <span class="nn">jax.nn.initializers</span> <span class="kn">import</span> <span class="n">lecun_normal</span>
<span class="kn">from</span> <span class="nn">jax.numpy.linalg</span> <span class="kn">import</span> <span class="n">eig</span><span class="p">,</span> <span class="n">inv</span><span class="p">,</span> <span class="n">matrix_power</span>
<span class="kn">from</span> <span class="nn">jax.scipy.signal</span> <span class="kn">import</span> <span class="n">convolve</span>
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">rng</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">PRNGKey</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
</code></pre></div></div>

<h2 id="part-1-state-space-models">Part 1: State Space Models</h2>

<p>Let’s get started! Our goal is the efficient
modeling of long sequences. To do this, we are going to build a
new neural network layer based on State Space Models. By the end of
this section we will be able to build and run a model with this layer.
First, though, we need some technical background, so let’s work
our way through the background material of the paper.</p>

<blockquote>
  <p>The <a href="https://en.wikipedia.org/wiki/State-space_representation">state space model</a> is defined by this simple equation.
It maps a 1-D input signal $u(t)$ to an $N$-D latent state $x(t)$
before projecting to a 1-D output signal $y(t)$.</p>
</blockquote>

\[\begin{aligned}
    x'(t) &amp;= \boldsymbol{A}x(t) + \boldsymbol{B}u(t) \\
    y(t) &amp;= \boldsymbol{C}x(t) + \boldsymbol{D}u(t)
  \end{aligned}\]

<blockquote>
  <p>Our goal is
to simply use the SSM as a black-box representation in a deep
sequence model, where $\boldsymbol{A}, \boldsymbol{B}, \boldsymbol{C}, \boldsymbol{D}$ are
parameters learned by gradient descent.  For the remainder, we will
omit the parameter $\boldsymbol{D}$ for exposition (or equivalently,
assume $\boldsymbol{D} = 0$  because the term $\boldsymbol{D}u$ can be
viewed as a skip connection and is easy to compute).</p>

  <p>An SSM maps an input $u(t)$ to a state representation vector $x(t)$ and an output $y(t)$.
For simplicity, we assume the input and output are one-dimensional, and the state representation
is $N$-dimensional. The first equation defines the change in $x(t)$ over time.</p>
</blockquote>

<p>Our SSMs will be defined by three matrices – $\boldsymbol{A}, \boldsymbol{B}, \boldsymbol{C}$ – which
we will learn. For now we begin with a random SSM, to define sizes,</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">random_SSM</span><span class="p">(</span><span class="n">rng</span><span class="p">,</span> <span class="n">N</span><span class="p">):</span>
    <span class="n">a_r</span><span class="p">,</span> <span class="n">b_r</span><span class="p">,</span> <span class="n">c_r</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">split</span><span class="p">(</span><span class="n">rng</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
    <span class="n">A</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">a_r</span><span class="p">,</span> <span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">N</span><span class="p">))</span>
    <span class="n">B</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">b_r</span><span class="p">,</span> <span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
    <span class="n">C</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">c_r</span><span class="p">,</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">N</span><span class="p">))</span>
    <span class="k">return</span> <span class="n">A</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">C</span>
</code></pre></div></div>
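
<p>As a quick check of the sizes involved, here is a minimal sketch in plain NumPy (imported as <code>onp</code>, since <code>np</code> is <code>jax.numpy</code> in this post). The seed, the state size <code>N = 4</code>, and the input value are arbitrary assumptions for illustration. It evaluates both SSM equations once, with $\boldsymbol{D} = 0$.</p>

```python
import numpy as onp  # plain NumPy; `np` in this post is jax.numpy

N = 4  # state size (arbitrary choice for this sketch)
gen = onp.random.default_rng(0)

# Same shapes as random_SSM: A is (N, N), B is (N, 1), C is (1, N).
A = gen.uniform(size=(N, N))
B = gen.uniform(size=(N, 1))
C = gen.uniform(size=(1, N))

x = onp.zeros((N, 1))   # latent state x(t) as a column vector
u = onp.array([[0.5]])  # one sample of the 1-D input u(t)

dx = A @ x + B @ u  # state derivative x'(t), shape (N, 1)
y = C @ x           # output y(t) with D = 0, shape (1, 1)

print(dx.shape, y.shape)
```

<p>The state derivative has the same shape as the state, and the output is a single number per time point.</p>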

<h3 id="discrete-time-ssm-the-recurrent-representation">Discrete-time SSM: The Recurrent Representation</h3>

<blockquote>
  <p>To be applied on a discrete input sequence $(u_0, u_1, \dots )$
instead of continuous function $u(t)$, the SSM must be
discretized by a <strong>step size</strong> $\Delta$ that represents the
resolution of the input.  Conceptually, the inputs $u_k$ can be
viewed as sampling an implicit underlying continuous signal $u(t)$,
where $u_k = u(k \Delta)$.</p>

  <p>To discretize the continuous-time SSM, we use
the <a href="https://en.wikipedia.org/wiki/Bilinear_transform">bilinear method</a>, which converts the
state matrix $\boldsymbol{A}$ into an approximation $\boldsymbol{\overline{A}}$.  The discrete SSM is:</p>
</blockquote>

\[\begin{aligned}
  \boldsymbol{\overline{A}} &amp;= (\boldsymbol{I} - \Delta/2 \cdot \boldsymbol{A})^{-1}(\boldsymbol{I} + \Delta/2 \cdot \boldsymbol{A}) \\
  \boldsymbol{\overline{B}} &amp;= (\boldsymbol{I} - \Delta/2 \cdot \boldsymbol{A})^{-1} \Delta \boldsymbol{B} \\
  \boldsymbol{\overline{C}} &amp;= \boldsymbol{C}\\
\end{aligned}\]

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">discretize</span><span class="p">(</span><span class="n">A</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">C</span><span class="p">,</span> <span class="n">step</span><span class="p">):</span>
    <span class="n">I</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">eye</span><span class="p">(</span><span class="n">A</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
    <span class="n">BL</span> <span class="o">=</span> <span class="n">inv</span><span class="p">(</span><span class="n">I</span> <span class="o">-</span> <span class="p">(</span><span class="n">step</span> <span class="o">/</span> <span class="mf">2.0</span><span class="p">)</span> <span class="o">*</span> <span class="n">A</span><span class="p">)</span>
    <span class="n">Ab</span> <span class="o">=</span> <span class="n">BL</span> <span class="o">@</span> <span class="p">(</span><span class="n">I</span> <span class="o">+</span> <span class="p">(</span><span class="n">step</span> <span class="o">/</span> <span class="mf">2.0</span><span class="p">)</span> <span class="o">*</span> <span class="n">A</span><span class="p">)</span>
    <span class="n">Bb</span> <span class="o">=</span> <span class="p">(</span><span class="n">BL</span> <span class="o">*</span> <span class="n">step</span><span class="p">)</span> <span class="o">@</span> <span class="n">B</span>
    <span class="k">return</span> <span class="n">Ab</span><span class="p">,</span> <span class="n">Bb</span><span class="p">,</span> <span class="n">C</span>
</code></pre></div></div>
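
<p>To see what the bilinear method does, it helps to try it on a 1x1 system, where the matrix inverse becomes a division. The sketch below mirrors <code>discretize</code> in plain NumPy (as <code>onp</code>). The helper name <code>discretize_np</code> and the values $a = -1$, $\Delta = 0.1$ are assumptions chosen for illustration. It compares the result with the closed-form scalar expression and with the exact one-step transition $e^{\Delta a}$ of the continuous system.</p>

```python
import numpy as onp  # plain NumPy; `np` in this post is jax.numpy


def discretize_np(A, B, C, step):
    # Bilinear transform, mirroring `discretize` in plain NumPy.
    I = onp.eye(A.shape[0])
    BL = onp.linalg.inv(I - (step / 2.0) * A)
    Ab = BL @ (I + (step / 2.0) * A)
    Bb = (BL * step) @ B
    return Ab, Bb, C


# Scalar decay x'(t) = a x(t) + u(t); a and step are illustrative values.
a, step = -1.0, 0.1
A = onp.array([[a]])
B = onp.array([[1.0]])
C = onp.array([[1.0]])

Ab, Bb, Cb = discretize_np(A, B, C, step)

# Closed form of the bilinear map for a 1x1 system.
Ab_closed = (1 + step * a / 2) / (1 - step * a / 2)
Bb_closed = step / (1 - step * a / 2)

# Exact one-step transition of the continuous system, for comparison.
Ab_exact = onp.exp(step * a)

print(Ab[0, 0], Ab_closed, Ab_exact)
```

<p>For these values the bilinear $\boldsymbol{\overline{A}}$ is $0.95 / 1.05 \approx 0.90476$, while $e^{-0.1} \approx 0.90484$, so the approximation is already close at this step size.</p>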

<blockquote>
  <p>This equation is now a <em>sequence-to-sequence</em> map $u_k \mapsto y_k$ instead of function-to-function.
Moreover the state equation is now a recurrence in $x_k$, allowing the discrete SSM to be computed like an RNN.
Concretely, $x_k \in \mathbb{R}^N$ can be viewed as a <em>hidden state</em> with transition matrix $\boldsymbol{\overline{A}}$.</p>
</blockquote>

\[\begin{aligned}
  x_{k} &amp;= \boldsymbol{\overline{A}} x_{k-1} + \boldsymbol{\overline{B}} u_k\\
  y_k &amp;= \boldsymbol{\overline{C}} x_k \\
\end{aligned}\]

<p>As the paper says, this “step” function does look superficially like that of
an RNN. We can implement this with a
<a href="https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html">scan</a>
in JAX,</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">scan_SSM</span><span class="p">(</span><span class="n">Ab</span><span class="p">,</span> <span class="n">Bb</span><span class="p">,</span> <span class="n">Cb</span><span class="p">,</span> <span class="n">u</span><span class="p">,</span> <span class="n">x0</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">step</span><span class="p">(</span><span class="n">x_k_1</span><span class="p">,</span> <span class="n">u_k</span><span class="p">):</span>
        <span class="n">x_k</span> <span class="o">=</span> <span class="n">Ab</span> <span class="o">@</span> <span class="n">x_k_1</span> <span class="o">+</span> <span class="n">Bb</span> <span class="o">@</span> <span class="n">u_k</span>
        <span class="n">y_k</span> <span class="o">=</span> <span class="n">Cb</span> <span class="o">@</span> <span class="n">x_k</span>
        <span class="k">return</span> <span class="n">x_k</span><span class="p">,</span> <span class="n">y_k</span>

    <span class="k">return</span> <span class="n">jax</span><span class="p">.</span><span class="n">lax</span><span class="p">.</span><span class="n">scan</span><span class="p">(</span><span class="n">step</span><span class="p">,</span> <span class="n">x0</span><span class="p">,</span> <span class="n">u</span><span class="p">)[</span><span class="mi">1</span><span class="p">]</span>
</code></pre></div></div>
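
<p>For readers new to <code>jax.lax.scan</code>, the sketch below spells out the loop it stands for: a carry is threaded through the step function and the per-step outputs are stacked. It is written in plain NumPy (as <code>onp</code>). The name <code>scan_reference</code> and the small hand-picked matrices are assumptions for illustration, not part of the library or the paper.</p>

```python
import numpy as onp  # plain NumPy; `np` in this post is jax.numpy


def scan_reference(f, init, xs):
    # Pure-Python loop with the semantics documented for jax.lax.scan:
    # thread a carry through f and stack the per-step outputs.
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, onp.stack(ys)


# A tiny discrete SSM with N = 2; the numbers are illustrative only.
Ab = onp.array([[0.5, 0.0], [0.0, 0.25]])
Bb = onp.array([[1.0], [1.0]])
Cb = onp.array([[1.0, 1.0]])


def step(x_k_1, u_k):
    x_k = Ab @ x_k_1 + Bb @ u_k
    y_k = Cb @ x_k
    return x_k, y_k


u = onp.array([1.0, 0.0, 0.0])  # a unit impulse at k = 0
x_final, ys = scan_reference(step, onp.zeros((2,)), u[:, onp.newaxis])

# y_0, y_1, y_2 are 2.0, 0.75, 0.3125 for this impulse input.
print(ys[:, 0])
```

<p>Each step only needs the previous state, which is what makes this form cheap at inference time and inherently sequential.</p>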

<p>Putting everything together, we can run the SSM
by first discretizing, then iterating step by step,</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">run_SSM</span><span class="p">(</span><span class="n">A</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">C</span><span class="p">,</span> <span class="n">u</span><span class="p">):</span>
    <span class="n">L</span> <span class="o">=</span> <span class="n">u</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
    <span class="n">N</span> <span class="o">=</span> <span class="n">A</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
    <span class="n">Ab</span><span class="p">,</span> <span class="n">Bb</span><span class="p">,</span> <span class="n">Cb</span> <span class="o">=</span> <span class="n">discretize</span><span class="p">(</span><span class="n">A</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">C</span><span class="p">,</span> <span class="n">step</span><span class="o">=</span><span class="mf">1.0</span> <span class="o">/</span> <span class="n">L</span><span class="p">)</span>

    <span class="c1"># Run recurrence
</span>    <span class="k">return</span> <span class="n">scan_SSM</span><span class="p">(</span><span class="n">Ab</span><span class="p">,</span> <span class="n">Bb</span><span class="p">,</span> <span class="n">Cb</span><span class="p">,</span> <span class="n">u</span><span class="p">[:,</span> <span class="n">np</span><span class="p">.</span><span class="n">newaxis</span><span class="p">],</span> <span class="n">np</span><span class="p">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">N</span><span class="p">,)))</span>
</code></pre></div></div>

<h3 id="tangent-a-mechanics-example">Tangent: A Mechanics Example</h3>

<p>To gain some more intuition and test our SSM implementation, we pause
 from machine learning to implement a <a href="https://en.wikipedia.org/wiki/State-space_representation#Moving_object_example">classic example from mechanics</a>.</p>

<p>In this example, we consider the forward position $y(t)$ of a mass attached to a wall with a spring.
 Over time, a varying force $u(t)$ is applied to this mass. The system is parameterized by mass ($m$),
 spring constant ($k$), and friction constant ($b$). We can relate these with the following differential equation:</p>

\[\begin{aligned}
my''(t) = u(t) - by'(t) - ky(t)
\end{aligned}\]

<p>Rewriting this in matrix form yields an SSM in the following form:</p>

\[\begin{aligned}
\boldsymbol{A} &amp;= \begin{bmatrix} 0 &amp; 1 \\ -k/m &amp; -b/m \end{bmatrix}  \\
\boldsymbol{B} &amp; = \begin{bmatrix} 0  \\ 1/m \end{bmatrix} &amp; \boldsymbol{C} = \begin{bmatrix} 1 &amp; 0  \end{bmatrix}  \\
\end{aligned}\]

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">example_mass</span><span class="p">(</span><span class="n">k</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">m</span><span class="p">):</span>
    <span class="n">A</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">([[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">],</span> <span class="p">[</span><span class="o">-</span><span class="n">k</span> <span class="o">/</span> <span class="n">m</span><span class="p">,</span> <span class="o">-</span><span class="n">b</span> <span class="o">/</span> <span class="n">m</span><span class="p">]])</span>
    <span class="n">B</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">([[</span><span class="mi">0</span><span class="p">],</span> <span class="p">[</span><span class="mf">1.0</span> <span class="o">/</span> <span class="n">m</span><span class="p">]])</span>
    <span class="n">C</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">([[</span><span class="mf">1.0</span><span class="p">,</span> <span class="mi">0</span><span class="p">]])</span>
    <span class="k">return</span> <span class="n">A</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">C</span>
</code></pre></div></div>

<p>Looking at $\boldsymbol{C}$, we should be able to convince ourselves that the
 first dimension of the hidden state is the position (since that becomes $y(t)$).
 The second dimension is the velocity, as it is impacted by $u(t)$ through
 $\boldsymbol{B}$. The transition $\boldsymbol{A}$ relates these terms.</p>
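
<p>This reading can be made explicit. As a sketch of the standard reduction to first order, take the state to be position and velocity, $x(t) = (y(t), y'(t))^\top$, and divide the equation of motion by $m$:</p>

\[\begin{aligned}
x'(t) = \begin{bmatrix} y'(t) \\ y''(t) \end{bmatrix}
&amp;= \begin{bmatrix} y'(t) \\ -\frac{k}{m} y(t) - \frac{b}{m} y'(t) + \frac{1}{m} u(t) \end{bmatrix} \\
&amp;= \begin{bmatrix} 0 &amp; 1 \\ -k/m &amp; -b/m \end{bmatrix} x(t) + \begin{bmatrix} 0 \\ 1/m \end{bmatrix} u(t)
\end{aligned}\]

<p>The output $y(t) = \begin{bmatrix} 1 &amp; 0 \end{bmatrix} x(t)$ then picks out the position.</p>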

<p>We’ll set $u$ to be a continuous function of $t$,</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">@</span><span class="n">partial</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">vectorize</span><span class="p">,</span> <span class="n">signature</span><span class="o">=</span><span class="s">"()-&gt;()"</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">example_force</span><span class="p">(</span><span class="n">t</span><span class="p">):</span>
    <span class="n">x</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">sin</span><span class="p">(</span><span class="mi">10</span> <span class="o">*</span> <span class="n">t</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">x</span> <span class="o">*</span> <span class="p">(</span><span class="n">x</span> <span class="o">&gt;</span> <span class="mf">0.5</span><span class="p">)</span>
</code></pre></div></div>

<p>Let’s run this SSM through our code.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">example_ssm</span><span class="p">():</span>
    <span class="c1"># SSM
</span>    <span class="n">ssm</span> <span class="o">=</span> <span class="n">example_mass</span><span class="p">(</span><span class="n">k</span><span class="o">=</span><span class="mi">40</span><span class="p">,</span> <span class="n">b</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">m</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>

    <span class="c1"># L samples of u(t).
</span>    <span class="n">L</span> <span class="o">=</span> <span class="mi">100</span>
    <span class="n">step</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="o">/</span> <span class="n">L</span>
    <span class="n">ks</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">arange</span><span class="p">(</span><span class="n">L</span><span class="p">)</span>
    <span class="n">u</span> <span class="o">=</span> <span class="n">example_force</span><span class="p">(</span><span class="n">ks</span> <span class="o">*</span> <span class="n">step</span><span class="p">)</span>

    <span class="c1"># Approximation of y(t).
</span>    <span class="n">y</span> <span class="o">=</span> <span class="n">run_SSM</span><span class="p">(</span><span class="o">*</span><span class="n">ssm</span><span class="p">,</span> <span class="n">u</span><span class="p">)</span>

    <span class="c1"># Plotting ---
</span>    <span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="n">plt</span>
    <span class="kn">import</span> <span class="nn">seaborn</span>
    <span class="kn">from</span> <span class="nn">celluloid</span> <span class="kn">import</span> <span class="n">Camera</span>

    <span class="n">seaborn</span><span class="p">.</span><span class="n">set_context</span><span class="p">(</span><span class="s">"paper"</span><span class="p">)</span>
    <span class="n">fig</span><span class="p">,</span> <span class="p">(</span><span class="n">ax1</span><span class="p">,</span> <span class="n">ax2</span><span class="p">,</span> <span class="n">ax3</span><span class="p">)</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">subplots</span><span class="p">(</span><span class="mi">3</span><span class="p">)</span>
    <span class="n">camera</span> <span class="o">=</span> <span class="n">Camera</span><span class="p">(</span><span class="n">fig</span><span class="p">)</span>
    <span class="n">ax1</span><span class="p">.</span><span class="n">set_title</span><span class="p">(</span><span class="s">"Force $u_k$"</span><span class="p">)</span>
    <span class="n">ax2</span><span class="p">.</span><span class="n">set_title</span><span class="p">(</span><span class="s">"Position $y_k$"</span><span class="p">)</span>
    <span class="n">ax3</span><span class="p">.</span><span class="n">set_title</span><span class="p">(</span><span class="s">"Object"</span><span class="p">)</span>
    <span class="n">ax1</span><span class="p">.</span><span class="n">set_xticks</span><span class="p">([],</span> <span class="p">[])</span>
    <span class="n">ax2</span><span class="p">.</span><span class="n">set_xticks</span><span class="p">([],</span> <span class="p">[])</span>

    <span class="c1"># Animate plot over time
</span>    <span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">L</span><span class="p">,</span> <span class="mi">2</span><span class="p">):</span>
        <span class="n">ax1</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">ks</span><span class="p">[:</span><span class="n">k</span><span class="p">],</span> <span class="n">u</span><span class="p">[:</span><span class="n">k</span><span class="p">],</span> <span class="n">color</span><span class="o">=</span><span class="s">"red"</span><span class="p">)</span>
        <span class="n">ax2</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">ks</span><span class="p">[:</span><span class="n">k</span><span class="p">],</span> <span class="n">y</span><span class="p">[:</span><span class="n">k</span><span class="p">],</span> <span class="n">color</span><span class="o">=</span><span class="s">"blue"</span><span class="p">)</span>
        <span class="n">ax3</span><span class="p">.</span><span class="n">boxplot</span><span class="p">(</span>
            <span class="p">[[</span><span class="n">y</span><span class="p">[</span><span class="n">k</span><span class="p">,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">-</span> <span class="mf">0.04</span><span class="p">,</span> <span class="n">y</span><span class="p">[</span><span class="n">k</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">y</span><span class="p">[</span><span class="n">k</span><span class="p">,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">+</span> <span class="mf">0.04</span><span class="p">]],</span>
            <span class="n">showcaps</span><span class="o">=</span><span class="bp">False</span><span class="p">,</span>
            <span class="n">whis</span><span class="o">=</span><span class="bp">False</span><span class="p">,</span>
            <span class="n">vert</span><span class="o">=</span><span class="bp">False</span><span class="p">,</span>
            <span class="n">widths</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span>
        <span class="p">)</span>
        <span class="n">camera</span><span class="p">.</span><span class="n">snap</span><span class="p">()</span>
    <span class="n">anim</span> <span class="o">=</span> <span class="n">camera</span><span class="p">.</span><span class="n">animate</span><span class="p">()</span>
    <span class="n">anim</span><span class="p">.</span><span class="n">save</span><span class="p">(</span><span class="s">"line.gif"</span><span class="p">,</span> <span class="n">dpi</span><span class="o">=</span><span class="mi">150</span><span class="p">,</span> <span class="n">writer</span><span class="o">=</span><span class="s">"imagemagick"</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">example_ssm</span><span class="p">()</span>
</code></pre></div></div>

<p><img src="https://iclr.iro.umontreal.ca/4de73081-c045-4ef8-8a76-5783bb8f6e9c_1642114847/public/images/2021-12-01-annotated-s4/line.gif" width="100%" /></p>

<p>Neat! And that was just 1 SSM, with 2 hidden states over 100 steps.
The final model will have <strong>100s of stacked SSMs</strong> over <strong>thousands of steps</strong>. But first – we
need to make these models practical to train.</p>
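
<p>One way to sanity-check the mechanics example is to push with a constant force and wait. Once the mass stops moving, the spring must balance the force, so the position should settle at $u/k$. The sketch below checks this in plain NumPy (as <code>onp</code>). The helper <code>discretize_np</code>, the constant unit force, the step size of 0.01, and the 1000 steps are all assumptions for illustration and differ from the <code>example_force</code> input used above.</p>

```python
import numpy as onp  # plain NumPy; `np` in this post is jax.numpy


def discretize_np(A, B, C, step):
    # Bilinear transform, mirroring `discretize` in plain NumPy.
    I = onp.eye(A.shape[0])
    BL = onp.linalg.inv(I - (step / 2.0) * A)
    Ab = BL @ (I + (step / 2.0) * A)
    Bb = (BL * step) @ B
    return Ab, Bb, C


# Same mass-spring parameters as example_ssm.
k, b, m = 40.0, 5.0, 1.0
A = onp.array([[0.0, 1.0], [-k / m, -b / m]])
B = onp.array([[0.0], [1.0 / m]])
C = onp.array([[1.0, 0.0]])

step = 0.01
Ab, Bb, Cb = discretize_np(A, B, C, step)

# Constant unit force for 1000 steps (10 simulated seconds).
force = onp.array([1.0])
x = onp.zeros((2,))
ys = []
for _ in range(1000):
    x = Ab @ x + Bb @ force
    ys.append((Cb @ x)[0])

y_final = ys[-1]
print(y_final, 1.0 / k)  # simulated resting position vs. u / k
```

<p>With $k = 40$ the resting position is $1/40 = 0.025$, and the velocity component of the state decays to zero.</p>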

<h3 id="training-ssms-the-convolutional-representation">Training SSMs: The Convolutional Representation</h3>

<p>The punchline of this section is that we can turn the “RNN” above into a “CNN”
by unrolling. Let’s go through the derivation.</p>

<blockquote>
  <p>The recurrent SSM is not practical for training on modern hardware
due to its sequential nature.  Instead, there is a well-known connection
between linear time-invariant (LTI) SSMs and
continuous convolutions.  Correspondingly, the recurrent SSM can actually be
written as a <a href="https://en.wikipedia.org/wiki/Convolution#Discrete_convolution">discrete convolution</a>.</p>

  <p>For simplicity let the initial state be $x_{-1} = 0$. Then unrolling  explicitly yields:</p>

\[\begin{aligned}
  x_0 &amp;= \boldsymbol{\overline{B}} u_0 &amp;
  x_1 &amp;= \boldsymbol{\overline{A}} \boldsymbol{\overline{B}} u_0 + \boldsymbol{\overline{B}} u_1 &amp;
  x_2 &amp;= \boldsymbol{\overline{A}}^2 \boldsymbol{\overline{B}} u_0 + \boldsymbol{\overline{A}} \boldsymbol{\overline{B}} u_1 + \boldsymbol{\overline{B}} u_2 &amp; \dots
  \\
  y_0 &amp;= \boldsymbol{\overline{C}} \boldsymbol{\overline{B}} u_0 &amp;
  y_1 &amp;= \boldsymbol{\overline{C}} \boldsymbol{\overline{A}} \boldsymbol{\overline{B}} u_0 + \boldsymbol{\overline{C}} \boldsymbol{\overline{B}} u_1 &amp;
  y_2 &amp;= \boldsymbol{\overline{C}} \boldsymbol{\overline{A}}^2 \boldsymbol{\overline{B}} u_0 + \boldsymbol{\overline{C}} \boldsymbol{\overline{A}} \boldsymbol{\overline{B}} u_1 + \boldsymbol{\overline{C}} \boldsymbol{\overline{B}} u_2
  &amp; \dots
\end{aligned}\]

  <p>This can be vectorized into a convolution with an explicit formula for the convolution kernel.</p>

\[\begin{aligned}
    y_k &amp;= \boldsymbol{\overline{C}} \boldsymbol{\overline{A}}^k \boldsymbol{\overline{B}} u_0 + \boldsymbol{\overline{C}} \boldsymbol{\overline{A}}^{k-1} \boldsymbol{\overline{B}} u_1 + \dots + \boldsymbol{\overline{C}} \boldsymbol{\overline{A}} \boldsymbol{\overline{B}} u_{k-1} + \boldsymbol{\overline{C}}\boldsymbol{\overline{B}} u_k
    \\
    y &amp;= \boldsymbol{\overline{K}} \ast u
\end{aligned}\]
</blockquote>

\[\begin{aligned}
  \boldsymbol{\overline{K}} \in \mathbb{R}^L  = (\boldsymbol{\overline{C}}\boldsymbol{\overline{B}}, \boldsymbol{\overline{C}}\boldsymbol{\overline{A}}\boldsymbol{\overline{B}}, \dots, \boldsymbol{\overline{C}}\boldsymbol{\overline{A}}^{L-1}\boldsymbol{\overline{B}})
\end{aligned}\]

<p>We call $\boldsymbol{\overline{K}}$ the <strong>SSM convolution kernel</strong> or filter.</p>

<p>Note that this is a <em>giant</em> filter. It is the size of the entire sequence!</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">K_conv</span><span class="p">(</span><span class="n">Ab</span><span class="p">,</span> <span class="n">Bb</span><span class="p">,</span> <span class="n">Cb</span><span class="p">,</span> <span class="n">L</span><span class="p">):</span>
    <span class="k">return</span> <span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">(</span>
        <span class="p">[(</span><span class="n">Cb</span> <span class="o">@</span> <span class="n">matrix_power</span><span class="p">(</span><span class="n">Ab</span><span class="p">,</span> <span class="n">l</span><span class="p">)</span> <span class="o">@</span> <span class="n">Bb</span><span class="p">).</span><span class="n">reshape</span><span class="p">()</span> <span class="k">for</span> <span class="n">l</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">L</span><span class="p">)]</span>
    <span class="p">)</span>
</code></pre></div></div>

<p>Warning: this implementation is naive and unstable. It recomputes a matrix power for every
entry of the kernel, and in practice it only works for very short sequences. However, we are going to
replace it with S4 in Part 2, so for now we just keep it around as a placeholder.</p>
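
<p>The equivalence between the two views is easy to check numerically on a short sequence. The sketch below, in plain NumPy (as <code>onp</code>), builds the kernel $\boldsymbol{\overline{K}}$ naively, applies it with a direct convolution, and compares the result with the step-by-step recurrence. The matrices, the length of 16, and the random input are hand-picked assumptions for illustration.</p>

```python
import numpy as onp  # plain NumPy; `np` in this post is jax.numpy

# Small discrete SSM with a stable transition matrix (illustrative values).
Ab = onp.array([[0.9, 0.1], [0.0, 0.8]])
Bb = onp.array([[1.0], [0.5]])
Cb = onp.array([[1.0, -1.0]])

L = 16
gen = onp.random.default_rng(0)
u = gen.normal(size=L)

# Kernel K = (CB, CAB, ..., CA^{L-1}B), built naively with matrix powers.
K = onp.array(
    [(Cb @ onp.linalg.matrix_power(Ab, l) @ Bb).item() for l in range(L)]
)

# Convolutional view: y = K * u, truncated to the first L outputs.
y_conv = onp.convolve(u, K, mode="full")[:L]

# Recurrent view: x_k = Ab x_{k-1} + Bb u_k, y_k = Cb x_k, with x_{-1} = 0.
x = onp.zeros((2,))
y_rec = []
for k in range(L):
    x = Ab @ x + Bb[:, 0] * u[k]
    y_rec.append((Cb @ x)[0])
y_rec = onp.array(y_rec)

# Largest absolute difference between the two views.
print(onp.max(onp.abs(y_conv - y_rec)))
```

<p>The two outputs agree up to floating-point error. The convolution needs the whole input up front, but it has no step-to-step dependency, which is what makes it attractive for training.</p>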

<p>We can compute the result of applying this filter either with a standard direct convolution or
with a padded (non-circular) <a href="https://en.wikipedia.org/wiki/Convolution_theorem">Fast Fourier Transform (FFT)</a>.
As the length gets longer the second method will be more efficient,</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">non_circular_convolution</span><span class="p">(</span><span class="n">u</span><span class="p">,</span> <span class="n">K</span><span class="p">,</span> <span class="n">nofft</span><span class="o">=</span><span class="bp">False</span><span class="p">):</span>
    <span class="k">if</span> <span class="n">nofft</span><span class="p">:</span>
        <span class="k">return</span> <span class="n">convolve</span><span class="p">(</span><span class="n">u</span><span class="p">,</span> <span class="n">K</span><span class="p">,</span> <span class="n">mode</span><span class="o">=</span><span class="s">"full"</span><span class="p">)[:</span> <span class="n">u</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]]</span>
    <span class="k">else</span><span class="p">:</span>
        <span class="k">assert</span> <span class="n">K</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">u</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
        <span class="n">ud</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">fft</span><span class="p">.</span><span class="n">rfft</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">pad</span><span class="p">(</span><span class="n">u</span><span class="p">,</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">K</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])))</span>
        <span class="n">Kd</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">fft</span><span class="p">.</span><span class="n">rfft</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">pad</span><span class="p">(</span><span class="n">K</span><span class="p">,</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">u</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])))</span>
        <span class="n">out</span> <span class="o">=</span> <span class="n">ud</span> <span class="o">*</span> <span class="n">Kd</span>
        <span class="k">return</span> <span class="n">np</span><span class="p">.</span><span class="n">fft</span><span class="p">.</span><span class="n">irfft</span><span class="p">(</span><span class="n">out</span><span class="p">)[:</span> <span class="n">u</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]]</span>
</code></pre></div></div>
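<p>To see why the padding matters, here is a minimal sketch that compares the two paths on random data. It uses plain NumPy rather than <code class="language-plaintext highlighter-rouge">jax.numpy</code>, and the helper names <code class="language-plaintext highlighter-rouge">causal_conv_direct</code> and <code class="language-plaintext highlighter-rouge">causal_conv_fft</code> are made up for this illustration.</p>

```python
import numpy as np


def causal_conv_direct(u, K):
    # Full linear convolution, truncated to the first L outputs.
    return np.convolve(u, K, mode="full")[: u.shape[0]]


def causal_conv_fft(u, K):
    # Zero-pad both signals to length 2L so that the circular convolution
    # computed through the FFT cannot wrap around into the first L outputs.
    L = u.shape[0]
    ud = np.fft.rfft(np.pad(u, (0, L)))
    Kd = np.fft.rfft(np.pad(K, (0, L)))
    return np.fft.irfft(ud * Kd)[:L]


rng = np.random.default_rng(0)
u = rng.uniform(size=64)
K = rng.uniform(size=64)

# Largest elementwise disagreement between the two methods.
max_gap = np.max(np.abs(causal_conv_direct(u, K) - causal_conv_fft(u, K)))
print(max_gap)
```

<p>The two results should agree up to floating-point error. Dropping the padding would turn the FFT path into a circular convolution, where late inputs leak into early outputs.</p>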

<p>The CNN method and the RNN method yield (roughly) the same result,</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">test_cnn_is_rnn</span><span class="p">(</span><span class="n">N</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">L</span><span class="o">=</span><span class="mi">16</span><span class="p">,</span> <span class="n">step</span><span class="o">=</span><span class="mf">1.0</span> <span class="o">/</span> <span class="mi">16</span><span class="p">):</span>
    <span class="n">ssm</span> <span class="o">=</span> <span class="n">random_SSM</span><span class="p">(</span><span class="n">rng</span><span class="p">,</span> <span class="n">N</span><span class="p">)</span>
    <span class="n">u</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">rng</span><span class="p">,</span> <span class="p">(</span><span class="n">L</span><span class="p">,))</span>

    <span class="c1"># "RNN"
</span>    <span class="n">rec</span> <span class="o">=</span> <span class="n">run_SSM</span><span class="p">(</span><span class="o">*</span><span class="n">ssm</span><span class="p">,</span> <span class="n">u</span><span class="p">)</span>

    <span class="c1"># "CNN"
</span>    <span class="n">ssmb</span> <span class="o">=</span> <span class="n">discretize</span><span class="p">(</span><span class="o">*</span><span class="n">ssm</span><span class="p">,</span> <span class="n">step</span><span class="o">=</span><span class="n">step</span><span class="p">)</span>
    <span class="n">conv</span> <span class="o">=</span> <span class="n">non_circular_convolution</span><span class="p">(</span><span class="n">u</span><span class="p">,</span> <span class="n">K_conv</span><span class="p">(</span><span class="o">*</span><span class="n">ssmb</span><span class="p">,</span> <span class="n">L</span><span class="p">))</span>

    <span class="c1"># Check
</span>    <span class="k">assert</span> <span class="n">np</span><span class="p">.</span><span class="n">isclose</span><span class="p">(</span><span class="n">rec</span><span class="p">.</span><span class="n">ravel</span><span class="p">(),</span> <span class="n">conv</span><span class="p">.</span><span class="n">ravel</span><span class="p">(),</span> <span class="n">rtol</span><span class="o">=</span><span class="mf">1e-2</span><span class="p">,</span> <span class="n">atol</span><span class="o">=</span><span class="mf">1e-4</span><span class="p">).</span><span class="nb">all</span><span class="p">()</span>
</code></pre></div></div>
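<p>The same equivalence can be checked without any of the helpers above. The following is a rough, self-contained NumPy sketch that assumes an already-discretized SSM with made-up random matrices (scaled so that powers of <code class="language-plaintext highlighter-rouge">Ab</code> stay bounded). It steps the recurrence by hand and compares it with the convolution against the kernel $\boldsymbol{\overline{C}}\boldsymbol{\overline{A}}^l\boldsymbol{\overline{B}}$.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
N, L = 4, 16

# A small, already-discretized SSM (random stand-ins, not learned values).
# Entries of Ab are below 0.2, so every row sum is below 1 and Ab^l stays bounded.
Ab = 0.2 * rng.uniform(size=(N, N))
Bb = rng.uniform(size=(N, 1))
Cb = rng.uniform(size=(1, N))
u = rng.uniform(size=L)

# "RNN": step the recurrence x_k = Ab x_{k-1} + Bb u_k, y_k = Cb x_k.
x = np.zeros((N, 1))
rec = []
for u_k in u:
    x = Ab @ x + Bb * u_k
    rec.append((Cb @ x).item())
rec = np.array(rec)

# "CNN": build the kernel K_l = Cb Ab^l Bb and convolve it with the input.
K = np.array(
    [(Cb @ np.linalg.matrix_power(Ab, l) @ Bb).item() for l in range(L)]
)
conv = np.convolve(u, K, mode="full")[:L]

print(np.max(np.abs(rec - conv)))
```

<p>Unrolling the recurrence gives $y_k = \sum_{j \le k} \boldsymbol{\overline{C}}\boldsymbol{\overline{A}}^{k-j}\boldsymbol{\overline{B}}\, u_j$, which is exactly the truncated convolution computed on the last lines.</p>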

<p>At this point we have all of the machinery used for SSM training. The next
steps are about 1) making these models stable to train, and 2) making them fast.</p>

<h3 id="addressing-long-range-dependencies-with-hippo">Addressing Long-Range Dependencies with HiPPO</h3>

<p><img src="https://iclr.iro.umontreal.ca/4de73081-c045-4ef8-8a76-5783bb8f6e9c_1642114847/public/images/2021-12-01-annotated-s4/images/hippo.png" width="100%" /></p>

<blockquote>
  <p><a href="https://arxiv.org/abs/2008.07669">Prior work</a> found that the basic SSM actually performs very poorly in
practice.  Intuitively, one explanation is that it suffers from gradients scaling exponentially in the sequence length (i.e., the
vanishing/exploding gradients problem).  To address this problem, previous work developed the HiPPO theory of
continuous-time memorization.</p>

  <p>HiPPO specifies a class of matrices $\boldsymbol{A} \in \mathbb{R}^{N \times N}$ that, when incorporated into the SSM,
allow the state $x(t)$ to memorize the history of the input $u(t)$.
The most important matrix in this class is the HiPPO matrix, defined as</p>

\[\begin{aligned}
  (\text{HiPPO Matrix})
  \qquad
  \boldsymbol{A}_{nk}
  =
  \begin{cases}
    (2n+1)^{1/2}(2k+1)^{1/2} &amp; \text{if } n &gt; k \\
    n+1 &amp; \text{if } n = k \\
    0 &amp; \text{if } n &lt; k
  \end{cases}
\end{aligned}\]

  <p>Previous work found that simply modifying an SSM from a random matrix $\boldsymbol{A}$ to HiPPO
improved its performance on the sequential MNIST classification benchmark from $50\%$ to $98\%$.</p>
</blockquote>

<p>This matrix is going to be really important, but it is a bit of
magic. For our purposes we mainly need to know that: 1) we only need to
calculate it once, and 2) it has a nice, simple structure (which we will exploit in
Part 2). Without going into the ODE math, the main takeaway
is that this matrix aims to remember the past history in the state in a
timescale-invariant manner,</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">make_HiPPO</span><span class="p">(</span><span class="n">N</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">v</span><span class="p">(</span><span class="n">n</span><span class="p">,</span> <span class="n">k</span><span class="p">):</span>
        <span class="k">if</span> <span class="n">n</span> <span class="o">&gt;</span> <span class="n">k</span><span class="p">:</span>
            <span class="k">return</span> <span class="n">np</span><span class="p">.</span><span class="n">sqrt</span><span class="p">(</span><span class="mi">2</span> <span class="o">*</span> <span class="n">n</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">np</span><span class="p">.</span><span class="n">sqrt</span><span class="p">(</span><span class="mi">2</span> <span class="o">*</span> <span class="n">k</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
        <span class="k">elif</span> <span class="n">n</span> <span class="o">==</span> <span class="n">k</span><span class="p">:</span>
            <span class="k">return</span> <span class="n">n</span> <span class="o">+</span> <span class="mi">1</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="k">return</span> <span class="mi">0</span>

    <span class="c1"># Do it slow so we don't mess it up :)
</span>    <span class="n">mat</span> <span class="o">=</span> <span class="p">[[</span><span class="n">v</span><span class="p">(</span><span class="n">n</span><span class="p">,</span> <span class="n">k</span><span class="p">)</span> <span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">N</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)]</span> <span class="k">for</span> <span class="n">n</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">N</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)]</span>
    <span class="k">return</span> <span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">(</span><span class="n">mat</span><span class="p">)</span>
</code></pre></div></div>
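<p>The "nice, simple structure" is easier to see in a vectorized form: the part below the diagonal is an outer product of a single vector with itself. The sketch below uses plain NumPy and a hypothetical helper name, <code class="language-plaintext highlighter-rouge">make_hippo_vectorized</code>. It is meant to reproduce the loop-based <code class="language-plaintext highlighter-rouge">make_HiPPO</code> as written above, including its index range of 1 to $N$.</p>

```python
import numpy as np


def make_hippo_vectorized(N):
    # Hypothetical helper: same entries as the loop-based make_HiPPO above,
    # using the same index range n, k = 1..N as that code.
    idx = np.arange(1, N + 1)
    P = np.sqrt(2 * idx + 1)
    # Strictly lower triangle: sqrt(2n+1) * sqrt(2k+1) for n > k.
    below = np.tril(np.outer(P, P), k=-1)
    # Diagonal: n + 1. Everything above the diagonal stays zero.
    return below + np.diag(idx + 1.0)


A = make_hippo_vectorized(3)
print(np.round(A, 3))
```

<p>For $N=3$ the diagonal is $2, 3, 4$ and the entries below it are $\sqrt{15}$, $\sqrt{21}$, and $\sqrt{35}$.</p>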

<p>Diving a bit deeper, the intuitive explanation of this matrix is
that it produces a hidden state that memorizes its history. It does
this by keeping track of the coefficients of a <a href="https://en.wikipedia.org/wiki/Legendre_polynomials">Legendre
polynomial</a> expansion. These
coefficients let it approximate all of the previous history. Let us
look at an example,</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">example_legendre</span><span class="p">(</span><span class="n">N</span><span class="o">=</span><span class="mi">8</span><span class="p">):</span>
    <span class="c1"># Random hidden state as coefficients
</span>    <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>
    <span class="kn">import</span> <span class="nn">numpy.polynomial.legendre</span>

    <span class="n">x</span> <span class="o">=</span> <span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">rand</span><span class="p">(</span><span class="n">N</span><span class="p">)</span> <span class="o">-</span> <span class="mf">0.5</span><span class="p">)</span> <span class="o">*</span> <span class="mi">2</span>
    <span class="n">t</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">100</span><span class="p">)</span>
    <span class="n">f</span> <span class="o">=</span> <span class="n">numpy</span><span class="p">.</span><span class="n">polynomial</span><span class="p">.</span><span class="n">legendre</span><span class="p">.</span><span class="n">Legendre</span><span class="p">(</span><span class="n">x</span><span class="p">)(</span><span class="n">t</span><span class="p">)</span>

    <span class="c1"># Plot
</span>    <span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="n">plt</span>
    <span class="kn">import</span> <span class="nn">seaborn</span>

    <span class="n">seaborn</span><span class="p">.</span><span class="n">set_context</span><span class="p">(</span><span class="s">"talk"</span><span class="p">)</span>
    <span class="n">fig</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">20</span><span class="p">,</span> <span class="mi">10</span><span class="p">))</span>
    <span class="n">ax</span> <span class="o">=</span> <span class="n">fig</span><span class="p">.</span><span class="n">add_subplot</span><span class="p">(</span><span class="n">projection</span><span class="o">=</span><span class="s">"3d"</span><span class="p">)</span>
    <span class="n">ax</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span>
        <span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="o">-</span><span class="mi">25</span><span class="p">,</span> <span class="p">(</span><span class="n">N</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="mi">100</span> <span class="o">+</span> <span class="mi">25</span><span class="p">,</span> <span class="mi">100</span><span class="p">),</span>
        <span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="mi">100</span><span class="p">,</span>
        <span class="n">zs</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span>
        <span class="n">zdir</span><span class="o">=</span><span class="s">"x"</span><span class="p">,</span>
        <span class="n">color</span><span class="o">=</span><span class="s">"black"</span><span class="p">,</span>
    <span class="p">)</span>
    <span class="n">ax</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">t</span><span class="p">,</span> <span class="n">f</span><span class="p">,</span> <span class="n">zs</span><span class="o">=</span><span class="n">N</span> <span class="o">*</span> <span class="mi">100</span><span class="p">,</span> <span class="n">zdir</span><span class="o">=</span><span class="s">"y"</span><span class="p">,</span> <span class="n">c</span><span class="o">=</span><span class="s">"r"</span><span class="p">)</span>
    <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">N</span><span class="p">):</span>
        <span class="n">coef</span> <span class="o">=</span> <span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">N</span>
        <span class="n">coef</span><span class="p">[</span><span class="n">N</span> <span class="o">-</span> <span class="n">i</span> <span class="o">-</span> <span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span>
        <span class="n">ax</span><span class="p">.</span><span class="n">set_zlim</span><span class="p">(</span><span class="o">-</span><span class="mi">4</span><span class="p">,</span> <span class="mi">4</span><span class="p">)</span>
        <span class="n">ax</span><span class="p">.</span><span class="n">set_yticks</span><span class="p">([])</span>
        <span class="n">ax</span><span class="p">.</span><span class="n">set_zticks</span><span class="p">([])</span>
        <span class="c1"># Plot basis function.
</span>        <span class="n">f</span> <span class="o">=</span> <span class="n">numpy</span><span class="p">.</span><span class="n">polynomial</span><span class="p">.</span><span class="n">legendre</span><span class="p">.</span><span class="n">Legendre</span><span class="p">(</span><span class="n">coef</span><span class="p">)(</span><span class="n">t</span><span class="p">)</span>
        <span class="n">ax</span><span class="p">.</span><span class="n">bar</span><span class="p">(</span>
            <span class="p">[</span><span class="mi">100</span> <span class="o">*</span> <span class="n">i</span><span class="p">],</span>
            <span class="p">[</span><span class="n">x</span><span class="p">[</span><span class="n">i</span><span class="p">]],</span>
            <span class="n">zs</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span>
            <span class="n">zdir</span><span class="o">=</span><span class="s">"x"</span><span class="p">,</span>
            <span class="n">label</span><span class="o">=</span><span class="s">"x%d"</span> <span class="o">%</span> <span class="n">i</span><span class="p">,</span>
            <span class="n">color</span><span class="o">=</span><span class="s">"brown"</span><span class="p">,</span>
            <span class="n">fill</span><span class="o">=</span><span class="bp">False</span><span class="p">,</span>
            <span class="n">width</span><span class="o">=</span><span class="mi">50</span><span class="p">,</span>
        <span class="p">)</span>
        <span class="n">ax</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">t</span><span class="p">,</span> <span class="n">f</span><span class="p">,</span> <span class="n">zs</span><span class="o">=</span><span class="mi">100</span> <span class="o">*</span> <span class="n">i</span><span class="p">,</span> <span class="n">zdir</span><span class="o">=</span><span class="s">"y"</span><span class="p">,</span> <span class="n">c</span><span class="o">=</span><span class="s">"b"</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.5</span><span class="p">)</span>
    <span class="n">ax</span><span class="p">.</span><span class="n">view_init</span><span class="p">(</span><span class="n">elev</span><span class="o">=</span><span class="mf">40.0</span><span class="p">,</span> <span class="n">azim</span><span class="o">=-</span><span class="mi">45</span><span class="p">)</span>
    <span class="n">fig</span><span class="p">.</span><span class="n">savefig</span><span class="p">(</span><span class="s">"images/leg.png"</span><span class="p">)</span>
</code></pre></div></div>

<p>The red line represents the curve we are approximating,
while the black bars represent the values of our hidden state.
Each bar is the coefficient for one element of the Legendre series,
shown as the blue functions. The intuition is that the HiPPO matrix
updates these coefficients each step.</p>

<p><img src="https://iclr.iro.umontreal.ca/4de73081-c045-4ef8-8a76-5783bb8f6e9c_1642114847/public/images/2021-12-01-annotated-s4/images/leg.png" width="100%" /></p>
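<p>As a rough companion to the figure, the sketch below shows how a handful of Legendre coefficients can summarize a whole curve. Note the assumptions: the "history" signal is made up, and the coefficients come from a one-off least-squares fit with <code class="language-plaintext highlighter-rouge">numpy.polynomial.legendre.legfit</code>, not from the online update that HiPPO performs. It only illustrates why $N$ coefficients are a reasonable compressed memory.</p>

```python
import numpy as np
from numpy.polynomial import legendre

t = np.linspace(-1, 1, 200)
history = np.sin(3 * t) + 0.5 * t**2  # made-up "history" signal on [-1, 1]

errors = {}
for N in (2, 4, 8, 16):
    # N coefficients correspond to Legendre polynomials of degree 0..N-1.
    coefs = legendre.legfit(t, history, deg=N - 1)
    approx = legendre.legval(t, coefs)
    # Root-mean-square reconstruction error on the sample points.
    errors[N] = np.sqrt(np.mean((history - approx) ** 2))

for N, err in errors.items():
    print(N, err)
```

<p>The reconstruction error shrinks as the state size $N$ grows, which is the sense in which a larger hidden state remembers the history more faithfully.</p>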

<h3 id="an-ssm-neural-network">An SSM Neural Network</h3>

<p>We now have everything we need to build an SSM neural network layer.
As defined above, the discrete SSM defines a map from $\mathbb{R}^L
\to \mathbb{R}^L$, i.e. a 1-D sequence map. We assume that we
are going to be learning the parameters $B$ and $C$, as well as a
step size $\Delta$ and a scalar $D$ parameter. The HiPPO matrix is
used for the transition $A$. We learn the step size in log space.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">log_step_initializer</span><span class="p">(</span><span class="n">dt_min</span><span class="o">=</span><span class="mf">0.001</span><span class="p">,</span> <span class="n">dt_max</span><span class="o">=</span><span class="mf">0.1</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">init</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">shape</span><span class="p">):</span>
        <span class="k">return</span> <span class="n">jax</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">shape</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span>
            <span class="n">np</span><span class="p">.</span><span class="n">log</span><span class="p">(</span><span class="n">dt_max</span><span class="p">)</span> <span class="o">-</span> <span class="n">np</span><span class="p">.</span><span class="n">log</span><span class="p">(</span><span class="n">dt_min</span><span class="p">)</span>
        <span class="p">)</span> <span class="o">+</span> <span class="n">np</span><span class="p">.</span><span class="n">log</span><span class="p">(</span><span class="n">dt_min</span><span class="p">)</span>

    <span class="k">return</span> <span class="n">init</span>
</code></pre></div></div>
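<p>Sampling uniformly in log space means the step sizes are spread evenly across orders of magnitude rather than clustered near <code class="language-plaintext highlighter-rouge">dt_max</code>. Here is a small check of that claim, using the same formula as <code class="language-plaintext highlighter-rouge">init</code> but with NumPy's random generator standing in for <code class="language-plaintext highlighter-rouge">jax.random</code> (an assumption made to keep the sketch self-contained).</p>

```python
import numpy as np

dt_min, dt_max = 0.001, 0.1
rng = np.random.default_rng(0)

# Same formula as `init` above: uniform in [log(dt_min), log(dt_max)).
log_step = rng.uniform(size=10_000) * (np.log(dt_max) - np.log(dt_min)) + np.log(dt_min)
step = np.exp(log_step)

print(step.min(), step.max())

# Fraction of samples below the geometric mean sqrt(dt_min * dt_max) = 0.01.
frac_below = np.mean(step < 0.01)
print(frac_below)
```

<p>About half of the samples fall below $0.01$, the geometric midpoint of the range. A uniform draw over $[0.001, 0.1]$ would put only about 9% of the samples there.</p>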

<p>For the SSM layer most of the work is to build the filter.
The actual call to the network is just the (huge) convolution we specified above.</p>

<p>Note for Torch users: <code class="language-plaintext highlighter-rouge">setup</code> in Flax is called each time the parameters are updated.
This is similar to the
<a href="https://pytorch.org/tutorials/intermediate/parametrizations.html">Torch parameterizations</a>.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">SSMLayer</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="n">A</span><span class="p">:</span> <span class="n">np</span><span class="p">.</span><span class="n">ndarray</span>  <span class="c1"># HiPPO
</span>    <span class="n">N</span><span class="p">:</span> <span class="nb">int</span>
    <span class="n">l_max</span><span class="p">:</span> <span class="nb">int</span>

    <span class="k">def</span> <span class="nf">setup</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="c1"># SSM parameters
</span>        <span class="bp">self</span><span class="p">.</span><span class="n">B</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">param</span><span class="p">(</span><span class="s">"B"</span><span class="p">,</span> <span class="n">lecun_normal</span><span class="p">(),</span> <span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">N</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">C</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">param</span><span class="p">(</span><span class="s">"C"</span><span class="p">,</span> <span class="n">lecun_normal</span><span class="p">(),</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">N</span><span class="p">))</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">D</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">param</span><span class="p">(</span><span class="s">"D"</span><span class="p">,</span> <span class="n">nn</span><span class="p">.</span><span class="n">initializers</span><span class="p">.</span><span class="n">ones</span><span class="p">,</span> <span class="p">(</span><span class="mi">1</span><span class="p">,))</span>

        <span class="c1"># Step parameter
</span>        <span class="bp">self</span><span class="p">.</span><span class="n">log_step</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">param</span><span class="p">(</span><span class="s">"log_step"</span><span class="p">,</span> <span class="n">log_step_initializer</span><span class="p">(),</span> <span class="p">(</span><span class="mi">1</span><span class="p">,))</span>

        <span class="n">step</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">exp</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">log_step</span><span class="p">)</span>
        <span class="n">ssm</span> <span class="o">=</span> <span class="n">discretize</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">A</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">B</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">C</span><span class="p">,</span> <span class="n">step</span><span class="o">=</span><span class="n">step</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">K</span> <span class="o">=</span> <span class="n">K_conv</span><span class="p">(</span><span class="o">*</span><span class="n">ssm</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">l_max</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">u</span><span class="p">):</span>
        <span class="k">return</span> <span class="n">non_circular_convolution</span><span class="p">(</span><span class="n">u</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">K</span><span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="p">.</span><span class="n">D</span> <span class="o">*</span> <span class="n">u</span>
</code></pre></div></div>

<p>Since our SSMs operate on scalars, we make $H$ different, stacked copies ($H$ different SSMs!) with
different parameters. Here we use the <a href="https://flax.readthedocs.io/en/latest/_autosummary/flax.linen.vmap.html">Flax vmap</a>
method to easily define these copies,</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">cloneLayer</span><span class="p">(</span><span class="n">layer</span><span class="p">):</span>
    <span class="k">return</span> <span class="n">nn</span><span class="p">.</span><span class="n">vmap</span><span class="p">(</span>
        <span class="n">layer</span><span class="p">,</span>
        <span class="n">in_axes</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
        <span class="n">out_axes</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
        <span class="n">variable_axes</span><span class="o">=</span><span class="p">{</span><span class="s">"params"</span><span class="p">:</span> <span class="mi">1</span><span class="p">},</span>
        <span class="n">split_rngs</span><span class="o">=</span><span class="p">{</span><span class="s">"params"</span><span class="p">:</span> <span class="bp">True</span><span class="p">},</span>
    <span class="p">)</span>
</code></pre></div></div>
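<p>To make the shapes concrete, here is what the mapped layer computes, written as an explicit loop in plain NumPy. This is not how <code class="language-plaintext highlighter-rouge">nn.vmap</code> is implemented, and the random <code class="language-plaintext highlighter-rouge">K</code> and <code class="language-plaintext highlighter-rouge">D</code> below are stand-ins for the $H$ sets of learned parameters. The point is only that axis 1 indexes independent copies, each seeing one column of the input.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
L, H = 16, 3

u = rng.uniform(size=(L, H))  # one length-L sequence per feature
K = rng.uniform(size=(L, H))  # one length-L filter per feature (stand-in values)
D = np.ones(H)                # one skip scalar per feature

# Apply each copy to its own column, then stack the outputs along axis 1 again.
y = np.stack(
    [
        np.convolve(u[:, h], K[:, h], mode="full")[:L] + D[h] * u[:, h]
        for h in range(H)
    ],
    axis=1,
)
print(y.shape)  # (16, 3)
```

<p>No information flows between columns here. In the full model, mixing across the $H$ features is left to the dense layer that follows.</p>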

<p>We then initialize $A$ with the HiPPO matrix, and pass it into the stack of modules above,</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">SSMInit</span><span class="p">(</span><span class="n">N</span><span class="p">):</span>
    <span class="k">return</span> <span class="n">partial</span><span class="p">(</span><span class="n">cloneLayer</span><span class="p">(</span><span class="n">SSMLayer</span><span class="p">),</span> <span class="n">A</span><span class="o">=</span><span class="n">make_HiPPO</span><span class="p">(</span><span class="n">N</span><span class="p">),</span> <span class="n">N</span><span class="o">=</span><span class="n">N</span><span class="p">)</span>
</code></pre></div></div>

<p>This SSM Layer can then be put into a standard NN. For instance, here
we have a Transformer-style stack of residual blocks, each containing the $H$ stacked SSMs.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">SeqInternal</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="n">layer</span><span class="p">:</span> <span class="n">nn</span><span class="p">.</span><span class="n">Module</span>
    <span class="n">l_max</span><span class="p">:</span> <span class="nb">int</span>
    <span class="n">dropout</span><span class="p">:</span> <span class="nb">float</span>
    <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span>
    <span class="n">training</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="bp">True</span>

    <span class="k">def</span> <span class="nf">setup</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">seq</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">layer</span><span class="p">(</span><span class="n">l_max</span><span class="o">=</span><span class="bp">self</span><span class="p">.</span><span class="n">l_max</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">norm</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">LayerNorm</span><span class="p">()</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">out</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Dense</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">d_model</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">drop</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Dropout</span><span class="p">(</span>
            <span class="bp">self</span><span class="p">.</span><span class="n">dropout</span><span class="p">,</span>
            <span class="n">broadcast_dims</span><span class="o">=</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span>
            <span class="n">deterministic</span><span class="o">=</span><span class="ow">not</span> <span class="bp">self</span><span class="p">.</span><span class="n">training</span><span class="p">,</span>
        <span class="p">)</span>

    <span class="k">def</span> <span class="nf">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">blank</span><span class="p">):</span>
        <span class="n">x2</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">seq</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
        <span class="n">z</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">drop</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">out</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">drop</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">gelu</span><span class="p">(</span><span class="n">x2</span><span class="p">))))</span>
        <span class="k">return</span> <span class="bp">self</span><span class="p">.</span><span class="n">norm</span><span class="p">(</span><span class="n">z</span> <span class="o">+</span> <span class="n">x</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">SeqModel</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="n">layer</span><span class="p">:</span> <span class="n">nn</span><span class="p">.</span><span class="n">Module</span>
    <span class="n">d_output</span><span class="p">:</span> <span class="nb">int</span>
    <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span>
    <span class="n">l_max</span><span class="p">:</span> <span class="nb">int</span>
    <span class="n">n_layers</span><span class="p">:</span> <span class="nb">int</span>
    <span class="n">dropout</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.2</span>
    <span class="n">training</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="bp">True</span>
    <span class="n">classification</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="bp">False</span>

    <span class="k">def</span> <span class="nf">setup</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">encoder</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Dense</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">d_model</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">decoder</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Dense</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">d_output</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">layers</span> <span class="o">=</span> <span class="p">[</span>
            <span class="n">SeqInternal</span><span class="p">(</span>
                <span class="n">layer</span><span class="o">=</span><span class="bp">self</span><span class="p">.</span><span class="n">layer</span><span class="p">,</span>
                <span class="n">d_model</span><span class="o">=</span><span class="bp">self</span><span class="p">.</span><span class="n">d_model</span><span class="p">,</span>
                <span class="n">dropout</span><span class="o">=</span><span class="bp">self</span><span class="p">.</span><span class="n">dropout</span><span class="p">,</span>
                <span class="n">training</span><span class="o">=</span><span class="bp">self</span><span class="p">.</span><span class="n">training</span><span class="p">,</span>
                <span class="n">l_max</span><span class="o">=</span><span class="bp">self</span><span class="p">.</span><span class="n">l_max</span><span class="p">,</span>
            <span class="p">)</span>
            <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">n_layers</span><span class="p">)</span>
        <span class="p">]</span>

    <span class="k">def</span> <span class="nf">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">encoder</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
        <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="bp">self</span><span class="p">.</span><span class="n">layers</span><span class="p">:</span>
            <span class="n">x</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="bp">None</span><span class="p">)</span>
        <span class="k">if</span> <span class="bp">self</span><span class="p">.</span><span class="n">classification</span><span class="p">:</span>
            <span class="n">x</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">mean</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">decoder</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">nn</span><span class="p">.</span><span class="n">log_softmax</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">BatchSeqModel</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">vmap</span><span class="p">(</span>
    <span class="n">SeqModel</span><span class="p">,</span>
    <span class="n">in_axes</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
    <span class="n">out_axes</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
    <span class="n">variable_axes</span><span class="o">=</span><span class="p">{</span><span class="s">"params"</span><span class="p">:</span> <span class="bp">None</span><span class="p">,</span> <span class="s">"dropout"</span><span class="p">:</span> <span class="bp">None</span><span class="p">},</span>
    <span class="n">split_rngs</span><span class="o">=</span><span class="p">{</span><span class="s">"params"</span><span class="p">:</span> <span class="bp">False</span><span class="p">,</span> <span class="s">"dropout"</span><span class="p">:</span> <span class="bp">True</span><span class="p">},</span>
<span class="p">)</span>
</code></pre></div></div>

<p>Overall, this defines a sequence-to-sequence map of shape (batch size, sequence length, hidden dimension),
exactly the signature exposed by related sequence models such as Transformers, RNNs, and CNNs.</p>

<p>While
we now have our main model, it is not fast enough to actually use. The next
section is all about making this SSM Layer faster – a lot faster!</p>

<h2 id="part-2-implementing-s4">Part 2: Implementing S4</h2>

<p>Warning: this section has a lot of math. Roughly it boils down to finding a
way to compute the filter from Part 1 with a “HiPPO-like” matrix <em>really
fast</em>. If you are interested, the details are really neat. If not,
skip to Part 3 for some cool applications like MNIST completion.</p>

<p><a href="#part-3-s4-in-practice">Skip Button</a></p>

<blockquote>
  <p>The fundamental bottleneck in computing the discrete-time SSM
is that it involves repeated matrix multiplication by
$\boldsymbol{\overline{A}}$. For example, computing $\boldsymbol{\overline{K}}$
naively involves $L$ successive multiplications
by $\boldsymbol{\overline{A}}$, requiring $O(N^2 L)$ operations and
$O(NL)$ space.</p>
</blockquote>

<p>Specifically, recall this function here:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">K_conv_</span><span class="p">(</span><span class="n">Ab</span><span class="p">,</span> <span class="n">Bb</span><span class="p">,</span> <span class="n">Cb</span><span class="p">,</span> <span class="n">L</span><span class="p">):</span>
    <span class="k">return</span> <span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">(</span>
        <span class="p">[(</span><span class="n">Cb</span> <span class="o">@</span> <span class="n">matrix_power</span><span class="p">(</span><span class="n">Ab</span><span class="p">,</span> <span class="n">l</span><span class="p">)</span> <span class="o">@</span> <span class="n">Bb</span><span class="p">).</span><span class="n">reshape</span><span class="p">()</span> <span class="k">for</span> <span class="n">l</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">L</span><span class="p">)]</span>
    <span class="p">)</span>
</code></pre></div></div>

<p>The contribution of S4 is a stable method for speeding up this particular operation.
To do this we are going to focus on the case where the SSM
has special structure: specifically, Diagonal Plus Low-Rank (DPLR) in complex
space.</p>

<p><strong>DPLR:</strong> The SSM is $(\boldsymbol{\Lambda} - \boldsymbol{p}\boldsymbol{q}^*, \boldsymbol{B}, \boldsymbol{C})$ for some diagonal $\boldsymbol{\Lambda}$ and vectors $\boldsymbol{p}, \boldsymbol{q}, \boldsymbol{B}, \boldsymbol{C} \in \mathbb{C}^{N \times 1}$.</p>

<p>Under this DPLR assumption, S4 overcomes the speed bottleneck in three steps:</p>

<ol>
  <li>Instead of computing $\boldsymbol{\overline{K}}$ directly,
we compute its spectrum by evaluating its <strong><a href="https://en.wikipedia.org/wiki/Generating_function">truncated generating function</a></strong>. This now involves a matrix <em>inverse</em> instead of a matrix <em>power</em>.</li>
  <li>We show that the diagonal matrix case is equivalent to the computation of a <strong><a href="https://en.wikipedia.org/wiki/Cauchy_matrix">Cauchy kernel</a></strong> $\frac{1}{\omega_j - \zeta_k}$.</li>
  <li>We show the low-rank term can now be corrected by applying the <strong><a href="https://en.wikipedia.org/wiki/Woodbury_matrix_identity">Woodbury identity</a></strong>, which expresses $(\boldsymbol{\Lambda} + \boldsymbol{p}\boldsymbol{q}^*)^{-1}$ in terms of $\boldsymbol{\Lambda}^{-1}$, truly reducing to the diagonal case.</li>
</ol>

<h3 id="step-1-ssm-generating-functions">Step 1. SSM Generating Functions</h3>

<p>The main step will be switching from computing the sequence to computing its generating function.
From the paper’s appendix:</p>

<blockquote>
  <p>To address the problem of computing powers of $\boldsymbol{\overline{A}}$, we introduce another technique.
Instead of computing the SSM convolution filter $\boldsymbol{\overline{K}}$ directly,
we introduce a generating function on its coefficients and compute evaluations of it.</p>

  <p>The <em>truncated SSM generating function</em> at node $z$ with truncation $L$ is
\(\hat{\mathcal{K}}_L(z; \boldsymbol{\overline{A}}, \boldsymbol{\overline{B}}, \boldsymbol{\overline{C}}) \in \mathbb{C} := \sum_{i=0}^{L-1} \boldsymbol{\overline{C}} \boldsymbol{\overline{A}}^i \boldsymbol{\overline{B}} z^i\)</p>
</blockquote>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">K_gen_simple</span><span class="p">(</span><span class="n">Ab</span><span class="p">,</span> <span class="n">Bb</span><span class="p">,</span> <span class="n">Cb</span><span class="p">,</span> <span class="n">L</span><span class="p">):</span>
    <span class="n">K</span> <span class="o">=</span> <span class="n">K_conv</span><span class="p">(</span><span class="n">Ab</span><span class="p">,</span> <span class="n">Bb</span><span class="p">,</span> <span class="n">Cb</span><span class="p">,</span> <span class="n">L</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">gen</span><span class="p">(</span><span class="n">z</span><span class="p">):</span>
        <span class="k">return</span> <span class="n">np</span><span class="p">.</span><span class="nb">sum</span><span class="p">(</span><span class="n">K</span> <span class="o">*</span> <span class="p">(</span><span class="n">z</span> <span class="o">**</span> <span class="n">np</span><span class="p">.</span><span class="n">arange</span><span class="p">(</span><span class="n">L</span><span class="p">)))</span>

    <span class="k">return</span> <span class="n">gen</span>
</code></pre></div></div>

<blockquote>
  <p>The generating function essentially converts the SSM convolution filter from the time domain to
frequency domain. Importantly, it preserves the same information, and the desired SSM convolution filter
can be recovered from evaluations of its
<a href="https://math.stackexchange.com/questions/3213142/root-of-unity-filter">generating function at the roots of unity</a>
$\Omega = \{ \exp(2\pi i \frac{k}{L}) : k \in [L] \}$ stably in $O(L \log L)$ operations by applying an
<a href="https://en.wikipedia.org/wiki/Fast_Fourier_transform">FFT</a>,</p>
</blockquote>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">conv_from_gen</span><span class="p">(</span><span class="n">gen</span><span class="p">,</span> <span class="n">L</span><span class="p">):</span>
    <span class="c1"># Evaluate at roots of unity
</span>    <span class="n">Omega_L</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">exp</span><span class="p">((</span><span class="mf">2j</span> <span class="o">*</span> <span class="n">np</span><span class="p">.</span><span class="n">pi</span> <span class="o">/</span> <span class="n">L</span><span class="p">)</span> <span class="o">*</span> <span class="n">np</span><span class="p">.</span><span class="n">arange</span><span class="p">(</span><span class="n">L</span><span class="p">))</span>
    <span class="n">atRoots</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">vmap</span><span class="p">(</span><span class="n">gen</span><span class="p">)(</span><span class="n">Omega_L</span><span class="p">)</span>
    <span class="c1"># Inverse FFT
</span>    <span class="n">out</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">fft</span><span class="p">.</span><span class="n">ifft</span><span class="p">(</span><span class="n">atRoots</span><span class="p">,</span> <span class="n">L</span><span class="p">).</span><span class="n">reshape</span><span class="p">(</span><span class="n">L</span><span class="p">)</span>
    <span class="c1"># Numpy returns the values out of order.
</span>    <span class="n">order</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">([</span><span class="n">i</span> <span class="k">if</span> <span class="n">i</span> <span class="o">==</span> <span class="mi">0</span> <span class="k">else</span> <span class="n">L</span> <span class="o">-</span> <span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">L</span><span class="p">)])</span>
    <span class="k">return</span> <span class="n">out</span><span class="p">[</span><span class="n">order</span><span class="p">].</span><span class="n">real</span>
</code></pre></div></div>
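<p>To see why the reordering step is needed, here is a minimal sketch on a toy filter. It uses plain NumPy rather than the post's <code class="language-plaintext highlighter-rouge">jax.numpy</code>, and the filter values are made up for illustration. Evaluating at $\exp(2\pi i k / L)$ and then applying an inverse FFT returns the coefficients index-reversed, so entry $m$ holds $K_{(-m) \bmod L}$; a forward FFT divided by $L$ would give them in order directly.</p>

```python
import numpy as np  # plain NumPy; the post's own code uses jax.numpy under the same alias

L = 8
K = np.arange(1.0, L + 1.0)  # toy filter coefficients K_0..K_{L-1} (illustrative values)


def gen(z):
    # Truncated generating function: sum_n K_n z^n
    return np.sum(K * z ** np.arange(L))


# Evaluate at the roots of unity exp(2*pi*i*k/L), k = 0..L-1.
omega = np.exp(2j * np.pi * np.arange(L) / L)
at_roots = np.array([gen(z) for z in omega])

# The inverse FFT returns the coefficients index-reversed:
# entry m holds K[(-m) mod L], hence the reindexing.
reversed_K = np.fft.ifft(at_roots)
order = (-np.arange(L)) % L
recovered = reversed_K[order].real

# Equivalent route: a forward FFT divided by L gives the coefficients in order.
recovered_fft = (np.fft.fft(at_roots) / L).real
```

<p>Both <code class="language-plaintext highlighter-rouge">recovered</code> and <code class="language-plaintext highlighter-rouge">recovered_fft</code> should match <code class="language-plaintext highlighter-rouge">K</code> up to floating-point error.</p>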

<p>More importantly, in the generating function we can replace the matrix power with an inverse!
\(\hat{\mathcal{K}}_L(z) = \sum_{i=0}^{L-1} \boldsymbol{\overline{C}} \boldsymbol{\overline{A}}^i \boldsymbol{\overline{B}} z^i = \boldsymbol{\overline{C}} (\boldsymbol{I} - \boldsymbol{\overline{A}}^L z^L) (\boldsymbol{I} - \boldsymbol{\overline{A}} z)^{-1} \boldsymbol{\overline{B}} = \boldsymbol{\tilde{C}}  (\boldsymbol{I} - \boldsymbol{\overline{A}} z)^{-1} \boldsymbol{\overline{B}}\)</p>

<p>For all $z \in \Omega_L$ we have $z^L = 1$, so the factor $z^L$ drops out and $\boldsymbol{I} - \boldsymbol{\overline{A}}^L$ becomes a constant
that we fold into a new $\boldsymbol{\tilde{C}} = \boldsymbol{\overline{C}} (\boldsymbol{I} - \boldsymbol{\overline{A}}^L)$. Critically, this function <strong>does not</strong> call <code class="language-plaintext highlighter-rouge">K_conv</code>,</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">K_gen_inverse</span><span class="p">(</span><span class="n">Ab</span><span class="p">,</span> <span class="n">Bb</span><span class="p">,</span> <span class="n">Cb</span><span class="p">,</span> <span class="n">L</span><span class="p">):</span>
    <span class="n">I</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">eye</span><span class="p">(</span><span class="n">Ab</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
    <span class="n">Ab_L</span> <span class="o">=</span> <span class="n">matrix_power</span><span class="p">(</span><span class="n">Ab</span><span class="p">,</span> <span class="n">L</span><span class="p">)</span>
    <span class="n">Ct</span> <span class="o">=</span> <span class="n">Cb</span> <span class="o">@</span> <span class="p">(</span><span class="n">I</span> <span class="o">-</span> <span class="n">Ab_L</span><span class="p">)</span>
    <span class="k">return</span> <span class="k">lambda</span> <span class="n">z</span><span class="p">:</span> <span class="p">(</span><span class="n">Ct</span> <span class="o">@</span> <span class="n">inv</span><span class="p">(</span><span class="n">I</span> <span class="o">-</span> <span class="n">Ab</span> <span class="o">*</span> <span class="n">z</span><span class="p">)</span> <span class="o">@</span> <span class="n">Bb</span><span class="p">).</span><span class="n">reshape</span><span class="p">()</span>
</code></pre></div></div>

<p>But it does output the same values,</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">test_gen_inverse</span><span class="p">(</span><span class="n">L</span><span class="o">=</span><span class="mi">16</span><span class="p">,</span> <span class="n">N</span><span class="o">=</span><span class="mi">4</span><span class="p">):</span>
    <span class="n">ssm</span> <span class="o">=</span> <span class="n">random_SSM</span><span class="p">(</span><span class="n">rng</span><span class="p">,</span> <span class="n">N</span><span class="p">)</span>
    <span class="n">b</span> <span class="o">=</span> <span class="n">K_conv</span><span class="p">(</span><span class="o">*</span><span class="n">ssm</span><span class="p">,</span> <span class="n">L</span><span class="o">=</span><span class="n">L</span><span class="p">)</span>

    <span class="n">a</span> <span class="o">=</span> <span class="n">conv_from_gen</span><span class="p">(</span><span class="n">K_gen_inverse</span><span class="p">(</span><span class="o">*</span><span class="n">ssm</span><span class="p">,</span> <span class="n">L</span><span class="o">=</span><span class="n">L</span><span class="p">),</span> <span class="n">L</span><span class="p">)</span>
    <span class="k">assert</span> <span class="n">np</span><span class="p">.</span><span class="n">isclose</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">rtol</span><span class="o">=</span><span class="mf">1e-2</span><span class="p">,</span> <span class="n">atol</span><span class="o">=</span><span class="mf">1e-4</span><span class="p">).</span><span class="nb">all</span><span class="p">()</span>
</code></pre></div></div>

<p>In summary, Step 1 allows us to replace the matrix power with an
 inverse by utilizing a truncated generating function.
 However, this inverse still needs to be calculated $L$
 times (once for each of the roots of unity).</p>
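<p>The identity behind Step 1 is the matrix geometric series, and it can be checked at a single root of unity. The sketch below uses plain NumPy and a random matrix as an illustrative stand-in for $\boldsymbol{\overline{A}}$; none of these values come from the post.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
N, L = 4, 16
A = 0.3 * rng.standard_normal((N, N))  # illustrative stand-in for A-bar
B = rng.standard_normal((N, 1))
C = rng.standard_normal((1, N))
I = np.eye(N)

z = np.exp(2j * np.pi * 3 / L)  # one of the L-th roots of unity, so z**L == 1

# Left-hand side: the truncated sum of powers, L matrix powers in total.
lhs = sum((C @ np.linalg.matrix_power(A, i) @ B).item() * z**i for i in range(L))

# Right-hand side: one matrix power (independent of z) and one inverse.
Ct = C @ (I - np.linalg.matrix_power(A, L))
rhs = (Ct @ np.linalg.inv(I - A * z) @ B).item()
```

<p>Here <code class="language-plaintext highlighter-rouge">lhs</code> and <code class="language-plaintext highlighter-rouge">rhs</code> should agree up to floating-point error. Note that <code class="language-plaintext highlighter-rouge">Ct</code> is computed once and reused for every node, while the inverse depends on <code class="language-plaintext highlighter-rouge">z</code>.</p>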

<h3 id="step-2-diagonal-case">Step 2: Diagonal Case</h3>

<p>The next step is to assume special <em>structure</em> on the matrix
$\boldsymbol{A}$ to avoid the inverse.  To begin, let us first
convert the equation above to use the original SSM matrices. With
some algebra you can expand the discretization and show:</p>

\[\begin{aligned}
  \boldsymbol{\tilde{C}}\left(\boldsymbol{I} - \boldsymbol{\overline{A}} z \right)^{-1} \boldsymbol{\overline{B}}
  =
  \frac{2\Delta}{1+z} \boldsymbol{\tilde{C}} \left[ {2 \frac{1-z}{1+z}} - \Delta \boldsymbol{A} \right]^{-1} \boldsymbol{B}
\end{aligned}\]
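<p>One way to fill in that algebra is sketched here. It assumes the bilinear discretization from Part 1, that is $\boldsymbol{\overline{A}} = (\boldsymbol{I} - \tfrac{\Delta}{2}\boldsymbol{A})^{-1}(\boldsymbol{I} + \tfrac{\Delta}{2}\boldsymbol{A})$ and $\boldsymbol{\overline{B}} = (\boldsymbol{I} - \tfrac{\Delta}{2}\boldsymbol{A})^{-1} \Delta \boldsymbol{B}$. Factoring $(\boldsymbol{I} - \tfrac{\Delta}{2}\boldsymbol{A})^{-1}$ out of $\boldsymbol{I} - \boldsymbol{\overline{A}} z$ lets it cancel against the same factor in $\boldsymbol{\overline{B}}$:</p>

\[\begin{aligned}
\left(\boldsymbol{I} - \boldsymbol{\overline{A}} z\right)^{-1} \boldsymbol{\overline{B}}
&amp;= \left[ \left(\boldsymbol{I} - \tfrac{\Delta}{2}\boldsymbol{A}\right)^{-1} \left( \left(\boldsymbol{I} - \tfrac{\Delta}{2}\boldsymbol{A}\right) - \left(\boldsymbol{I} + \tfrac{\Delta}{2}\boldsymbol{A}\right) z \right) \right]^{-1} \left(\boldsymbol{I} - \tfrac{\Delta}{2}\boldsymbol{A}\right)^{-1} \Delta \boldsymbol{B} \\
&amp;= \left[ (1 - z) \boldsymbol{I} - \tfrac{\Delta}{2} (1 + z) \boldsymbol{A} \right]^{-1} \Delta \boldsymbol{B} \\
&amp;= \frac{2\Delta}{1+z} \left[ 2 \frac{1-z}{1+z} \boldsymbol{I} - \Delta \boldsymbol{A} \right]^{-1} \boldsymbol{B}
\end{aligned}\]

<p>Multiplying on the left by $\boldsymbol{\tilde{C}}$ gives the identity.</p>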

<p>Now imagine $\boldsymbol{A}=\boldsymbol{\Lambda}$ for a diagonal $\boldsymbol{\Lambda}$. Substituting into the discretization
formula, the authors show that the generating function can be written in the following manner:</p>

<p>\(\begin{aligned}
\boldsymbol{\hat{K}}_{\boldsymbol{\Lambda}}(z) &amp; = c(z) \cdot \sum_i \frac{\tilde{C}_i B_i} {(g(z) - \Lambda_{i})} = c(z) \cdot k_{z, \boldsymbol{\Lambda}}(\boldsymbol{\tilde{C}}, \boldsymbol{B}) \\
 \end{aligned}\)
where $c(z) = \frac{2}{1+z}$ and $g(z) = \frac{2}{\Delta} \frac{1-z}{1+z}$ are scalar functions of $z$.</p>

<p>We have effectively replaced an inverse with a weighted dot product.
Let’s make a small helper function to compute this weighted dot product.
Here <a href="https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.vectorize.html">vectorize</a>
is a decorator that lets us broadcast this function automatically,</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">@</span><span class="n">partial</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">vectorize</span><span class="p">,</span> <span class="n">signature</span><span class="o">=</span><span class="s">"(c),(),(c)-&gt;()"</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">cauchy_dot</span><span class="p">(</span><span class="n">v</span><span class="p">,</span> <span class="n">omega</span><span class="p">,</span> <span class="n">lambd</span><span class="p">):</span>
    <span class="k">return</span> <span class="p">(</span><span class="n">v</span> <span class="o">/</span> <span class="p">(</span><span class="n">omega</span> <span class="o">-</span> <span class="n">lambd</span><span class="p">)).</span><span class="nb">sum</span><span class="p">()</span>
</code></pre></div></div>

<p>While not important for our implementation, it is worth noting
that this is a <a href="https://en.wikipedia.org/wiki/Cauchy_matrix">Cauchy kernel</a>
and is the subject of many <a href="https://en.wikipedia.org/wiki/Fast_multipole_method">fast implementations</a>.
On a GPU though, it is efficient enough just to compute it directly.</p>
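<p>As a sanity check on the diagonal case, the following sketch compares the explicit inverse against the elementwise Cauchy sum at a single node. It uses plain NumPy with random illustrative values, and <code class="language-plaintext highlighter-rouge">g</code> is an arbitrary stand-in for $g(z)$.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
N = 4
Lambda = rng.standard_normal(N) + 1j * rng.standard_normal(N)  # diagonal entries
Ct = rng.standard_normal(N) + 1j * rng.standard_normal(N)      # entries of C-tilde
B = rng.standard_normal(N) + 1j * rng.standard_normal(N)       # entries of B
g = 0.7 + 0.2j                                                 # stand-in for g(z) at one node

# Direct route: build the N x N matrix and invert it.
direct = Ct @ np.linalg.inv(g * np.eye(N) - np.diag(Lambda)) @ B

# Cauchy route: elementwise divide and sum, no matrix inverse needed.
cauchy = np.sum(Ct * B / (g - Lambda))
```

<p>The two values should agree up to floating-point error, while the Cauchy route costs $O(N)$ per node rather than a full matrix inversion.</p>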

<h3 id="step-3-diagonal-plus-low-rank">Step 3: Diagonal Plus Low-Rank</h3>

<p>The final step is to relax the diagonal assumption. In addition to
the diagonal term we allow a low-rank component with
$\boldsymbol{p}, \boldsymbol{q} \in \mathbb{C}^{N\times 1}$ such that:</p>

\[\boldsymbol{A} = \boldsymbol{\Lambda} - \boldsymbol{p}  \boldsymbol{q}^*\]

<p>The <a href="https://en.wikipedia.org/wiki/Woodbury_matrix_identity">Woodbury identity</a>
tells us that the inverse of a diagonal plus rank-1 term is equal to the
inverse of the diagonal plus a rank-1 term. Or in math:</p>

\[\begin{aligned}
(\boldsymbol{\Lambda} + \boldsymbol{p}  \boldsymbol{q}^*)^{-1} &amp;= \boldsymbol{\Lambda}^{-1} - \boldsymbol{\Lambda}^{-1} \boldsymbol{p} (1 + \boldsymbol{q}^* \boldsymbol{\Lambda}^{-1} \boldsymbol{p})^{-1} \boldsymbol{q}^* \boldsymbol{\Lambda}^{-1}
 \end{aligned}\]
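<p>The rank-1 case of the Woodbury identity is easy to check numerically. This sketch uses plain NumPy and random illustrative values; the scalar correction factor is $1 + \boldsymbol{q}^* \boldsymbol{\Lambda}^{-1} \boldsymbol{p}$.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
N = 4
Lambda = rng.standard_normal(N) + 1j * rng.standard_normal(N)
p = rng.standard_normal((N, 1)) + 1j * rng.standard_normal((N, 1))
q = rng.standard_normal((N, 1)) + 1j * rng.standard_normal((N, 1))
qH = q.conj().T  # q^*, a 1 x N row vector

Lam_inv = np.diag(1.0 / Lambda)  # inverting a diagonal matrix is elementwise

# Direct route: invert the full diagonal-plus-rank-1 matrix.
direct = np.linalg.inv(np.diag(Lambda) + p @ qH)

# Woodbury route: only the diagonal inverse and one scalar division.
scalar = 1.0 + (qH @ Lam_inv @ p).item()
woodbury = Lam_inv - (Lam_inv @ p @ qH @ Lam_inv) / scalar
```

<p>The matrices <code class="language-plaintext highlighter-rouge">direct</code> and <code class="language-plaintext highlighter-rouge">woodbury</code> should match up to floating-point error.</p>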

<p>There is a bunch of algebra not shown, but it mostly consists of substituting this component in for $\boldsymbol{A}$,
 applying the Woodbury identity and distributing terms. We end up with 4 terms that
 all look like Step 2 above:</p>

\[\begin{aligned}
\boldsymbol{\hat{K}}_{DPLR}(z) &amp; = c(z) [k_{z, \Lambda}(\boldsymbol{\tilde{C}}, \boldsymbol{B}) - k_{z, \Lambda}(\boldsymbol{\tilde{C}}, \boldsymbol{p}) (1 + k_{z, \Lambda}(\boldsymbol{q^*}, \boldsymbol{p}) )^{-1} k_{z, \Lambda}(\boldsymbol{q^*}, \boldsymbol{B}) ]
 \end{aligned}\]

<p>The code consists of collecting up the terms and applying 4 weighted dot products,</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">K_gen_DPLR</span><span class="p">(</span><span class="n">Lambda</span><span class="p">,</span> <span class="n">p</span><span class="p">,</span> <span class="n">q</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">Ct</span><span class="p">,</span> <span class="n">step</span><span class="p">):</span>
    <span class="n">aterm</span> <span class="o">=</span> <span class="p">(</span><span class="n">Ct</span><span class="p">.</span><span class="n">conj</span><span class="p">().</span><span class="n">ravel</span><span class="p">(),</span> <span class="n">q</span><span class="p">.</span><span class="n">conj</span><span class="p">().</span><span class="n">ravel</span><span class="p">())</span>
    <span class="n">bterm</span> <span class="o">=</span> <span class="p">(</span><span class="n">B</span><span class="p">.</span><span class="n">ravel</span><span class="p">(),</span> <span class="n">p</span><span class="p">.</span><span class="n">ravel</span><span class="p">())</span>

    <span class="k">def</span> <span class="nf">gen</span><span class="p">(</span><span class="n">o</span><span class="p">):</span>
        <span class="n">g</span> <span class="o">=</span> <span class="p">(</span><span class="mf">2.0</span> <span class="o">/</span> <span class="n">step</span><span class="p">)</span> <span class="o">*</span> <span class="p">((</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">o</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span><span class="mf">1.0</span> <span class="o">+</span> <span class="n">o</span><span class="p">))</span>
        <span class="n">c</span> <span class="o">=</span> <span class="mf">2.0</span> <span class="o">/</span> <span class="p">(</span><span class="mf">1.0</span> <span class="o">+</span> <span class="n">o</span><span class="p">)</span>

        <span class="k">def</span> <span class="nf">k</span><span class="p">(</span><span class="n">a</span><span class="p">):</span>
            <span class="k">return</span> <span class="n">cauchy_dot</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">g</span><span class="p">,</span> <span class="n">Lambda</span><span class="p">)</span>

        <span class="n">k00</span> <span class="o">=</span> <span class="n">k</span><span class="p">(</span><span class="n">aterm</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">bterm</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
        <span class="n">k01</span> <span class="o">=</span> <span class="n">k</span><span class="p">(</span><span class="n">aterm</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">bterm</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
        <span class="n">k10</span> <span class="o">=</span> <span class="n">k</span><span class="p">(</span><span class="n">aterm</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="n">bterm</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
        <span class="n">k11</span> <span class="o">=</span> <span class="n">k</span><span class="p">(</span><span class="n">aterm</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="n">bterm</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
        <span class="k">return</span> <span class="n">c</span> <span class="o">*</span> <span class="p">(</span><span class="n">k00</span> <span class="o">-</span> <span class="n">k01</span> <span class="o">*</span> <span class="p">(</span><span class="mf">1.0</span> <span class="o">/</span> <span class="p">(</span><span class="mf">1.0</span> <span class="o">+</span> <span class="n">k11</span><span class="p">))</span> <span class="o">*</span> <span class="n">k10</span><span class="p">)</span>

    <span class="k">return</span> <span class="n">gen</span>
</code></pre></div></div>

<p>This is our final version of the $K$ function. Now we can check whether it worked.
First, let’s generate a random Diagonal Plus Low Rank (DPLR) matrix,</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">random_DPLR</span><span class="p">(</span><span class="n">rng</span><span class="p">,</span> <span class="n">N</span><span class="p">):</span>
    <span class="n">l_r</span><span class="p">,</span> <span class="n">p_r</span><span class="p">,</span> <span class="n">q_r</span><span class="p">,</span> <span class="n">b_r</span><span class="p">,</span> <span class="n">c_r</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">split</span><span class="p">(</span><span class="n">rng</span><span class="p">,</span> <span class="mi">5</span><span class="p">)</span>
    <span class="n">Lambda</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">l_r</span><span class="p">,</span> <span class="p">(</span><span class="n">N</span><span class="p">,))</span>
    <span class="n">p</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">p_r</span><span class="p">,</span> <span class="p">(</span><span class="n">N</span><span class="p">,))</span>
    <span class="n">q</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">q_r</span><span class="p">,</span> <span class="p">(</span><span class="n">N</span><span class="p">,))</span>
    <span class="n">B</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">b_r</span><span class="p">,</span> <span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
    <span class="n">C</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">c_r</span><span class="p">,</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">N</span><span class="p">))</span>
    <span class="k">return</span> <span class="n">Lambda</span><span class="p">,</span> <span class="n">p</span><span class="p">,</span> <span class="n">q</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">C</span>
</code></pre></div></div>

<p>We can check that the DPLR method yields the same filter as computing $\boldsymbol{A}$ directly,</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">test_gen_dplr</span><span class="p">(</span><span class="n">L</span><span class="o">=</span><span class="mi">16</span><span class="p">,</span> <span class="n">N</span><span class="o">=</span><span class="mi">4</span><span class="p">):</span>
    <span class="n">I</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">eye</span><span class="p">(</span><span class="n">N</span><span class="p">)</span>

    <span class="c1"># Create a DPLR A matrix and discretize
</span>    <span class="n">Lambda</span><span class="p">,</span> <span class="n">p</span><span class="p">,</span> <span class="n">q</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">C</span> <span class="o">=</span> <span class="n">random_DPLR</span><span class="p">(</span><span class="n">rng</span><span class="p">,</span> <span class="n">N</span><span class="p">)</span>
    <span class="n">A</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">diag</span><span class="p">(</span><span class="n">Lambda</span><span class="p">)</span> <span class="o">-</span> <span class="n">p</span><span class="p">[:,</span> <span class="n">np</span><span class="p">.</span><span class="n">newaxis</span><span class="p">]</span> <span class="o">*</span> <span class="n">q</span><span class="p">[</span><span class="n">np</span><span class="p">.</span><span class="n">newaxis</span><span class="p">,</span> <span class="p">:]</span>
    <span class="n">Ab</span><span class="p">,</span> <span class="n">Bb</span><span class="p">,</span> <span class="n">Cb</span> <span class="o">=</span> <span class="n">discretize</span><span class="p">(</span><span class="n">A</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">C</span><span class="p">,</span> <span class="mf">1.0</span> <span class="o">/</span> <span class="n">L</span><span class="p">)</span>
    <span class="n">a</span> <span class="o">=</span> <span class="n">K_conv</span><span class="p">(</span><span class="n">Ab</span><span class="p">,</span> <span class="n">Bb</span><span class="p">,</span> <span class="n">Cb</span><span class="p">,</span> <span class="n">L</span><span class="o">=</span><span class="n">L</span><span class="p">)</span>

    <span class="c1"># Compare to the DPLR generating function approach.
</span>    <span class="n">Ct</span> <span class="o">=</span> <span class="p">(</span><span class="n">I</span> <span class="o">-</span> <span class="n">matrix_power</span><span class="p">(</span><span class="n">Ab</span><span class="p">,</span> <span class="n">L</span><span class="p">)).</span><span class="n">conj</span><span class="p">().</span><span class="n">T</span> <span class="o">@</span> <span class="n">Cb</span><span class="p">.</span><span class="n">ravel</span><span class="p">()</span>
    <span class="n">b</span> <span class="o">=</span> <span class="n">conv_from_gen</span><span class="p">(</span><span class="n">K_gen_DPLR</span><span class="p">(</span><span class="n">Lambda</span><span class="p">,</span> <span class="n">p</span><span class="p">,</span> <span class="n">q</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">Ct</span><span class="p">,</span> <span class="n">step</span><span class="o">=</span><span class="mf">1.0</span> <span class="o">/</span> <span class="n">L</span><span class="p">),</span> <span class="n">L</span><span class="p">)</span>
    <span class="k">assert</span> <span class="n">np</span><span class="p">.</span><span class="n">isclose</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">rtol</span><span class="o">=</span><span class="mf">1e-2</span><span class="p">,</span> <span class="n">atol</span><span class="o">=</span><span class="mf">1e-4</span><span class="p">).</span><span class="nb">all</span><span class="p">()</span>
</code></pre></div></div>

<h3 id="turning-hippo-to-dplr">Turning HiPPO to DPLR</h3>

<p>This approach applies to DPLR matrices, but remember we would like it to also apply to the HiPPO matrix.
 While not DPLR in its current form, the HiPPO matrix <em>does have special structure</em>. It is
 <a href="https://en.wikipedia.org/wiki/Normal_matrix">Normal</a> Plus Low-Rank (NPLR). The paper argues that
this is just as good as DPLR for the purposes of learning an SSM network.</p>

<blockquote>
  <p>The S4 techniques can apply to any matrix $\boldsymbol{A}$ that can be decomposed as <em>Normal Plus Low-Rank (NPLR)</em>.
\(\boldsymbol{A} = \boldsymbol{V} \boldsymbol{\Lambda} \boldsymbol{V}^* - \boldsymbol{p} \boldsymbol{q}^\top = \boldsymbol{V} \left( \boldsymbol{\Lambda} - \boldsymbol{V}^* \boldsymbol{p} (\boldsymbol{V}^*\boldsymbol{q})^* \right) \boldsymbol{V}^*\)
for <a href="https://en.wikipedia.org/wiki/Unitary_matrix">unitary</a> $\boldsymbol{V} \in \mathbb{C}^{N \times N}$, diagonal $\boldsymbol{\Lambda}$, and low-rank factorization $\boldsymbol{p}, \boldsymbol{q} \in \mathbb{R}^{N \times r}$.  An NPLR SSM is therefore <a href="https://en.wikipedia.org/wiki/Unitary_matrix">unitarily</a> equivalent to some DPLR matrix.</p>
</blockquote>

<p>For S4, we need to work with a HiPPO matrix for $\boldsymbol{A}$. This requires extracting
 $\boldsymbol{\Lambda}$ from this decomposition. The appendix of the paper shows how, by writing
 the HiPPO matrix as a normal matrix (a <a href="https://en.wikipedia.org/wiki/Skew-symmetric_matrix">skew-symmetric</a>
 matrix shifted by a multiple of the identity) plus a low-rank term. We can use this math to extract the DPLR terms,</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">make_NPLR_HiPPO</span><span class="p">(</span><span class="n">N</span><span class="p">):</span>
    <span class="c1"># Make -HiPPO
</span>    <span class="n">nhippo</span> <span class="o">=</span> <span class="o">-</span><span class="n">make_HiPPO</span><span class="p">(</span><span class="n">N</span><span class="p">)</span>

    <span class="c1"># Add in a rank 1 term. Makes it Normal.
</span>    <span class="n">p</span> <span class="o">=</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">np</span><span class="p">.</span><span class="n">sqrt</span><span class="p">(</span><span class="mi">2</span> <span class="o">*</span> <span class="n">np</span><span class="p">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">N</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">+</span> <span class="mf">1.0</span><span class="p">)</span>
    <span class="n">q</span> <span class="o">=</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">p</span>
    <span class="n">S</span> <span class="o">=</span> <span class="n">nhippo</span> <span class="o">+</span> <span class="n">p</span><span class="p">[:,</span> <span class="n">np</span><span class="p">.</span><span class="n">newaxis</span><span class="p">]</span> <span class="o">*</span> <span class="n">q</span><span class="p">[</span><span class="n">np</span><span class="p">.</span><span class="n">newaxis</span><span class="p">,</span> <span class="p">:]</span>

    <span class="c1"># Diagonalize S to V \Lambda V^*
</span>    <span class="n">Lambda</span><span class="p">,</span> <span class="n">V</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">jit</span><span class="p">(</span><span class="n">eig</span><span class="p">,</span> <span class="n">backend</span><span class="o">=</span><span class="s">"cpu"</span><span class="p">)(</span><span class="n">S</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">nhippo</span><span class="p">,</span> <span class="n">Lambda</span><span class="p">,</span> <span class="n">p</span><span class="p">,</span> <span class="n">q</span><span class="p">,</span> <span class="n">V</span>
</code></pre></div></div>

<p>Final sanity check just to make sure those identities hold,</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">test_nplr</span><span class="p">(</span><span class="n">N</span><span class="o">=</span><span class="mi">8</span><span class="p">):</span>
    <span class="n">A2</span><span class="p">,</span> <span class="n">Lambda</span><span class="p">,</span> <span class="n">p</span><span class="p">,</span> <span class="n">q</span><span class="p">,</span> <span class="n">V</span> <span class="o">=</span> <span class="n">make_NPLR_HiPPO</span><span class="p">(</span><span class="n">N</span><span class="p">)</span>
    <span class="n">p</span><span class="p">,</span> <span class="n">q</span> <span class="o">=</span> <span class="n">p</span><span class="p">[:,</span> <span class="n">np</span><span class="p">.</span><span class="n">newaxis</span><span class="p">],</span> <span class="n">q</span><span class="p">[:,</span> <span class="n">np</span><span class="p">.</span><span class="n">newaxis</span><span class="p">]</span>
    <span class="n">Lambda</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">diag</span><span class="p">(</span><span class="n">Lambda</span><span class="p">)</span>
    <span class="n">Vc</span> <span class="o">=</span> <span class="n">V</span><span class="p">.</span><span class="n">conj</span><span class="p">().</span><span class="n">T</span>
    <span class="n">A3</span> <span class="o">=</span> <span class="n">V</span> <span class="o">@</span> <span class="p">(</span><span class="n">Lambda</span> <span class="o">-</span> <span class="p">(</span><span class="n">Vc</span> <span class="o">@</span> <span class="n">p</span><span class="p">)</span> <span class="o">@</span> <span class="p">(</span><span class="n">Vc</span> <span class="o">@</span> <span class="n">q</span><span class="p">.</span><span class="n">conj</span><span class="p">()).</span><span class="n">conj</span><span class="p">().</span><span class="n">T</span><span class="p">)</span> <span class="o">@</span> <span class="n">Vc</span>
    <span class="n">A4</span> <span class="o">=</span> <span class="n">V</span> <span class="o">@</span> <span class="n">Lambda</span> <span class="o">@</span> <span class="n">Vc</span> <span class="o">-</span> <span class="p">(</span><span class="n">p</span> <span class="o">@</span> <span class="n">q</span><span class="p">.</span><span class="n">T</span><span class="p">)</span>
    <span class="k">assert</span> <span class="n">np</span><span class="p">.</span><span class="n">allclose</span><span class="p">(</span><span class="n">A2</span><span class="p">,</span> <span class="n">A3</span><span class="p">,</span> <span class="n">atol</span><span class="o">=</span><span class="mf">1e-2</span><span class="p">,</span> <span class="n">rtol</span><span class="o">=</span><span class="mf">1e-2</span><span class="p">)</span>
    <span class="k">assert</span> <span class="n">np</span><span class="p">.</span><span class="n">allclose</span><span class="p">(</span><span class="n">A2</span><span class="p">,</span> <span class="n">A4</span><span class="p">,</span> <span class="n">atol</span><span class="o">=</span><span class="mf">1e-2</span><span class="p">,</span> <span class="n">rtol</span><span class="o">=</span><span class="mf">1e-2</span><span class="p">)</span>
</code></pre></div></div>
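<p>To see why adding the rank-1 term makes the matrix normal, here is a small NumPy sketch that is not part of the original code base. It assumes the HiPPO matrix used earlier has entries $\sqrt{2n+1}\sqrt{2k+1}$ below the diagonal and $n+1$ on it, with $n, k$ running from 1 to $N$; the helper <code class="language-plaintext highlighter-rouge">hippo_matrix</code> is a hypothetical stand-in for <code class="language-plaintext highlighter-rouge">make_HiPPO</code>. Under that assumption, the corrected matrix is $-\frac{1}{2}\boldsymbol{I}$ plus a skew-symmetric matrix, so it is normal and all of its eigenvalues have real part $-\frac{1}{2}$.</p>

```python
import numpy as np


def hippo_matrix(N):
    # Assumed form of the HiPPO matrix, with n and k running from 1 to N.
    n = np.arange(1, N + 1)
    scale = np.sqrt(2 * n + 1.0)
    below_diagonal = np.tril(np.outer(scale, scale), k=-1)
    return below_diagonal + np.diag(n + 1.0)


N = 8
nhippo = -hippo_matrix(N)

# Same rank-1 correction as in make_NPLR_HiPPO.
p = 0.5 * np.sqrt(2 * np.arange(1, N + 1) + 1.0)
q = 2 * p
S = nhippo + np.outer(p, q)

# Removing the -1/2 on the diagonal should leave a skew-symmetric matrix.
shifted = S + 0.5 * np.eye(N)
is_skew = np.allclose(shifted, -shifted.T)

# A real matrix is normal when it commutes with its transpose.
is_normal = np.allclose(S @ S.T, S.T @ S)

eigenvalues = np.linalg.eigvals(S)
print(is_skew, is_normal, eigenvalues.real)
```

<p>Because a normal matrix is diagonalized by a unitary $\boldsymbol{V}$, this is the property that lets us conjugate the NPLR form into a DPLR one.</p>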

<h2 id="part-3-s4-in-practice">Part 3: S4 in Practice</h2>

<p>That was a lot of work, but now the actual model is concise. In fact
we are only using four functions:</p>

<ol>
  <li><code class="language-plaintext highlighter-rouge">discretize</code> → Convert SSM to discrete form.</li>
  <li><code class="language-plaintext highlighter-rouge">K_gen_DPLR</code> → Truncated generating function when $\boldsymbol{A}$ is DPLR (S4-part)</li>
  <li><code class="language-plaintext highlighter-rouge">conv_from_gen</code> → Convert generating function to filter</li>
  <li><code class="language-plaintext highlighter-rouge">non_circular_convolution</code> → Run convolution</li>
</ol>
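<p>As a rough illustration of how steps 2 and 3 fit together, the following NumPy sketch (a toy under our own assumptions, not the S4 implementation) uses a small dense $\boldsymbol{\overline{A}}$ instead of a DPLR one. It evaluates the truncated generating function $\boldsymbol{\overline{C}}(\boldsymbol{I} - \boldsymbol{\overline{A}}^L)(\boldsymbol{I} - \boldsymbol{\overline{A}}z)^{-1}\boldsymbol{\overline{B}}$ at the $L$ roots of unity and applies an inverse FFT, which should recover the same filter as computing $\boldsymbol{\overline{C}}\boldsymbol{\overline{A}}^k\boldsymbol{\overline{B}}$ term by term. The function names and the random test matrices are hypothetical.</p>

```python
import numpy as np


def kernel_by_powers(Ab, Bb, Cb, L):
    # Direct definition: K[k] = Cb @ Ab^k @ Bb for k = 0 .. L-1.
    return np.array(
        [(Cb @ np.linalg.matrix_power(Ab, k) @ Bb).item() for k in range(L)]
    )


def kernel_by_generating_function(Ab, Bb, Cb, L):
    I = np.eye(Ab.shape[0])
    # Fold the truncation factor (I - Ab^L) into C once, up front.
    Ct = Cb @ (I - np.linalg.matrix_power(Ab, L))
    # Roots of unity, using the same sign convention as a forward FFT.
    roots = np.exp(-2j * np.pi * np.arange(L) / L)
    at_roots = np.array(
        [(Ct @ np.linalg.solve(I - Ab * z, Bb)).item() for z in roots]
    )
    # The values at the roots are the DFT of the filter, so invert it.
    return np.fft.ifft(at_roots).real


rng = np.random.default_rng(0)
N, L = 4, 16
# Scale the random matrix down so that I - Ab * z is safely invertible.
Ab = 0.3 * rng.standard_normal((N, N)) / np.sqrt(N)
Bb = rng.standard_normal((N, 1))
Cb = rng.standard_normal((1, N))

K_direct = kernel_by_powers(Ab, Bb, Cb, L)
K_gen = kernel_by_generating_function(Ab, Bb, Cb, L)
print(np.max(np.abs(K_direct - K_gen)))
```

<p>The sketch still computes one matrix power to build the modified $\boldsymbol{C}$ and solves a dense linear system per root, so it only demonstrates that the two routes agree. The speedup in S4 comes from the DPLR structure, which replaces the dense solve with cheap diagonal operations.</p>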

<p>A full S4 Layer is very similar to the simple SSM layer above. The
 only difference is in the computation of $\boldsymbol{K}$.
 Additionally, instead of learning $\boldsymbol{C}$, we learn
 $\boldsymbol{\tilde{C}}$ so we avoid computing powers of
 $\boldsymbol{A}$. Note as well that in the original paper $\Lambda, p, q$ are
 also learned. However, in this post, we leave them fixed for simplicity.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">S4Layer</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="n">A</span><span class="p">:</span> <span class="n">np</span><span class="p">.</span><span class="n">DeviceArray</span>
    <span class="n">p</span><span class="p">:</span> <span class="n">np</span><span class="p">.</span><span class="n">DeviceArray</span>
    <span class="n">q</span><span class="p">:</span> <span class="n">np</span><span class="p">.</span><span class="n">DeviceArray</span>
    <span class="n">Lambda</span><span class="p">:</span> <span class="n">np</span><span class="p">.</span><span class="n">DeviceArray</span>

    <span class="n">N</span><span class="p">:</span> <span class="nb">int</span>
    <span class="n">l_max</span><span class="p">:</span> <span class="nb">int</span>

    <span class="k">def</span> <span class="nf">setup</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">B</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">param</span><span class="p">(</span><span class="s">"B"</span><span class="p">,</span> <span class="n">lecun_normal</span><span class="p">(),</span> <span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">N</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">D</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">param</span><span class="p">(</span><span class="s">"D"</span><span class="p">,</span> <span class="n">nn</span><span class="p">.</span><span class="n">initializers</span><span class="p">.</span><span class="n">ones</span><span class="p">,</span> <span class="p">(</span><span class="mi">1</span><span class="p">,))</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">Ct</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">param</span><span class="p">(</span>
            <span class="s">"Ct"</span><span class="p">,</span> <span class="n">lecun_normal</span><span class="p">(</span><span class="n">dtype</span><span class="o">=</span><span class="n">jax</span><span class="p">.</span><span class="n">numpy</span><span class="p">.</span><span class="n">complex64</span><span class="p">),</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">N</span><span class="p">)</span>
        <span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">log_step</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">param</span><span class="p">(</span><span class="s">"log_step"</span><span class="p">,</span> <span class="n">log_step_initializer</span><span class="p">(),</span> <span class="p">(</span><span class="mi">1</span><span class="p">,))</span>
        <span class="n">step</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">exp</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">log_step</span><span class="p">)</span>

        <span class="n">K_gen</span> <span class="o">=</span> <span class="n">K_gen_DPLR</span><span class="p">(</span>
            <span class="bp">self</span><span class="p">.</span><span class="n">Lambda</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">p</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">q</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">B</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">Ct</span><span class="p">,</span> <span class="n">step</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
        <span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">K</span> <span class="o">=</span> <span class="n">conv_from_gen</span><span class="p">(</span><span class="n">K_gen</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">l_max</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">u</span><span class="p">):</span>
        <span class="k">return</span> <span class="n">non_circular_convolution</span><span class="p">(</span><span class="n">u</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">K</span><span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="p">.</span><span class="n">D</span> <span class="o">*</span> <span class="n">u</span>
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">S4Layer</span> <span class="o">=</span> <span class="n">cloneLayer</span><span class="p">(</span><span class="n">S4Layer</span><span class="p">)</span>
</code></pre></div></div>

<p>We initialize the model by computing a DPLR initializer similar to HiPPO,</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">S4LayerInit</span><span class="p">(</span><span class="n">N</span><span class="p">):</span>
    <span class="n">_</span><span class="p">,</span> <span class="n">Lambda</span><span class="p">,</span> <span class="n">p</span><span class="p">,</span> <span class="n">q</span><span class="p">,</span> <span class="n">V</span> <span class="o">=</span> <span class="n">make_NPLR_HiPPO</span><span class="p">(</span><span class="n">N</span><span class="p">)</span>
    <span class="n">Vc</span> <span class="o">=</span> <span class="n">V</span><span class="p">.</span><span class="n">conj</span><span class="p">().</span><span class="n">T</span>
    <span class="n">p</span> <span class="o">=</span> <span class="n">Vc</span> <span class="o">@</span> <span class="n">p</span>
    <span class="n">q</span> <span class="o">=</span> <span class="n">Vc</span> <span class="o">@</span> <span class="n">q</span><span class="p">.</span><span class="n">conj</span><span class="p">()</span>
    <span class="n">A</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">diag</span><span class="p">(</span><span class="n">Lambda</span><span class="p">)</span> <span class="o">-</span> <span class="n">p</span><span class="p">[:,</span> <span class="n">np</span><span class="p">.</span><span class="n">newaxis</span><span class="p">]</span> <span class="o">@</span> <span class="n">q</span><span class="p">[:,</span> <span class="n">np</span><span class="p">.</span><span class="n">newaxis</span><span class="p">].</span><span class="n">conj</span><span class="p">().</span><span class="n">T</span>
    <span class="k">return</span> <span class="n">partial</span><span class="p">(</span><span class="n">S4Layer</span><span class="p">,</span> <span class="n">N</span><span class="o">=</span><span class="n">N</span><span class="p">,</span> <span class="n">A</span><span class="o">=</span><span class="n">A</span><span class="p">,</span> <span class="n">p</span><span class="o">=</span><span class="n">p</span><span class="p">,</span> <span class="n">q</span><span class="o">=</span><span class="n">q</span><span class="p">,</span> <span class="n">Lambda</span><span class="o">=</span><span class="n">Lambda</span><span class="p">)</span>
</code></pre></div></div>

<h3 id="experiments">Experiments</h3>

<p>Now that we have the model, we can try it out on some MNIST experiments.
For these experiments we linearize MNIST and just treat each image as a sequence of
pixels.</p>

<p>The first experiments we ran were on MNIST classification. While
not in theory a hard problem, treating MNIST as a linear sequence
classification task is a bit strange. In practice, however, the model
with $H=256$ and four layers seems to get up near 99% right away.</p>

<p>A more visually interesting task is generating MNIST digits, by predicting entire
sequences of pixels! Here, we simply feed a sequence of pixels into the model and have it
predict the next one, as in language modeling. With a little
tweaking, we are able to get the model to an NLL of 0.52 on this
task with size 512 and 6 layers (~2M parameters).</p>

<p>The metric usually used for this task is <em><a href="https://paperswithcode.com/sota/image-generation-on-mnist">bits per
dimension</a></em> (BPD), which for MNIST is the
NLL expressed in base 2. An NLL of 0.52 nats is ~0.75 BPD, which is near PixelCNN++.</p>
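<p>The conversion is just a change of logarithm base. Here is a minimal sketch, assuming the reported NLL is the average per-pixel loss in nats (natural log), which is what a standard cross-entropy loss returns; the function name is hypothetical.</p>

```python
import math


def nats_to_bits_per_dim(nll_nats):
    # log2(x) = ln(x) / ln(2), so a loss measured in nats is divided by ln(2).
    return nll_nats / math.log(2)


bpd = nats_to_bits_per_dim(0.52)
print(round(bpd, 2))
```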

<p>We can sample from the model using the CNN implementation. Ideally we would use the
RNN form, but that would require a bit more plumbing,</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">sample_mnist</span><span class="p">():</span>
    <span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="n">plt</span>
    <span class="kn">from</span> <span class="nn">flax.training</span> <span class="kn">import</span> <span class="n">checkpoints</span>

    <span class="n">model</span> <span class="o">=</span> <span class="n">S4LayerInit</span><span class="p">(</span><span class="n">N</span><span class="o">=</span><span class="mi">64</span><span class="p">)</span>
    <span class="n">model</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span>
        <span class="n">BatchSeqModel</span><span class="p">,</span>
        <span class="n">layer</span><span class="o">=</span><span class="n">model</span><span class="p">,</span>
        <span class="n">d_output</span><span class="o">=</span><span class="mi">256</span><span class="p">,</span>
        <span class="n">d_model</span><span class="o">=</span><span class="mi">256</span><span class="p">,</span>
        <span class="n">n_layers</span><span class="o">=</span><span class="mi">6</span><span class="p">,</span>
        <span class="n">l_max</span><span class="o">=</span><span class="mi">783</span><span class="p">,</span>
    <span class="p">)</span>
    <span class="n">rng</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">PRNGKey</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
    <span class="n">state</span> <span class="o">=</span> <span class="n">checkpoints</span><span class="p">.</span><span class="n">restore_checkpoint</span><span class="p">(</span><span class="s">"models/best_84"</span><span class="p">,</span> <span class="bp">None</span><span class="p">)</span>
    <span class="n">model</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">training</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
    <span class="n">start</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">zeros</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">784</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>

    <span class="k">def</span> <span class="nf">loop</span><span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="n">cur</span><span class="p">):</span>
        <span class="n">cur</span><span class="p">,</span> <span class="n">rng</span> <span class="o">=</span> <span class="n">cur</span>
        <span class="n">r</span><span class="p">,</span> <span class="n">rng</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">split</span><span class="p">(</span><span class="n">rng</span><span class="p">)</span>
        <span class="n">out</span> <span class="o">=</span> <span class="n">model</span><span class="p">.</span><span class="nb">apply</span><span class="p">({</span><span class="s">"params"</span><span class="p">:</span> <span class="n">state</span><span class="p">[</span><span class="s">"params"</span><span class="p">]},</span> <span class="n">cur</span><span class="p">[:,</span> <span class="p">:</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
        <span class="c1"># Sample with the fresh subkey r and carry rng forward to the next step</span>
        <span class="n">p</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">categorical</span><span class="p">(</span><span class="n">r</span><span class="p">,</span> <span class="n">out</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="n">i</span><span class="p">])</span>
        <span class="n">cur</span> <span class="o">=</span> <span class="n">cur</span><span class="p">.</span><span class="n">at</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">].</span><span class="nb">set</span><span class="p">(</span><span class="n">p</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">cur</span><span class="p">,</span> <span class="n">rng</span>

    <span class="n">out</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">lax</span><span class="p">.</span><span class="n">fori_loop</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">783</span><span class="p">,</span> <span class="n">jax</span><span class="p">.</span><span class="n">jit</span><span class="p">(</span><span class="n">loop</span><span class="p">),</span> <span class="p">(</span><span class="n">start</span><span class="p">,</span> <span class="n">rng</span><span class="p">))[</span><span class="mi">0</span><span class="p">]</span>
    <span class="n">plt</span><span class="p">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">out</span><span class="p">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">28</span><span class="p">,</span> <span class="mi">28</span><span class="p">))</span>
    <span class="n">plt</span><span class="p">.</span><span class="n">savefig</span><span class="p">(</span><span class="s">"sample.png"</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">sample_mnist</span><span class="p">()</span>
</code></pre></div></div>

<p><img src="https://iclr.iro.umontreal.ca/4de73081-c045-4ef8-8a76-5783bb8f6e9c_1642114847/public/images/2021-12-01-annotated-s4/images/sample.png" width="100%" /></p>

<p>We can also do prefix-samples – given the first 300 pixels, try to complete the image.
S4 is on the left, true on the right.</p>

<p><img src="https://iclr.iro.umontreal.ca/4de73081-c045-4ef8-8a76-5783bb8f6e9c_1642114847/public/images/2021-12-01-annotated-s4/images/im12.png" width="45%" />
<img src="https://iclr.iro.umontreal.ca/4de73081-c045-4ef8-8a76-5783bb8f6e9c_1642114847/public/images/2021-12-01-annotated-s4/images/im13.png" width="45%" />
<img src="https://iclr.iro.umontreal.ca/4de73081-c045-4ef8-8a76-5783bb8f6e9c_1642114847/public/images/2021-12-01-annotated-s4/images/im14.png" width="45%" />
<img src="https://iclr.iro.umontreal.ca/4de73081-c045-4ef8-8a76-5783bb8f6e9c_1642114847/public/images/2021-12-01-annotated-s4/images/im15.png" width="45%" />
<img src="https://iclr.iro.umontreal.ca/4de73081-c045-4ef8-8a76-5783bb8f6e9c_1642114847/public/images/2021-12-01-annotated-s4/images/im16.png" width="45%" />
<img src="https://iclr.iro.umontreal.ca/4de73081-c045-4ef8-8a76-5783bb8f6e9c_1642114847/public/images/2021-12-01-annotated-s4/images/im17.png" width="45%" />
<img src="https://iclr.iro.umontreal.ca/4de73081-c045-4ef8-8a76-5783bb8f6e9c_1642114847/public/images/2021-12-01-annotated-s4/images/im18.png" width="45%" />
<img src="https://iclr.iro.umontreal.ca/4de73081-c045-4ef8-8a76-5783bb8f6e9c_1642114847/public/images/2021-12-01-annotated-s4/images/im19.png" width="45%" /></p>
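<p>The prefix completion shown above can be sketched in a model-agnostic way. This is a simplified NumPy illustration, not the code used to produce the images: <code class="language-plaintext highlighter-rouge">logits_fn</code> is a hypothetical callable standing in for a trained model, and <code class="language-plaintext highlighter-rouge">toy_logits</code> is a dummy that just prefers to repeat the previous pixel.</p>

```python
import numpy as np


def complete_from_prefix(logits_fn, prefix, total_len, rng):
    # Keep the given pixels fixed and sample the remaining ones left to right.
    seq = np.zeros(total_len, dtype=np.int64)
    seq[: len(prefix)] = prefix
    for i in range(len(prefix), total_len):
        logits = logits_fn(seq[:i])  # scores for pixel i given pixels 0 .. i-1
        probs = np.exp(logits - logits.max())  # numerically stable softmax
        probs /= probs.sum()
        seq[i] = rng.choice(len(probs), p=probs)
    return seq


def toy_logits(history, n_values=256):
    # Hypothetical stand-in for a trained model: favor the previous pixel value.
    logits = np.zeros(n_values)
    logits[history[-1]] = 5.0
    return logits


rng = np.random.default_rng(0)
prefix = np.full(300, 128)  # pretend these are the first 300 pixels of an image
image = complete_from_prefix(toy_logits, prefix, 784, rng).reshape(28, 28)
```

<p>Calling the model on the whole history at every step mirrors the CNN-style sampling used in this post. An RNN-form sampler would instead carry a hidden state forward and do constant work per pixel.</p>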

<p>Next we tried training a model to generate drawings. For this we
used the <a href="https://github.com/googlecreativelab/quickdraw-dataset">QuickDraw
dataset</a>.
It includes a version downsampled to MNIST
size, so we can use roughly the same model as above. The dataset
is much larger though (5M images) and more complex. We only trained
for 1 epoch with an $H=256$, 4-layer model. Still, the approach was
able to generate relatively coherent completions. These are prefix
samples with 500 pixels given.</p>

<p><img src="https://iclr.iro.umontreal.ca/4de73081-c045-4ef8-8a76-5783bb8f6e9c_1642114847/public/images/2021-12-01-annotated-s4/images/quickdraw/im1.png" width="45%" />
<img src="https://iclr.iro.umontreal.ca/4de73081-c045-4ef8-8a76-5783bb8f6e9c_1642114847/public/images/2021-12-01-annotated-s4/images/quickdraw/im2.png" width="45%" />
<img src="https://iclr.iro.umontreal.ca/4de73081-c045-4ef8-8a76-5783bb8f6e9c_1642114847/public/images/2021-12-01-annotated-s4/images/quickdraw/im3.png" width="45%" />
<img src="https://iclr.iro.umontreal.ca/4de73081-c045-4ef8-8a76-5783bb8f6e9c_1642114847/public/images/2021-12-01-annotated-s4/images/quickdraw/im4.png" width="45%" />
<img src="https://iclr.iro.umontreal.ca/4de73081-c045-4ef8-8a76-5783bb8f6e9c_1642114847/public/images/2021-12-01-annotated-s4/images/quickdraw/im5.png" width="45%" />
<img src="https://iclr.iro.umontreal.ca/4de73081-c045-4ef8-8a76-5783bb8f6e9c_1642114847/public/images/2021-12-01-annotated-s4/images/quickdraw/im6.png" width="45%" /></p>

<p>Our full code base contains
more examples and infrastructure for training models for generation and
classification.</p>

<h2 id="conclusion">Conclusion</h2>

<p>Putting together this post inspired lots of thoughts about future
work in this area. One obvious conclusion is that long-range
models have all sorts of future applications, from acoustic modeling to
genomic sequences to trajectories (not to mention our shared area of
NLP). Another is some surprise that linear models can be so effective
here, while also opening up a range of efficient techniques.
Finally, at a practical level, the transformations in JAX
make it really nice to implement complex models like this
in a very concise way (~200 LoC), with similar efficiency and performance!</p>

<p>/ Cheers</p>
