from typing import Callable, Optional
from recordclass import RecordClass

from ....Layers.Neuron import Neuron
from ....Layers.NeuronConfig import NeuronConfig
from ....util import Lift
from ....Normalization import SNNNorm3D

import torch
from torch import nn

class StandardBackbone(nn.Module):
    """ImageNet-style spiking stem: 7x7/2 conv -> norm -> neuron -> 3x3/2 max pool.

    Same layer layout as the torchvision ResNet stem, with ``neuron`` in the
    position ResNet uses ReLU and the 2D ops wrapped in ``Lift``. Each spatial
    dimension is reduced to ``ceil(size / 4)``.

    Args:
        in_channels: Number of channels produced by the stem (i.e. the input
            channel count of the next stage). The stem itself always consumes
            3-channel images.
        neuron: Neuron factory, called as ``neuron(params, config)``.
        params: Neuron parameters; ``params.v_th`` is also forwarded to
            ``SNNNorm3D`` norm layers.
        config: Neuron configuration forwarded to ``neuron``.
        norm_layer: Normalization factory called with the channel count.
            Subclasses of ``SNNNorm3D`` additionally receive ``v_th`` and are
            not wrapped in ``Lift``; any other class or callable is wrapped in
            ``Lift``. Defaults to ``nn.BatchNorm2d`` when ``None``.
    """

    def __init__(
        self,
        in_channels: int,
        neuron: Callable[..., Neuron],
        params: RecordClass,
        config: NeuronConfig,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()

        if norm_layer is None:
            # Previously issubclass(None, ...) raised TypeError; follow the
            # torchvision convention of falling back to BatchNorm2d.
            norm_layer = nn.BatchNorm2d

        # Modules are built in the original order (conv, norm, neuron, pool)
        # so random parameter initialisation consumes the RNG identically.
        conv = Lift(nn.Conv2d(3, in_channels, kernel_size=7, stride=2, padding=3, bias=False))
        # issubclass() only accepts classes; the isinstance guard lets plain
        # factories (lambdas, partials) through as ordinary lifted norms.
        if isinstance(norm_layer, type) and issubclass(norm_layer, SNNNorm3D):
            norm = norm_layer(in_channels, v_th=params.v_th)
        else:
            norm = Lift(norm_layer(in_channels))
        spike = Lift(neuron(params, config), return_state=False)
        pool = Lift(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

        self.model = nn.Sequential(conv, norm, spike, pool)

    def forward(self, x) -> torch.Tensor:
        """Run the stem on ``x``.

        ``x`` must be in whatever layout ``Lift`` expects (presumably a stack
        of frames over time -- TODO confirm against ``Lift``).
        """
        return self.model(x)
    
class CifarBackbone(nn.Module):
    """CIFAR-style stem: a single 3x3 conv followed by normalization.

    No spiking neuron is applied here; ``neuron`` and ``config`` are accepted
    only for signature compatibility with the other backbones and are unused.

    NOTE(review): the conv has no padding, so each spatial dimension shrinks
    by 2 (e.g. 32 -> 30), whereas the other 3x3 stems in this module use
    ``padding=1``. Confirm this is intentional before relying on output size.

    Args:
        in_channels: Number of channels produced by the stem (i.e. the input
            channel count of the next stage). The stem itself always consumes
            3-channel images.
        neuron: Unused.
        params: Neuron parameters; ``params.v_th`` is forwarded to
            ``SNNNorm3D`` norm layers.
        config: Unused.
        norm_layer: Normalization factory called with the channel count.
            Subclasses of ``SNNNorm3D`` additionally receive ``v_th`` and are
            not wrapped in ``Lift``; any other class or callable is wrapped in
            ``Lift``. Defaults to ``nn.BatchNorm2d`` when ``None``.
    """

    def __init__(
        self,
        in_channels: int,
        neuron: Callable[..., Neuron],
        params: RecordClass,
        config: NeuronConfig,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()

        if norm_layer is None:
            # Previously issubclass(None, ...) raised TypeError; follow the
            # torchvision convention of falling back to BatchNorm2d.
            norm_layer = nn.BatchNorm2d

        # Conv is built before the norm, as in the original, to keep the
        # RNG consumption order of parameter initialisation unchanged.
        conv = Lift(nn.Conv2d(3, in_channels, kernel_size=3, bias=False))
        # issubclass() only accepts classes; the isinstance guard lets plain
        # factories (lambdas, partials) through as ordinary lifted norms.
        if isinstance(norm_layer, type) and issubclass(norm_layer, SNNNorm3D):
            norm = norm_layer(in_channels, v_th=params.v_th)
        else:
            norm = Lift(norm_layer(in_channels))

        self.model = nn.Sequential(conv, norm)

    def forward(self, x) -> torch.Tensor:
        """Run the stem on ``x``.

        ``x`` must be in whatever layout ``Lift`` expects (presumably a stack
        of frames over time -- TODO confirm against ``Lift``).
        """
        return self.model(x)
    
class MSBackbone(nn.Module):
    """Stem: 3x3/2 conv -> norm -> 3x3/2 average pool, with no spiking neuron.

    Each spatial dimension is reduced to ``ceil(ceil(size / 2) / 2)``.
    ``neuron`` and ``config`` are accepted for signature compatibility with
    the other backbones but are unused. Presumably the following blocks begin
    with their own neuron (membrane-shortcut style) -- confirm against callers.

    Args:
        in_channels: Number of channels produced by the stem (i.e. the input
            channel count of the next stage). The stem itself always consumes
            3-channel images.
        neuron: Unused.
        params: Neuron parameters; ``params.v_th`` is forwarded to
            ``SNNNorm3D`` norm layers.
        config: Unused.
        norm_layer: Normalization factory called with the channel count.
            Subclasses of ``SNNNorm3D`` additionally receive ``v_th`` and are
            not wrapped in ``Lift``; any other class or callable is wrapped in
            ``Lift``. Defaults to ``nn.BatchNorm2d`` when ``None``.
    """

    def __init__(
        self,
        in_channels: int,
        neuron: Callable[..., Neuron],
        params: RecordClass,
        config: NeuronConfig,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()

        if norm_layer is None:
            # Previously issubclass(None, ...) raised TypeError; follow the
            # torchvision convention of falling back to BatchNorm2d.
            norm_layer = nn.BatchNorm2d

        # Modules are built in the original order (conv, norm, pool) so
        # random parameter initialisation consumes the RNG identically.
        conv = Lift(nn.Conv2d(3, in_channels, kernel_size=3, stride=2, padding=1, bias=False))
        # issubclass() only accepts classes; the isinstance guard lets plain
        # factories (lambdas, partials) through as ordinary lifted norms.
        if isinstance(norm_layer, type) and issubclass(norm_layer, SNNNorm3D):
            norm = norm_layer(in_channels, v_th=params.v_th)
        else:
            norm = Lift(norm_layer(in_channels))
        pool = Lift(nn.AvgPool2d(kernel_size=3, stride=2, padding=1))

        self.model = nn.Sequential(conv, norm, pool)

    def forward(self, x) -> torch.Tensor:
        """Run the stem on ``x``.

        ``x`` must be in whatever layout ``Lift`` expects (presumably a stack
        of frames over time -- TODO confirm against ``Lift``).
        """
        return self.model(x)
    
class ZhengBackbone(nn.Module):
    """Backbone described in Going Deeper With Directly-Trained Larger Spiking Neural Networks
        https://arxiv.org/abs/2011.05280

    Stem: 3x3 conv (stride 1, padding 1, so spatial size is preserved) ->
    norm -> spiking neuron.

    Args:
        in_channels: Number of channels produced by the stem (i.e. the input
            channel count of the next stage). The stem itself always consumes
            3-channel images.
        neuron: Neuron factory, called as ``neuron(params, config)``.
        params: Neuron parameters; ``params.v_th`` is also forwarded to
            ``SNNNorm3D`` norm layers.
        config: Neuron configuration forwarded to ``neuron``.
        norm_layer: Normalization factory called with the channel count.
            Subclasses of ``SNNNorm3D`` additionally receive ``v_th`` and are
            not wrapped in ``Lift``; any other class or callable is wrapped in
            ``Lift``. Defaults to ``nn.BatchNorm2d`` when ``None``.
    """

    def __init__(
        self,
        in_channels: int,
        neuron: Callable[..., Neuron],
        params: RecordClass,
        config: NeuronConfig,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()

        if norm_layer is None:
            # Previously issubclass(None, ...) raised TypeError; follow the
            # torchvision convention of falling back to BatchNorm2d.
            norm_layer = nn.BatchNorm2d

        # Modules are built in the original order (conv, norm, neuron) so
        # random parameter initialisation consumes the RNG identically.
        conv = Lift(nn.Conv2d(3, in_channels, kernel_size=3, stride=1, padding=1, bias=False))
        # issubclass() only accepts classes; the isinstance guard lets plain
        # factories (lambdas, partials) through as ordinary lifted norms.
        if isinstance(norm_layer, type) and issubclass(norm_layer, SNNNorm3D):
            norm = norm_layer(in_channels, v_th=params.v_th)
        else:
            norm = Lift(norm_layer(in_channels))
        spike = Lift(neuron(params, config), return_state=False)

        self.model = nn.Sequential(conv, norm, spike)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the stem on ``x``.

        ``x`` must be in whatever layout ``Lift`` expects (presumably a stack
        of frames over time -- TODO confirm against ``Lift``).
        """
        return self.model(x)
    
class ZhengStandardBackbone(nn.Module):
    """Backbone described in Going Deeper With Directly-Trained Larger Spiking Neural Networks
        https://arxiv.org/abs/2011.05280

    Stem: 7x7/2 conv -> norm -> spiking neuron (no pooling).

    NOTE(review): the 7x7 stride-2 conv uses ``padding=1``, whereas
    ``StandardBackbone`` in this module uses ``padding=3`` for the same kernel
    (e.g. 224 -> 110 here vs 224 -> 112 there). Confirm this is intended.

    Args:
        in_channels: Number of channels produced by the stem (i.e. the input
            channel count of the next stage). The stem itself always consumes
            3-channel images.
        neuron: Neuron factory, called as ``neuron(params, config)``.
        params: Neuron parameters; ``params.v_th`` is also forwarded to
            ``SNNNorm3D`` norm layers.
        config: Neuron configuration forwarded to ``neuron``.
        norm_layer: Normalization factory called with the channel count.
            Subclasses of ``SNNNorm3D`` additionally receive ``v_th`` and are
            not wrapped in ``Lift``; any other class or callable is wrapped in
            ``Lift``. Defaults to ``nn.BatchNorm2d`` when ``None``.
    """

    def __init__(
        self,
        in_channels: int,
        neuron: Callable[..., Neuron],
        params: RecordClass,
        config: NeuronConfig,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()

        if norm_layer is None:
            # Previously issubclass(None, ...) raised TypeError; follow the
            # torchvision convention of falling back to BatchNorm2d.
            norm_layer = nn.BatchNorm2d

        # Modules are built in the original order (conv, norm, neuron) so
        # random parameter initialisation consumes the RNG identically.
        conv = Lift(nn.Conv2d(3, in_channels, kernel_size=7, stride=2, padding=1, bias=False))
        # issubclass() only accepts classes; the isinstance guard lets plain
        # factories (lambdas, partials) through as ordinary lifted norms.
        if isinstance(norm_layer, type) and issubclass(norm_layer, SNNNorm3D):
            norm = norm_layer(in_channels, v_th=params.v_th)
        else:
            norm = Lift(norm_layer(in_channels))
        spike = Lift(neuron(params, config), return_state=False)

        self.model = nn.Sequential(conv, norm, spike)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the stem on ``x``.

        ``x`` must be in whatever layout ``Lift`` expects (presumably a stack
        of frames over time -- TODO confirm against ``Lift``).
        """
        return self.model(x)