from AbstractModels.SpikingConvolutionModel import SpikingConvolutionModel

import torch
from torch import nn

from SNN.Layers import ITLIF, LIFParameters

from norse.torch import SequentialState

from SNN.util import Lift
from SNN.Normalization import TDBN3D

import json

class ITLIFDVSGestureNet(SpikingConvolutionModel):
    """
    A spiking convolutional classifier built from ITLIF neurons.

    The network is five identical stages of
    ``Conv2d(3x3, no bias) -> TDBN3D -> ITLIF -> AvgPool2d(2, stride 2)``
    followed by the head
    ``Flatten -> Dropout -> Linear(2*2*c1, 128) -> ITLIF -> Dropout ->
    Linear(128, num_classes)``. Every layer except ``TDBN3D`` is wrapped in
    ``Lift``.

    With unpadded 3x3 convolutions and stride-2 pooling, a 128x128 input
    shrinks to 2x2 after the five stages (128->126->63->61->30->28->14->12->6
    ->4->2), which matches the ``2*2*c1`` input size of the first Linear.
    NOTE(review): this presumably means inputs are 2-channel 128x128 frames —
    confirm against the data pipeline.

    The LIF parameters are not fixed. Each construction reads the next
    ``(beta, u_th)`` combination from a JSON sweep file and advances the
    counter stored in that file (see ``_next_lif_parameters``). Building the
    model therefore rewrites that file on disk.

    Arguments:
        num_classes (int): Number of output classes.
        method (str): Threshold method, forwarded to ``LIFParameters``.
        dtype (torch.dtype): Stored on ``self.dtype``; not otherwise used in
            this class.
        *args, **kwargs: Forwarded to ``SpikingConvolutionModel``.
        lif_combinations_file (str): Keyword-only path of the JSON sweep file
            (defaults to the original hard-coded location).
    """

    # Number of identical Conv -> BN -> ITLIF -> Pool stages.
    _NUM_CONV_STAGES = 5

    def __init__(
            self,
            num_classes: int = 10,
            method: str = "super",
            dtype: torch.dtype = torch.float,
            *args,
            lif_combinations_file: str = "./Graphs/Robustness/lif_combinations.json",
            **kwargs
        ):
        # Positional *args are always bound before keywords, so this call is
        # equivalent to the original `super(Cls, self).__init__(num_classes=..., *args, ...)`.
        super().__init__(*args, num_classes=num_classes, **kwargs)
        self.name = f'{self.__class__.__name__}'

        self.dtype = dtype

        self.params = self._next_lif_parameters(lif_combinations_file, method)

        neuron = ITLIF

        bias = False

        factor = 1
        c1 = 128 // factor

        fc_dropout = 0.0

        # The layers are built in the same order as the original hand-written
        # list. This keeps the SequentialState indices (and thus state_dict
        # keys) and the weight-initialisation RNG order unchanged.
        layers = []
        in_channels = 2  # NOTE(review): presumably DVS on/off polarity channels — confirm.
        for _ in range(self._NUM_CONV_STAGES):
            layers += [
                Lift(nn.Conv2d(in_channels=in_channels, out_channels=c1, kernel_size=3, bias=bias)),
                TDBN3D(c1, v_th=self.params.v_th),
                Lift(neuron(p=self.params, config=self.config)),
                Lift(nn.AvgPool2d(kernel_size=2, stride=2)),
            ]
            in_channels = c1

        layers += [
            Lift(nn.Flatten()),
            Lift(nn.Dropout(p=fc_dropout)),
            # 2*2 spatial positions remain after the five stages for 128x128 input.
            Lift(nn.Linear(2 * 2 * c1, 128)),
            Lift(neuron(p=self.params, config=self.config)),
            Lift(nn.Dropout(p=fc_dropout)),
            Lift(nn.Linear(128, num_classes)),
        ]

        self.model = SequentialState(*layers)

    @staticmethod
    def _next_lif_parameters(file_name: str, method: str) -> "LIFParameters":
        """Build LIF parameters from the next sweep combination and advance the counter.

        ``file_name`` holds a JSON list. Its last element carries an integer
        ``cur_num``, and ``data[cur_num]`` supplies ``beta`` (used as
        ``tau_mem_inv``) and ``u_th`` (used as ``v_th``). After the parameters
        are built, ``cur_num`` is incremented and the whole list is written
        back, so each call consumes one combination.

        Raises:
            IndexError: ``cur_num`` points outside the list.
            KeyError: ``data[cur_num]`` lacks ``beta`` or ``u_th`` (e.g. the
                counter now points at the counter entry itself).
        """
        with open(file_name, "r", encoding="utf-8") as file:
            data = json.load(file)

        index = data[-1]['cur_num']
        try:
            combination = data[index]
            beta = combination['beta']
            u_th = combination['u_th']
        except IndexError as err:
            raise IndexError(
                f"{file_name}: cur_num={index} is outside the {len(data)} stored entries; "
                f"the LIF combinations may be exhausted"
            ) from err
        except KeyError as err:
            raise KeyError(
                f"{file_name}: entry {index} has no {err} field; "
                f"cur_num may have run past the last LIF combination"
            ) from err

        # Build the parameters before persisting the counter. If construction
        # fails, the file is left untouched, as in the original code.
        params = LIFParameters(
            method=method,
            tau_mem_inv=torch.as_tensor(beta),
            v_th=torch.as_tensor(u_th),
            v_reset=torch.as_tensor(0.0),
            alpha=torch.as_tensor(1.0)
        )

        data[-1]['cur_num'] = index + 1
        with open(file_name, "w", encoding="utf-8") as file:
            json.dump(data, file, indent=4)

        return params

    def forward(self, x: torch.Tensor):
        """Run the network and return only the output.

        ``SequentialState`` returns ``(output, state)``. The neuron states
        (index 1) are discarded here.
        """
        return self.model(x)[0]

    def get_params(self) -> dict:
        """Return the neuron parameters as plain Python values.

        ``tau_syn_inv`` and ``v_leak`` are not set by ``__init__``.
        NOTE(review): they are assumed to come from ``LIFParameters``
        defaults — confirm.
        """
        return {
            "tau_syn_inv": self.params.tau_syn_inv.item(),
            "tau_mem_inv": self.params.tau_mem_inv.item(),
            "v_leak": self.params.v_leak.item(),
            "v_reset": self.params.v_reset.item(),
            "v_th": self.params.v_th.item(),
            "method": self.params.method,
            "alpha": self.params.alpha.item(),
        }