from typing import Type, List, Optional, Callable, Any
from recordclass import RecordClass

from ....Layers.Neuron import Neuron
from ....Layers.NeuronConfig import NeuronConfig
from ....util import Lift
from ....Normalization import TDBN3D, SNNNorm3D

from .ResNetBlocks import BasicBlock, ZhengBlock, MSBasicBlock, Bottleneck, MSBottleneck, SEWBlock
from .ResNetBackbones import StandardBackbone, CifarBackbone, ZhengBackbone, ZhengNeuromorphicBackbone, MSBackbone
from .ResNetClassifiers import ZhengClassifier, ZhengStandardClassifier, StandardClassifier
from .util import conv1x1

import torch
from torch import nn

# Type aliases for the pluggable ResNet components. Each alias is a ``Type[...]``
# over a union, i.e. callers pass one of the listed *classes* (not an instance).
# Residual block classes usable for every stage (``block`` argument).
BasicBlocks = Type[BasicBlock | ZhengBlock | SEWBlock | MSBasicBlock | Bottleneck | MSBottleneck]
# Stem classes placed before the residual stages (``backbone`` argument).
Backbones = Type[StandardBackbone | CifarBackbone | ZhengBackbone | ZhengNeuromorphicBackbone | MSBackbone]
# Head classes placed after global pooling + flatten (``classifier`` argument).
Classifiers = Type[ZhengClassifier | ZhengStandardClassifier | StandardClassifier]

class ResNet(nn.Module):
    """
        Spiking ResNet: a backbone (stem), four residual stages, global average
        pooling + flatten, and a classifier head, chained in ``self.model``.

        Code taken and modified from
            - https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
    """
    def __init__(
        self,
        backbone: Backbones,
        block: BasicBlocks,
        classifier: Classifiers,
        layers: List[int],
        neuron: Neuron,
        params: RecordClass,
        config: NeuronConfig,
        num_classes: int = 1000,
        in_channels: int = 64,
        groups: int = 1,
        width_per_group: int = 64,
        channels: Optional[List[Optional[int]]] = None,
        replace_stride_with_dilation: Optional[List[bool]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        """
        Args:
            backbone: stem class, instantiated with ``in_channels``, ``neuron``,
                ``params``, ``config`` and ``norm_layer``.
            block: residual block class used by all four stages.
            classifier: head class, instantiated with ``num_classes``,
                ``neuron``, ``params``, ``config`` and ``norm_layer``.
            layers: number of blocks in each of the four stages.
            neuron: neuron factory, called as ``neuron(params, config)``.
            params: neuron parameters; ``params.v_th`` is passed to the
                downsample normalization layer.
            config: neuron configuration forwarded to every sub-module.
            num_classes: forwarded to the classifier.
            in_channels: forwarded to the backbone and used as the input
                channel count of the first stage.
            groups: forwarded to every block.
            width_per_group: forwarded to every block as ``base_width``.
            channels: per-stage output channels (before ``block.expansion``);
                a ``None`` entry turns that stage into ``nn.Identity``.
                Defaults to ``[64, 128, 256, 512]``.
            replace_stride_with_dilation: three flags, one per stride-2 stage,
                replacing the stride with dilation.
            norm_layer: normalization factory; defaults to ``nn.BatchNorm2d``.

        Raises:
            ValueError: if ``replace_stride_with_dilation`` does not have
                exactly three entries.
        """
        super().__init__()
        if norm_layer is None:
            # NOTE(review): the downsample path calls ``norm_layer(..., v_th=...)``
            # and ``nn.BatchNorm2d`` has no ``v_th`` argument, so this default
            # appears to fail as soon as a downsample is built — confirm.
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        if channels is None:
            # Fresh list per call rather than a shared mutable default.
            channels = [64, 128, 256, 512]

        self.in_channels = in_channels
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                f"or a 3-element tuple, got {replace_stride_with_dilation}"
            )

        self.groups = groups
        self.base_width = width_per_group

        self.model = nn.Sequential(
            backbone(in_channels=self.in_channels, neuron=neuron, params=params, config=config, norm_layer=norm_layer),
            self._make_layer(block, channels[0], layers[0], neuron=neuron, params=params, config=config),
            self._make_layer(block, channels[1], layers[1], stride=2, dilate=replace_stride_with_dilation[0], neuron=neuron, params=params, config=config),
            self._make_layer(block, channels[2], layers[2], stride=2, dilate=replace_stride_with_dilation[1], neuron=neuron, params=params, config=config),
            self._make_layer(block, channels[3], layers[3], stride=2, dilate=replace_stride_with_dilation[2], neuron=neuron, params=params, config=config),
            Lift(nn.AdaptiveAvgPool2d((1, 1))),
            Lift(nn.Flatten()),
            classifier(num_classes=num_classes, neuron=neuron, params=params, config=config, norm_layer=norm_layer)
        )

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)) and not isinstance(m, TDBN3D):
                # TDBN3D is excluded, presumably because it sets up its own
                # affine parameters — confirm. Non-affine norms have None
                # weight/bias, which nn.init cannot fill.
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(
        self,
        block: BasicBlocks,
        out_channels: int | None,
        blocks: int,
        neuron: Neuron,
        params: RecordClass,
        config: NeuronConfig,
        stride: int = 1,
        dilate: bool = False,
    ) -> nn.Module:
        """Build one residual stage made of ``blocks`` instances of ``block``.

        Returns ``nn.Identity()`` when ``out_channels`` is None (stage skipped),
        otherwise an ``nn.Sequential`` of blocks. Updates ``self.in_channels``
        to ``out_channels * block.expansion`` and, when ``dilate`` is set,
        multiplies ``self.dilation`` by ``stride`` and uses stride 1 instead.
        """
        if out_channels is None:
            # Stage disabled: pass activations through unchanged.
            return nn.Identity()
        norm_layer = self._norm_layer

        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1

        expanded_channels = out_channels * block.expansion
        if stride != 1 or self.in_channels != expanded_channels:
            # ``block`` is a class, so ``isinstance(block, SEWBlock)`` is always
            # False; check subclassing instead (guarded for non-class factories).
            is_sew_block = isinstance(block, type) and issubclass(block, SEWBlock)
            downsample = nn.Sequential(
                Lift(conv1x1(self.in_channels, expanded_channels, stride)),
                norm_layer(expanded_channels, v_th=params.v_th),
                # SEW blocks get a spiking neuron on the shortcut path.
                Lift(neuron(params, config)) if is_sew_block else nn.Identity(),
            )

        layers = [
            block(
                in_channels=self.in_channels,
                out_channels=out_channels,
                stride=stride,
                downsample=downsample,
                groups=self.groups,
                base_width=self.base_width,
                dilation=previous_dilation,
                norm_layer=norm_layer,
                neuron=neuron,
                params=params,
                config=config
            )
        ]
        self.in_channels = expanded_channels
        for _ in range(1, blocks):
            layers.append(
                block(
                    in_channels=self.in_channels,
                    out_channels=out_channels,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer,
                    neuron=neuron,
                    params=params,
                    config=config
                )
            )

        return nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through backbone, stages, pooling and classifier."""
        return self.model(x)
    
def _resnet(
    block: BasicBlocks,
    backbone: Backbones,
    classifier: Classifiers,
    layers: List[int],
    channels: List[int],
    in_channels: int,
    num_classes: int,
    neuron: Neuron,
    params: RecordClass,
    config: NeuronConfig,
    norm_layer: Optional[Callable[..., nn.Module]] = None,
    **kwargs: Any,
) -> ResNet:
    """Construct a :class:`ResNet` from explicit architecture components.

    Every named argument is forwarded to ``ResNet`` by keyword, together with
    any extra keyword arguments collected in ``kwargs`` (for example
    ``groups`` or ``width_per_group``).
    """
    resnet_kwargs = dict(
        block=block,
        backbone=backbone,
        classifier=classifier,
        layers=layers,
        channels=channels,
        in_channels=in_channels,
        num_classes=num_classes,
        neuron=neuron,
        params=params,
        config=config,
        norm_layer=norm_layer,
    )
    # ``kwargs`` can never repeat a named parameter of this function, so the
    # merge cannot overwrite any entry above.
    resnet_kwargs.update(kwargs)
    return ResNet(**resnet_kwargs)

def resnet(
    num_classes: int,
    neuron: Neuron,
    params: RecordClass,
    config: NeuronConfig,
    layers: Optional[List[int]] = None,
    channels: Optional[List[int | None]] = None,
    in_channels: int = 64,
    block: BasicBlocks = BasicBlock,
    backbone: Backbones = StandardBackbone,
    classifier: Classifiers = StandardClassifier,
    norm_layer: Optional[Callable[..., nn.Module]] = None
) -> ResNet:
    """Build a configurable spiking ResNet.

    Args:
        num_classes: forwarded to the classifier head.
        neuron: neuron factory shared by all sub-modules.
        params: neuron parameters.
        config: neuron configuration.
        layers: blocks per stage; defaults to ``[2, 2, 2, 2]``.
        channels: per-stage output channels (``None`` skips a stage);
            defaults to ``[64, 128, 256, 512]``.
        in_channels: channel count given to the backbone and first stage.
        block: residual block class.
        backbone: stem class.
        classifier: head class.
        norm_layer: normalization factory (``None`` lets ``ResNet`` choose).

    Returns:
        The constructed ``ResNet``.
    """
    # Resolve defaults per call instead of sharing mutable default lists.
    if layers is None:
        layers = [2, 2, 2, 2]
    if channels is None:
        channels = [64, 128, 256, 512]
    return _resnet(
        block=block,
        backbone=backbone,
        classifier=classifier,
        layers=layers,
        channels=channels,
        in_channels=in_channels,
        num_classes=num_classes,
        neuron=neuron,
        params=params,
        config=config,
        norm_layer=norm_layer
    )

def resnet18(
    num_classes: int,
    neuron: Neuron,
    params: RecordClass,
    config: NeuronConfig,
    channels: Optional[List[int | None]] = None,
    block: BasicBlocks = BasicBlock,
    backbone: Backbones = StandardBackbone,
    classifier: Classifiers = StandardClassifier,
    norm_layer: Optional[Callable[..., nn.Module]] = None,
    in_channels: int = 64,
) -> ResNet:
    """Build a spiking ResNet with the ResNet-18 stage layout ``[2, 2, 2, 2]``.

    Args:
        num_classes: forwarded to the classifier head.
        neuron: neuron factory shared by all sub-modules.
        params: neuron parameters.
        config: neuron configuration.
        channels: per-stage output channels (``None`` skips a stage);
            defaults to ``[64, 128, 256, 512]``.
        block: residual block class.
        backbone: stem class.
        classifier: head class.
        norm_layer: normalization factory (``None`` lets ``ResNet`` choose).
        in_channels: channel count given to the backbone and first stage
            (defaults to 64, matching ``resnet``).

    Returns:
        The constructed ``ResNet``.
    """
    if channels is None:
        channels = [64, 128, 256, 512]
    return _resnet(
        block=block,
        backbone=backbone,
        classifier=classifier,
        layers=[2, 2, 2, 2],
        channels=channels,
        # ``_resnet`` requires ``in_channels``; omitting it raised TypeError.
        in_channels=in_channels,
        num_classes=num_classes,
        neuron=neuron,
        params=params,
        config=config,
        norm_layer=norm_layer
    )

def resnext18(
    num_classes: int,
    neuron: Neuron,
    params: RecordClass,
    config: NeuronConfig,
    channels: Optional[List[int | None]] = None,
    block: BasicBlocks = BasicBlock,
    backbone: Backbones = StandardBackbone,
    norm_layer: Optional[Callable[..., nn.Module]] = None,
    in_channels: int = 64,
    classifier: Classifiers = StandardClassifier,
) -> ResNet:
    """Build a grouped (ResNeXt-style, 32 groups x width 4) spiking ResNet
    with stage layout ``[2, 2, 2, 2]``.

    Args:
        num_classes: forwarded to the classifier head.
        neuron: neuron factory shared by all sub-modules.
        params: neuron parameters.
        config: neuron configuration.
        channels: per-stage output channels (``None`` skips a stage);
            defaults to ``[64, 128, 256, 512]``.
        block: residual block class.
        backbone: stem class.
        norm_layer: normalization factory (``None`` lets ``ResNet`` choose).
        in_channels: channel count given to the backbone and first stage
            (defaults to 64, matching ``resnet``).
        classifier: head class (defaults to ``StandardClassifier``, the value
            previously hard-coded).

    Returns:
        The constructed ``ResNet``.
    """
    # NOTE(review): torchvision's BasicBlock rejects groups != 1 / base_width
    # != 64; confirm the project's default ``BasicBlock`` supports groups=32.
    if channels is None:
        channels = [64, 128, 256, 512]
    return _resnet(
        block=block,
        backbone=backbone,
        classifier=classifier,
        layers=[2, 2, 2, 2],
        channels=channels,
        # ``_resnet`` requires ``in_channels``; omitting it raised TypeError.
        in_channels=in_channels,
        num_classes=num_classes,
        neuron=neuron,
        params=params,
        config=config,
        norm_layer=norm_layer,
        groups=32,
        width_per_group=4
    )