#!/usr/bin/env python
# coding: utf-8

# # Spike Accumulation Forwarding for Effective Training of Spiking Neural Networks
# 
# ## 1. Setup
# ---
# 
# ### 1.1. Import

# In[ ]:


import math
import os
import random
from typing import Any

import matplotlib.pyplot as plt  # matplotlib==3.6.1
import numpy as np  # numpy==1.23.3
import pandas as pd  # pandas==1.5.0
import torch  # torch==1.10.0+cu113
import torch.nn as nn
import torch.nn.functional as F
import torchvision  # torchvision==0.11.1+cu113
import torchvision.transforms as transforms
from torch import Tensor

# Working directory; the dataset and log folders live below it.
cwd = "./"
# Dataset root (CIFAR-10 is used in this notebook) and log output directory.
data_dir, out_dir = (os.path.join(cwd, sub) for sub in ("dataset", "logs"))
# Number of worker processes used by each DataLoader.
num_workers = 4
# Compute device.
device = torch.device("cuda")


# ### 1.2. Hyperparameters
# 
#  - `batch_size` : Batch Size
#  - `epochs` : Number of Epochs
#  - `lr` : Learning Rate
#  - `momentum` : momentum for SGD
#  - `T_max` : T_max for CosineAnnealingLR
#  - `weight_decay` : Weight Decay
#  - `loss_alpha` : α * MSE + (1 - α) * CE
#  - `t_step` : simulating time-steps
#  - `lif_lambda` : λ = 1 - (1 / τ)
#  - `cfg` : Model Architecture(VGG)

# In[ ]:


# ---- optimisation ----
batch_size = 128  # mini-batch size
epochs = 300  # number of training epochs
lr = 0.1  # initial SGD learning rate
momentum = 0.9  # SGD momentum
T_max = 300  # CosineAnnealingLR period (same as `epochs`)
weight_decay = 0.0  # unused

# ---- loss / spiking simulation ----
loss_alpha = 0.05  # loss = α * MSE + (1 - α) * CE
t_step = 6  # number of simulated time steps per sample
lif_lambda = 0.5  # λ = 1 - (1 / τ); λ = 1.0 turns LIF into IF

# ---- architecture ----
# VGG11: 64C3-128C3-AP2-256C3-256C3-AP2-512C3-512C3-AP2-512C3-512C3-GAP-FC
# ints are 3x3 conv output channels, "A" is 2x2 average pooling
# ("M", 2x2 max pooling, is also understood by the model builders).
cfg = [64, 128, "A", 256, 256, "A", 512, 512, "A", 512, 512]


# ### 1.3. Dataset
# #### 1.3.1. Data Augmentation

# In[ ]:


class Cutout(object):
    """Randomly erase one rectangular patch of a (C, H, W) image tensor.

    The patch area is a random fraction ``scale`` of the image area and its
    height/width ratio is drawn from ``ratio``. The patch is filled with a
    single value drawn from ``random.random()``.
    """

    def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3)):
        """Initial setting

        Args:
            p: probability that the cutout operation will be performed.
            scale: range of proportion of cutout area against input.
            ratio: range of aspect ratio (height / width) of cutout area.
        """

        self.p = p
        self.scale = scale
        self.ratio = ratio

    def _sample_box(self, img_h, img_w):
        """Sample the top-left corner and size of the cutout rectangle.

        Args:
            img_h: image height
            img_w: image width

        Returns:
            tuple: ``(i, j, h, w)`` with ``0 <= i <= img_h - h`` and
            ``0 <= j <= img_w - w``.
        """
        area = img_h * img_w
        cutout_area = area * random.uniform(*self.scale)
        aspect_ratio = random.uniform(*self.ratio)
        h = int(round(math.sqrt(cutout_area * aspect_ratio)))
        w = int(round(math.sqrt(cutout_area / aspect_ratio)))
        # Large scale * ratio values (the defaults included: 0.33 * 3.3 on a
        # square image gives h > img_h) can produce a box larger than the
        # image, which would make randint() raise ValueError. Clip instead.
        h = min(h, img_h)
        w = min(w, img_w)
        # top left
        i = random.randint(0, img_h - h)
        j = random.randint(0, img_w - w)
        return i, j, h, w

    def __call__(self, x):
        """Perform a cutout transform on the image with probability ``p``.

        Args:
            x: input image tensor; its shape is unpacked as (C, H, W)

        Returns:
            Tensor: the erased image, or ``x`` itself when no cutout is applied
        """
        if random.random() < self.p:
            _, img_h, img_w = x.shape
            i, j, h, w = self._sample_box(img_h, img_w)
            return transforms.functional.erase(x, i, j, h, w, random.random())
        return x


# CIFAR-10 per-channel mean / std used by Normalize().
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

# Training pipeline: pad-and-crop, horizontal flip, tensor conversion,
# cutout (applied before normalisation) and normalisation.
transform_train = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        Cutout(scale=(0.02, 0.4), ratio=(0.4, 2.5)),
        transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
    ]
)

# Evaluation pipeline: tensor conversion and normalisation only.
transform_test = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD)]
)


# #### 1.3.2. CIFAR10

# In[ ]:


num_classes = 10

trainset = torchvision.datasets.CIFAR10(
    root=data_dir, train=True, download=True, transform=transform_train
)
train_loader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers
)

testset = torchvision.datasets.CIFAR10(
    root=data_dir, train=False, download=False, transform=transform_test
)
test_loader = torch.utils.data.DataLoader(
    testset, batch_size=batch_size, shuffle=False, num_workers=num_workers
)

# Preview one augmented training batch. Use the builtin next(iter(...)):
# the `.next()` method on DataLoader iterators was a Python 2 alias that
# newer PyTorch releases no longer provide.
x, _ = next(iter(train_loader))
# Statistics used to undo Normalize() for display, shaped to broadcast over
# an (H, W, C) image.
mean = np.array([0.4914, 0.4822, 0.4465]).reshape(1, 1, -1)
std = np.array([0.2023, 0.1994, 0.2010]).reshape(1, 1, -1)
plt.figure(figsize=(10, 10))
for i, xi in enumerate(x[:25]):
    plt.subplot(5, 5, i + 1)
    img = np.array(xi).transpose(1, 2, 0)  # (C, H, W) -> (H, W, C)
    img = np.clip(img * std + mean, 0, 1)
    plt.imshow(img)
plt.show()


# ## 2. Model Architecture
# 
# In this section, we will define the model architecture in the following four steps:
# 
#  - **Neuron** : Define Online-LIF (OLIF) and Spike Accumulation Forwarding (SAF)
#  - **Custom Module** : Define the custom modules used by the models
#  - **VGG** : Define the VGG model structure
#  - **Feedback-VGG** : Add a feedback connection to the VGG models
# 
# ### 2.1. Neuron
# 
# For simplicity, each neuron is implemented with a fixed threshold $V_{th} = 1.0$. Spikes are generated with the Heaviside step function, which is defined as follows:

# In[ ]:


def heaviside(x):
    """Heaviside step function used to turn potentials into spikes.

    Args:
        x: input tensor (the neurons call it with potential minus threshold)

    Returns:
        Tensor: 1 where ``x >= 0`` and 0 elsewhere, with the dtype and device of ``x``
    """

    fired = torch.ge(x, 0)
    return fired.to(x)


# #### 2.1.1. Online Leaky-Integrate-and-Fire
# 
# Since OTTT requires both the spikes $s[t]$ and the spike accumulation $\hat{a}[t]$, we define OLIF, a neuron that propagates both. OLIF can be defined as follows:

# In[ ]:


class OLIF(nn.Module):
    """Online Leaky-Integrate-and-Fire neuron with a fixed threshold of 1.0.

    Each call simulates one time step. Two states are carried between calls:
    the membrane potential ``u`` (soft reset: the emitted spike is
    subtracted) and the leaky spike accumulation ``a[t] = λ * a[t-1] + s[t]``.
    Both are stored detached, so no gradient flows back through time.
    """

    def __init__(
        self,
        l: float = 0.5,
        alpha: float = 4.0,
    ):
        """Initial setting

        Args:
            l: λ = 1 - (1 / τ)
            alpha: steepness of the sigmoid surrogate gradient
        """

        super(OLIF, self).__init__()
        self.l = l
        self.alpha = torch.tensor(alpha)
        self.func = self.Function.apply
        self.reset()

    def forward(self, input: Tensor) -> Tensor:
        """Simulate one time step.

        Args:
            input: input current at this time step

        Returns:
            Tensor: ``torch.cat((spike, accumulated_spike))`` along dim 0
        """

        decay = self.l
        # leak the potential left over from the previous step
        self.u = decay * self.u
        spike = self.func(input, self.u, self.alpha)
        accumulated = decay * self.a + spike
        self.a = accumulated.clone().detach()
        # integrate this step's input and apply the soft reset
        self.u = (self.u + input - spike).clone().detach()
        return torch.cat((spike, accumulated))

    def reset(self):
        """Clear the membrane potential ``u`` and the accumulated spike ``a``."""

        self.u, self.a = 0, 0

    class Function(torch.autograd.Function):
        """Heaviside spike in forward, sigmoid surrogate gradient in backward."""

        @staticmethod
        def forward(ctx, i_t: Tensor, u: Tensor, alpha: Tensor) -> Tensor:
            """Emit ``H(u + i_t - 1)``.

            Args:
                i_t: input current
                u: (leaked) membrane potential
                alpha: surrogate steepness

            Returns:
                Tensor: spike
            """
            margin = (u + i_t) - 1.0
            if i_t.requires_grad:
                ctx.alpha = alpha
                ctx.save_for_backward(margin)
            return heaviside(margin)

        @staticmethod
        def backward(ctx: Any, grad_output: Tensor) -> Tensor:
            """Gradient w.r.t. ``i_t`` is ``alpha * σ(alpha x) * (1 - σ(alpha x))``.

            Args:
                ctx: saved parameters
                grad_output: grad from output

            Returns:
                tuple: (grad for i_t or None, None, None)
            """
            if not ctx.needs_input_grad[0]:
                return None, None, None
            (margin,) = ctx.saved_tensors
            sig = torch.sigmoid(margin * ctx.alpha)
            return grad_output * (1.0 - sig) * sig * ctx.alpha, None, None


# #### 2.1.2. Spike Accumulation Forwarding
# 
# Our proposed SAF can be defined as follows:

# In[ ]:


class SAF(nn.Module):
    """Spike Accumulation Forwarding neuron with a fixed threshold of 1.0.

    SAF receives the accumulated potential ``V[t]`` and keeps only the leaky
    spike accumulation ``a`` as state (stored detached). Each call computes
    ``s[t] = H(V[t] - λ * a[t-1] - 1)`` and ``a[t] = λ * a[t-1] + s[t]``.
    """

    def __init__(
        self,
        l: float = 0.5,
        alpha: float = 4.0,
        spike: bool = False,
    ):
        """Initial setting

        Args:
            l: λ = 1 - (1 / τ)
            alpha: steepness of the sigmoid surrogate gradient
            spike: True -> return ``torch.cat((spike, accumulated spike))``,
                False -> return only the accumulated spike
        """

        super(SAF, self).__init__()
        self.l = l
        self.alpha = torch.tensor(alpha)
        self.func = self.Function.apply
        self.spike = spike
        self.reset()

    def forward(self, V: Tensor) -> Tensor:
        """Simulate one time step.

        Args:
            V: accumulated potential

        Returns:
            Tensor: accumulated spike, or ``torch.cat((spike, accumulated
            spike))`` along dim 0 when ``self.spike`` is True
        """

        # decayed accumulated spike λ * a[t-1]
        la = self.l * self.a
        # spike
        st = self.func(V, la, self.alpha)
        # a_t-1 -> a_t
        at = la + st
        self.a = at.clone().detach()
        if self.spike:
            return torch.cat((st, at))
        return at

    def reset(self):
        """Reset the accumulated spike ``a`` (SAF keeps no membrane potential)."""

        self.a = 0

    class Function(torch.autograd.Function):
        """Heaviside spike in forward, sigmoid surrogate gradient in backward."""

        @staticmethod
        def forward(ctx, V: Tensor, la: Tensor, alpha: Tensor) -> Tensor:
            """Forward process

            Args:
                V: accumulated potential
                la: decayed accumulated spike (λ * a[t-1])
                alpha: grad tuning param

            Returns:
                Tensor: spike ``H(V - la - 1)``
            """

            x = V - la - 1.0
            st = heaviside(x)
            if V.requires_grad:
                ctx.save_for_backward(x)
                ctx.alpha = alpha
            return st

        @staticmethod
        def backward(ctx: Any, grad_output: Tensor) -> Tensor:
            """Backward process (gradient only w.r.t. ``V``)

            Args:
                ctx: saved parameters
                grad_output: grad from output

            Returns:
                tuple: exactly one entry per forward input (V, la, alpha)
            """

            grad_x = None
            if ctx.needs_input_grad[0]:
                sgax = (ctx.saved_tensors[0] * ctx.alpha).sigmoid_()
                grad_x = grad_output * (1.0 - sgax) * sgax * ctx.alpha
            # forward() takes three inputs, so return three gradients
            # (previously four, relying on PyTorch dropping the extra None).
            return grad_x, None, None


# ### 2.2. Custom Module
# 
#  - **OutputSwap** : Compute the forward pass on the spikes while backpropagating through the spike accumulation (rate)
#  - **LSUM** : Compute a leaky (λ-weighted) running sum over time steps
#  - **Leaky-Bias** : Since the bias can be regarded as an input at every time step, it is accumulated in the same leaky way as LSUM
#  - **sWSConv2d** : Unlike BN, which normalizes activations using mini-batch statistics, sWS standardizes the weights
#  - **Combine CE and MSE** : Define a loss that combines the cross-entropy (CE) loss and the mean-squared-error (MSE) loss
# 
# #### 2.2.1. OutputSwap

# In[ ]:


class OutputSwap(nn.Module):
    """Use spikes for the forward value but the rate for the gradient.

    The input is ``[spike; rate]`` concatenated along dim 0. In training mode
    the output value is ``f(spike)`` while gradients are those of
    ``f(rate)``; in eval mode only ``f(spike)`` is computed.
    """

    def __init__(self, f):
        """Initial setting

        Args:
            f: wrapped function (e.g. Conv2d, Linear)
        """

        super(OutputSwap, self).__init__()
        self.f = f
        self.func = self.Function.apply

    def forward(self, x):
        """Forward process

        Args:
            x: ``[spike; rate]`` concatenated along dim 0

        Returns:
            Tensor: value of ``f(spike)``; in training, backward follows ``f(rate)``
        """

        spike, rate = torch.chunk(x, 2, dim=0)
        if not self.training:
            return self.f(spike)
        with torch.no_grad():
            spike_out = self.f(spike).detach()
        return self.func(self.f(rate), spike_out)

    class Function(torch.autograd.Function):
        """Forward returns ``x2``; backward sends the whole gradient to ``x1``."""

        @staticmethod
        def forward(ctx, x1: Tensor, x2: Tensor) -> Tensor:
            """Return the value tensor ``x2`` unchanged.

            Args:
                x1: tensor that receives the gradient
                x2: tensor that provides the value
            """

            return x2

        @staticmethod
        def backward(ctx: Any, grad_output: Tensor) -> Tensor:
            """Pass ``grad_output`` straight to ``x1``; ``x2`` gets no gradient."""

            return grad_output, None


# #### 2.2.2. Leaky-SUM

# In[ ]:


class LSUM(nn.Module):
    """Leaky running sum over time steps: ``y[t] = λ * y[t-1] + x[t]``."""

    def __init__(
        self,
        l: float = 0.5,
    ):
        """Initial setting

        Args:
            l: λ = 1 - (1 / τ)
        """

        super(LSUM, self).__init__()
        self.l = l  # λ = 1 - (1 / τ)
        self.reset()

    def forward(self, x: Tensor) -> Tensor:
        """Add ``x`` to the decayed running sum and return the new sum.

        Args:
            x: input data

        Returns:
            Tensor: accumulated x (the stored copy is detached)
        """

        total = self.l * self.lx + x
        self.lx = total.clone().detach()
        return total

    def reset(self):
        """Clear the running sum."""

        self.lx = 0


# #### 2.2.3. Leaky-Bias

# In[ ]:


class LBias2d(nn.Module):
    """Leaky per-channel bias for 4-D inputs: ``b[t] = λ * b[t-1] + bias``."""

    def __init__(
        self,
        channels: int,
        l: float = 0.5,
    ):
        """Initial setting

        Args:
            channels: input channels
            l: λ = 1 - (1 / τ)
        """
        super(LBias2d, self).__init__()
        self.l = l
        self.bias = nn.Parameter(torch.zeros(channels))
        self.b = 0  # accumulated bias (detached copy)

    def forward(self, x):
        """Accumulate the bias once more and add it to ``x`` along dim 1.

        Args:
            x: input, broadcast against a (1, channels, 1, 1) bias

        Returns:
            Tensor: ``x`` plus the accumulated bias
        """

        acc = self.l * self.b + self.bias
        self.b = acc.clone().detach()
        return x + acc[None, :, None, None]

    def reset(self):
        """Clear the accumulated bias."""

        self.b = 0


class LBias(nn.Module):
    """Leaky per-feature bias for 2-D inputs: ``b[t] = λ * b[t-1] + bias``."""

    def __init__(
        self,
        channels: int,
        l: float = 0.5,
    ):
        """Initial setting

        Args:
            channels: input channels (features)
            l: λ = 1 - (1 / τ)
        """
        super(LBias, self).__init__()
        self.l = l
        self.bias = nn.Parameter(torch.zeros(channels))
        self.b = 0  # accumulated bias (detached copy)

    def forward(self, x):
        """Accumulate the bias once more and add it to ``x`` along dim 1.

        Args:
            x: input, broadcast against a (1, channels) bias

        Returns:
            Tensor: ``x`` plus the accumulated bias
        """

        acc = self.l * self.b + self.bias
        self.b = acc.clone().detach()
        return x + acc[None, :]

    def reset(self):
        """Clear the accumulated bias."""

        self.b = 0


# #### 2.2.4. sWSConv2d

# In[ ]:


class ScaledWeightStandardization(nn.Module):
    """Standardize a weight tensor over ``dim`` and scale it by ``1 / sqrt(n)``.

    Computes ``(w - mean) / sqrt(var * n + eps)``, where ``mean`` and the
    unbiased ``var`` are taken over ``dim``.
    """

    def __init__(
        self,
        n,
        dim,
        eps=1e-4,
    ):
        """Initial setting

        Args:
            n: number of inputs (fan-in)
            dim: dimension(s) to normalize over
            eps: a value added to the denominator for numerical stability
        """

        super(ScaledWeightStandardization, self).__init__()
        self.n, self.dim, self.eps = n, dim, eps

    def forward(self, x):
        """Return the standardized weight.

        Args:
            x: weight

        Returns:
            Tensor: normalized weight with the same shape as ``x``
        """

        centered = x - x.mean(dim=self.dim, keepdim=True)
        denom = torch.sqrt(x.var(dim=self.dim, keepdim=True) * self.n + self.eps)
        return centered / denom


class sWSConv2d(nn.Conv2d):
    """Convolution2D with scaled weight standardization.

    The kernel used on every forward pass is ``gamma * sWS(weight)``. Here
    ``sWS`` standardizes each output filter over dims (1, 2, 3) with the
    filter's fan-in ``n``, and ``gamma`` is a learnable per-output-channel gain.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        eps=1e-4,
    ):
        """Initial setting

        Args:
            in_channels: input channels
            out_channels: output channels
            kernel_size: kernel size
            stride: stride
            padding: padding
            dilation: spacing between kernel elements
            groups: num of blocked connections from in to out
            bias: if True, adds a learnable bias
            eps: a value added to the denominator of the weight
                standardization for numerical stability
        """

        super(sWSConv2d, self).__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
        )
        # fan-in of one output filter (product of weight dims 1..3)
        n = np.prod(self.weight.shape[1:])
        # pass eps through: previously it was accepted but silently ignored
        self.sws = ScaledWeightStandardization(n, (1, 2, 3), eps=eps)
        # learnable scaling factor for the weights
        self.gamma = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1))

    def forward(self, x):
        """Forward process

        Args:
            x: input

        Returns:
            Tensor: conv2d(x) with the standardized, gamma-scaled kernel
        """

        weight = self.gamma * self.sws(self.weight)
        return F.conv2d(
            x,
            weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )


class Gamma(nn.Module):
    """Scale the input by a constant; default γ = 1/σH ≈ 2.74."""

    def __init__(self, gamma=2.74):
        """Initial setting

        Args:
            gamma: constant multiplier
        """

        super(Gamma, self).__init__()
        self.gamma = gamma

    def forward(self, x):
        """Return ``gamma * x``.

        Args:
            x: input
        """

        factor = self.gamma
        return factor * x


# #### 2.2.5. Combine CE and MSE

# In[ ]:


class CombineCEandMSE(nn.Module):
    """Weighted sum of MSE (against one-hot targets) and cross-entropy."""

    def __init__(self, alpha, num_classes):
        """Initial setting

        Args:
            alpha: weight of the MSE term; loss = α * MSE + (1 - α) * CE
            num_classes: number of classes (width of the one-hot target)
        """
        super(CombineCEandMSE, self).__init__()
        self.mse = nn.MSELoss(reduction="mean")
        self.a = alpha
        self.n = num_classes

    def forward(self, y, t):
        """Compute the combined loss.

        Args:
            y: model output (raw scores, used directly in both terms)
            t: integer target labels

        Returns:
            Tensor: α * MSE + (1 - α) * CE
        """

        target = F.one_hot(t, self.n).float()
        mse_term = self.mse(y, target)
        ce_term = F.cross_entropy(y, t)
        return self.a * mse_term + (1 - self.a) * ce_term


# ### 2.3. VGG
# 
# We leverage the VGG network architecture (64C3-128C3-AP2-256C3-256C3-AP2-512C3-512C3-AP2-512C3-512C3-GAP-FC) for experiments on CIFAR-10. We then define three models (**OTTT-based VGG**, **SAF-based VGG** and **Output Leaky-FR model**) with equivalent forward propagation.
# 
# #### 2.3.1. OTTT-based VGG
# 
# Combining this model with **Training_E** and **Training_A** defined in 3.1 results in $OTTT_O$ and $OTTT_A$.

# In[ ]:


class OTTTVGG(nn.Module):
    """OTTT-based VGG model built from OLIF neurons.

    The layer stack follows the global ``cfg``. An int adds a 3x3
    ``sWSConv2d`` -> ``OLIF`` -> ``Gamma`` stage with that many output
    channels; ``"A"`` / ``"M"`` add 2x2 average / max pooling. Every conv
    after the first is wrapped in ``OutputSwap`` because it receives the
    ``[spike; accumulated spike]`` pair produced by OLIF. The head is global
    average pooling, flatten and a 10-way linear layer (also ``OutputSwap``).
    """

    def __init__(self, l):
        """Initial setting

        Args:
            l: λ = 1 - (1 / τ)
        """

        super(OTTTVGG, self).__init__()
        stages = []
        channels = 3
        for idx, spec in enumerate(cfg):
            if spec == "A":
                stages.append(nn.AvgPool2d(kernel_size=2, stride=2))
            elif spec == "M":
                stages.append(nn.MaxPool2d(kernel_size=2, stride=2))
            else:
                conv = sWSConv2d(channels, spec, 3, 1, 1)
                # the first conv sees the plain image, not a spike/rate pair
                stages.extend(
                    (conv if idx == 0 else OutputSwap(conv), OLIF(l=l), Gamma())
                )
                channels = spec
        stages.extend(
            (
                nn.AdaptiveAvgPool2d((1, 1)),
                nn.Flatten(),
                OutputSwap(nn.Linear(cfg[-1], 10)),
            )
        )
        self.features = nn.Sequential(*stages)

    def forward(self, x, init):
        """Run one time step.

        Args:
            x: input
            init: if True, reset the neuron states first (first time step)

        Returns:
            Tensor: output
        """

        if init:
            self.reset()
        return self.features(x)

    def reset(self):
        """Reset the state of every stateful (leaky / spiking) submodule."""

        stateful = (LSUM, OLIF, SAF, LBias2d, LBias)
        for module in self.modules():
            if isinstance(module, stateful):
                module.reset()


# #### 2.3.2. SAF-based VGG
# 
# Combining this model with **Training_E** defined in 3.1 results in SAF-E.

# In[ ]:


class SAFVGG(nn.Module):
    """SAF-based VGG model.

    The first conv (with bias) is applied to the image and its output is
    accumulated by ``LSUM`` before the first SAF neuron. Later convs are
    bias-free and followed by ``LBias2d``. The last conv stage's SAF also
    emits spikes (``spike=True``), so the ``OutputSwap`` classifier head
    receives a ``[spike; accumulated spike]`` pair.
    """

    def __init__(self, l):
        """Initial setting

        Args:
            l: λ = 1 - (1 / τ)
        """

        super(SAFVGG, self).__init__()
        stages = []
        channels = 3
        last_idx = len(cfg) - 1
        for idx, spec in enumerate(cfg):
            if spec == "A":
                stages.append(nn.AvgPool2d(kernel_size=2, stride=2))
            elif spec == "M":
                stages.append(nn.MaxPool2d(kernel_size=2, stride=2))
            elif idx == 0:
                stages.extend(
                    (
                        sWSConv2d(channels, spec, 3, 1, 1),
                        LSUM(l=l),
                        SAF(l=l),
                        Gamma(),
                    )
                )
                channels = spec
            else:
                stages.extend(
                    (
                        sWSConv2d(channels, spec, 3, 1, 1, bias=False),
                        LBias2d(spec, l=l),
                        SAF(l=l, spike=(idx == last_idx)),
                        Gamma(),
                    )
                )
                channels = spec
        stages.extend(
            (
                nn.AdaptiveAvgPool2d((1, 1)),
                nn.Flatten(),
                OutputSwap(nn.Linear(cfg[-1], 10)),
            )
        )
        self.features = nn.Sequential(*stages)

    def forward(self, x, init):
        """Run one time step.

        Args:
            x: input
            init: if True, reset all u, b and a in Leaky-XXX Layer first

        Returns:
            Tensor: output
        """

        if init:
            self.reset()
        return self.features(x)

    def reset(self):
        """Reset the state of every stateful (leaky / spiking) submodule."""

        stateful = (LSUM, OLIF, SAF, LBias2d, LBias)
        for module in self.modules():
            if isinstance(module, stateful):
                module.reset()


# #### 2.3.3. Output Leaky-FR
# 
# Combining this model with **Training_F** defined in 3.1 results in SAF-F.

# In[ ]:


class SAFVGG_FR(nn.Module):
    """SAF-based VGG whose output is a leaky accumulation of logits (Output Leaky-FR).

    The image is first accumulated by ``LSUM``. Every conv is bias-free and
    followed by ``LBias2d``, so biases accumulate with the same λ. The head is
    global average pooling, flatten, a bias-free 10-way linear layer and ``LBias``.
    """

    def __init__(self, l):
        """Initial setting

        Args:
            l: λ = 1 - (1 / τ)
        """

        super(SAFVGG_FR, self).__init__()
        stages = [LSUM(l=l)]
        channels = 3
        for spec in cfg:
            if spec == "A":
                stages.append(nn.AvgPool2d(kernel_size=2, stride=2))
            elif spec == "M":
                stages.append(nn.MaxPool2d(kernel_size=2, stride=2))
            else:
                stages.extend(
                    (
                        sWSConv2d(channels, spec, 3, 1, 1, bias=False),
                        LBias2d(spec, l=l),
                        SAF(l=l),
                        Gamma(),
                    )
                )
                channels = spec
        stages.extend(
            (
                nn.AdaptiveAvgPool2d((1, 1)),
                nn.Flatten(),
                nn.Linear(cfg[-1], 10, bias=False),
                LBias(10, l=l),
            )
        )
        self.features = nn.Sequential(*stages)

    def forward(self, x, init):
        """Run one time step.

        Args:
            x: input
            init: if True, reset all u, b and a in Leaky-XXX Layer first

        Returns:
            Tensor: accumulated output
        """

        if init:
            self.reset()
        return self.features(x)

    def reset(self):
        """Reset the state of every stateful (leaky / spiking) submodule."""

        stateful = (LSUM, OLIF, SAF, LBias2d, LBias)
        for module in self.modules():
            if isinstance(module, stateful):
                module.reset()


# ### 2.4. Feedback-VGG
# 
# Add a feedback connection to the models defined in **2.3. VGG**.
# 
# #### 2.4.1. OTTT-based Feedback-VGG

# In[ ]:


class FOTTTVGG(nn.Module):
    """OTTT-based Feedback-VGG model.

    Same layer stack as ``OTTTVGG``, split into an input layer (first conv),
    a feature trunk and an output head. The trunk output of the previous time
    step is stored detached in ``self.fb_features``. It is upsampled by
    ``2 ** (#"A" + #"M")`` and passed through an ``OutputSwap`` 3x3 conv
    (``cfg[-1]`` -> ``cfg[0]`` channels), then added to the first conv's
    output on every step except the first.
    """

    def __init__(self, l):
        """Initial setting

        Args:
            l: λ = 1 - (1 / τ)
        """

        super(FOTTTVGG, self).__init__()
        scale_factor = 2 ** cfg.count("A") * 2 ** cfg.count("M")
        self.up = nn.Upsample(scale_factor=scale_factor, mode="nearest")
        self.fb = OutputSwap(
            nn.Conv2d(cfg[-1], cfg[0], kernel_size=3, padding=1, stride=1, bias=False)
        )
        stages = []
        channels = 3
        for idx, spec in enumerate(cfg):
            if spec == "A":
                stages.append(nn.AvgPool2d(kernel_size=2, stride=2))
            elif spec == "M":
                stages.append(nn.MaxPool2d(kernel_size=2, stride=2))
            else:
                conv = sWSConv2d(channels, spec, 3, 1, 1)
                stages.extend(
                    (conv if idx == 0 else OutputSwap(conv), OLIF(l=l), Gamma())
                )
                channels = spec
        stages.extend(
            (
                nn.AdaptiveAvgPool2d((1, 1)),
                nn.Flatten(),
                OutputSwap(nn.Linear(cfg[-1], 10)),
            )
        )
        head = 3  # pooling, flatten, classifier
        self.input_layer = nn.Sequential(*stages[:1])
        self.features = nn.Sequential(*stages[1:-head])
        self.output_layer = nn.Sequential(*stages[-head:])

    def forward(self, x, init):
        """Run one time step.

        Args:
            x: input
            init: if True, reset all u and a in OLIF and skip the feedback

        Returns:
            Tensor: output
        """

        if init:
            self.reset()
            feedback = 0
        else:
            feedback = self.fb(self.up(self.fb_features))
        hidden = self.features(self.input_layer(x) + feedback)
        self.fb_features = hidden.clone().detach()
        return self.output_layer(hidden)

    def reset(self):
        """Reset the state of every stateful (leaky / spiking) submodule."""

        stateful = (LSUM, OLIF, SAF, LBias2d, LBias)
        for module in self.modules():
            if isinstance(module, stateful):
                module.reset()


# #### 2.4.2. SAF-based Feedback-VGG

# In[ ]:


class FSAFVGG(nn.Module):
    """SAF-based Feedback-VGG model.

    Same layer stack as ``SAFVGG``, split into an input layer (first conv and
    its ``LSUM``), a feature trunk and an output head. The rate half of the
    previous step's trunk output is stored detached in ``self.fb_features``.
    It is upsampled and passed through a plain 3x3 conv
    (``cfg[-1]`` -> ``cfg[0]`` channels), then added to the accumulated input
    on every step except the first.
    """

    def __init__(self, l):
        """Initial setting

        Args:
            l: λ = 1 - (1 / τ)
        """

        super(FSAFVGG, self).__init__()
        scale_factor = 2 ** cfg.count("A") * 2 ** cfg.count("M")
        self.up = nn.Upsample(scale_factor=scale_factor, mode="nearest")
        self.fb = nn.Conv2d(
            cfg[-1], cfg[0], kernel_size=3, padding=1, stride=1, bias=False
        )
        stages = []
        channels = 3
        last_idx = len(cfg) - 1
        for idx, spec in enumerate(cfg):
            if spec == "A":
                stages.append(nn.AvgPool2d(kernel_size=2, stride=2))
            elif spec == "M":
                stages.append(nn.MaxPool2d(kernel_size=2, stride=2))
            elif idx == 0:
                conv = sWSConv2d(channels, spec, 3, 1, 1)
                stages.extend((conv, LSUM(l=l), SAF(l=l), Gamma()))
                channels = spec
            else:
                conv = sWSConv2d(channels, spec, 3, 1, 1, bias=False)
                leaky_bias = LBias2d(spec, l=l)
                # the last SAF also emits spikes for the OutputSwap head
                neuron = SAF(l=l, spike=True) if idx == last_idx else SAF(l=l)
                stages.extend((conv, leaky_bias, neuron, Gamma()))
                channels = spec
        stages.extend(
            (
                nn.AdaptiveAvgPool2d((1, 1)),
                nn.Flatten(),
                OutputSwap(nn.Linear(cfg[-1], 10)),
            )
        )
        head = 3  # pooling, flatten, classifier
        self.input_layer = nn.Sequential(*stages[:2])
        self.features = nn.Sequential(*stages[2:-head])
        self.output_layer = nn.Sequential(*stages[-head:])

    def forward(self, x, init):
        """Run one time step.

        Args:
            x: input
            init: if True, reset all u, b and a in Leaky-XXX Layer and skip the feedback

        Returns:
            Tensor: output
        """

        if init:
            self.reset()
            feedback = 0
        else:
            feedback = self.fb(self.up(self.fb_features))
        hidden = self.features(self.input_layer(x) + feedback)
        _, rate = torch.chunk(hidden, 2, dim=0)
        self.fb_features = rate.clone().detach()
        return self.output_layer(hidden)

    def reset(self):
        """Reset the state of every stateful (leaky / spiking) submodule."""

        stateful = (LSUM, OLIF, SAF, LBias2d, LBias)
        for module in self.modules():
            if isinstance(module, stateful):
                module.reset()


# #### 2.4.3. Output Leaky-FR

# In[ ]:


class FSAFVGG_FR(nn.Module):
    """SAF-based Feedback-VGG with a leaky-accumulated output (Output Leaky-FR).

    Same layer stack as ``SAFVGG_FR``, split into an input layer
    (``LSUM``, first conv, ``LBias2d``), a feature trunk and an output head.
    The previous step's trunk output is stored detached in
    ``self.fb_features``. It is upsampled and passed through a plain 3x3 conv
    (``cfg[-1]`` -> ``cfg[0]`` channels), then added to the accumulated input
    on every step except the first.
    """

    def __init__(self, l):
        """Initial setting

        Args:
            l: λ = 1 - (1 / τ)
        """

        super(FSAFVGG_FR, self).__init__()
        scale_factor = 2 ** cfg.count("A") * 2 ** cfg.count("M")
        self.up = nn.Upsample(scale_factor=scale_factor, mode="nearest")
        self.fb = nn.Conv2d(
            cfg[-1], cfg[0], kernel_size=3, padding=1, stride=1, bias=False
        )
        in_ch = 3
        layers = [LSUM(l=l)]
        for x in cfg:
            if x == "A":
                layers += [nn.AvgPool2d(kernel_size=2, stride=2)]
            elif x == "M":
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                conv = sWSConv2d(in_ch, x, 3, 1, 1, bias=False)
                bias = LBias2d(x, l=l)
                layers += [conv, bias, SAF(l=l), Gamma()]
                in_ch = x
        layers += [
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            # bias=False: the output bias is supplied by LBias so that it is
            # leaky-accumulated like every other bias in this model (as in
            # SAFVGG_FR). A plain Linear bias would add a constant,
            # non-accumulated term on top of LBias.
            nn.Linear(cfg[-1], 10, bias=False),
            LBias(10, l=l),
        ]
        self.input_layer = nn.Sequential(*layers[:3])
        self.features = nn.Sequential(*layers[3:-4])
        self.output_layer = nn.Sequential(*layers[-4:])

    def forward(self, x, init):
        """Run one time step.

        Args:
            x: input
            init: if True, reset all u, b and a in Leaky-XXX Layer and skip the feedback

        Returns:
            Tensor: accumulated output
        """

        if init:
            self.reset()
            fb = 0
        else:
            fb = self.fb(self.up(self.fb_features))
        h = self.input_layer(x) + fb
        h = self.features(h)
        self.fb_features = h.clone().detach()
        return self.output_layer(h)

    def reset(self):
        """Reset the state of every stateful (leaky / spiking) submodule."""

        for f in self.modules():
            if isinstance(f, (LSUM, OLIF, SAF, LBias2d, LBias)):
                f.reset()


# ## 3. Training and Validation
# 
# In this section, we will define the model training and validation process.
# 
# ### 3.1. Training Process
# 
# We define three types of training methods.
# 
#  - **Training_E** : Update the parameters with the gradients at each time step
#  - **Training_A** : Update the parameters with the sum of the gradients computed at each time step
#  - **Training_F** : Update the parameters with the gradients at the final time step
# 
# #### 3.1.1. Training_E

# In[ ]:


def train_e(loader, model, criterion, optimizer, t_step, epoch):
    """Train-E: one epoch that updates the parameters after every time step.

    For each mini-batch the model is simulated for ``t_step`` steps; the
    first step resets its state. After each step the loss of that step's
    output, divided by ``t_step``, is backpropagated and the optimizer steps.
    Accuracy uses the sum of the outputs over all time steps.

    Args:
        loader: DataLoader (iterable of ``(x, t)`` batches)
        model: training model, called as ``model(x, init)``
        criterion: loss function
        optimizer: optimizer
        t_step: simulating time-steps
        epoch: current epoch (unused; kept for a uniform signature)

    Returns:
        tuple: (mean loss per sample, accuracy in [0, 1]); both are 0.0
        for an empty loader
    """

    model.train()
    # Move batches to the device holding the model's parameters (identical to
    # `.cuda()` for a CUDA model, but also works for CPU models).
    first_param = next(model.parameters(), None)
    target_device = (
        first_param.device if first_param is not None else torch.device("cuda")
    )
    total_loss = 0.0
    total_acc = 0.0
    total_samples = 0
    for x, t in loader:
        x = x.to(target_device)
        t = t.to(target_device)
        Y = 0
        for j in range(t_step):
            optimizer.zero_grad()
            y = model(x, j == 0)
            # Y is only used for accuracy: detach so it does not keep
            # references to every step's autograd graph (as in train_a).
            Y = Y + y.detach()
            loss = criterion(y, t) / t_step
            loss.backward()
            total_loss += loss.item() * t.numel()
            optimizer.step()
        total_samples += t.numel()
        total_acc += (Y.argmax(1) == t).float().sum().item()
    # avoid ZeroDivisionError on an empty loader
    denom = max(total_samples, 1)
    total_loss /= denom
    total_acc /= denom
    print(f"[Train] Accuracy: {(100*total_acc):>0.1f}%, Loss: {total_loss:>8f}")
    return total_loss, total_acc


# #### 3.1.2. Training_A

# In[ ]:


def train_a(loader, model, criterion, optimizer, t_step, epoch):
    """Train-A

    One training epoch with a single parameter update per mini-batch. The
    gradients of every time step are accumulated in ``.grad``: one
    ``backward()`` per step, with each step's loss divided by ``t_step``.
    They are then applied with one ``optimizer.step()`` after the last step.

    Args:
        loader: DataLoader yielding ``(input, target)`` mini-batches
        model: training model, called as ``model(x, init)`` with ``init``
            True on the first time step
        criterion: loss function
        optimizer: optimizer
        t_step: simulating time-steps
        epoch: current epoch (not used by this function)

    Returns:
        tuple: (loss, accuracy) over the epoch. Accuracy is computed from the
        sum of the detached outputs over all time steps.
    """

    model.train()
    total_loss, total_acc, seen = 0, 0, 0
    for inputs, targets in loader:
        inputs, targets = inputs.cuda(), targets.cuda()
        n_batch = targets.numel()
        summed_out = 0
        optimizer.zero_grad()
        for step in range(t_step):
            out = model(inputs, step == 0)
            summed_out = summed_out + out.clone().detach()
            step_loss = criterion(out, targets) / t_step
            step_loss.backward()
            total_loss += step_loss.item() * n_batch
        # single update using the gradients summed over all time steps
        optimizer.step()
        seen += n_batch
        total_acc += (summed_out.argmax(1) == targets).float().sum().item()
    total_loss /= seen
    total_acc /= seen
    print(f"[Train] Accuracy: {(100*total_acc):>0.1f}%, Loss: {total_loss:>8f}")
    return total_loss, total_acc


# #### 3.1.3. Training_F

# In[ ]:


def train_f(loader, model, criterion, optimizer, t_step, epoch, l=lif_lambda):
    """Train-F

    One training epoch in which only the final time step contributes to the
    loss. The model is run for ``t_step`` steps; outputs of the earlier steps
    are discarded (presumably those calls only advance the model's internal
    state). The final output is divided by
    ``a = sum_{k=0}^{t_step-1} l**k``, and one backward pass plus one
    ``optimizer.step()`` is done per mini-batch.

    Args:
        loader: DataLoader yielding ``(input, target)`` mini-batches
        model: training model, called as ``model(x, init)`` with ``init``
            True on the first time step
        criterion: loss function
        optimizer: optimizer
        t_step: simulating time-steps (must be >= 1)
        epoch: current epoch (not used by this function)
        l: λ = 1 - (1 / τ)

    Returns:
        tuple: (loss, accuracy) over the epoch, both computed from the
        normalized final-step output.

    Raises:
        ValueError: if ``t_step`` < 1, since there is then no final-step
            output to compute the loss on.
    """

    if t_step < 1:
        raise ValueError(f"t_step must be >= 1, got {t_step}")

    model.train()
    total_loss = 0
    total_acc = 0
    total_samples = 0
    for x, t in loader:
        x = x.cuda()
        t = t.cuda()
        # a follows a_j = 1 + l * a_{j-1}, so after the last step it equals
        # sum_{k<t_step} l**k. Presumably this turns a λ-weighted accumulation
        # in the model output into a weighted mean; confirm against *_FR models.
        a = 0
        optimizer.zero_grad()
        for j in range(t_step):
            a = 1 + a * l
            y = model(x, j == 0) / a
        # only the final (normalized) output is used for the loss
        loss = criterion(y, t)
        loss.backward()
        total_loss += loss.item() * t.numel()
        optimizer.step()
        total_samples += t.numel()
        total_acc += (y.argmax(1) == t).float().sum().item()
    total_loss /= total_samples
    total_acc /= total_samples
    print(f"[Train] Accuracy: {(100*total_acc):>0.1f}%, Loss: {total_loss:>8f}")
    return total_loss, total_acc


# ### 3.2. Validation Process

# In[ ]:


def validation(loader, model, criterion, t_step, epoch):
    """Validation Process

    Evaluates ``model`` on ``loader`` in eval mode without gradient tracking,
    running ``t_step`` time steps per mini-batch.

    Args:
        loader: DataLoader yielding ``(input, target)`` mini-batches
        model: model under evaluation, called as ``model(x, init)``
        criterion: loss function
        t_step: simulating time-steps
        epoch: current epoch (not used by this function)

    Returns:
        tuple: (loss, accuracy). Accuracy uses the sum of the outputs over all
        time steps; the loss uses only the output of the last time step.
    """

    model.eval()
    total_loss = 0
    total_acc = 0
    total_samples = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.cuda()
            targets = targets.cuda()
            n_batch = targets.numel()
            summed_out = 0
            for step in range(t_step):
                out = model(inputs, step == 0)
                summed_out = summed_out + out.clone().detach()
            # NOTE: loss is taken from the last step's output only
            total_loss += criterion(out, targets).item() * n_batch
            total_samples += n_batch
            total_acc += (summed_out.argmax(1) == targets).float().sum().item()
    total_loss /= total_samples
    total_acc /= total_samples
    print(f"[ Test] Accuracy: {(100*total_acc):>0.1f}%, Loss: {total_loss:>8f}")
    return total_loss, total_acc


# ### 3.3. Optimization process

# In[ ]:


def optimize_model(model, out_path, train):
    """Optimization of model parameters

    Trains ``model`` for ``epochs`` epochs with ``train`` and runs
    ``validation`` after every epoch. The setup is read from module-level
    settings: SGD (``lr``, ``momentum``, ``weight_decay``) with
    CosineAnnealingLR(``T_max``) and the
    ``CombineCEandMSE(loss_alpha, num_classes)`` loss. Nothing is done if
    the final checkpoint ``epoch_<epochs:04d>.pth`` already exists in
    ``out_path``.

    Files written to ``out_path``:
        - ``log.csv``: per-epoch train/test loss and accuracy, rewritten
          after every epoch.
        - ``epoch_XXXX.pth``: ``{"state_dict": ...}`` every 100 epochs and
          after the final epoch.

    Args:
        model: training model
        out_path: model output path
        train: training process (train_e, train_a or train_f)
    """

    def checkpoint_path(epoch):
        # e.g. <out_path>/epoch_0300.pth
        return os.path.join(out_path, "epoch_" + str(epoch).zfill(4) + ".pth")

    last_model = checkpoint_path(epochs)
    if os.path.exists(last_model):
        print(last_model, "is exists.")
        return

    os.makedirs(out_path, exist_ok=True)
    criterion = CombineCEandMSE(loss_alpha, num_classes)
    optimizer = torch.optim.SGD(
        model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=T_max)
    columns = [
        "Epoch",
        "Train Loss",
        "Train ACC",
        "Test Loss",
        "Test ACC",
    ]
    log = []
    for epoch in range(1, epochs + 1):
        print(f"Epoch {epoch}\n-------------------------------")
        train_loss, train_acc = train(
            train_loader,
            model,
            criterion,
            optimizer,
            t_step,
            epoch,
        )
        test_loss, test_acc = validation(
            test_loader,
            model,
            criterion,
            t_step,
            epoch,
        )
        log.append([epoch, train_loss, train_acc, test_loss, test_acc])
        pd.DataFrame(log, columns=columns).to_csv(os.path.join(out_path, "log.csv"))
        scheduler.step()
        # Always save after the last epoch as well. The skip check above looks
        # for epoch_<epochs>.pth, which the every-100-epochs rule alone never
        # writes when ``epochs`` is not a multiple of 100.
        if epoch % 100 == 0 or epoch == epochs:
            torch.save(
                {"state_dict": model.state_dict()},
                checkpoint_path(epoch),
            )


# ## 4. Experiment
# 
# In the previous sections, we defined the functions needed for the experiment. In this section, we compare these methods experimentally. We trained SAF on the CIFAR-10 dataset (Krizhevsky and Hinton, 2009) and performed inference with an SNN composed of LIF neurons.
# 
# ### 4.1. Training

# In[ ]:


def initialize_weights(model):
    """Initialize weight parameters of every submodule of ``model`` in place.

    - ``nn.Conv2d`` / ``sWSConv2d``: Kaiming-normal weights
      (``mode="fan_out"``, ReLU gain); bias set to 0 when present.
    - ``LBias2d`` / ``LBias``: bias set to 0.
    - ``nn.Linear``: weights drawn from a normal with mean 0 and std 0.01;
      bias set to 0 when present.

    Args:
        model: model
    """

    for module in model.modules():
        if isinstance(module, (nn.Conv2d, sWSConv2d)):
            nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
            bias = module.bias
        elif isinstance(module, (LBias2d, LBias)):
            nn.init.constant_(module.bias, 0)
            continue
        elif isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, 0, 0.01)
            bias = module.bias
        else:
            continue
        if bias is not None:
            nn.init.constant_(bias, 0)


def get_training_list(models, process, num):
    """Get a list of training settings

    Enumerates every (model, training process, run index) combination.

    Args:
        models: sequence of model classes
        process: sequence of training functions
        num: number of runs per combination; run indices are 1..num

    Returns:
        np.ndarray: object array of shape
        ``(len(models) * len(process) * num, 3)`` whose rows are
        ``(model, process, run_index)``. Rows are ordered with ``process``
        outermost, then ``models``, then the run index.
    """
    import itertools

    # itertools.product keeps the original Python objects. np.meshgrid would
    # coerce all three inputs to one common NumPy dtype (e.g. the run index
    # becomes a string when models/process are strings).
    combos = list(itertools.product(process, models, range(1, num + 1)))
    table = np.empty((len(combos), 3), dtype=object)
    for row, (proc, model, run) in enumerate(combos):
        # assign element-wise so NumPy never tries to unpack the objects
        table[row, 0] = model
        table[row, 1] = proc
        table[row, 2] = run
    return table


# Number of independently initialized runs per (model, training process) pair
N = 1
training_list = []
# Training_E: all four model variants
models = [SAFVGG, OTTTVGG, FSAFVGG, FOTTTVGG]
updates = [train_e]
training_list.append(get_training_list(models, updates, N))
# Training_A: OTTT only
models = [OTTTVGG]
updates = [train_a]
training_list.append(get_training_list(models, updates, N))
# Training_F: the *_FR model variants
models = [SAFVGG_FR, FSAFVGG_FR]
updates = [train_f]
training_list.append(get_training_list(models, updates, N))
# Concatenate into one table of (model class, training function, run index)
training_list = np.concatenate(training_list, axis=0)
for model_cls, train, num in training_list:
    # output directory, e.g. <out_dir>/TRAIN_E-SAFVGG/01
    tag = train.__name__.upper() + "-" + model_cls.__name__
    out_path = os.path.join(out_dir, tag, str(num).zfill(2))
    model = model_cls(lif_lambda)
    initialize_weights(model)
    model.cuda()
    optimize_model(model, out_path, train)
    # release GPU memory before building the next model
    del model
    torch.cuda.empty_cache()


# ### 4.2. Test
# #### 4.2.1. LIF VGG
# 
# A simple SNN composed of LIF neurons can be defined as:

# In[ ]:


class LIF(nn.Module):
    """Leaky-Integrate-and-Fire neuron with reset by subtraction.

    Each call updates the membrane potential ``u <- l * u + input``, emits
    ``heaviside(u - 1.0)`` and subtracts the emitted value from ``u``. The
    potential persists across calls until ``reset()`` is called.
    """

    def __init__(self, l: float = 0.0):
        """Initial setting

        Args:
            l: λ = 1 - (1 / τ), factor applied to the previous potential
        """

        super().__init__()
        self.l = l
        self.reset()

    def forward(self, input: Tensor) -> Tensor:
        """Advance the neuron by one time step.

        Args:
            input: input (intensity of current)

        Returns:
            Tensor: the spike output ``heaviside(u - 1.0)`` of this step,
            also stored in ``self.st``
        """

        potential = self.l * self.u + input
        spike = heaviside(potential - 1.0)
        self.st = spike
        # reset by subtraction: remove the emitted value instead of zeroing u
        self.u = potential - spike
        return spike

    def reset(self):
        """Clear the membrane potential ``u`` and the last spike ``st``."""
        self.u = 0  # membrane potential
        self.st = None  # last emitted spike (set by forward)


class VGG(nn.Module):
    """VGG built from LIF neurons (target of ``convert`` for non-feedback models).

    The layer stack follows the module-level ``cfg``:

    - an integer ``c`` adds ``sWSConv2d(in, c, 3, 1, 1) -> LIF(l) -> Gamma()``;
    - ``"A"`` adds 2x2 average pooling and ``"M"`` adds 2x2 max pooling.

    The head is global average pooling, flatten and
    ``nn.Linear(cfg[-1], 10)``.
    """

    def __init__(self, l):
        """Initial setting

        Args:
            l: λ = 1 - (1 / τ), passed to every LIF layer
        """

        super().__init__()
        stack = []
        channels = 3
        for item in cfg:
            if item == "A":
                stack.append(nn.AvgPool2d(kernel_size=2, stride=2))
            elif item == "M":
                stack.append(nn.MaxPool2d(kernel_size=2, stride=2))
            else:
                stack.extend([sWSConv2d(channels, item, 3, 1, 1), LIF(l=l), Gamma()])
                channels = item
        stack.extend(
            [nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten(), nn.Linear(cfg[-1], 10)]
        )
        self.features = nn.Sequential(*stack)

    def forward(self, x, init):
        """Run one time step.

        Args:
            x: input
            init: if True, reset every LIF layer before this step

        Returns:
            Tensor: output of the linear head for this step
        """

        if init:
            self.reset()
        return self.features(x)

    def reset(self):
        """Reset the state of every LIF layer in the model."""

        neurons = (module for module in self.modules() if isinstance(module, LIF))
        for neuron in neurons:
            neuron.reset()


class FVGG(nn.Module):
    """Feedback-VGG built from LIF neurons (target of ``convert`` for F-models).

    The layer stack is the same as ``VGG`` but split into three parts:
    ``input_layer`` (the first conv), ``features`` (everything up to the
    head) and ``output_layer`` (global average pooling, flatten, linear).
    From the second time step on, the previous step's detached ``features``
    output goes through two stages before being added to ``input_layer``'s
    output:

    - it is upsampled by ``2 ** (number of "A" and "M" entries in cfg)``;
    - it is passed through the 3x3 conv ``fb`` (``cfg[-1] -> cfg[0]``
      channels, no bias).
    """

    def __init__(self, l):
        """Initial setting

        Args:
            l: λ = 1 - (1 / τ), passed to every LIF layer
        """

        super().__init__()
        n_pool = cfg.count("A") + cfg.count("M")
        self.up = nn.Upsample(scale_factor=2 ** n_pool, mode="nearest")
        self.fb = nn.Conv2d(
            cfg[-1], cfg[0], kernel_size=3, padding=1, stride=1, bias=False
        )
        stack = []
        channels = 3
        for item in cfg:
            if item == "A":
                stack.append(nn.AvgPool2d(kernel_size=2, stride=2))
            elif item == "M":
                stack.append(nn.MaxPool2d(kernel_size=2, stride=2))
            else:
                stack.extend([sWSConv2d(channels, item, 3, 1, 1), LIF(l=l), Gamma()])
                channels = item
        stack.extend(
            [nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten(), nn.Linear(cfg[-1], 10)]
        )
        head_start = len(stack) - 3
        self.input_layer = nn.Sequential(*stack[:1])
        self.features = nn.Sequential(*stack[1:head_start])
        self.output_layer = nn.Sequential(*stack[head_start:])

    def forward(self, x, init):
        """Run one time step.

        Args:
            x: input
            init: if True, reset every LIF layer and use no feedback. The
                first call must use ``init=True``, because ``fb_features``
                is only created by a previous forward call.

        Returns:
            Tensor: output of the linear head for this step
        """

        if not init:
            feedback = self.fb(self.up(self.fb_features))
        else:
            self.reset()
            feedback = 0
        hidden = self.features(self.input_layer(x) + feedback)
        # stored without gradient history for the next step's feedback
        self.fb_features = hidden.clone().detach()
        return self.output_layer(hidden)

    def reset(self):
        """Reset the state of every LIF layer in the model."""

        neurons = (module for module in self.modules() if isinstance(module, LIF))
        for neuron in neurons:
            neuron.reset()


# #### 4.2.2. Conversion Process

# In[ ]:


def convert(base_model, lif_lambda, target):
    """Convert OTTT or SAF model to simple LIF model

    Builds ``target(lif_lambda)`` and copies the tensors returned by
    ``get_params`` one-to-one, in module order, from ``base_model`` into it.

    Args:
        base_model: trained OTTT or SAF model
        lif_lambda: λ = 1 - (1 / τ), passed to the target model's constructor
        target: target model class (e.g. ``VGG`` or ``FVGG``)

    Returns:
        nn.Module: target_model holding copies of the base model's parameters

    Raises:
        ValueError: if the two models expose a different number of
            parameters, or two corresponding parameters differ in shape.
            Without this check, ``zip`` silently drops the surplus, and
            assigning ``.data`` does not enforce matching shapes.
    """

    model = target(lif_lambda)
    params = get_params(model)
    base_params = get_params(base_model)
    if len(params) != len(base_params):
        raise ValueError(
            f"parameter count mismatch: base model has {len(base_params)}, "
            f"target model has {len(params)}"
        )
    for idx, (base_param, param) in enumerate(zip(base_params, params)):
        if base_param.shape != param.shape:
            raise ValueError(
                f"shape mismatch at parameter {idx}: base {tuple(base_param.shape)} "
                f"vs target {tuple(param.shape)}"
            )
        param.data = base_param.data.clone().detach()
    return model


def get_params(model):
    """Get list of model parameters

    Walks ``model.modules()`` in order and collects the following per module:

    - ``sWSConv2d`` / ``nn.Linear`` / ``nn.Conv2d``: ``weight``, then
      ``gamma`` (``sWSConv2d`` only, if not None), then ``bias`` (if not None);
    - ``LBias2d`` / ``LBias``: ``bias``.

    Args:
        model: model

    Returns:
        list: model parameters, in the order described above
    """

    weight_types = (sWSConv2d, nn.Linear, nn.Conv2d)
    bias_types = (LBias2d, LBias)
    collected = []
    for module in model.modules():
        if isinstance(module, weight_types):
            collected.append(module.weight)
            if isinstance(module, sWSConv2d) and module.gamma is not None:
                collected.append(module.gamma)
            if module.bias is not None:
                collected.append(module.bias)
        if isinstance(module, bias_types):
            collected.append(module.bias)
    return collected


# #### 4.2.3. Test Process

# In[ ]:


def get_fr(model):
    """Get the firing rate of every LIF layer for the last time step.

    Args:
        model: model containing LIF layers that have already run at least
            one step (their ``st`` must not be None)

    Returns:
        np.ndarray: mean of each LIF layer's last spike output, one entry per
        layer, in module order
    """

    rates = [
        torch.mean(layer.st).item()
        for layer in model.modules()
        if isinstance(layer, LIF)
    ]
    return np.array(rates)


def test(loader, model, criterion, t_step, epoch):
    """Test Process

    Runs ``model`` for ``t_step`` time steps on every mini-batch of
    ``loader`` (inputs cast to float and moved to CUDA). It also records the
    per-layer firing rate reported by ``get_fr`` after each step.

    Args:
        loader: DataLoader yielding ``(input, target)`` mini-batches
        model: converted LIF model (e.g. VGG or FVGG)
        criterion: loss function
        t_step: simulating time-steps
        epoch: current epoch (not used by this function)

    Returns:
        tuple: (loss, accuracy, firing rates).

        - Accuracy uses the sum of the outputs over all time steps.
        - Loss uses only the last step's output.
        - Firing rates is an array with one entry per LIF layer, averaged
          over time steps and samples (weighted by batch size).
    """

    model.eval()
    total_loss = 0
    total_acc = 0
    total_samples = 0
    total_fr = 0
    # Inference only. Module parameters require grad by default, so without
    # no_grad every time step would record an autograd graph for nothing.
    with torch.no_grad():
        for x, t in loader:
            x = x.float().cuda()
            t = t.cuda()
            Y = 0
            for j in range(t_step):
                y = model(x, j == 0)
                Y = Y + y.clone().detach()
                total_fr = total_fr + get_fr(model) * t.numel() / t_step
            loss = criterion(y, t)
            total_loss += loss.item() * t.numel()
            total_samples += t.numel()
            total_acc += (Y.argmax(1) == t).float().sum().item()
    total_loss /= total_samples
    total_acc /= total_samples
    total_fr /= total_samples
    print(f"[ Test] Accuracy: {(100*total_acc):>0.1f}%, Loss: {total_loss:>8f}")
    return total_loss, total_acc, total_fr


# #### 4.2.4. Run

# In[28]:


criterion = CombineCEandMSE(loss_alpha, num_classes)
# Final-epoch checkpoint, with the same naming as the file optimize_model
# checks for (epoch_<epochs:04d>.pth).
pth = "epoch_" + str(epochs).zfill(4) + ".pth"
nums = [str(x + 1).zfill(2) for x in range(N)]
# Training-run directories (<out_dir>/<TRAIN_X>-<Model>), paired by position
# with the model class each one was trained with.
run_names = [
    "TRAIN_E-OTTTVGG",
    "TRAIN_A-OTTTVGG",
    "TRAIN_E-SAFVGG",
    "TRAIN_F-SAFVGG_FR",
    "TRAIN_E-FOTTTVGG",
    "TRAIN_E-FSAFVGG",
]
path_list = [os.path.join(out_dir, run_name) for run_name in run_names]
structure_list = [
    OTTTVGG,
    OTTTVGG,
    SAFVGG,
    SAFVGG_FR,
    FOTTTVGG,
    FSAFVGG,
]
logs = []
for path, structure in zip(path_list, structure_list):
    # create base (SAF or OTTT) model
    base_model = structure(lif_lambda)
    name = os.path.basename(path)
    # feedback models (class name starting with "F") convert to FVGG
    target = FVGG if structure.__name__.startswith("F") else VGG
    for num in nums:
        print(name, num)
        # load state to base model
        model_path = os.path.join(path, num, pth)
        base_model.load_state_dict(torch.load(model_path)["state_dict"])
        # base model state to LIF model
        model = convert(base_model, lif_lambda, target).cuda()
        loss, acc, fr = test(test_loader, model, criterion, t_step, epochs)
        logs.append([name, num, acc, loss, *fr])
        del model
        torch.cuda.empty_cache()
    del base_model
    torch.cuda.empty_cache()
# Each row is [name, num, acc, loss, *firing_rates]. Take the layer count
# from the rows rather than the leftover loop variable ``fr``, which would be
# undefined if no model was evaluated.
n_layers = len(logs[0]) - 4 if logs else 0
layer_num = ["FR layer " + str(x + 1).zfill(2) for x in range(n_layers)]
columns = ["model", "num", "Test ACC", "Test Loss", *layer_num]
df = pd.DataFrame(logs, columns=columns)
df.to_csv('out.csv')
