from AbstractModels.SpikingConvolutionModel import SpikingConvolutionModel
from .ResNet import resnet, ZhengBackbone, ZhengClassifier, ZhengBlock

from SNN.Layers import ITLIF, LIFParameters
from SNN.Normalization import TDBN3D

import torch

class ITLIFResNet19(SpikingConvolutionModel):
    """ResNet-19-style spiking network built from ITLIF neurons.

    Wraps a network produced by ``resnet`` (Zheng backbone / block /
    classifier, ``TDBN3D`` normalization, ``ITLIF`` neurons) and delegates
    ``forward`` to it.
    """

    def __init__(
        self,
        num_classes: int = 10,
        method: str = "super",
        dtype: torch.dtype = torch.float,
        *args,
        **kwargs,
    ):
        """Build the network, run one warm-up pass and print its size.

        Args:
            num_classes: Number of output classes; forwarded both to the base
                class and to ``resnet``.
            method: Passed as ``LIFParameters(method=...)``; presumably selects
                the spike/surrogate-gradient method — confirm in SNN.Layers.
            dtype: Stored on ``self.dtype`` only; this constructor does not
                apply it to the model or to the warm-up input.
            *args, **kwargs: Forwarded to ``SpikingConvolutionModel.__init__``.
        """
        # Same argument binding as the original call; positional extras are
        # simply listed first so the call reads in binding order.
        super().__init__(*args, num_classes=num_classes, **kwargs)
        self.name = type(self).__name__
        self.dtype = dtype

        # Neuron hyperparameters handed to ``resnet`` together with ITLIF.
        # Fields not given here (e.g. tau_syn_inv, v_leak) come from
        # LIFParameters' own defaults.
        self.params = LIFParameters(
            method=method,
            tau_mem_inv=torch.as_tensor(0.25),
            v_th=torch.as_tensor(0.5),
            v_reset=torch.as_tensor(0.0),
            alpha=torch.as_tensor(1.0),
        )

        self.model = resnet(
            backbone=ZhengBackbone,
            block=ZhengBlock,
            classifier=ZhengClassifier,
            num_classes=num_classes,
            channels=[128, 256, 512, None],
            layers=[3, 3, 2, None],
            neuron=ITLIF,
            params=self.params,
            config=self.config,  # presumably set by the base-class __init__
            norm_layer=TDBN3D,
        )

        # Warm-up forward pass (kept before the parameter count in case it
        # materializes lazily created parameters).
        # Presumably (time, batch, channels, height, width) for 32x32 RGB
        # input — TODO confirm against the expected input layout.
        fake_data = torch.randn(1, 1, 3, 32, 32)
        # Eval mode keeps BatchNorm-style layers from folding statistics of
        # random noise into their running estimates; no_grad avoids building
        # (and possibly retaining in neuron state) a throwaway autograd graph.
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                self.model(fake_data)
        finally:
            self.model.train(was_training)

        total_params = sum(p.numel() for p in self.model.parameters())
        print(f"Total number of parameters: {total_params}")

    def forward(self, x):
        """Run ``x`` through the wrapped network and return its output."""
        return self.model(x)

    def get_params(self) -> dict:
        """Return the neuron hyperparameters as plain Python values.

        Numeric fields of ``self.params`` are converted to Python scalars;
        ``method`` is returned as stored.
        """
        p = self.params

        def _scalar(value):
            # torch.as_tensor is a no-op for tensors and also accepts plain
            # numbers, so fields left at LIFParameters defaults (tau_syn_inv,
            # v_leak) work whether those defaults are tensors or floats.
            return torch.as_tensor(value).item()

        return {
            "tau_syn_inv": _scalar(p.tau_syn_inv),
            "tau_mem_inv": _scalar(p.tau_mem_inv),
            "v_leak": _scalar(p.v_leak),
            "v_reset": _scalar(p.v_reset),
            "v_th": _scalar(p.v_th),
            "method": p.method,
            "alpha": _scalar(p.alpha),
        }