from AbstractModels.SpikingConvolutionModel import SpikingConvolutionModel

import torch
from torch import nn

from SNN.Layers import ITQIF

from SNN.Layers import QIFParameters
from norse.torch import SequentialState

from SNN.Normalization import TDBN3D
from SNN.util import Lift

class ITQIFLeNet5(SpikingConvolutionModel):
    """
    A spiking LeNet-5-style convolutional network built from ITQIF neuron layers.

    Layout: two (Conv2d -> TDBN3D -> ITQIF -> AvgPool2d) stages, a flatten,
    then Linear(256, 120) -> ITQIF -> Linear(120, num_classes).

    Arguments:
        num_classes (int): Number of classes; also the width of the last linear layer
        method (str): Threshold method, handed to QIFParameters as ``method``
        dtype (torch.dtype): Stored as ``self.dtype``; not applied to the layers here
        *args, **kwargs: Forwarded unchanged to the base-class constructor
    """

    def __init__(
        self,
        num_classes: int = 10,
        method: str = "super",
        dtype: torch.dtype = torch.float,
        *args,
        **kwargs
    ):
        # The base class is initialised first; self.config (used below) is not
        # assigned in this class, so it is presumably provided by the parent
        # -- TODO confirm.
        super().__init__(*args, num_classes=num_classes, **kwargs)

        self.name = self.__class__.__name__
        self.dtype = dtype

        # One parameter set shared by every ITQIF layer and by the TDBN3D
        # normalisers (which read its v_th).
        self.params = QIFParameters(
            method=method,
            v_th=torch.as_tensor(0.5),
            v_c=torch.as_tensor(0.5),
            a=torch.as_tensor(0.2),
            v_reset=torch.as_tensor(0.0),
            alpha=torch.as_tensor(1.0),
        )

        def spiking_layer():
            # A new ITQIF wrapper per call; all of them share self.params.
            return Lift(ITQIF(p=self.params, config=self.config))

        # Layers are created in network order so parameter initialisation
        # happens in the same sequence as the layers are applied.
        layers = [
            # Stage 1: 1 -> 6 feature maps, 5x5 kernel, then 2x2 average pooling.
            Lift(nn.Conv2d(1, 6, kernel_size=5)),
            TDBN3D(6, v_th=self.params.v_th),
            spiking_layer(),
            Lift(nn.AvgPool2d(kernel_size=2, stride=2)),
            # Stage 2: 6 -> 16 feature maps, 5x5 kernel, then 2x2 average pooling.
            Lift(nn.Conv2d(6, 16, kernel_size=5)),
            TDBN3D(16, v_th=self.params.v_th),
            spiking_layer(),
            Lift(nn.AvgPool2d(kernel_size=2, stride=2)),
            Lift(nn.Flatten()),
            # 256 = 16 channels x 4 x 4, i.e. what the two stages above yield
            # for a 28x28 single-channel frame.
            Lift(nn.Linear(256, 120)),
            spiking_layer(),
            # Read-out layer: no spiking neuron after it.
            Lift(nn.Linear(120, num_classes)),
        ]
        self.model = SequentialState(*layers)

    def forward(self, x: torch.Tensor):
        """Run ``x`` through the network and return element 0 of its result."""
        # The stateful container's result is indexed at 0; the remainder
        # (presumably the layer states -- TODO confirm) is dropped.
        result = self.model(x)
        return result[0]

    def get_params(self) -> dict:
        """Return the neuron parameters as a dict of plain Python values."""
        p = self.params
        # tau_syn_inv, tau_mem_inv and v_leak are not set in __init__, so they
        # presumably come from QIFParameters defaults -- verify against its
        # definition.
        scalar_fields = (
            "tau_syn_inv",
            "tau_mem_inv",
            "v_leak",
            "v_reset",
            "v_th",
            "v_c",
            "a",
        )
        summary = {field: getattr(p, field).item() for field in scalar_fields}
        # "method" is reported as stored, without .item().
        summary["method"] = p.method
        summary["alpha"] = p.alpha.item()
        return summary