from AbstractModels.SpikingModel import SpikingModel
from .ResNet import resnet, ZhengBackbone, ZhengStandardClassifier, SEWBlock

from SNN.Layers import LNM, LNMParameters
from SNN.Normalization import TDBN3D

import torch

class ITLIFResNet19(SpikingModel):
    """ResNet-19-style spiking classifier built from LNM neurons and SEW blocks.

    Wires a ``resnet`` composed of ``ZhengBackbone``, ``SEWBlock`` residual
    blocks and ``ZhengStandardClassifier``. It has three stages of
    128/256/512 channels with 3, 3 and 2 blocks respectively. ``TDBN3D`` is
    the normalization layer and ``LNM`` is the spiking neuron.

    Args:
        num_classes: Number of output classes. Forwarded to both
            ``SpikingModel.__init__`` and ``resnet``.
        method: Passed to ``LNMParameters`` as ``method``. This is presumably
            the surrogate-gradient shape; confirm against ``SNN.Layers``.
        dtype: Stored as ``self.dtype``. It is not otherwise used in this class.
        *args, **kwargs: Forwarded unchanged to ``SpikingModel.__init__``.
    """

    def __init__(
        self,
        num_classes: int = 10,
        method: str = "rectangle",
        dtype: torch.dtype = torch.float,
        *args,
        **kwargs
    ):
        # Python binds positional extras before keywords no matter where they
        # appear in the call, so placing *args first only makes that explicit.
        super().__init__(*args, num_classes=num_classes, **kwargs)
        self.name = type(self).__name__
        self.dtype = dtype

        # Fixed neuron hyper-parameters; the same object is handed to
        # ``resnet`` below via ``params=``.
        self.params = LNMParameters(
            method=method,
            tau_mem_inv=torch.as_tensor(0.25),
            v_th=torch.as_tensor(0.5),
            v_reset=torch.as_tensor(0.0),
        )

        # Per-stage widths and block counts. The trailing ``None`` entries
        # presumably disable a fourth stage; confirm in ``resnet``.
        stage_channels = [128, 256, 512, None]
        stage_depths = [3, 3, 2, None]

        # ``self.config`` is not assigned in this class. It is presumably set
        # by ``SpikingModel.__init__``; confirm.
        self.model = resnet(
            backbone=ZhengBackbone,
            block=SEWBlock,
            classifier=ZhengStandardClassifier,
            num_classes=num_classes,
            channels=stage_channels,
            layers=stage_depths,
            neuron=LNM,
            params=self.params,
            config=self.config,
            norm_layer=TDBN3D,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pass ``x`` through the encoder, the ResNet body and the decoder in turn."""
        # ``encoder``/``decoder`` are not assigned in this class. They are
        # presumably provided by ``SpikingModel``; confirm.
        encoded = self.encoder(x)
        features = self.model(encoded)
        return self.decoder(features)

    def get_params(self) -> dict:
        """Return the neuron hyper-parameters as plain Python values.

        Tensor-valued fields are converted with ``.item()``. ``method`` is
        returned unchanged. Key order: tau_syn_inv, tau_mem_inv, v_leak,
        v_reset, v_th, method, alpha.
        """
        neuron = self.params
        # NOTE(review): tau_syn_inv, v_leak and alpha are never set in
        # __init__. This relies on LNMParameters supplying them (e.g. as
        # defaults). Confirm, otherwise this raises AttributeError.
        summary = {
            key: getattr(neuron, key).item()
            for key in ("tau_syn_inv", "tau_mem_inv", "v_leak", "v_reset", "v_th")
        }
        summary["method"] = neuron.method
        summary["alpha"] = neuron.alpha.item()
        return summary