STDP Learning
=======================================
Author: `fangwei123456 <https://github.com/fangwei123456>`_

Researchers working on SNNs have long been interested in biological learning rules. SpikingJelly provides STDP (Spike Timing Dependent Plasticity), \
which can be applied to convolutional or linear layers.

STDP(Spike Timing Dependent Plasticity)
-----------------------------------------------------

STDP (Spike Timing Dependent Plasticity), first reported by [#STDP]_, is a synaptic plasticity rule found in biological \
neural systems. Experiments on biological neural systems show that the weight of a synapse is influenced by the firing times \
of the pre and post neurons. More specifically, STDP can be formulated as:

- If the pre neuron fires earlier and the post neuron fires later, the weight will increase;
- If the pre neuron fires later and the post neuron fires earlier, the weight will decrease.

The curve [#STDP_figure]_ fitted to the experimental data is as follows:

.. image:: ../_static/tutorials/activation_based/stdp/stdp.*
    :width: 100%

We can use the following equation to describe STDP:

.. math::

    \begin{align}
    \begin{split}
    \Delta w_{ij} =
    \begin{cases}
        A\exp(\frac{-|t_{i}-t_{j}|}{\tau_{+}}) , t_{i} \leq t_{j}, A > 0\\
        B\exp(\frac{-|t_{i}-t_{j}|}{\tau_{-}}) , t_{i} > t_{j}, B < 0
    \end{cases}
    \end{split}
    \end{align}

where :math:`A, B` are the maximal weight changes for potentiation and depression, respectively, and :math:`\tau_{+}, \tau_{-}` are time constants.
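
As a minimal sketch of the equation above, the weight change can be evaluated directly from one pair of firing times. The function name ``pairwise_stdp`` and the default values of ``A``, ``B``, ``tau_plus`` and ``tau_minus`` are hypothetical choices for illustration and are not part of SpikingJelly:

.. code-block:: python

    import math

    def pairwise_stdp(t_pre, t_post, A=1.0, B=-1.0, tau_plus=20.0, tau_minus=20.0):
        """Weight change for one pre/post spike pair (illustrative parameters)."""
        dt = abs(t_pre - t_post)
        if t_pre <= t_post:
            # Pre fires no later than post: potentiation, since A > 0.
            return A * math.exp(-dt / tau_plus)
        # Pre fires after post: depression, since B < 0.
        return B * math.exp(-dt / tau_minus)

    dw_potentiation = pairwise_stdp(t_pre=10.0, t_post=15.0)
    dw_depression = pairwise_stdp(t_pre=15.0, t_post=10.0)
    print(dw_potentiation, dw_depression)

The magnitude of the change decays exponentially as the two firing times move apart, which matches the shape of the curve shown earlier.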

However, the above equation is seldom used in practice because it requires recording all firing times of the pre and post neurons. \
The trace method [#Trace]_ is a more popular way to implement STDP.

For the pre neuron :math:`i` and the post neuron :math:`j`, we use the traces :math:`tr_{pre}[i]` and :math:`tr_{post}[j]` to track their firing. The update of the \
traces is similar to that of the LIF neuron:

.. math::

    tr_{pre}[i][t] = tr_{pre}[i][t-1] -\frac{tr_{pre}[i][t-1]}{\tau_{pre}} + s[i][t]

    tr_{post}[j][t] = tr_{post}[j][t-1] -\frac{tr_{post}[j][t-1]}{\tau_{post}} + s[j][t]

where :math:`\tau_{pre}, \tau_{post}` are time constants of the pre and post neuron. :math:`s[i][t], s[j][t]` are the \
spikes at time-step :math:`t` of the pre neuron :math:`i` and the post neuron :math:`j`, which can only be 0 or 1.
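
The trace update can be sketched in a few lines of plain Python, reading the right-hand side as the previous trace, decayed by a factor that depends on the time constant, plus the current spike. The helper name ``update_trace`` and the spike train below are hypothetical and only serve as an illustration:

.. code-block:: python

    def update_trace(trace, spike, tau):
        """One time-step: decay the previous trace by trace / tau, then add the spike."""
        return trace - trace / tau + spike

    tau_pre = 2.0
    spikes = [1.0, 0.0, 0.0, 1.0]  # hypothetical spike train of one pre neuron

    trace = 0.0
    history = []
    for s in spikes:
        trace = update_trace(trace, s, tau_pre)
        history.append(trace)

    print(history)  # [1.0, 0.5, 0.25, 1.125]

Each spike raises the trace by 1, and without spikes the trace is halved at every time-step because ``tau_pre = 2``.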

The update of weight is:

.. math::

    \Delta w[i][j][t] = F_{post}(w[i][j][t]) \cdot tr_{pre}[i][t] \cdot s[j][t] - F_{pre}(w[i][j][t]) \cdot tr_{post}[j][t] \cdot s[i][t]

where :math:`F_{pre}, F_{post}` are functions that control how the weight changes.
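
For a single synapse, the update rule above can be sketched as follows. The names ``f_weight`` and ``stdp_step`` are hypothetical, and the traces are assumed to be updated before the weight change is computed; this is a scalar illustration, not the implementation used by SpikingJelly:

.. code-block:: python

    def f_weight(w):
        # Hypothetical weight-dependence function: clamp the weight to [-1, 1].
        return max(-1.0, min(1.0, w))

    def stdp_step(w, tr_pre, tr_post, s_pre, s_post, tau_pre=2.0, tau_post=2.0):
        """Update both traces, then return (delta_w, tr_pre, tr_post) for one synapse."""
        tr_pre = tr_pre - tr_pre / tau_pre + s_pre
        tr_post = tr_post - tr_post / tau_post + s_post
        delta_w = f_weight(w) * tr_pre * s_post - f_weight(w) * tr_post * s_pre
        return delta_w, tr_pre, tr_post

    w = 0.4
    # The pre neuron fires at t=0 and the post neuron fires at t=1. The pre trace
    # is still positive when the post spike arrives, so the change at t=1 is positive.
    dw0, tr_pre, tr_post = stdp_step(w, 0.0, 0.0, s_pre=1.0, s_post=0.0)
    dw1, tr_pre, tr_post = stdp_step(w, tr_pre, tr_post, s_pre=0.0, s_post=1.0)
    print(dw0, dw1)

Swapping the order of the two spikes makes the second term dominate, and the weight change becomes negative.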

STDP Learner
-----------------------------------------------------
:class:`spikingjelly.activation_based.learning.STDPLearner` can apply STDP learning to convolutional or linear layers. \
Please read the API documentation first to learn how to use it.

Now let us use ``STDPLearner`` to build the simplest ``1x1`` SNN with only one pre and one post neuron, \
and set the weight to ``0.4``:

.. code-block:: python

    import torch
    import torch.nn as nn
    from spikingjelly.activation_based import neuron, layer, learning
    from matplotlib import pyplot as plt
    torch.manual_seed(0)

    def f_weight(x):
        return torch.clamp(x, -1, 1.)

    tau_pre = 2.
    tau_post = 2.
    T = 128
    N = 1
    lr = 0.01
    net = nn.Sequential(
        layer.Linear(1, 1, bias=False),
        neuron.IFNode()
    )
    nn.init.constant_(net[0].weight.data, 0.4)

``STDPLearner`` can add the negative weight variation ``- delta_w * scale`` to the gradient of the weight, which makes it compatible with deep learning methods: \
optimizers and learning rate schedulers can be used together with ``STDPLearner``.

In this example, we use the simplest parameter update method:

.. math::

    W = W - lr \cdot \nabla W

where :math:`\nabla W` is ``- delta_w * scale``. Thus, the optimizer will apply \
``weight.data = weight.data - lr * weight.grad = weight.data + lr * delta_w * scale``.

We can implement the above parameter update method by the plain :class:`torch.optim.SGD` with ``momentum=0.``:

.. code-block:: python

    optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.)
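
To check the sign convention, the following sketch writes a hypothetical weight variation into ``.grad`` by hand instead of calling ``STDPLearner``. The names ``toy_weight``, ``toy_optimizer`` and the value of ``delta_w`` are assumptions made for illustration only:

.. code-block:: python

    import torch

    toy_weight = torch.nn.Parameter(torch.tensor([0.4]))
    toy_optimizer = torch.optim.SGD([toy_weight], lr=0.01, momentum=0.)

    delta_w = torch.tensor([0.5])  # hypothetical positive STDP weight variation
    scale = 1.0

    # Store the negative weight variation where the optimizer expects a gradient.
    toy_weight.grad = -delta_w * scale
    toy_optimizer.step()

    print(toy_weight.data)  # close to 0.4 + 0.01 * 0.5 = 0.405

Because the optimizer subtracts ``lr * grad``, a positive ``delta_w`` ends up increasing the weight.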

Then we create the input spikes and set ``STDPLearner``:

.. code-block:: python

    in_spike = (torch.rand([T, N, 1]) > 0.7).float()
    stdp_learner = learning.STDPLearner(step_mode='s', synapse=net[0], sn=net[1], tau_pre=tau_pre, tau_post=tau_post,
                                        f_pre=f_weight, f_post=f_weight)

Then we send the data to the network. Note that to plot the figure, we ``squeeze()`` the data, which reshapes them from ``shape = [T, N, 1]`` \
to ``shape = [T]``:

.. code-block:: python

    out_spike = []
    trace_pre = []
    trace_post = []
    weight = []
    with torch.no_grad():
        for t in range(T):
            optimizer.zero_grad()
            out_spike.append(net(in_spike[t]).squeeze())
            stdp_learner.step(on_grad=True)  # add ``- delta_w * scale`` on grad
            optimizer.step()
            weight.append(net[0].weight.data.clone().squeeze())
            trace_pre.append(stdp_learner.trace_pre.squeeze())
            trace_post.append(stdp_learner.trace_post.squeeze())

    in_spike = in_spike.squeeze()
    out_spike = torch.stack(out_spike)
    trace_pre = torch.stack(trace_pre)
    trace_post = torch.stack(trace_post)
    weight = torch.stack(weight)

The complete code is available at ``spikingjelly/activation_based/examples/stdp_trace.py``.

Let us plot ``in_spike, out_spike, trace_pre, trace_post, weight``:

.. image:: ../_static/tutorials/activation_based/stdp/stdp_trace.*
    :width: 100%

This figure is similar to Fig. 3 in [#Trace]_ (note that they use :math:`j` for the pre neuron and :math:`i` for the post neuron, which is the opposite of our notation):

.. image:: ../_static/tutorials/activation_based/stdp/trace_paper_fig3.*
    :width: 100%


Combine STDP Learning with Gradient Descent
-----------------------------------------------------
A widely used approach is to train different layers of an SNN with gradient descent and with STDP, respectively. \
With ``STDPLearner``, we can combine STDP learning with gradient descent easily.

Our goal is to build a deep SNN, train convolutional layers with STDP, and train linear layers with gradient descent. \
First, let us define the hyper-parameters:


.. code-block:: python

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.optim import SGD, Adam
    from spikingjelly.activation_based import learning, layer, neuron, functional

    T = 8
    N = 2
    C = 3
    H = 32
    W = 32
    lr = 0.1
    tau_pre = 2.
    tau_post = 100.
    step_mode = 'm'

Here we use the input with ``shape = [T, N, C, H, W] = [8, 2, 3, 32, 32]``.

Then we define the weight function and the SNN. Here we build a convolutional SNN with a multi-step mode:

.. code-block:: python

    def f_weight(x):
        return torch.clamp(x, -1, 1.)


    net = nn.Sequential(
        layer.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False),
        neuron.IFNode(),
        layer.MaxPool2d(2, 2),
        layer.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias=False),
        neuron.IFNode(),
        layer.MaxPool2d(2, 2),
        layer.Flatten(),
        layer.Linear(16 * 8 * 8, 64, bias=False),
        neuron.IFNode(),
        layer.Linear(64, 10, bias=False),
        neuron.IFNode(),
    )

    functional.set_step_mode(net, step_mode)

We want to use STDP to train the ``layer.Conv2d`` layers, while the other layers are trained with gradient descent. \
We use ``instances_stdp`` to specify the layer types that are trained by STDP:

.. code-block:: python

    instances_stdp = (layer.Conv2d, )

We create an STDP learner for each layer in the SNN that is an instance of a type in ``instances_stdp``:

.. code-block:: python

    stdp_learners = []

    for i in range(net.__len__()):
        if isinstance(net[i], instances_stdp):
            stdp_learners.append(
                learning.STDPLearner(step_mode=step_mode, synapse=net[i], sn=net[i+1], tau_pre=tau_pre, tau_post=tau_post,
                                    f_pre=f_weight, f_post=f_weight)
            )

Now we split the parameters into two groups, depending on whether they belong to layers whose types are in ``instances_stdp``, \
and assign each group to its own optimizer. Here we use ``Adam`` to optimize the parameters which are trained by gradient descent, and ``SGD`` \
to optimize the parameters which are trained by STDP:

.. code-block:: python

    params_stdp = []
    for m in net.modules():
        if isinstance(m, instances_stdp):
            for p in m.parameters():
                params_stdp.append(p)
                
    params_stdp_set = set(params_stdp)
    params_gradient_descent = []
    for p in net.parameters():
        if p not in params_stdp_set:
            params_gradient_descent.append(p)

    optimizer_gd = Adam(params_gradient_descent, lr=lr)
    optimizer_stdp = SGD(params_stdp, lr=lr, momentum=0.)
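
As a quick sanity check of this kind of split, the sketch below verifies that the two groups are disjoint and together cover all parameters. It uses plain ``torch.nn`` layers as stand-ins for the SpikingJelly layers, and all ``toy_*`` names are hypothetical:

.. code-block:: python

    import torch.nn as nn

    # Plain torch.nn layers stand in for the spikingjelly layers (assumption).
    toy_net = nn.Sequential(
        nn.Conv2d(3, 4, kernel_size=3, padding=1, bias=False),
        nn.Flatten(),
        nn.Linear(4 * 8 * 8, 10, bias=False),
    )
    toy_instances_stdp = (nn.Conv2d, )

    toy_params_stdp = []
    for m in toy_net.modules():
        if isinstance(m, toy_instances_stdp):
            toy_params_stdp.extend(m.parameters())

    # Compare parameters by identity rather than by value.
    stdp_ids = {id(p) for p in toy_params_stdp}
    toy_params_gd = [p for p in toy_net.parameters() if id(p) not in stdp_ids]

    gd_ids = {id(p) for p in toy_params_gd}
    all_ids = {id(p) for p in toy_net.parameters()}
    print(stdp_ids.isdisjoint(gd_ids), (stdp_ids | gd_ids) == all_ids)

If a parameter appeared in both groups, it would be updated twice per step, once by each optimizer.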

When training the SNN on an actual task, e.g., classifying CIFAR-10, we would get samples from the dataset. Here we only want to \
show an example, so we create the samples manually:

.. code-block:: python

    x_seq = (torch.rand([T, N, C, H, W]) > 0.5).float()
    target = torch.randint(low=0, high=10, size=[N])

Then we use the two optimizers to update the parameters. Note that the following code differs from the plain \
gradient descent procedure we used before.


First, let us clear all gradients, run a forward pass, calculate the loss, and run a backward pass:

.. code-block:: python

    optimizer_gd.zero_grad()
    optimizer_stdp.zero_grad()
    y = net(x_seq).mean(0)
    loss = F.cross_entropy(y, target)
    loss.backward()

Note that even though ``optimizer_gd`` only updates the parameters in ``params_gradient_descent``, ``loss.backward()`` \
calculates and sets ``.grad`` for all parameters, including those whose weight variation should be computed by STDP (which is also written to ``.grad``).

Thus, we need to clear the gradients of ``params_stdp``:

.. code-block:: python

    optimizer_stdp.zero_grad()
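
The effect can be reproduced with a small sketch that uses plain ``torch.nn`` layers instead of the SNN above. The names ``conv``, ``fc`` and ``opt_stdp`` are hypothetical stand-ins for an STDP-trained layer, a gradient-trained layer and the STDP optimizer:

.. code-block:: python

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    torch.manual_seed(0)
    conv = nn.Conv2d(1, 2, kernel_size=3, padding=1, bias=False)  # "STDP" layer
    fc = nn.Linear(2 * 4 * 4, 3, bias=False)                      # "gradient" layer
    opt_stdp = torch.optim.SGD(conv.parameters(), lr=0.1, momentum=0.)

    x = torch.rand(2, 1, 4, 4)
    target = torch.tensor([0, 2])
    loss = F.cross_entropy(fc(conv(x).flatten(1)), target)
    loss.backward()

    # backward() has written a gradient to the "STDP" layer as well.
    grad_set_by_backward = conv.weight.grad is not None

    opt_stdp.zero_grad()
    # Depending on the PyTorch version, zero_grad() sets .grad to None or fills it with zeros.
    conv_grad_cleared = conv.weight.grad is None or bool((conv.weight.grad == 0).all())
    fc_grad_kept = fc.weight.grad is not None
    print(grad_set_by_backward, conv_grad_cleared, fc_grad_kept)

Only the gradients owned by ``opt_stdp`` are cleared, so the gradient of ``fc`` stays available for the other optimizer.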


Then we need to use ``STDPLearner`` to get "gradients", and use two optimizers to update all parameters:

.. code-block:: python

    for i in range(stdp_learners.__len__()):
        stdp_learners[i].step(on_grad=True)

    optimizer_gd.step()
    optimizer_stdp.step()


All learners (``STDPLearner``, for instance) inherit from ``MemoryModule``. \
Hence, they have internal memories (``trace_pre, trace_post`` for ``STDPLearner``). \
In addition, the monitors inside the learners record the firing histories of the pre-synaptic and post-synaptic neurons; these histories \
may also be considered as internal memories of the learners. We should call the ``reset()`` method to clear the internal memory promptly \
so that memory consumption does not keep growing. We suggest resetting the learners together with the network after each batch:

.. code-block:: python

    functional.reset_net(net)
    for i in range(stdp_learners.__len__()):
        stdp_learners[i].reset()


The complete code is as follows:

.. code-block:: python

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.optim import SGD, Adam
    from spikingjelly.activation_based import learning, layer, neuron, functional

    T = 8
    N = 2
    C = 3
    H = 32
    W = 32
    lr = 0.1
    tau_pre = 2.
    tau_post = 100.
    step_mode = 'm'

    def f_weight(x):
        return torch.clamp(x, -1, 1.)


    net = nn.Sequential(
        layer.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False),
        neuron.IFNode(),
        layer.MaxPool2d(2, 2),
        layer.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias=False),
        neuron.IFNode(),
        layer.MaxPool2d(2, 2),
        layer.Flatten(),
        layer.Linear(16 * 8 * 8, 64, bias=False),
        neuron.IFNode(),
        layer.Linear(64, 10, bias=False),
        neuron.IFNode(),
    )

    functional.set_step_mode(net, step_mode)

    instances_stdp = (layer.Conv2d, )

    stdp_learners = []

    for i in range(net.__len__()):
        if isinstance(net[i], instances_stdp):
            stdp_learners.append(
                learning.STDPLearner(step_mode=step_mode, synapse=net[i], sn=net[i+1], tau_pre=tau_pre, tau_post=tau_post,
                                    f_pre=f_weight, f_post=f_weight)
            )


    params_stdp = []
    for m in net.modules():
        if isinstance(m, instances_stdp):
            for p in m.parameters():
                params_stdp.append(p)

    params_stdp_set = set(params_stdp)
    params_gradient_descent = []
    for p in net.parameters():
        if p not in params_stdp_set:
            params_gradient_descent.append(p)

    optimizer_gd = Adam(params_gradient_descent, lr=lr)
    optimizer_stdp = SGD(params_stdp, lr=lr, momentum=0.)



    x_seq = (torch.rand([T, N, C, H, W]) > 0.5).float()
    target = torch.randint(low=0, high=10, size=[N])

    optimizer_gd.zero_grad()
    optimizer_stdp.zero_grad()

    y = net(x_seq).mean(0)
    loss = F.cross_entropy(y, target)
    loss.backward()



    optimizer_stdp.zero_grad()

    for i in range(stdp_learners.__len__()):
        stdp_learners[i].step(on_grad=True)

    optimizer_gd.step()
    optimizer_stdp.step()

    functional.reset_net(net)
    for i in range(stdp_learners.__len__()):
        stdp_learners[i].reset()







.. [#STDP] Bi, Guo-qiang, and Mu-ming Poo. "Synaptic modifications in cultured hippocampal neurons: dependence on spike timing, synaptic strength, and postsynaptic cell type." Journal of neuroscience 18.24 (1998): 10464-10472.

.. [#STDP_figure] Froemke, Robert C., et al. "Contribution of individual spikes in burst-induced long-term synaptic modification." Journal of neurophysiology (2006).

.. [#Trace] Morrison, Abigail, Markus Diesmann, and Wulfram Gerstner. "Phenomenological models of synaptic plasticity based on spike timing." Biological cybernetics 98.6 (2008): 459-478.