from AbstractModels.SpikingConvolutionModel import SpikingConvolutionModel

import torch
from torch import nn

from SNN.Layers import ITLIF, LIFParameters

from norse.torch import SequentialState

from SNN.util import Lift
from SNN.Normalization import TDBN3D

import math

class ITLIFVGG11(SpikingConvolutionModel):
    """
    A spiking VGG11-style convolutional network with ITLIF neuron dynamics.

    Eight 3x3 convolution stages (conv -> TDBN3D -> ITLIF) interleaved with
    four 2x2 average pools, followed by dropout and one linear classifier.
    Stateless torch layers are wrapped in ``Lift`` (presumably so they are
    applied at every time step -- confirm in ``SNN.util``).

    Arguments:
        num_classes (int): Number of output classes.
        method (str): Threshold method passed to ``LIFParameters``.
        dtype (torch.dtype): Stored on ``self.dtype``; not applied to any
            layer inside this class.
        *args, **kwargs: Forwarded to ``SpikingConvolutionModel.__init__``.
    """
    def __init__(
            self,
            num_classes: int = 10,
            method: str = "super",
            dtype: torch.dtype = torch.float,
            *args,
            **kwargs
        ):
        # Positional *args always bind before keywords, so this is the same
        # call as ``__init__(num_classes=num_classes, *args, **kwargs)``.
        super().__init__(*args, num_classes=num_classes, **kwargs)
        self.name = type(self).__name__

        self.dtype = dtype

        # One parameter set shared by every ITLIF layer; its v_th is also
        # handed to each TDBN3D layer.
        self.params = LIFParameters(
            method=method,
            tau_mem_inv=torch.as_tensor(0.25),
            v_th=torch.as_tensor(0.5),
            v_reset=torch.as_tensor(0.0),
            alpha=torch.as_tensor(1.0)
        )

        neuron = ITLIF
        # Convs carry no bias; presumably redundant ahead of TDBN3D.
        bias = False

        # Channel widths; increase ``factor`` to build a thinner network.
        factor = 1
        c1 = 64 // factor
        c2 = 128 // factor
        c3 = 256 // factor
        c4 = 512 // factor

        fc_dropout = 0.6

        # VGG11 layer plan: an int is a conv stage with that many output
        # channels, ``pool`` marks a 2x2 average pool.
        pool = "M"
        plan = (c1, pool, c2, pool, c3, c3, pool, c4, c4, c4, c4, pool)

        layers = []
        # Two input channels -- presumably event polarities; TODO confirm.
        in_channels = 2
        for entry in plan:
            if entry == pool:
                layers.append(Lift(nn.AvgPool2d(kernel_size=2, stride=2)))
            else:
                layers.extend(self._conv_stage(neuron, in_channels, entry, bias))
                in_channels = entry

        # 3x3 / padding=1 convs keep the spatial size and each pool halves it
        # (floor), so the 3*3*c4 classifier input implies input frames of
        # 48-63 pixels per side (e.g. 48x48) -- confirm against the data
        # pipeline.
        layers += [
            Lift(nn.Flatten()),
            Lift(nn.Dropout(p=fc_dropout)),
            Lift(nn.Linear(3 * 3 * c4, num_classes)),
        ]

        self.model = SequentialState(*layers)

        # He (fan-out) init for every conv, equivalent to drawing weights from
        # N(0, sqrt(2 / (k_h * k_w * out_channels))). The linear layer keeps
        # PyTorch's default initialisation.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def _conv_stage(self, neuron, in_channels: int, out_channels: int, bias: bool) -> list:
        """
        Build one VGG stage: lifted 3x3 conv, TDBN3D, lifted spiking neuron.

        ``self.config`` is not assigned in this class; presumably it is set by
        ``SpikingConvolutionModel.__init__``.
        """
        return [
            Lift(nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                           kernel_size=3, padding=1, bias=bias)),
            TDBN3D(out_channels, v_th=self.params.v_th),
            Lift(neuron(p=self.params, config=self.config)),
        ]

    def forward(self, x: torch.Tensor):
        """
        Run the network on ``x`` and return the classifier output.

        No state is passed to ``SequentialState``; only the first element of
        its returned ``(output, state)`` pair is returned. The expected layout
        of ``x`` (time/batch ordering) is not established here -- confirm
        against ``Lift``.
        """
        return self.model(x)[0]

    def get_params(self) -> dict:
        """
        Return the neuron parameters as plain Python values, e.g. for logging.

        NOTE(review): ``tau_syn_inv``, ``v_leak``, ``v_c`` and ``a`` are not
        passed to ``LIFParameters`` in ``__init__``; this assumes
        ``LIFParameters`` supplies tensor defaults for them -- confirm.
        """
        return {
            "tau_syn_inv": self.params.tau_syn_inv.item(),
            "tau_mem_inv": self.params.tau_mem_inv.item(),
            "v_leak": self.params.v_leak.item(),
            "v_c": self.params.v_c.item(),
            "a": self.params.a.item(),
            "v_reset": self.params.v_reset.item(),
            "v_th": self.params.v_th.item(),
            "method": self.params.method,
            "alpha": self.params.alpha.item(),
        }