Clock driven: Neurons
=======================================
Author: `fangwei123456 <https://github.com/fangwei123456>`_

Translator: `YeYumin <https://github.com/YEYUMIN>`_

This tutorial focuses on :class:`spikingjelly.activation_based.neuron` and introduces spiking neurons and clock-driven
simulation methods.

Spiking Neuron Model
-----------------------------------------------
In ``spikingjelly``, we define a neuron that can only output spikes, i.e. 0 or 1, as a "spiking neuron".
Networks that use spiking neurons are called Spiking Neural Networks (SNNs).
:class:`spikingjelly.activation_based.neuron` defines various common spiking neuron models.
We take :class:`spikingjelly.activation_based.neuron.LIFNode` as an example to introduce spiking neurons.

First, we need to import the relevant modules:

.. code-block:: python

    import torch
    import torch.nn as nn
    import numpy as np
    from spikingjelly.activation_based import neuron
    from spikingjelly import visualizing
    from matplotlib import pyplot as plt

Then we create a new layer of LIF neurons:

.. code-block:: python

    lif = neuron.LIFNode()

The LIF neurons layer has some parameters, which are explained in detail in the API documentation:

    - **tau** -- membrane time constant

    - **v_threshold** -- the threshold voltage of the neuron

    - **v_reset** -- the reset voltage of the neuron. If it is not ``None``, the voltage will be reset to ``v_reset`` when the neuron releases a spike; if it is set to ``None``, ``v_threshold`` will be subtracted from the voltage instead

    - **surrogate_function** -- the surrogate function used to calculate the gradient of the spike function during back propagation

The ``surrogate_function`` behaves exactly the same as the step function during forward propagation,
and we will introduce its working principle for back propagation later. We can just ignore it now.

You may be curious about the number of neurons in this layer. For most neuron layers in :class:`spikingjelly.activation_based.neuron`,
the number of neurons is determined automatically from the ``shape`` of the first input received after initialization, or after re-initialization by calling ``reset()``.

Similar to neurons in RNN, spiking neurons are also stateful (they have memory).
The state variable of a spiking neuron is generally its membrane potential :math:`V_{t}`.
Therefore, neurons in :class:`spikingjelly.activation_based.neuron` have state variable ``v``.
We can print the membrane potential of the newly created LIF neurons layer:

.. code-block:: python

    print(lif.v)
    # 0.0

We can see that ``lif.v`` is now ``0.0`` because we haven't given it any input yet.
If we feed the layer several different inputs and observe the ``shape`` of ``lif.v``, we find that it matches the
``shape`` of the input:

.. code-block:: python

    x = torch.rand(size=[2, 3])
    lif(x)
    print('x.shape', x.shape, 'lif.v.shape', lif.v.shape)
    # x.shape torch.Size([2, 3]) lif.v.shape torch.Size([2, 3])
    lif.reset()

    x = torch.rand(size=[4, 5, 6])
    lif(x)
    print('x.shape', x.shape, 'lif.v.shape', lif.v.shape)
    # x.shape torch.Size([4, 5, 6]) lif.v.shape torch.Size([4, 5, 6])
    lif.reset()

What is the relationship between :math:`V_{t}` and the input :math:`X_{t}`? In a spiking neuron,
:math:`V_{t}` depends not only on the input :math:`X_{t}` at time-step ``t``,
but also on the membrane potential :math:`V_{t-1}` at the previous time-step ``t-1``.

We often use the sub-threshold (when the membrane potential does not exceed the threshold potential :math:`V_{threshold}`) neuronal dynamics equation :math:`\frac{\mathrm{d}V(t)}{\mathrm{d}t} = f(V(t), X(t))` to describe a continuous-time
spiking neuron. For example, for LIF neurons, the equation is:

.. math::
    \tau_{m} \frac{\mathrm{d}V(t)}{\mathrm{d}t} = -(V(t) - V_{reset}) + X(t)

where :math:`\tau_{m}` is the membrane time constant and :math:`V_{reset}` is the reset potential. Because :math:`X(t)` is generally not constant, it is difficult to obtain an explicit analytical solution of such a differential equation.

The neurons in :class:`spikingjelly.activation_based.neuron` use discrete difference equations to approximate continuous differential equations.
From the perspective of the discrete equation, the charging equation of the LIF neuron is:

.. math::
    \tau_{m} (V_{t} - V_{t-1}) = -(V_{t-1} - V_{reset}) + X_{t}

The expression of :math:`V_{t}` can be obtained as

.. math::
    V_{t} = f(V_{t-1}, X_{t}) = V_{t-1} + \frac{1}{\tau_{m}}(-(V_{t - 1} - V_{reset}) + X_{t})

The corresponding code can be found in :class:`spikingjelly.activation_based.neuron.LIFNode.neuronal_charge`:

.. code-block:: python

    def neuronal_charge(self, x: torch.Tensor):
        if self.v_reset is None:
            self.v += (x - self.v) / self.tau

        else:
            if isinstance(self.v_reset, float) and self.v_reset == 0.:
                self.v += (x - self.v) / self.tau
            else:
                self.v += (x - (self.v - self.v_reset)) / self.tau
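
As a quick numerical check of the charging equation above, the following minimal sketch applies it to plain Python floats. It does not use ``spikingjelly``; the helper ``lif_charge`` and the values ``tau=2.0``, ``v_reset=0.0`` and ``x=1.0`` are assumptions chosen only for illustration.

.. code-block:: python

    def lif_charge(v, x, tau=2.0, v_reset=0.0):
        """One discrete LIF charging step (hypothetical helper, not the library API).

        Implements V_t = V_{t-1} + (X_t - (V_{t-1} - V_reset)) / tau.
        """
        return v + (x - (v - v_reset)) / tau

    v = 0.0
    trace = []
    for _ in range(3):
        # Constant sub-threshold input: the potential moves toward x
        v = lif_charge(v, x=1.0)
        trace.append(v)

    print(trace)
    # [0.5, 0.75, 0.875]

With a constant input, each step closes a fraction ``1 / tau`` of the remaining gap between the membrane potential and ``v_reset + x``.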

Different neurons have different charging equations. However, once the membrane potential exceeds the threshold potential,
the release of a spike and the reset of the membrane potential are the same for all kinds of neurons. Therefore,
they all inherit from :class:`spikingjelly.activation_based.neuron.BaseNode` and share the same discharge and reset equations. The code for neuronal fire can be found in :class:`spikingjelly.activation_based.neuron.BaseNode.neuronal_fire`:

.. code-block:: python

    def neuronal_fire(self):
        self.spike = self.surrogate_function(self.v - self.v_threshold)

``surrogate_function()`` is a Heaviside step function during forward propagation. When its input is greater than or equal
to 0, it returns 1; otherwise it returns 0. We regard this kind of ``tensor``, whose elements are only 0 or 1, as spikes.

The release of spikes consumes the previously accumulated electric charge of the neuron, so there will be an
instantaneous decrease in the membrane potential, which is the neuronal reset. In SNNs, there are
two ways to realize neuronal reset:

#. Hard method: After releasing a spike, the membrane potential is directly set to the reset potential :math:`V = V_{reset}`

#. Soft method: After releasing a spike, the membrane potential subtracts the threshold voltage :math:`V = V - V_{threshold}`

Note that neurons using the soft method do not need a reset voltage :math:`V_{reset}`.
For the neurons in :class:`spikingjelly.activation_based.neuron`, when ``v_reset`` is set to a float value (e.g., the default value ``0.0``), the neuron uses the hard reset; if ``v_reset`` is set to ``None``, the soft reset will be used.
We can find the corresponding code in :class:`spikingjelly.activation_based.neuron.BaseNode.neuronal_reset`:

.. code-block:: python

    def neuronal_reset(self):
        # ...
        if self.v_reset is None:
            self.v = self.v - spike * self.v_threshold
        else:
            self.v = (1 - spike) * self.v + spike * self.v_reset
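
The difference between the two reset methods can be seen on a single value. The sketch below mirrors the two branches above with plain Python floats; the function names ``hard_reset`` and ``soft_reset`` and the numbers used are hypothetical and chosen for illustration only.

.. code-block:: python

    def hard_reset(v, spike, v_reset=0.0):
        # Where spike == 1 the potential becomes v_reset; otherwise it is unchanged
        return (1.0 - spike) * v + spike * v_reset

    def soft_reset(v, spike, v_threshold=1.0):
        # Where spike == 1 the threshold is subtracted; otherwise v is unchanged
        return v - spike * v_threshold

    h = 1.25  # membrane potential after charging, above the threshold of 1.0

    print(hard_reset(h, spike=1.0))  # 0.0
    print(soft_reset(h, spike=1.0))  # 0.25
    print(hard_reset(h, spike=0.0))  # 1.25

The soft reset keeps the charge that exceeds the threshold (``0.25`` here), while the hard reset discards it.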


Three Equations to Describe Discrete Spiking Neurons
--------------------------------------------------------------
We can use the three discrete equations: neuronal charge, neuronal fire, and neuronal reset to describe all kinds of discrete spiking neurons. The neuronal charge and fire equations are:

.. math::
    H_{t} & = f(V_{t-1}, X_{t}) \\
    S_{t} & = g(H_{t} - V_{threshold}) = \Theta(H_{t} - V_{threshold})

where :math:`\Theta(x)` is the ``surrogate_function()`` in the parameters, which is a Heaviside step function:

.. math::
    \Theta(x) =
    \begin{cases}
    1, & x \geq 0 \\
    0, & x < 0
    \end{cases}

The hard reset is:

.. math::
    V_{t} = H_{t} \cdot (1 - S_{t}) + V_{reset} \cdot S_{t}

The soft reset is:

.. math::
    V_{t} = H_{t} - V_{threshold} \cdot S_{t}

where :math:`V_{t}` is the membrane potential of the neuron and :math:`X_{t}` is the external input, such as a voltage increment.
To avoid confusion, we use :math:`H_{t}` to represent the membrane potential after neuronal charge but before neuronal fire, while
:math:`V_{t}` is the membrane potential after neuronal fire and reset. :math:`f(V_{t-1}, X_{t})` is the neuronal charge function.
Different kinds of neurons differ only in their neuronal charge function.
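
The three equations can be combined into one time-step. The following is a minimal sketch in plain Python, assuming the LIF charge function and the values ``tau=2.0``, ``v_threshold=1.0`` and ``v_reset=0.0``; ``heaviside`` and ``lif_step`` are hypothetical helper names, not part of ``spikingjelly``.

.. code-block:: python

    def heaviside(x):
        # Theta(x): 1 if x >= 0, else 0
        return 1.0 if x >= 0 else 0.0

    def lif_step(v, x, tau=2.0, v_threshold=1.0, v_reset=0.0):
        """Return (H_t, S_t, V_t) for one step of a discrete LIF neuron."""
        # Neuronal charge: H_t = f(V_{t-1}, X_t); v_reset=None charges toward 0
        rest = 0.0 if v_reset is None else v_reset
        h = v + (x - (v - rest)) / tau
        # Neuronal fire: S_t = Theta(H_t - V_threshold)
        s = heaviside(h - v_threshold)
        # Neuronal reset: soft if v_reset is None, hard otherwise
        if v_reset is None:
            v_next = h - v_threshold * s
        else:
            v_next = h * (1.0 - s) + v_reset * s
        return h, s, v_next

    v = 0.0
    h_list, s_list = [], []
    for _ in range(4):
        h, s, v = lif_step(v, x=1.5)
        h_list.append(h)
        s_list.append(s)

    print(h_list)  # [0.75, 1.125, 0.75, 1.125]
    print(s_list)  # [0.0, 1.0, 0.0, 1.0]

Swapping the first line of the charge step for another :math:`f` would give a different neuron model, while the fire and reset steps stay the same.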

Clock-driven Simulation
---------------------------

:class:`spikingjelly.activation_based` uses a clock-driven approach to simulate SNNs.

Next, we will stimulate the neuron and check its membrane potential and output spikes.

Now let us give constant input to the LIF neurons layer and plot the membrane potential and output spikes:

.. code-block:: python

    lif.reset()
    x = torch.as_tensor([2.])
    T = 150
    s_list = []
    v_list = []
    for t in range(T):
        s_list.append(lif(x))
        v_list.append(lif.v)

    visualizing.plot_one_neuron_v_s(np.asarray(v_list), np.asarray(s_list), v_threshold=lif.v_threshold, v_reset=lif.v_reset,
                                    dpi=200)
    plt.show()

The input has ``shape=[1]``, so this LIF neurons layer has only 1 neuron. Its membrane potential and output spikes change with time-step as follows:

.. image:: ../_static/tutorials/activation_based/0_neuron/0.*
    :width: 100%
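
The same kind of experiment can be sketched without ``spikingjelly`` by iterating the charge, fire and hard-reset equations on a single float. The function ``simulate_lif`` is a hypothetical helper, and ``tau=100.0`` is an assumption chosen so that the slow charging phase is visible; it is not necessarily the value behind the figure above.

.. code-block:: python

    def simulate_lif(x, T, tau, v_threshold=1.0, v_reset=0.0):
        """Clock-driven simulation of one LIF neuron with constant input x."""
        v = v_reset
        v_trace, s_trace = [], []
        for _ in range(T):
            h = v + (x - (v - v_reset)) / tau     # neuronal charge
            s = 1.0 if h >= v_threshold else 0.0  # neuronal fire
            v = h * (1.0 - s) + v_reset * s       # hard reset
            v_trace.append(v)
            s_trace.append(s)
        return v_trace, s_trace

    v_trace, s_trace = simulate_lif(x=2.0, T=150, tau=100.0)
    spike_steps = [t for t, s in enumerate(s_trace) if s == 1.0]

    print(spike_steps)
    # [68, 137]

Under these assumptions the neuron charges for 69 time-steps, fires, is reset to ``v_reset``, and then repeats the same cycle.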

We reset the neurons layer and give an input with ``shape=[32]`` to see the membrane potential and output spikes of these 32 neurons:

.. code-block:: python

    lif.reset()
    x = torch.rand(size=[32]) * 4
    T = 50
    s_list = []
    v_list = []
    for t in range(T):
        s_list.append(lif(x).unsqueeze(0))
        v_list.append(lif.v.unsqueeze(0))

    s_list = torch.cat(s_list)
    v_list = torch.cat(v_list)

    visualizing.plot_2d_heatmap(array=np.asarray(v_list), title='Membrane Potentials', xlabel='Simulating Step',
                                ylabel='Neuron Index', int_x_ticks=True, x_max=T, dpi=200)
    visualizing.plot_1d_spikes(spikes=np.asarray(s_list), title='Spikes', xlabel='Simulating Step',
                               ylabel='Neuron Index', dpi=200)
    plt.show()

The results are as follows:

.. image:: ../_static/tutorials/activation_based/0_neuron/1.*
    :width: 100%

.. image:: ../_static/tutorials/activation_based/0_neuron/2.*
    :width: 100%
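
With random inputs, each neuron fires at a rate that depends on the strength of its own input. The plain-Python sketch below illustrates this for three hand-picked inputs, assuming ``tau=2.0``, ``v_threshold=1.0`` and a hard reset to ``v_reset=0.0``; ``count_spikes`` is a hypothetical helper, not a ``spikingjelly`` function.

.. code-block:: python

    def count_spikes(x, T=50, tau=2.0, v_threshold=1.0, v_reset=0.0):
        """Count the spikes of one LIF neuron driven by constant input x."""
        v = v_reset
        n = 0
        for _ in range(T):
            h = v + (x - (v - v_reset)) / tau  # neuronal charge
            if h >= v_threshold:               # neuronal fire
                n += 1
                v = v_reset                    # hard reset
            else:
                v = h
        return n

    inputs = [0.5, 1.5, 3.0]
    counts = [count_spikes(x) for x in inputs]

    print(counts)
    # [0, 25, 50]

Under these assumptions, an input of ``0.5`` never drives the potential to the threshold, ``1.5`` produces a spike every second step, and ``3.0`` produces a spike at every step.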