from AbstractModels.SpikingModel import SpikingModel

import torch
from torch import nn

from SNN.Layers import LNM, LNMParameters

from norse.torch import SequentialState

from SNN.util import Lift
from SNN.Normalization import TDBN3D

import math

class LIFVGGSNN(SpikingModel):
    """
    A spiking convolutional VGG-style network with LIF dynamics.

    The body is four stages of two ``Conv2d -> TDBN3D -> LNM`` blocks, each
    stage followed by 2x2 average pooling, and a single dropout + linear
    classification head.

    Arguments:
        num_classes (int): Number of classes (output features of the head)
        method (str): Threshold method, forwarded to ``LNMParameters``
        dtype (torch.dtype): Stored on the instance as ``self.dtype``; it is
            not otherwise used inside this class
        *args, **kwargs: Forwarded unchanged to ``SpikingModel.__init__``
    """
    def __init__(
            self,
            num_classes: int = 10,
            method: str = "rectangle",
            dtype: torch.dtype = torch.float,
            *args,
            **kwargs
        ):

        super().__init__(
            *args,
            num_classes=num_classes,
            **kwargs
        )
        self.name = self.__class__.__name__

        self.dtype = dtype

        # One parameter set shared by every spiking layer in the network.
        self.params = LNMParameters(
            method=method,
            tau_mem_inv=torch.as_tensor(0.25),
            v_th=torch.as_tensor(0.5),
            v_reset=torch.as_tensor(0.0),
            alpha=torch.as_tensor(1.0)
        )

        # Channel widths of the four stages.
        c1, c2, c3, c4 = 64, 128, 256, 512

        fc_dropout = 0.6

        # (in_channels, out_channels) of the two conv blocks in each stage.
        # The first convolution takes 2-channel input.
        stages = (
            ((2, c1), (c1, c2)),
            ((c2, c3), (c3, c3)),
            ((c3, c4), (c4, c4)),
            ((c4, c4), (c4, c4)),
        )

        layers = []
        for stage in stages:
            for in_channels, out_channels in stage:
                layers.extend(self._conv_block(in_channels, out_channels))
            # Every stage ends by halving the spatial resolution.
            layers.append(Lift(nn.AvgPool2d(kernel_size=2, stride=2)))

        layers.append(Lift(nn.Flatten()))
        layers.append(Lift(nn.Dropout(p=fc_dropout)))
        # The head expects a 3x3 feature map with c4 channels. The convs keep
        # the spatial size (kernel 3, padding 1) and the four poolings each
        # halve it, so this presumably corresponds to 48x48 input frames.
        # NOTE(review): confirm against the encoder / dataset resolution.
        layers.append(Lift(nn.Linear(3*3*c4, num_classes)))

        self.model = SequentialState(*layers)

        self._init_conv_weights()

    def _conv_block(self, in_channels: int, out_channels: int) -> tuple:
        """Build one bias-free 3x3 Conv2d -> TDBN3D -> LNM triple."""
        # NOTE(review): self.config is not set in this class; presumably it
        # is provided by SpikingModel.__init__ - confirm.
        return (
            Lift(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1, bias=False)),
            TDBN3D(out_channels, v_th=self.params.v_th),
            Lift(LNM(p=self.params, config=self.config)),
        )

    def _init_conv_weights(self) -> None:
        """Re-initialise every Conv2d reachable from this module.

        Weights are drawn from N(0, sqrt(2 / n)) with
        n = kernel_height * kernel_width * out_channels; biases, where
        present, are zeroed.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                nn.init.normal_(m.weight, mean=0.0, std=math.sqrt(2. / n))
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x``, run it through the network and decode the output.

        ``self.model`` returns an indexable result; only element 0 is passed
        on to the decoder, the rest is discarded.
        NOTE(review): self.encoder and self.decoder are not defined in this
        class; presumably they come from SpikingModel - confirm.
        """
        return self.decoder(self.model(self.encoder(x))[0])

    def get_params(self) -> dict:
        """Return the neuron parameters as a dict of plain Python values.

        Tensor-valued fields are converted with ``.item()``; ``method`` is
        returned as stored.
        NOTE(review): tau_syn_inv and v_leak are not passed to LNMParameters
        in __init__, so this relies on LNMParameters supplying defaults for
        them - confirm.
        """
        return {
            "tau_syn_inv": self.params.tau_syn_inv.item(),
            "tau_mem_inv": self.params.tau_mem_inv.item(),
            "v_leak": self.params.v_leak.item(),
            "v_reset": self.params.v_reset.item(),
            "v_th": self.params.v_th.item(),
            "method": self.params.method,
            "alpha": self.params.alpha.item(),
        }