Clock driven: Use convolutional SNN to classify Fashion-MNIST
=============================================================================================

Author: `fangwei123456 <https://github.com/fangwei123456>`_

Translator: `YeYumin <https://github.com/YEYUMIN>`_

In this tutorial, we will build a convolutional spiking neural network to classify the `Fashion-MNIST <https://github.com/zalandoresearch/fashion-mnist>`__ dataset.
Fashion-MNIST has the same format as MNIST: each sample is a ``1 * 28 * 28`` grayscale image.

Network structure
----------------------------

Most common convolutional networks in the ANN field consist of convolutional layers followed by fully-connected layers.
We use a similar structure for our SNN. Let us import the required modules and subclass ``torch.nn.Module`` to define the network:

.. code-block:: python

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torchvision
    from spikingjelly.activation_based import neuron, functional, surrogate, layer
    from torch.utils.tensorboard import SummaryWriter
    import os
    import time
    import argparse
    import numpy as np
    from torch.cuda import amp
    _seed_ = 2020
    torch.manual_seed(_seed_)  # use torch.manual_seed() to seed the RNG for all devices (both CPU and CUDA)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(_seed_)

    class PythonNet(nn.Module):
        def __init__(self, T):
            super().__init__()
            self.T = T

Then we add the convolutional layers to ``PythonNet``. We stack two Conv-BN-IFNode-MaxPool blocks:

.. code-block:: python

    self.conv = nn.Sequential(
        nn.Conv2d(1, 128, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(128),
        neuron.IFNode(surrogate_function=surrogate.ATan()),
        nn.MaxPool2d(2, 2),  # 14 * 14

        nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(128),
        neuron.IFNode(surrogate_function=surrogate.ATan()),
        nn.MaxPool2d(2, 2)  # 7 * 7
        )

The input with ``shape=[N, 1, 28, 28]`` will be converted to spikes with ``shape=[N, 128, 7, 7]``.
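
The spatial sizes can be checked by hand with the usual output-size formula for convolution and pooling,
``floor((n + 2 * padding - kernel_size) / stride) + 1``. The snippet below is a small sketch in plain Python; the helper
name ``out_size`` is hypothetical and is not part of PyTorch or SpikingJelly:

.. code-block:: python

    def out_size(n, kernel_size, stride=1, padding=0):
        """Spatial output size of a convolution or pooling layer (dilation 1, floor mode)."""
        return (n + 2 * padding - kernel_size) // stride + 1

    size = 28
    trace = [size]
    for _ in range(2):  # the two Conv-BN-IFNode-MaxPool blocks
        size = out_size(size, kernel_size=3, stride=1, padding=1)  # nn.Conv2d(..., kernel_size=3, padding=1)
        trace.append(size)
        size = out_size(size, kernel_size=2, stride=2)             # nn.MaxPool2d(2, 2)
        trace.append(size)

    print(trace)             # [28, 28, 14, 14, 7]
    print(128 * size ** 2)   # 6272 values per sample once the [128, 7, 7] output is flattened

The ``3 * 3`` convolutions with ``padding=1`` keep the spatial size, so only the two pooling layers halve it.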

Such convolutional layers can actually function as an encoder. In the previous tutorial (classifying MNIST), we used a
Poisson encoder to encode images into spikes. However, we can also send the image directly
to the SNN. In this case, the first spiking neuron layer (SN) and the layers before it can be regarded as an
encoder with learnable parameters. Specifically, the encoder is composed of the following layers:

.. code-block:: python

    nn.Conv2d(1, 128, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(128),
    neuron.IFNode(surrogate_function=surrogate.ATan())

These layers receive images as input and output spikes, which can be regarded as an encoder.
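
To see why a spiking neuron can turn a static value into a spike train, consider a single integrate-and-fire neuron that
receives the same input at every time-step. The sketch below is plain Python rather than ``neuron.IFNode``. It assumes a
threshold of ``1.0`` and a hard reset to ``0.0``, and the function name ``if_spike_count`` is made up for illustration:

.. code-block:: python

    def if_spike_count(x, T, v_threshold=1.0, v_reset=0.0):
        """Count the spikes an integrate-and-fire neuron emits for a constant input x over T steps."""
        v = v_reset
        spikes = 0
        for _ in range(T):
            v += x                 # charge: integrate the input
            if v >= v_threshold:   # fire once the membrane potential reaches the threshold
                spikes += 1
                v = v_reset        # hard reset after a spike
        return spikes

    for x in (0.0, 0.25, 0.5, 1.0):
        print(x, if_spike_count(x, T=4))   # spike counts: 0, 1, 2, 4

Larger inputs produce more spikes, so the spike count over ``T`` steps carries information about the strength of the
Conv-BN output at each position.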

Next, we add two fully-connected layers as the classifier. The output layer has 10 neurons because Fashion-MNIST
has 10 classes.

.. code-block:: python

        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 7 * 7, 128 * 4 * 4, bias=False),
            neuron.IFNode(surrogate_function=surrogate.ATan()),
            nn.Linear(128 * 4 * 4, 10, bias=False),
            neuron.IFNode(surrogate_function=surrogate.ATan()),
        )

Now let us define the forward function.

.. code-block:: python

    def forward(self, x):
        # run the whole network at every time-step and accumulate the output spikes
        out_spikes_counter = self.fc(self.conv(x))
        for t in range(1, self.T):
            out_spikes_counter += self.fc(self.conv(x))

        return out_spikes_counter / self.T
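
The value returned by ``forward`` is the firing rate of each of the 10 output neurons, averaged over ``T`` time-steps.
The toy example below repeats the accumulate-then-average pattern on hand-written spikes in plain Python. Reading out
the class as the neuron with the highest firing rate is an assumption about typical usage and is not stated in this
section; the spike values are invented for illustration:

.. code-block:: python

    T = 4
    # hypothetical output spikes of the 10 output neurons for one image, one row per time-step
    out_spikes = [
        [0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 1, 0, 0, 0, 1, 0, 0, 0],
        [0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
    ]

    # same pattern as forward(): sum the spikes over time, then divide by T
    counter = [0] * 10
    for t in range(T):
        counter = [c + s for c, s in zip(counter, out_spikes[t])]
    firing_rate = [c / T for c in counter]

    # assumed read-out: the neuron with the highest firing rate gives the class
    predicted = max(range(10), key=lambda i: firing_rate[i])
    print(firing_rate[2], firing_rate[6])   # 0.75 0.5
    print(predicted)                        # 2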

Avoid Duplicated Computing
--------------------------------

We can train this network directly, just as in the previous MNIST classification tutorial. However, if we re-examine the
structure of the network, we find that some computation is duplicated. Consider the first two layers of the network
(highlighted in the following code):

.. code-block:: python
    :emphasize-lines: 2, 3

    self.conv = nn.Sequential(
            nn.Conv2d(1, 128, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(128),
            neuron.IFNode(surrogate_function=surrogate.ATan()),
            nn.MaxPool2d(2, 2),  # 14 * 14

            nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(128),
            neuron.IFNode(surrogate_function=surrogate.ATan()),
            nn.MaxPool2d(2, 2)  # 7 * 7
        )

The input images are static and do not change with ``t``, yet they are processed inside the ``for`` loop: at each time-step
they flow through the first two layers and produce exactly the same result. We can therefore move these two layers out of
the loop over time-steps. The complete code is:

.. code-block:: python

    class PythonNet(nn.Module):
        def __init__(self, T):
            super().__init__()
            self.T = T

            self.static_conv = nn.Sequential(
                nn.Conv2d(1, 128, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(128),
            )

            self.conv = nn.Sequential(
                neuron.IFNode(surrogate_function=surrogate.ATan()),
                nn.MaxPool2d(2, 2),  # 14 * 14

                nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(128),
                neuron.IFNode(surrogate_function=surrogate.ATan()),
                nn.MaxPool2d(2, 2)  # 7 * 7

            )
            self.fc = nn.Sequential(
                nn.Flatten(),
                nn.Linear(128 * 7 * 7, 128 * 4 * 4, bias=False),
                neuron.IFNode(surrogate_function=surrogate.ATan()),
                nn.Linear(128 * 4 * 4, 10, bias=False),
                neuron.IFNode(surrogate_function=surrogate.ATan()),
            )


        def forward(self, x):
            x = self.static_conv(x)

            out_spikes_counter = self.fc(self.conv(x))
            for t in range(1, self.T):
                out_spikes_counter += self.fc(self.conv(x))

            return out_spikes_counter / self.T

We put these stateless layers into ``self.static_conv`` so that they are computed only once per input.
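
The effect of moving a stateless layer out of the time loop can be reproduced with a toy model. The sketch below is plain
Python; ``CountingLayer`` and ``IFNeuron`` are hypothetical stand-ins for the Conv-BN layers and for ``neuron.IFNode``, not
SpikingJelly classes. Only the stateless layer is hoisted, because the neuron keeps its membrane potential between
time-steps and therefore has to stay inside the loop:

.. code-block:: python

    class CountingLayer:
        """Stateless layer stand-in that records how many times it is called."""
        def __init__(self):
            self.calls = 0

        def __call__(self, x):
            self.calls += 1
            return 0.5 * x  # the same input always gives the same output

    class IFNeuron:
        """Stateful integrate-and-fire neuron (threshold 1.0, hard reset to 0.0)."""
        def __init__(self):
            self.v = 0.0

        def __call__(self, x):
            self.v += x
            if self.v >= 1.0:
                self.v = 0.0
                return 1
            return 0

    def run_naive(x, T):
        static, if_neuron = CountingLayer(), IFNeuron()
        spikes = sum(if_neuron(static(x)) for _ in range(T))  # static layer inside the loop
        return spikes / T, static.calls

    def run_hoisted(x, T):
        static, if_neuron = CountingLayer(), IFNeuron()
        h = static(x)                                         # computed once, outside the loop
        spikes = sum(if_neuron(h) for _ in range(T))
        return spikes / T, static.calls

    print(run_naive(1.0, T=4))    # (0.5, 4)
    print(run_hoisted(1.0, T=4))  # (0.5, 1)

Both versions return the same firing rate, but the hoisted version evaluates the stateless layer once instead of ``T`` times.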

Training network
----------------------------
The complete code is available at :class:`spikingjelly.activation_based.examples.conv_fashion_mnist`. The training arguments are:

.. code-block:: shell

    Classify Fashion-MNIST

    optional arguments:
      -h, --help            show this help message and exit
      -T T                  simulating time-steps
      -device DEVICE        device
      -b B                  batch size
      -epochs N             number of total epochs to run
      -j N                  number of data loading workers (default: 4)
      -data_dir DATA_DIR    root dir of Fashion-MNIST dataset
      -out_dir OUT_DIR      root dir for saving logs and checkpoint
      -resume RESUME        resume from the checkpoint path
      -amp                  automatic mixed precision training
      -cupy                 use cupy neuron and multi-step forward mode
      -opt OPT              use which optimizer. SDG or Adam
      -lr LR                learning rate
      -momentum MOMENTUM    momentum for SGD
      -lr_scheduler LR_SCHEDULER
                            use which schedule. StepLR or CosALR
      -step_size STEP_SIZE  step_size for StepLR
      -gamma GAMMA          gamma for StepLR
      -T_max T_MAX          T_max for CosineAnnealingLR
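
A help message of this kind is what ``argparse`` generates. The following reduced sketch covers only a few of the options.
The defaults are copied from the ``Namespace`` printout shown later in this section, while the ``type`` arguments are
assumptions, so the actual example script may define these options differently:

.. code-block:: python

    import argparse

    parser = argparse.ArgumentParser(description='Classify Fashion-MNIST')
    parser.add_argument('-T', default=4, type=int, help='simulating time-steps')
    parser.add_argument('-device', default='cuda:0', help='device')
    parser.add_argument('-b', default=128, type=int, help='batch size')
    parser.add_argument('-epochs', default=64, type=int, metavar='N', help='number of total epochs to run')
    parser.add_argument('-amp', action='store_true', help='automatic mixed precision training')
    parser.add_argument('-opt', type=str, help='use which optimizer. SGD or Adam')
    parser.add_argument('-lr', default=0.1, type=float, help='learning rate')

    # parse the same flags as the training command in this section (without -data_dir)
    args = parser.parse_args(['-opt', 'SGD', '-amp'])
    print(args.T, args.b, args.opt, args.amp, args.lr)

Options that are not given on the command line keep their defaults, which is why the printed ``Namespace`` contains
values such as ``T=4`` and ``b=128`` even though the command only sets ``-opt``, ``-data_dir`` and ``-amp``.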

The checkpoint is saved in the same directory as the ``tensorboard`` log files. The server used for training this
network has an ``Intel(R) Xeon(R) Gold 6148 CPU @ 2.40GHz`` CPU and a ``GeForce RTX 2080 Ti`` GPU.

.. code-block:: shell

    (pytorch-env) root@e8b6e4800dae4011eb0918702bd7ddedd51c-fangw1598-0:/# python -m spikingjelly.activation_based.examples.conv_fashion_mnist -opt SGD -data_dir /userhome/datasets/FashionMNIST/ -amp

    Namespace(T=4, T_max=64, amp=True, b=128, cupy=False, data_dir='/userhome/datasets/FashionMNIST/', device='cuda:0', epochs=64, gamma=0.1, j=4, lr=0.1, lr_scheduler='CosALR', momentum=0.9, opt='SGD', out_dir='./logs', resume=None, step_size=32)
    PythonNet(
      (static_conv): Sequential(
        (0): Conv2d(1, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (conv): Sequential(
        (0): IFNode(
          v_threshold=1.0, v_reset=0.0, detach_reset=False
          (surrogate_function): ATan(alpha=2.0, spiking=True)
        )
        (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (4): IFNode(
          v_threshold=1.0, v_reset=0.0, detach_reset=False
          (surrogate_function): ATan(alpha=2.0, spiking=True)
        )
        (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      )
      (fc): Sequential(
        (0): Flatten(start_dim=1, end_dim=-1)
        (1): Linear(in_features=6272, out_features=2048, bias=False)
        (2): IFNode(
          v_threshold=1.0, v_reset=0.0, detach_reset=False
          (surrogate_function): ATan(alpha=2.0, spiking=True)
        )
        (3): Linear(in_features=2048, out_features=10, bias=False)
        (4): IFNode(
          v_threshold=1.0, v_reset=0.0, detach_reset=False
          (surrogate_function): ATan(alpha=2.0, spiking=True)
        )
      )
    )
    Mkdir ./logs/T_4_b_128_SGD_lr_0.1_CosALR_64_amp.
    Namespace(T=4, T_max=64, amp=True, b=128, cupy=False, data_dir='/userhome/datasets/FashionMNIST/', device='cuda:0', epochs=64, gamma=0.1, j=4, lr=0.1, lr_scheduler='CosALR', momentum=0.9, opt='SGD', out_dir='./logs', resume=None, step_size=32)
    ./logs/T_4_b_128_SGD_lr_0.1_CosALR_64_amp
    epoch=0, train_loss=0.028124165828697957, train_acc=0.8188267895299145, test_loss=0.023525000348687174, test_acc=0.8633, max_test_acc=0.8633, total_time=16.86261749267578
    Namespace(T=4, T_max=64, amp=True, b=128, cupy=False, data_dir='/userhome/datasets/FashionMNIST/', device='cuda:0', epochs=64, gamma=0.1, j=4, lr=0.1, lr_scheduler='CosALR', momentum=0.9, opt='SGD', out_dir='./logs', resume=None, step_size=32)
    ./logs/T_4_b_128_SGD_lr_0.1_CosALR_64_amp
    epoch=1, train_loss=0.018544567498163536, train_acc=0.883613782051282, test_loss=0.02161250041425228, test_acc=0.8745, max_test_acc=0.8745, total_time=16.618073225021362
    Namespace(T=4, T_max=64, amp=True, b=128, cupy=False, data_dir='/userhome/datasets/FashionMNIST/', device='cuda:0', epochs=64, gamma=0.1, j=4, lr=0.1, lr_scheduler='CosALR', momentum=0.9, opt='SGD', out_dir='./logs', resume=None, step_size=32)

    ...

    ./logs/T_4_b_128_SGD_lr_0.1_CosALR_64_amp
    epoch=62, train_loss=0.0010829827882937538, train_acc=0.997512686965812, test_loss=0.011441250185668468, test_acc=0.9316, max_test_acc=0.933, total_time=15.976636171340942
    Namespace(T=4, T_max=64, amp=True, b=128, cupy=False, data_dir='/userhome/datasets/FashionMNIST/', device='cuda:0', epochs=64, gamma=0.1, j=4, lr=0.1, lr_scheduler='CosALR', momentum=0.9, opt='SGD', out_dir='./logs', resume=None, step_size=32)
    ./logs/T_4_b_128_SGD_lr_0.1_CosALR_64_amp
    epoch=63, train_loss=0.0010746361010835525, train_acc=0.9977463942307693, test_loss=0.01154562517106533, test_acc=0.9296, max_test_acc=0.933, total_time=15.83976149559021
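
Lines of this form are easy to post-process. The sketch below uses only the standard library to pull the metrics out of
a few of the log lines shown above; the helper name ``parse_epoch_line`` is hypothetical and not part of SpikingJelly:

.. code-block:: python

    import re

    log_lines = [
        './logs/T_4_b_128_SGD_lr_0.1_CosALR_64_amp',
        'epoch=0, train_loss=0.028124165828697957, train_acc=0.8188267895299145, test_loss=0.023525000348687174, test_acc=0.8633, max_test_acc=0.8633, total_time=16.86261749267578',
        'epoch=62, train_loss=0.0010829827882937538, train_acc=0.997512686965812, test_loss=0.011441250185668468, test_acc=0.9316, max_test_acc=0.933, total_time=15.976636171340942',
        'epoch=63, train_loss=0.0010746361010835525, train_acc=0.9977463942307693, test_loss=0.01154562517106533, test_acc=0.9296, max_test_acc=0.933, total_time=15.83976149559021',
    ]

    def parse_epoch_line(line):
        """Return the key=value fields of an 'epoch=...' log line as a dict, or None for other lines."""
        if not line.startswith('epoch='):
            return None
        fields = dict(re.findall(r'(\w+)=([-+.\deE]+)', line))
        return {k: (int(v) if k == 'epoch' else float(v)) for k, v in fields.items()}

    records = [r for r in map(parse_epoch_line, log_lines) if r is not None]
    best = max(records, key=lambda r: r['test_acc'])
    print(best['epoch'], best['test_acc'])   # 62 0.9316
    print(records[-1]['max_test_acc'])       # 0.933

Among the epochs listed here the best ``test_acc`` is 0.9316, while ``max_test_acc=0.933`` refers to an epoch in the
elided part of the log.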

After training for 64 epochs, the accuracy curves on the training set and the test set are as follows:

.. image:: ../_static/tutorials/activation_based/4_conv_fashion_mnist/train.*
    :width: 100%

.. image:: ../_static/tutorials/activation_based/4_conv_fashion_mnist/test.*
    :width: 100%

After training for 64 epochs, the highest test accuracy reaches 93.3%, which is a very good accuracy for an
SNN. It is only slightly lower than ResNet18 (94.9%) with normalization, random horizontal flip, random vertical flip,
random translation and random rotation in the benchmark of `Fashion-MNIST <https://github.com/zalandoresearch/fashion-mnist>`__.

Visual Encoder
------------------------------------
As mentioned above, the first spiking neuron layer (SN) and the layers before it can be regarded as an encoder with learnable parameters. Specifically, it is the highlighted part of our network shown below:

.. code-block:: python
    :emphasize-lines: 5, 6, 10

    class PythonNet(nn.Module):
        def __init__(self, T):
            ...
            self.static_conv = nn.Sequential(
                nn.Conv2d(1, 128, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(128),
            )

            self.conv = nn.Sequential(
                neuron.IFNode(surrogate_function=surrogate.ATan()),
            ...
            )

Now let us take a look at the output spikes of the trained encoder. We create a new Python file, import the related
modules, and define a data loader with ``batch_size=1``, because we want to view the images one by one:

.. code-block:: python

    from matplotlib import pyplot as plt
    import numpy as np
    from spikingjelly.activation_based.examples.conv_fashion_mnist import PythonNet
    from spikingjelly import visualizing
    import torch
    import torch.nn as nn
    import torchvision

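    # root directory of the dataset; this is the path used in the training command above, change it to your own
    dataset_dir = '/userhome/datasets/FashionMNIST/'

    test_data_loader = torch.utils.data.DataLoader(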
        dataset=torchvision.datasets.FashionMNIST(
            root=dataset_dir,
            train=False,
            transform=torchvision.transforms.ToTensor(),
            download=True),
        batch_size=1,
        shuffle=True,
        drop_last=False)

We load the trained network from the checkpoint and build the encoder from its first layers:

.. code-block:: python

    net = torch.load('./logs/T_4_b_128_SGD_lr_0.1_CosALR_64_amp/checkpoint_max.pth', 'cpu')['net']
    encoder = nn.Sequential(
        net.static_conv,
        net.conv[0]
    )
    encoder.eval()

Let us take an image from the dataset, send it to the encoder, and check the accumulated value :math:`\sum_{t} S_{t}` of the output spikes. To make the display clearer, we also linearly rescale the values of each output ``feature_map`` to ``[0, 1]``.

.. code-block:: python

    from spikingjelly.activation_based import functional

    with torch.no_grad():
        for img, label in test_data_loader:
            # img has shape [1, 1, 28, 28]: dimension 0 is the batch and dimension 1 is the channel,
            # so squeeze() is used to reduce it to [28, 28] before calling imshow
            fig = plt.figure(dpi=200)
            plt.imshow(img.squeeze().numpy(), cmap='gray')
            plt.title('Input image', fontsize=20)
            plt.xticks([])
            plt.yticks([])
            plt.show()

            # clear the membrane potential left over from the previous image
            functional.reset_net(encoder)
            out_spikes = 0
            for t in range(net.T):
                # encoder(img) has shape [1, 128, 28, 28]; squeeze() reduces it to [128, 28, 28]
                out_spikes += encoder(img).squeeze()
                if t == 0 or t == net.T - 1:
                    out_spikes_c = out_spikes.clone()
                    for i in range(out_spikes_c.shape[0]):
                        if out_spikes_c[i].max().item() > out_spikes_c[i].min().item():
                            # min-max normalize each feature map to [0, 1] to make the display clearer
                            out_spikes_c[i] = (out_spikes_c[i] - out_spikes_c[i].min()) / (out_spikes_c[i].max() - out_spikes_c[i].min())
                    visualizing.plot_2d_spiking_feature_map(out_spikes_c, 8, 16, 1, None)
                    plt.title('$\\sum_{t} S_{t}$ at $t = ' + str(t) + '$', fontsize=20)
                    plt.show()
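
The rescaling applied to each feature map in the loop is ordinary min-max normalization, guarded against the case where
all values are equal. As a stand-alone illustration in plain Python (the helper name ``min_max_normalize`` is hypothetical):

.. code-block:: python

    def min_max_normalize(values):
        """Linearly rescale a list of numbers to [0, 1]; a constant list is returned unchanged."""
        lo, hi = min(values), max(values)
        if hi == lo:  # avoid dividing by zero, as the max > min check does in the loop above
            return list(values)
        return [(v - lo) / (hi - lo) for v in values]

    print(min_max_normalize([0, 1, 2, 4]))  # [0.0, 0.25, 0.5, 1.0]
    print(min_max_normalize([3, 3, 3]))     # [3, 3, 3]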

The following figures show two input images and the cumulative spikes :math:`\sum_{t} S_{t}` encoded by the encoder at ``t=0`` and ``t=7``:

.. image:: ../_static/tutorials/activation_based/4_conv_fashion_mnist/x0.*
    :width: 100%

.. image:: ../_static/tutorials/activation_based/4_conv_fashion_mnist/y00.*
    :width: 100%

.. image:: ../_static/tutorials/activation_based/4_conv_fashion_mnist/y07.*
    :width: 100%

.. image:: ../_static/tutorials/activation_based/4_conv_fashion_mnist/x1.*
    :width: 100%

.. image:: ../_static/tutorials/activation_based/4_conv_fashion_mnist/y10.*
    :width: 100%

.. image:: ../_static/tutorials/activation_based/4_conv_fashion_mnist/y17.*
    :width: 100%

We can see that the cumulative spikes :math:`\sum_{t} S_{t}` are very similar to the original images, indicating that the encoder has a strong encoding ability.
