from typing import Callable, Optional
from recordclass import RecordClass

from ....Layers.Neuron import Neuron
from ....Layers.NeuronConfig import NeuronConfig
from ....util import Lift

import torch
from torch import nn

class StandardBackbone(nn.Module):
    """ImageNet-style spiking stem: 7x7/2 conv -> norm -> neuron -> 3x3/2 avg-pool.

    Follows the layout of the classic ResNet stem, with a spiking neuron in
    place of ReLU and average pooling in place of max pooling. Both the conv
    and the pool have stride 2, so each spatial dimension shrinks roughly 4x
    (e.g. 224 -> 112 -> 56).

    The conv, neuron and pool are wrapped in ``Lift``. Presumably this applies
    a per-frame module across the time dimension of a spiking input (TODO
    confirm against ``util.Lift``). ``norm_layer`` is NOT wrapped, so it must
    handle the lifted input itself.

    Args:
        in_channels: Number of output channels of the stem conv. Presumably
            this is the channel count expected by the first downstream stage.
        neuron: Neuron factory, instantiated here as ``neuron(params, config)``.
        params: Neuron hyper-parameters. ``params.v_th`` is also forwarded to
            ``norm_layer``.
        config: Neuron configuration passed through to ``neuron``.
        norm_layer: Required factory, called as
            ``norm_layer(in_channels, v_th=params.v_th)``. Despite the ``None``
            default there is no fallback.

    Raises:
        TypeError: If ``norm_layer`` is ``None``.
    """

    def __init__(
        self,
        in_channels: int,
        neuron: Neuron,
        params: RecordClass,
        config: NeuronConfig,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()

        # The None default exists only for signature compatibility. Fail with
        # a clear message instead of "'NoneType' object is not callable".
        if norm_layer is None:
            raise TypeError(
                f"{type(self).__name__} requires a norm_layer factory "
                "called as norm_layer(num_features, v_th=...); got None"
            )

        self.model = nn.Sequential(
            Lift(nn.Conv2d(3, in_channels, kernel_size=7, stride=2, padding=3, bias=False)),
            norm_layer(in_channels, v_th=params.v_th),
            # return_state=False: presumably makes Lift drop the neuron's
            # internal state and forward only its output (TODO confirm in Lift).
            Lift(neuron(params, config), return_state=False),
            Lift(nn.AvgPool2d(kernel_size=3, stride=2, padding=1)),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the stem to ``x``. The expected layout is whatever ``Lift`` accepts."""
        return self.model(x)
    
class CifarBackbone(nn.Module):
    """CIFAR-style stem: a single 3x3 conv followed by normalization.

    No spiking neuron is applied here. ``neuron`` and ``config`` are unused;
    they are accepted so this class shares the constructor signature of the
    other backbones. Presumably the first downstream block starts with a
    neuron (TODO confirm).

    The conv is wrapped in ``Lift``. Presumably this applies it per time step
    of a spiking input (TODO confirm against ``util.Lift``). ``norm_layer`` is
    NOT wrapped, so it must handle the lifted input itself.

    Args:
        in_channels: Number of output channels of the stem conv.
        neuron: Unused.
        params: Neuron hyper-parameters. Only ``params.v_th`` is read, and it
            is forwarded to ``norm_layer``.
        config: Unused.
        norm_layer: Required factory, called as
            ``norm_layer(in_channels, v_th=params.v_th)``. Despite the ``None``
            default there is no fallback.

    Raises:
        TypeError: If ``norm_layer`` is ``None``.
    """

    def __init__(
        self,
        in_channels: int,
        neuron: Neuron,
        params: RecordClass,
        config: NeuronConfig,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()

        # The None default exists only for signature compatibility. Fail with
        # a clear message instead of "'NoneType' object is not callable".
        if norm_layer is None:
            raise TypeError(
                f"{type(self).__name__} requires a norm_layer factory "
                "called as norm_layer(num_features, v_th=...); got None"
            )

        self.model = nn.Sequential(
            # NOTE(review): unlike the other 3x3 stems in this module, this conv
            # has no padding, so each spatial dim shrinks by 2 (32x32 -> 30x30).
            # Confirm this is intended before relying on the output size.
            Lift(nn.Conv2d(3, in_channels, kernel_size=3, bias=False)),
            norm_layer(in_channels, v_th=params.v_th),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the stem to ``x``. The expected layout is whatever ``Lift`` accepts."""
        return self.model(x)
    
class MSBackbone(nn.Module):
    """Stem with 3x3/2 conv -> norm -> 3x3/2 avg-pool, and no spiking neuron.

    The conv and the pool both have stride 2, so each spatial dimension shrinks
    roughly 4x. ``neuron`` and ``config`` are unused; they are accepted so this
    class shares the constructor signature of the other backbones. "MS"
    presumably refers to a membrane-shortcut ResNet variant whose blocks begin
    with the neuron (TODO confirm).

    The conv and pool are wrapped in ``Lift``. Presumably this applies them per
    time step of a spiking input (TODO confirm against ``util.Lift``).
    ``norm_layer`` is NOT wrapped, so it must handle the lifted input itself.

    Args:
        in_channels: Number of output channels of the stem conv.
        neuron: Unused.
        params: Neuron hyper-parameters. Only ``params.v_th`` is read, and it
            is forwarded to ``norm_layer``.
        config: Unused.
        norm_layer: Required factory, called as
            ``norm_layer(in_channels, v_th=params.v_th)``. Despite the ``None``
            default there is no fallback.

    Raises:
        TypeError: If ``norm_layer`` is ``None``.
    """

    def __init__(
        self,
        in_channels: int,
        neuron: Neuron,
        params: RecordClass,
        config: NeuronConfig,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()

        # The None default exists only for signature compatibility. Fail with
        # a clear message instead of "'NoneType' object is not callable".
        if norm_layer is None:
            raise TypeError(
                f"{type(self).__name__} requires a norm_layer factory "
                "called as norm_layer(num_features, v_th=...); got None"
            )

        self.model = nn.Sequential(
            Lift(nn.Conv2d(3, in_channels, kernel_size=3, stride=2, padding=1, bias=False)),
            norm_layer(in_channels, v_th=params.v_th),
            Lift(nn.AvgPool2d(kernel_size=3, stride=2, padding=1)),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the stem to ``x``. The expected layout is whatever ``Lift`` accepts."""
        return self.model(x)
    
class ZhengBackbone(nn.Module):
    """Backbone described in Going Deeper With Directly-Trained Larger Spiking Neural Networks
        https://arxiv.org/abs/2011.05280

    Layout: 3x3 conv (stride 1, padding 1) -> norm -> spiking neuron. The
    spatial size is preserved and the 3 input channels are mapped to
    ``in_channels``.

    The conv and neuron are wrapped in ``Lift``. Presumably this applies them
    per time step of a spiking input (TODO confirm against ``util.Lift``).
    ``norm_layer`` is NOT wrapped, so it must handle the lifted input itself.
    Because it receives ``v_th``, it is presumably the paper's threshold-
    dependent batch norm (tdBN) (TODO confirm).

    Args:
        in_channels: Number of output channels of the stem conv.
        neuron: Neuron factory, instantiated here as ``neuron(params, config)``.
        params: Neuron hyper-parameters. ``params.v_th`` is also forwarded to
            ``norm_layer``.
        config: Neuron configuration passed through to ``neuron``.
        norm_layer: Required factory, called as
            ``norm_layer(in_channels, v_th=params.v_th)``. Despite the ``None``
            default there is no fallback.

    Raises:
        TypeError: If ``norm_layer`` is ``None``.
    """

    def __init__(
        self,
        in_channels: int,
        neuron: Neuron,
        params: RecordClass,
        config: NeuronConfig,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()

        # The None default exists only for signature compatibility. Fail with
        # a clear message instead of "'NoneType' object is not callable".
        if norm_layer is None:
            raise TypeError(
                f"{type(self).__name__} requires a norm_layer factory "
                "called as norm_layer(num_features, v_th=...); got None"
            )

        self.model = nn.Sequential(
            Lift(nn.Conv2d(3, in_channels, kernel_size=3, stride=1, padding=1, bias=False)),
            norm_layer(in_channels, v_th=params.v_th),
            # return_state=False: presumably makes Lift drop the neuron's
            # internal state and forward only its output (TODO confirm in Lift).
            Lift(neuron(params, config), return_state=False),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the stem to ``x``. The expected layout is whatever ``Lift`` accepts."""
        return self.model(x)
    
class ZhengNeuromorphicBackbone(nn.Module):
    """Backbone described in Going Deeper With Directly-Trained Larger Spiking Neural Networks
        https://arxiv.org/abs/2011.05280

    Same layout as the RGB variant, 3x3 conv (stride 1, padding 1) -> norm ->
    spiking neuron, but the conv takes 2 input channels. These are presumably
    the ON/OFF polarity channels of event-camera data (TODO confirm). The
    spatial size is preserved.

    The conv and neuron are wrapped in ``Lift``. Presumably this applies them
    per time step of a spiking input (TODO confirm against ``util.Lift``).
    ``norm_layer`` is NOT wrapped, so it must handle the lifted input itself.

    Args:
        in_channels: Number of output channels of the stem conv.
        neuron: Neuron factory, instantiated here as ``neuron(params, config)``.
        params: Neuron hyper-parameters. ``params.v_th`` is also forwarded to
            ``norm_layer``.
        config: Neuron configuration passed through to ``neuron``.
        norm_layer: Required factory, called as
            ``norm_layer(in_channels, v_th=params.v_th)``. Despite the ``None``
            default there is no fallback.

    Raises:
        TypeError: If ``norm_layer`` is ``None``.
    """

    def __init__(
        self,
        in_channels: int,
        neuron: Neuron,
        params: RecordClass,
        config: NeuronConfig,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()

        # The None default exists only for signature compatibility. Fail with
        # a clear message instead of "'NoneType' object is not callable".
        if norm_layer is None:
            raise TypeError(
                f"{type(self).__name__} requires a norm_layer factory "
                "called as norm_layer(num_features, v_th=...); got None"
            )

        self.model = nn.Sequential(
            Lift(nn.Conv2d(2, in_channels, kernel_size=3, stride=1, padding=1, bias=False)),
            norm_layer(in_channels, v_th=params.v_th),
            # return_state=False: presumably makes Lift drop the neuron's
            # internal state and forward only its output (TODO confirm in Lift).
            Lift(neuron(params, config), return_state=False),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the stem to ``x``. The expected layout is whatever ``Lift`` accepts."""
        return self.model(x)