from AbstractModels.SpikingConvolutionModel import SpikingConvolutionModel
from .ResNet import resnet, ZhengBlock, ZhengBackbone, ZhengClassifier, ZhengStandardBackbone, ZhengStandardClassifier

from SNN.Layers import ITQIF
from SNN.Layers.QIF import QIFParameters
from SNN.Normalization import TDBN3D

import torch

class ITQIFResNet34(SpikingConvolutionModel):
    """Spiking ResNet assembled by ``resnet`` with ITQIF neurons and TDBN3D.

    The network uses the ``ZhengStandardBackbone`` / ``ZhengBlock`` /
    ``ZhengStandardClassifier`` components with four stages of widths
    ``[64, 128, 256, 512]`` and depths ``[3, 4, 6, 3]``. Every neuron shares
    the single ``QIFParameters`` instance stored in ``self.params``.
    """

    def __init__(
            self,
            num_classes: int = 10,
            method: str = "super",
            dtype: torch.dtype = torch.float,
            *args,
            **kwargs
        ):
        """Build the neuron parameters and the underlying ResNet.

        Args:
            num_classes: Number of output classes; forwarded to both the
                base class and ``resnet``.
            method: Forwarded unchanged as ``QIFParameters(method=...)``.
            dtype: Stored on the instance as ``self.dtype``.
            *args: Extra positional arguments for the base class.
            **kwargs: Extra keyword arguments for the base class.
        """
        # Same binding as the original call; positional pass-through is
        # simply written before the explicit keyword for readability.
        super().__init__(*args, num_classes=num_classes, **kwargs)
        self.name = self.__class__.__name__
        # NOTE(review): dtype is only recorded here; nothing in this class
        # applies it to the parameters or the model.
        self.dtype = dtype

        self.params = QIFParameters(
            method=method,
            v_th=torch.as_tensor(0.5),
            v_c=torch.as_tensor(0.5),
            v_reset=torch.as_tensor(0.0),
            # Was `torch.torch.as_tensor`, which only resolved through the
            # module's accidental self-reference.
            a=torch.as_tensor(0.25),
            alpha=torch.as_tensor(1.0),
        )

        # NOTE(review): self.config is not assigned in this class; presumably
        # SpikingConvolutionModel.__init__ provides it -- confirm.
        self.model = resnet(
            backbone=ZhengStandardBackbone,
            block=ZhengBlock,
            classifier=ZhengStandardClassifier,
            num_classes=num_classes,
            channels=[64, 128, 256, 512],
            layers=[3, 4, 6, 3],
            neuron=ITQIF,
            params=self.params,
            config=self.config,
            norm_layer=TDBN3D,
        )

    def forward(self, x):
        """Return the output of the wrapped ``resnet`` model for ``x``."""
        return self.model(x)

    def get_params(self) -> dict:
        """Return selected neuron parameters as plain Python values.

        Tensor-valued fields are converted with ``.item()``; ``method`` is
        returned as stored.

        NOTE(review): ``tau_syn_inv``, ``tau_mem_inv`` and ``v_leak`` are not
        passed to ``QIFParameters`` in ``__init__``, so they must exist as
        defaults on that type or this method raises ``AttributeError`` --
        confirm against ``QIFParameters``. The explicitly configured ``v_c``
        and ``a`` are not reported; the key set is kept as-is so existing
        consumers of this dict are unaffected.
        """
        params = self.params
        return {
            "tau_syn_inv": params.tau_syn_inv.item(),
            "tau_mem_inv": params.tau_mem_inv.item(),
            "v_leak": params.v_leak.item(),
            "v_reset": params.v_reset.item(),
            "v_th": params.v_th.item(),
            "method": params.method,
            "alpha": params.alpha.item(),
        }