import json
import os

import torch
from torch import nn
from norse.torch import SequentialState

from AbstractModels.SpikingConvolutionModel import SpikingConvolutionModel
from SNN.Layers import ITQIF, QIFParameters
from SNN.Normalization import TDBN3D
from SNN.util import Lift

class ITQIFDVSGestureNet(SpikingConvolutionModel):
    """
    A spiking convolutional network built from ITQIF neurons (DVS Gesture).

    Architecture: five identical stages of
    ``Conv2d(3x3, no bias) -> TDBN3D -> ITQIF -> AvgPool2d(2)``, then
    ``Flatten -> Dropout -> Linear(2*2*128, 128) -> ITQIF -> Dropout ->
    Linear(128, num_classes)``, all wrapped in a norse ``SequentialState``.

    The neuron parameters are not passed in. Every instantiation reads the
    combination selected by the ``cur_num`` cursor in ``PARAMS_FILE``. It then
    writes the file back with the cursor incremented, so constructing the
    model (for training *or* evaluation) modifies that file on disk.

    Arguments:
        num_classes (int): Number of output classes.
        method (str): Threshold method forwarded to ``QIFParameters``.
        dtype (torch.dtype): Stored on ``self.dtype``; not otherwise used
            by this class.
    """

    # JSON list of parameter combinations with keys 'u_c', 'u_th' and 'a'.
    # The last element carries the 'cur_num' cursor that selects which entry
    # the next instantiation uses.
    PARAMS_FILE = "./Graphs/Robustness/qlif_combinations.json"

    def __init__(
            self,
            num_classes: int = 10,
            method: str = "super",
            dtype: torch.dtype = torch.float,
            *args,
            **kwargs
        ):
        # Equivalent to the previous ``num_classes=..., *args`` spelling
        # (extra positionals are forwarded first either way); written in this
        # order so that is explicit.
        super().__init__(*args, num_classes=num_classes, **kwargs)
        self.name = type(self).__name__

        self.dtype = dtype

        self.params = self._load_next_params(method, self.PARAMS_FILE)

        neuron = ITQIF
        bias = False

        factor = 1
        c1 = 128 // factor

        fc_dropout = 0.0

        # Modules are created in exactly the same order as a hand-written
        # list would create them. SequentialState indices (and therefore
        # state_dict keys) and RNG consumption during init are unchanged.
        layers = []
        in_channels = 2  # presumably the two event polarities — confirm
        for _ in range(5):
            layers += [
                Lift(nn.Conv2d(in_channels=in_channels, out_channels=c1, kernel_size=3, bias=bias)),
                TDBN3D(c1, v_th=self.params.v_th),
                Lift(neuron(p=self.params, config=self.config)),
                Lift(nn.AvgPool2d(kernel_size=2, stride=2)),
            ]
            in_channels = c1

        layers += [
            Lift(nn.Flatten()),
            Lift(nn.Dropout(p=fc_dropout)),
            # After five (3x3 valid conv, 2x2 pool) stages a 128x128 input
            # ends at 2x2 spatially. Other input sizes need a different
            # in_features here — NOTE(review): confirm dataset resolution.
            Lift(nn.Linear(2 * 2 * c1, 128)),
            Lift(neuron(p=self.params, config=self.config)),
            Lift(nn.Dropout(p=fc_dropout)),
            Lift(nn.Linear(128, num_classes)),
        ]

        self.model = SequentialState(*layers)

    @classmethod
    def _load_next_params(cls, method: str, file_name: str) -> "QIFParameters":
        """Build QIFParameters from the entry selected by the file's cursor.

        Reads ``file_name``, takes the index stored in ``data[-1]['cur_num']``
        and builds ``QIFParameters`` from that entry's 'u_c', 'u_th' and 'a'
        values. It then writes the file back with the cursor incremented by
        one. The file is only rewritten after the parameters were built
        successfully.

        Raises:
            IndexError: if the cursor lies outside the list of entries.
        """
        with open(file_name, "r", encoding="utf-8") as file:
            data = json.load(file)

        index = data[-1]['cur_num']
        # Without this check a negative cursor would silently wrap around,
        # and an exhausted one would fail with an unhelpful lookup error.
        if not 0 <= index < len(data):
            raise IndexError(
                f"cur_num={index} is out of range for the {len(data)} entries "
                f"in {file_name!r}; all parameter combinations may have been "
                f"consumed"
            )

        entry = data[index]
        params = QIFParameters(
            method=method,
            v_c=torch.as_tensor(entry['u_c']),
            v_th=torch.as_tensor(entry['u_th']),
            a=torch.as_tensor(entry['a']),
            v_reset=torch.as_tensor(0.0),
            alpha=torch.as_tensor(1.0)
        )

        data[-1]['cur_num'] = index + 1
        cls._write_json_atomic(file_name, data)
        return params

    @staticmethod
    def _write_json_atomic(path: str, obj) -> None:
        """Write ``obj`` as indented JSON to ``path`` via a temp file + rename.

        Opening ``path`` directly with mode "w" truncates it before any data
        is written, so a failed or interrupted dump would destroy the file.
        Here the new content is fully written to a sibling ``.tmp`` file
        first and then moved over ``path`` with ``os.replace``.
        """
        tmp_path = f"{path}.tmp"
        try:
            with open(tmp_path, "w", encoding="utf-8") as file:
                json.dump(obj, file, indent=4)
            os.replace(tmp_path, path)
        except BaseException:
            # Leave the original file untouched and don't litter the directory.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise

    def forward(self, x: torch.Tensor):
        """Run the network and return only the output of ``SequentialState``.

        ``SequentialState`` returns an ``(output, state)`` pair; the state is
        discarded here.
        """
        return self.model(x)[0]

    def get_params(self) -> dict:
        """Return the neuron parameters as plain Python values.

        Tensor fields are converted with ``.item()``; ``method`` is returned
        as stored on ``self.params``.
        """
        return {
            "tau_syn_inv": self.params.tau_syn_inv.item(),
            "tau_mem_inv": self.params.tau_mem_inv.item(),
            "v_leak": self.params.v_leak.item(),
            "v_c": self.params.v_c.item(),
            "a": self.params.a.item(),
            "v_reset": self.params.v_reset.item(),
            "v_th": self.params.v_th.item(),
            "method": self.params.method,
            "alpha": self.params.alpha.item(),
        }