from typing import Callable, Optional
from recordclass import RecordClass

from ....Layers.Neuron import Neuron
from ....Layers.NeuronConfig import NeuronConfig

from ....util import Lift
from ....Normalization import SNNNorm3D

from .util import conv1x1, conv3x3

import torch
from torch import nn

from math import sqrt

class BasicBlock(nn.Module):
    """
        Spiking ResNet basic block: two 3x3 convolutions plus a normalized shortcut.

        Code taken and modified from
            - https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py

        Layout:
            model    : conv3x3(stride) -> norm -> neuron -> conv3x3 -> norm
            shortcut : ``downsample`` if given, otherwise a norm layer applied to the input
            output   : residual neuron applied to ``model(x) + shortcut(x)``

        When ``norm_layer`` is a subclass of ``SNNNorm3D`` it is constructed with the
        neuron threshold ``params.v_th``, plus ``eta=1/sqrt(2)`` for the norms feeding
        the addition, and is used unwrapped. Any other norm layer is wrapped in ``Lift``.
    """

    expansion: int = 1

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        neuron: Neuron,
        params: RecordClass,
        config: NeuronConfig,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        """
            Args:
                in_channels: channels of the block input.
                out_channels: output channels of the main branch (times ``expansion``).
                neuron: spiking neuron class, instantiated as ``neuron(params, config)``.
                params: neuron parameters; ``params.v_th`` is also forwarded to SNNNorm3D layers.
                config: neuron configuration passed to every neuron instance.
                stride: stride of the first 3x3 convolution.
                downsample: optional shortcut module. When omitted, the input must already
                    match the main-branch output shape for the addition in ``forward``.
                groups, base_width: only the defaults (1, 64) are supported.
                dilation: only 1 is supported.
                norm_layer: normalization class/factory; defaults to ``nn.BatchNorm2d``.

            Raises:
                ValueError: if ``groups != 1`` or ``base_width != 64``.
                NotImplementedError: if ``dilation > 1``.
        """
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")

        # norm_layer may be any callable (e.g. functools.partial); issubclass()
        # raises TypeError for non-class arguments, so check isinstance first.
        is_snn_norm = isinstance(norm_layer, type) and issubclass(norm_layer, SNNNorm3D)

        def make_norm(num_features: int, **snn_kwargs: float) -> nn.Module:
            # SNNNorm3D layers need the neuron threshold and are used unwrapped.
            # Other norms (e.g. nn.BatchNorm2d) do not accept v_th/eta, so they get
            # no extra kwargs and are wrapped in Lift like the convolutions.
            if is_snn_norm:
                return norm_layer(num_features, v_th=params.v_th, **snn_kwargs)
            return Lift(norm_layer(num_features))

        self.model = nn.Sequential(
            Lift(conv3x3(in_channels, out_channels, stride)),
            make_norm(out_channels),
            Lift(neuron(params, config)),
            Lift(conv3x3(out_channels, out_channels)),
            make_norm(out_channels * self.expansion, eta=1/sqrt(2)),
        )

        self.stride = stride
        # Without an explicit downsample, the identity path is normalized too.
        self.shortcut = downsample if downsample is not None else make_norm(out_channels, eta=1/sqrt(2))

        self.residual = Lift(neuron(params, config))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``residual(model(x) + shortcut(x))``."""
        basic_block = self.model(x)
        shortcut = self.shortcut(x)

        out = basic_block + shortcut
        out = self.residual(out)

        return out
    
class ZhengBlock(nn.Module):
    """
    Basic Block structure described in the work
    "Going Deeper With Directly-Trained Larger Spiking Neural Networks"

    Layout:
        model    : conv3x3(stride) -> norm -> neuron -> conv3x3 -> norm
        shortcut : ``downsample`` if given, otherwise identity
        output   : residual neuron applied to ``model(x) + shortcut(x)``

    When ``norm_layer`` is a subclass of ``SNNNorm3D`` it is constructed with
    ``v_th=params.v_th`` and used unwrapped. Any other norm layer is wrapped in ``Lift``.
    """
    expansion: int = 1

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        neuron: Neuron,
        params: RecordClass,
        config: NeuronConfig,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        """
        Args:
            in_channels: channels of the block input.
            out_channels: output channels of the main branch.
            neuron: spiking neuron class, instantiated as ``neuron(params, config)``.
            params: neuron parameters; ``params.v_th`` is also forwarded to SNNNorm3D layers.
            config: neuron configuration passed to every neuron instance.
            stride: stride of the first 3x3 convolution.
            downsample: optional shortcut module; identity is used when omitted.
            groups, base_width: only the defaults (1, 64) are supported.
            dilation: only 1 is supported.
            norm_layer: normalization class/factory; defaults to ``nn.BatchNorm2d``.

        Raises:
            ValueError: if ``groups != 1`` or ``base_width != 64``.
            NotImplementedError: if ``dilation > 1``.
        """
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        if groups != 1 or base_width != 64:
            raise ValueError("ZhengBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in ZhengBlock")

        # norm_layer may be any callable (e.g. functools.partial); issubclass()
        # raises TypeError for non-class arguments, so check isinstance first.
        is_snn_norm = isinstance(norm_layer, type) and issubclass(norm_layer, SNNNorm3D)

        def make_norm(num_features: int) -> nn.Module:
            # SNNNorm3D layers need the neuron threshold and are used unwrapped;
            # any other norm layer is wrapped in Lift like the convolutions.
            if is_snn_norm:
                return norm_layer(num_features, v_th=params.v_th)
            return Lift(norm_layer(num_features))

        # NOTE(review): the final norm here gets no eta=1/sqrt(2), unlike the
        # BasicBlock/MS variants in this file; confirm this is intended.
        self.model = nn.Sequential(
            Lift(conv3x3(in_channels, out_channels, stride)),
            make_norm(out_channels),
            Lift(neuron(params, config)),
            Lift(conv3x3(out_channels, out_channels)),
            make_norm(out_channels),
        )

        self.stride = stride
        self.shortcut = downsample if downsample is not None else nn.Identity()

        self.residual = Lift(neuron(params, config))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``residual(model(x) + shortcut(x))``."""
        basic_block = self.model(x)
        shortcut = self.shortcut(x)

        out = basic_block + shortcut
        out = self.residual(out)

        return out
    
class MSBasicBlock(nn.Module):
    """
    MS-ResNet Basic block as described in the work
    [1] Y. Hu, L. Deng, Y. Wu, M. Yao, and G. Li, “Advancing Spiking Neural Networks Toward Deep Residual Learning,” 
    IEEE Trans. Neural Netw. Learning Syst., pp. 1-15, 2024, doi: 10.1109/TNNLS.2024.3355393.

    Layout:
        model    : neuron -> conv3x3(stride) -> norm -> neuron -> conv3x3 -> norm
        shortcut : ``downsample`` if given, otherwise identity
        output   : ``model(x) + shortcut(x)``. No spiking neuron is applied after the addition.

    When ``norm_layer`` is a subclass of ``SNNNorm3D`` it is constructed with
    ``v_th=params.v_th``, plus ``eta=1/sqrt(2)`` for the last norm, and used unwrapped.
    Any other norm layer is wrapped in ``Lift``.
    """
    expansion: int = 1

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        neuron: Neuron,
        params: RecordClass,
        config: NeuronConfig,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        """
        Args:
            in_channels: channels of the block input.
            out_channels: output channels of the main branch (times ``expansion``).
            neuron: spiking neuron class, instantiated as ``neuron(params, config)``.
            params: neuron parameters; ``params.v_th`` is also forwarded to SNNNorm3D layers.
            config: neuron configuration passed to every neuron instance.
            stride: stride of the first 3x3 convolution.
            downsample: optional shortcut module; identity is used when omitted.
            groups, base_width: only the defaults (1, 64) are supported.
            dilation: only 1 is supported.
            norm_layer: normalization class/factory; defaults to ``nn.BatchNorm2d``.

        Raises:
            ValueError: if ``groups != 1`` or ``base_width != 64``.
            NotImplementedError: if ``dilation > 1``.
        """
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        if groups != 1 or base_width != 64:
            raise ValueError("MSBasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in MSBasicBlock")

        # norm_layer may be any callable (e.g. functools.partial); issubclass()
        # raises TypeError for non-class arguments, so check isinstance first.
        is_snn_norm = isinstance(norm_layer, type) and issubclass(norm_layer, SNNNorm3D)

        def make_norm(num_features: int, **snn_kwargs: float) -> nn.Module:
            # SNNNorm3D layers need the neuron threshold and are used unwrapped;
            # other norms take no v_th/eta kwargs and are wrapped in Lift.
            if is_snn_norm:
                return norm_layer(num_features, v_th=params.v_th, **snn_kwargs)
            return Lift(norm_layer(num_features))

        # Membrane-shortcut ordering: the neuron precedes each convolution.
        self.model = nn.Sequential(
            Lift(neuron(params, config)),
            Lift(conv3x3(in_channels, out_channels, stride)),
            make_norm(out_channels),
            Lift(neuron(params, config)),
            Lift(conv3x3(out_channels, out_channels * self.expansion)),
            make_norm(out_channels * self.expansion, eta=1/sqrt(2)),
        )

        self.stride = stride

        # Wrapped in nn.Sequential, which determines the state_dict key layout.
        self.shortcut = nn.Sequential(
            downsample if downsample is not None else nn.Identity()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``model(x) + shortcut(x)``, without a post-addition neuron."""
        basic_block = self.model(x)
        shortcut = self.shortcut(x)

        return basic_block + shortcut
    
class Bottleneck(nn.Module):
    """
    Spiking ResNet bottleneck block (1x1 -> 3x3 -> 1x1 convolutions).

    Bottleneck in torchvision places the stride for downsampling at the 3x3 convolution,
    while the original implementation places it at the first 1x1 convolution, according to
    "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385.
    This variant is also known as ResNet V1.5 and improves accuracy according to
    https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
    Here the stride is applied at the 3x3 convolution, as in torchvision.

    Layout:
        model    : conv1x1 -> norm -> neuron -> conv3x3(stride) -> norm -> neuron -> conv1x1 -> norm
        shortcut : ``downsample`` if given, otherwise identity
        output   : residual neuron applied to ``model(x) + shortcut(x)``
    """

    expansion: int = 4

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        neuron: Neuron,
        params: RecordClass,
        config: NeuronConfig,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        """
        Args:
            in_channels: channels of the block input.
            out_channels: bottleneck base channels; the block outputs ``out_channels * expansion``.
            neuron: spiking neuron class, instantiated as ``neuron(params, config)``.
            params: neuron parameters; ``params.v_th`` is also forwarded to SNNNorm3D layers.
            config: neuron configuration passed to every neuron instance.
            stride: stride of the 3x3 convolution.
            downsample: optional shortcut module; identity is used when omitted.
            groups, base_width: set the inner width ``int(out_channels * base_width / 64) * groups``.
            dilation: forwarded to the 3x3 convolution.
            norm_layer: normalization class/factory; defaults to ``nn.BatchNorm2d``.
        """
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        width = int(out_channels * (base_width / 64.0)) * groups

        # norm_layer may be any callable (e.g. functools.partial); issubclass()
        # raises TypeError for non-class arguments, so check isinstance first.
        is_snn_norm = isinstance(norm_layer, type) and issubclass(norm_layer, SNNNorm3D)

        def make_norm(num_features: int) -> nn.Module:
            # SNNNorm3D layers need the neuron threshold and are used unwrapped;
            # any other norm layer is wrapped in Lift like the convolutions.
            if is_snn_norm:
                return norm_layer(num_features, v_th=params.v_th)
            return Lift(norm_layer(num_features))

        # NOTE(review): the final norm gets no eta=1/sqrt(2), unlike MSBottleneck;
        # confirm this is intended.
        self.model = nn.Sequential(
            Lift(conv1x1(in_channels, width)),
            make_norm(width),
            Lift(neuron(params, config)),

            Lift(conv3x3(width, width, stride, groups, dilation)),
            make_norm(width),
            Lift(neuron(params, config)),

            Lift(conv1x1(width, out_channels * self.expansion)),
            make_norm(out_channels * self.expansion),
        )

        # Wrapped in nn.Sequential, which determines the state_dict key layout.
        self.shortcut = nn.Sequential(
            downsample if downsample is not None else nn.Identity()
        )

        self.residual = nn.Sequential(
            Lift(neuron(params, config))
        )

        self.stride = stride

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``residual(model(x) + shortcut(x))``."""
        basic_block = self.model(x)
        shortcut = self.shortcut(x)

        out = basic_block + shortcut
        out = self.residual(out)

        return out

class MSBottleneck(nn.Module):
    """
    MS-ResNet (membrane-shortcut) bottleneck block (1x1 -> 3x3 -> 1x1 convolutions).

    Bottleneck in torchvision places the stride for downsampling at the 3x3 convolution,
    while the original implementation places it at the first 1x1 convolution, according to
    "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385.
    This variant is also known as ResNet V1.5 and improves accuracy according to
    https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
    Here the stride is applied at the 3x3 convolution, as in torchvision.

    Layout:
        model    : neuron -> conv1x1 -> norm -> neuron -> conv3x3(stride) -> norm
                   -> neuron -> conv1x1 -> norm
        shortcut : ``downsample`` if given, otherwise identity
        output   : ``model(x) + shortcut(x)``. No spiking neuron is applied after the addition.
    """

    expansion: int = 4

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        neuron: Neuron,
        params: RecordClass,
        config: NeuronConfig,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        """
        Args:
            in_channels: channels of the block input.
            out_channels: bottleneck base channels; the block outputs ``out_channels * expansion``.
            neuron: spiking neuron class, instantiated as ``neuron(params, config)``.
            params: neuron parameters; ``params.v_th`` is also forwarded to SNNNorm3D layers.
            config: neuron configuration passed to every neuron instance.
            stride: stride of the 3x3 convolution.
            downsample: optional shortcut module; identity is used when omitted.
            groups, base_width: set the inner width ``int(out_channels * base_width / 64) * groups``.
            dilation: forwarded to the 3x3 convolution.
            norm_layer: normalization class/factory; defaults to ``nn.BatchNorm2d``.
        """
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        width = int(out_channels * (base_width / 64.0)) * groups

        # norm_layer may be any callable (e.g. functools.partial); issubclass()
        # raises TypeError for non-class arguments, so check isinstance first.
        is_snn_norm = isinstance(norm_layer, type) and issubclass(norm_layer, SNNNorm3D)

        def make_norm(num_features: int, **snn_kwargs: float) -> nn.Module:
            # SNNNorm3D layers need the neuron threshold and are used unwrapped;
            # other norms take no v_th/eta kwargs and are wrapped in Lift.
            if is_snn_norm:
                return norm_layer(num_features, v_th=params.v_th, **snn_kwargs)
            return Lift(norm_layer(num_features))

        # Membrane-shortcut ordering: the neuron precedes each convolution.
        self.model = nn.Sequential(
            Lift(neuron(params, config)),
            Lift(conv1x1(in_channels, width)),
            make_norm(width),

            Lift(neuron(params, config)),
            Lift(conv3x3(width, width, stride, groups, dilation)),
            make_norm(width),

            Lift(neuron(params, config)),
            Lift(conv1x1(width, out_channels * self.expansion)),
            make_norm(out_channels * self.expansion, eta=1/sqrt(2)),
        )

        # Wrapped in nn.Sequential, which determines the state_dict key layout.
        self.shortcut = nn.Sequential(
            downsample if downsample is not None else nn.Identity()
        )

        self.stride = stride

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``model(x) + shortcut(x)``, without a post-addition neuron."""
        basic_block = self.model(x)
        shortcut = self.shortcut(x)

        # Some downsample modules may return a tuple; only its first element is added.
        if isinstance(shortcut, tuple):
            shortcut = shortcut[0]

        return basic_block + shortcut
    
