from AbstractModels.SpikingModel import SpikingModel
from .ResNet import resnet, SEWBlock, ZhengStandardClassifier, StandardBackbone

from SNN.Layers import LNM, LNMParameters
from .TDBN3D import TDBN3D

import torch

class ITLIFResNet34(SpikingModel):
    """Spiking ResNet-34 assembled from ``SEWBlock`` blocks and ``LNM`` neurons.

    The network is built by ``resnet`` with four stages of 64/128/256/512
    channels and 3/4/6/3 blocks, ``StandardBackbone`` as the stem,
    ``TDBN3D`` as the normalisation layer and ``ZhengStandardClassifier`` as
    the head. All neurons receive the same ``LNMParameters`` instance, which
    is kept on ``self.params``.
    """

    def __init__(
            self,
            num_classes: int = 10,
            method: str = "super",
            dtype: torch.dtype = torch.float,
            *args,
            **kwargs
        ):
        """Create the shared neuron parameters and the ResNet backbone.

        Args:
            num_classes: Number of output classes; forwarded both to
                ``SpikingModel.__init__`` and to ``resnet``.
            method: Stored as ``LNMParameters.method``. NOTE(review):
                presumably selects the neuron's gradient/spike method —
                confirm in ``SNN.Layers``.
            dtype: Stored on ``self.dtype``. This class does not cast the
                model to it.
            *args: Forwarded unchanged to ``SpikingModel.__init__``.
            **kwargs: Forwarded unchanged to ``SpikingModel.__init__``.
        """
        # Python binds positional *args before any keyword, so extra positional
        # args must not also supply num_classes (TODO confirm SpikingModel's
        # parameter order).
        super().__init__(*args, num_classes=num_classes, **kwargs)
        self.name = type(self).__name__
        self.dtype = dtype

        # One parameter set shared by every LNM layer of the backbone. Fields
        # not given here (e.g. tau_syn_inv, v_leak, read by get_params) rely
        # on LNMParameters' own defaults — TODO confirm they exist.
        self.params = LNMParameters(
            method=method,
            tau_mem_inv=torch.as_tensor(0.25),
            v_th=torch.as_tensor(0.5),
            v_reset=torch.as_tensor(0.0),
            alpha=torch.as_tensor(0.5),
        )

        stage_channels = [64, 128, 256, 512]
        stage_depths = [3, 4, 6, 3]  # ResNet-34 block counts per stage
        # self.config is not assigned in this class; presumably set by
        # SpikingModel.__init__ — verify.
        self.model = resnet(
            backbone=StandardBackbone,
            block=SEWBlock,
            classifier=ZhengStandardClassifier,
            num_classes=num_classes,
            channels=stage_channels,
            layers=stage_depths,
            neuron=LNM,
            params=self.params,
            config=self.config,
            norm_layer=TDBN3D,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through encoder, ResNet backbone and decoder in turn.

        ``self.encoder`` and ``self.decoder`` are not defined in this class;
        presumably ``SpikingModel`` provides them — confirm.
        """
        encoded = self.encoder(x)
        features = self.model(encoded)
        return self.decoder(features)

    def get_params(self) -> dict:
        """Return the neuron hyper-parameters as plain Python values.

        Tensor-valued fields are converted with ``.item()``; ``method`` is
        returned as stored.
        """
        p = self.params
        tensor_fields = ("tau_syn_inv", "tau_mem_inv", "v_leak", "v_reset", "v_th")
        summary = {field: getattr(p, field).item() for field in tensor_fields}
        summary["method"] = p.method
        summary["alpha"] = p.alpha.item()
        return summary
