from torch import nn

from spikingjelly.activation_based import layer, neuron, surrogate

from .model_utils import batch_norm_2d, batch_norm_2d1


class SpikingPlainBlock(nn.Module):
    """Spiking conv stage (Conv -> batch_norm_2d -> IF neuron) plus a 3x3 conv output head.

    ``forward`` returns a pair: the IF-neuron output of the first stage and the
    result of applying the output head (3x3 conv -> batch_norm_2d1) to it.

    Args:
        in_channels: number of input channels of the first convolution.
        out_channels: number of channels produced by both sub-networks.
        kernel_size: kernel size of the first convolution (stride is fixed to 1).
        padding: padding of the first convolution.
        step_mode: spikingjelly step mode forwarded to every layer.
        backend: backend forwarded to the IF neuron.
        args: optional config object; its ``T`` attribute is the expected size of
            the leading input dimension. When ``None`` the check is skipped.
    """

    def __init__(self, in_channels, out_channels, kernel_size=7, padding=3, step_mode='m', backend='cupy', args=None):
        super().__init__()
        # ``args`` defaults to None, so accept it instead of crashing with AttributeError.
        self.T = None if args is None else args.T
        cardinality = 1
        self.plain_function = nn.Sequential(
            # NOTE(review): if batch_norm_2d is a standard BatchNorm this bias is
            # redundant; it is kept so existing checkpoints still load.
            layer.Conv2d(in_channels,
                         out_channels,
                         kernel_size=kernel_size,
                         stride=1,
                         padding=padding,
                         groups=cardinality,
                         bias=True,
                         step_mode=step_mode),
            batch_norm_2d(out_channels),
            neuron.IFNode(step_mode=step_mode, backend=backend),
        )

        self.out_function = nn.Sequential(
            layer.Conv2d(out_channels,
                         out_channels,
                         kernel_size=3,
                         stride=1,
                         padding=1,
                         groups=cardinality,
                         bias=False,
                         step_mode=step_mode),
            batch_norm_2d1(out_channels),
        )

    def forward(self, x):
        """Return ``(spikes, head_output)`` for input ``x``.

        When ``T`` is known, the leading dimension of ``x`` must equal it
        (presumably the time-step axis in multi-step mode).
        """
        if self.T is not None:
            assert x.shape[0] == self.T, (
                f"expected leading dimension {self.T}, got {x.shape[0]}")

        down = self.plain_function(x)

        return down, self.out_function(down)


class SpikingDownBlock(nn.Module):
    """Optional 2x max-pool followed by Conv -> batch_norm_2d -> IF neuron, with an optional 1x1 head.

    Args:
        in_channels: number of input channels of the convolution.
        out_channels: number of channels produced by the block.
        kernel_size: kernel size of the main convolution (its stride is fixed to 1).
        padding: padding of the main convolution.
        stride: ``2`` inserts a 2x2/stride-2 max-pool before the convolution;
            any other value leaves the spatial resolution unchanged.
        step_mode: spikingjelly step mode forwarded to every layer.
        backend: backend forwarded to the IF neuron.
        out: when True, ``forward`` also returns the 1x1 conv -> batch_norm_2d1
            head applied to the spikes.
        args: optional config object; its ``T`` attribute is the expected size of
            the leading input dimension. When ``None`` the check is skipped.
    """

    def __init__(self, in_channels, out_channels, kernel_size=7, padding=3, stride=2, step_mode='m', backend='cupy',
                 out=True, args=None):
        super().__init__()
        # ``args`` defaults to None, so accept it instead of crashing with AttributeError.
        self.T = None if args is None else args.T
        self.out = out
        cardinality = 1
        self.down_function = nn.Sequential(
            # Downsampling is done by pooling, not by a strided convolution; only
            # stride == 2 is recognised.
            layer.MaxPool2d(2, 2, 0, step_mode=step_mode) if stride == 2 else nn.Identity(),
            layer.Conv2d(in_channels,
                         out_channels,
                         kernel_size=kernel_size,
                         stride=1,
                         padding=padding,
                         groups=cardinality,
                         bias=False,
                         step_mode=step_mode),
            batch_norm_2d(out_channels),
            neuron.IFNode(step_mode=step_mode, backend=backend),
        )

        if self.out:
            self.out_function = nn.Sequential(
                layer.Conv2d(out_channels,
                             out_channels,
                             kernel_size=1,
                             stride=1,
                             padding=0,
                             groups=cardinality,
                             bias=False,
                             step_mode=step_mode),
                batch_norm_2d1(out_channels),
            )
        else:
            # Keep the attribute present (empty) so the module layout is uniform.
            self.out_function = nn.Sequential()

    def forward(self, x):
        """Return ``(spikes, head_output)`` when ``out`` is True, else just ``spikes``.

        When ``T`` is known, the leading dimension of ``x`` must equal it
        (presumably the time-step axis in multi-step mode).
        """
        if self.T is not None:
            assert x.shape[0] == self.T, (
                f"expected leading dimension {self.T}, got {x.shape[0]}")

        down = self.down_function(x)

        if self.out:
            return down, self.out_function(down)

        return down

class OrigResBlock(nn.Module):
    """Spiking residual block: ``activate(residual(x) + shortcut(x))``.

    The residual branch is conv(kernel_size, stride) -> batch_norm_2d -> IF ->
    conv(kernel_size) -> batch_norm_2d1. The shortcut is the identity when the
    branch preserves the shape (``stride == 1`` and equal channel counts). In
    every other case it is a 1x1 strided conv -> batch_norm_2d projection.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: kernel size of both residual convolutions.
        padding: padding of both residual convolutions.
        stride: stride of the first residual convolution and of the projection.
        step_mode: spikingjelly step mode forwarded to every layer.
        backend: backend forwarded to the IF neurons.
        act: when True, apply an IF neuron after the addition; otherwise identity.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, stride=1, step_mode='m', backend='cupy',
                 act=True):
        super().__init__()
        self.residual_function = nn.Sequential(
            layer.Conv2d(in_channels,
                         out_channels,
                         kernel_size=kernel_size,
                         stride=stride,
                         padding=padding,
                         bias=False,
                         step_mode=step_mode),
            batch_norm_2d(out_channels),
            neuron.IFNode(step_mode=step_mode, backend=backend),
            layer.Conv2d(out_channels,
                         out_channels,
                         kernel_size=kernel_size,
                         padding=padding,
                         bias=False,
                         step_mode=step_mode),
            batch_norm_2d1(out_channels),
        )

        # The identity shortcut is only valid when the residual branch keeps the
        # tensor shape. A channel change at stride 1 also needs a projection;
        # otherwise the addition in forward() fails with a shape mismatch.
        if stride == 1 and in_channels == out_channels:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                layer.Conv2d(in_channels,
                             out_channels,
                             kernel_size=1,
                             stride=stride,
                             bias=False,
                             step_mode=step_mode),
                batch_norm_2d(out_channels),
            )

        if act:
            self.activate = neuron.IFNode(step_mode=step_mode, backend=backend)
        else:
            self.activate = nn.Identity()

    def forward(self, x):
        """Add the residual branch and the shortcut, then apply ``activate``."""
        return self.activate(self.residual_function(x) + self.shortcut(x))
