from AbstractModels.SpikingConvolutionModel import SpikingConvolutionModel

import torch
from torch import nn

from SNN.Layers import ITLIF, LIFParameters

from norse.torch import SequentialState

from SNN.util import Lift
from SNN.Normalization import TDBN3D

class ITLIFLeNet5(SpikingConvolutionModel):
    """
    A spiking convolutional LeNet-5 built from ITLIF neurons.

    Layer layout (every stateless torch layer is wrapped in ``Lift``):
        conv(1 -> 6, k=5)  -> TDBN3D -> ITLIF -> avgpool(2)
        conv(6 -> 16, k=5) -> TDBN3D -> ITLIF -> avgpool(2)
        flatten -> linear(256 -> 120) -> ITLIF -> linear(120 -> num_classes)

    The first ``nn.Linear`` expects 256 = 16 * 4 * 4 features. A single-channel
    28x28 input produces exactly that (28 -> 24 -> 12 -> 8 -> 4 per side).
    Inputs whose spatial size does not reduce to 4x4 will not match this layer.

    Arguments:
        num_classes (int): Number of output classes (width of the final linear layer).
        method (str): Threshold method, forwarded to ``LIFParameters``.
        dtype (torch.dtype): Stored on ``self.dtype``. This class does not cast any
            layer to it.
        *args, **kwargs: Forwarded unchanged to ``SpikingConvolutionModel.__init__``.
    """
    def __init__(
            self,
            num_classes: int = 10,
            method: str = "super",
            dtype: torch.dtype = torch.float,
            *args,
            **kwargs
        ):
        # Python binds *args positionally before any keyword argument. Writing it
        # first makes that explicit; the call is the same as
        # `f(num_classes=..., *args)` (flake8-bugbear B026).
        super().__init__(*args, num_classes=num_classes, **kwargs)
        self.name = type(self).__name__
        # Not used elsewhere in this class; presumably read by the base class or
        # the training code (TODO confirm).
        self.dtype = dtype

        # Shared by every ITLIF layer below.
        # tau_mem_inv and v_leak are not set here, so LIFParameters' own defaults
        # apply to them; get_params() still reports both.
        self.params = LIFParameters(
            method=method,
            tau_syn_inv=torch.as_tensor(0.2),
            v_th=torch.as_tensor(0.5),
            v_reset=torch.as_tensor(0.0),
            alpha=torch.as_tensor(1.0),
        )

        # NOTE(review): self.config is not assigned in this class; it is assumed
        # to be set by SpikingConvolutionModel.__init__ - confirm.
        # Modules are created in the original order so that random parameter
        # initialisation is unchanged for a given seed.
        self.model = SequentialState(
            # Stage 1: 1 -> 6 channels. TDBN3D receives the firing threshold
            # (presumably to scale its normalisation; see TDBN3D).
            Lift(nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)),
            TDBN3D(6, self.params.v_th),
            Lift(ITLIF(p=self.params, config=self.config)),
            Lift(nn.AvgPool2d(kernel_size=2, stride=2)),

            # Stage 2: 6 -> 16 channels.
            Lift(nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)),
            TDBN3D(16, self.params.v_th),
            Lift(ITLIF(p=self.params, config=self.config)),
            Lift(nn.AvgPool2d(kernel_size=2, stride=2)),

            Lift(nn.Flatten()),

            # Classifier head. 256 = 16 channels * 4 * 4 (see the class docstring).
            Lift(nn.Linear(256, 120)),
            Lift(ITLIF(p=self.params, config=self.config)),

            Lift(nn.Linear(120, num_classes)),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run the network on ``x`` and return the output of the last layer.

        No initial state is passed to the model. Only the first element of its
        return value (the network output) is kept, and the state part is discarded.
        NOTE(review): ``x`` is presumably time-major (time first), as the ``Lift``
        wrappers suggest - confirm against the data pipeline.
        """
        return self.model(x)[0]

    def get_params(self) -> dict:
        """
        Return the neuron hyper-parameters as plain Python values.

        Tensor-valued fields are converted with ``.item()``. ``method`` is
        returned unchanged. ``tau_mem_inv`` and ``v_leak`` are whatever
        ``LIFParameters`` defaulted them to, because ``__init__`` does not
        set them.
        """
        p = self.params
        return {
            "tau_syn_inv": p.tau_syn_inv.item(),
            "tau_mem_inv": p.tau_mem_inv.item(),
            "v_leak": p.v_leak.item(),
            "v_reset": p.v_reset.item(),
            "v_th": p.v_th.item(),
            "method": p.method,
            "alpha": p.alpha.item(),
        }