import math
from .layer import *


def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` with padding 1.

    With padding 1 the spatial size is preserved at ``stride=1`` and divided by
    ``stride`` otherwise.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride (default 1).

    Returns:
        The configured ``nn.Conv2d`` module.
    """
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_options)


class BasicBlock(nn.Module):
    """Spiking ResNet basic block: two time-distributed 3x3 convolutions plus a shortcut.

    Data path: conv1 -> bn1 -> LIF -> conv2 -> bn2, summed with the shortcut
    branch applied to the block input, then passed through a second LIF.

    Args:
        in_planes: channel count of the incoming tensor.
        planes: channel count produced by both 3x3 convolutions.
        stride: stride of the first 3x3 convolution and, when a projection is
            needed, of the 1x1 shortcut convolution.
        **kwargs_spikes: forwarded verbatim to every ``LIF``; must contain
            ``'nb_steps'``, which is also passed to each ``tdLayer``
            (presumably the number of simulation time steps).
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride, **kwargs_spikes):
        super().__init__()
        self.kwargs_spikes = kwargs_spikes
        self.nb_steps = kwargs_spikes['nb_steps']

        # Main branch: two 3x3 convs (conv3x3 builds the same bias-free, padding-1 Conv2d).
        self.conv1 = tdLayer(conv3x3(in_planes, planes, stride), self.nb_steps)
        self.bn1 = tdBatchNorm(nn.BatchNorm2d(planes), 1)
        self.conv2 = tdLayer(conv3x3(planes, planes), self.nb_steps)
        # NOTE(review): alpha presumably rescales each of the two summed branches
        # (tdBN-style 1/sqrt(2)); confirm against tdBatchNorm.
        self.bn2 = tdBatchNorm(nn.BatchNorm2d(planes), alpha=0.5 ** 0.5)

        self.stride = stride
        self.spike1 = LIF(**kwargs_spikes)
        self.spike2 = LIF(**kwargs_spikes)

        # Shortcut for the shape-preserving case: batch-normalise the input itself.
        # It is replaced below when channels or resolution change.
        self.downsample = tdBatchNorm(nn.BatchNorm2d(planes), alpha=0.5 ** 0.5)

        out_planes = planes * BasicBlock.expansion
        if stride != 1 or in_planes != out_planes:
            # Strided 1x1 projection so the shortcut matches the main branch's shape.
            projection = tdLayer(nn.Conv2d(in_planes, out_planes, 1, stride, bias=False), self.nb_steps)
            projection_bn = tdBatchNorm(nn.BatchNorm2d(out_planes), alpha=1 / math.sqrt(2.))
            self.downsample = nn.Sequential(projection, projection_bn)

    def forward(self, x):
        """Run the block on ``x`` and return the spikes emitted by the output LIF."""
        hidden = self.spike1(self.bn1(self.conv1(x)))
        merged = self.bn2(self.conv2(hidden))
        merged += self.downsample(x)
        return self.spike2(merged)

class Bottleneck(nn.Module):
    """Spiking ResNet bottleneck block (1x1 -> 3x3 -> 1x1, expansion 4).

    Data path: conv1 (1x1) -> bn1 -> LIF -> conv2 (3x3, strided) -> bn2 -> LIF
    -> conv3 (1x1, widens to ``4 * planes``) -> bn3, summed with the shortcut
    branch applied to the block input, then passed through a final LIF.

    Args:
        in_planes: channel count of the incoming tensor.
        planes: bottleneck width; the block outputs ``4 * planes`` channels.
        stride: stride of the 3x3 convolution and of the 1x1 shortcut
            projection when one is needed.
        **kwargs_spikes: forwarded as keyword arguments to every ``LIF`` (the
            same way ``BasicBlock`` and ``SResNet`` do); must contain
            ``'nb_steps'``, which is also passed to each ``tdLayer``.
    """
    expansion = 4

    def __init__(self, in_planes, planes, stride=1, **kwargs_spikes):
        super(Bottleneck, self).__init__()
        self.kwargs_spikes = kwargs_spikes
        self.nb_steps = kwargs_spikes['nb_steps']
        out_planes = self.expansion * planes

        self.conv1 = tdLayer(nn.Conv2d(in_planes, planes, kernel_size=1, bias=False), nb_steps=self.nb_steps)
        self.bn1 = tdBatchNorm(nn.BatchNorm2d(planes))
        # The spike config must be unpacked: LIF takes keyword options (e.g. nb_steps),
        # not a single positional dict.
        self.spike1 = LIF(**kwargs_spikes)
        self.conv2 = tdLayer(nn.Conv2d(planes, planes, kernel_size=3,
                                       stride=stride, padding=1, bias=False), nb_steps=self.nb_steps)
        self.bn2 = tdBatchNorm(nn.BatchNorm2d(planes))
        self.spike2 = LIF(**kwargs_spikes)

        self.conv3 = tdLayer(nn.Conv2d(planes, out_planes, kernel_size=1, bias=False), nb_steps=self.nb_steps)
        self.bn3 = tdBatchNorm(nn.BatchNorm2d(out_planes), alpha=1 / math.sqrt(2.))

        # Shortcut for the shape-preserving case: batch-normalise the input itself.
        self.shortcut = tdBatchNorm(nn.BatchNorm2d(out_planes), alpha=1 / math.sqrt(2.))
        if stride != 1 or in_planes != out_planes:
            # Strided 1x1 projection so the shortcut matches the main branch's shape.
            self.shortcut = nn.Sequential(
                tdLayer(nn.Conv2d(in_planes, out_planes,
                                  kernel_size=1, stride=stride, bias=False), self.nb_steps),
                tdBatchNorm(nn.BatchNorm2d(out_planes), alpha=1 / math.sqrt(2.))
            )
        self.spike3 = LIF(**kwargs_spikes)

    def forward(self, x):
        """Run the block on ``x`` and return the spikes emitted by the output LIF."""
        out = self.spike1(self.bn1(self.conv1(x)))
        out = self.spike2(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = self.spike3(out)
        return out


class SResNet(nn.Module):
    """Spiking ResNet: stem conv + four residual stages + pooled linear LIF readout.

    Args:
        block: residual block class (``BasicBlock`` or ``Bottleneck``); its
            ``expansion`` attribute sets the channel multiplier of each stage.
        num_blocks: four ints, the number of blocks in stages 1-4.
        num_class: output size of the final linear layer.
        **kwargs_spikes: forwarded to every LIF and to every block; must
            contain ``'nb_steps'``, the length of the time axis added to the
            input in ``forward``.
    """

    def __init__(self, block, num_blocks, num_class=10, **kwargs_spikes):
        super(SResNet, self).__init__()
        self.in_planes = 64
        self.kwargs_spike = kwargs_spikes
        self.nb_steps = kwargs_spikes['nb_steps']
        self.conv1 = tdLayer(nn.Conv2d(3, 64, kernel_size=3,
                                       stride=1, padding=1, bias=False), nb_steps=self.nb_steps)
        self.bn1 = tdBatchNorm(nn.BatchNorm2d(64), 1)
        self.spike0 = LIF(**kwargs_spikes)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.avg_pool = tdLayer(nn.AdaptiveAvgPool2d((1, 1)), nb_steps=self.nb_steps)
        self.linear = tdLayer(nn.Linear(512 * block.expansion, num_class), nb_steps=self.nb_steps)
        self.readout = LIF(readout=True, **kwargs_spikes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first uses ``stride``.

        Side effect: advances ``self.in_planes`` to ``planes * block.expansion``
        so the next stage is built with the right input width.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride, **self.kwargs_spike))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Repeat the static input over ``nb_steps`` time steps and run the network.

        Args:
            x: 4-D tensor whose dim 1 has 3 channels (the stem is
                ``Conv2d(3, ...)``); presumably (N, C, H, W) images - confirm.

        Returns:
            The readout LIF applied to the linear layer's output; the shape is
            determined by ``LIF(readout=True)`` (not visible here).
        """
        self.reset_mask()
        # Add a trailing time axis: (N, C, H, W) -> (N, C, H, W, nb_steps).
        # expand() yields a stride-0 view on x's own device, so CUDA inputs work
        # and nothing is allocated. Broadcasting against a separately created
        # torch.zeros tensor would always put that tensor on the CPU.
        x = x.float().unsqueeze(-1).expand(*x.shape, self.nb_steps)
        x = self.spike0(self.bn1(self.conv1(x)))
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avg_pool(out)
        # Flatten everything between the batch dim and the trailing time dim.
        out = out.view(out.shape[0], -1, out.shape[-1])
        out = self.readout(self.linear(out))
        return out

    def reset_mask(self):
        """Call ``reset_mask()`` on every ``NoisySpike`` submodule."""
        for m in self.modules():
            if isinstance(m, NoisySpike):
                m.reset_mask()

    def set_noisy_rate(self, p):
        """Set attribute ``p`` on every ``NoisySpike`` submodule.

        NOTE(review): ``p`` is presumably the noise/mask probability - confirm
        against ``NoisySpike``.
        """
        for m in self.modules():
            if isinstance(m, NoisySpike):
                m.p = p


# Depths for which ``cfg`` has a recipe.
imgnet_depth_lst = [18, 34, 50, 101, 152]


def cfg(depth):
    """Return ``(block_class, blocks_per_stage)`` for a supported ResNet depth.

    Depths 18 and 34 use ``BasicBlock``; 50, 101 and 152 use ``Bottleneck``.
    ``blocks_per_stage`` is a fresh list of four ints, one per stage.

    Raises:
        AssertionError: if ``depth`` is not in ``imgnet_depth_lst``.

    NOTE: under ``python -O`` the assert is stripped, and an unsupported depth
    then surfaces as a ``KeyError`` from the lookup instead.
    """
    assert depth in imgnet_depth_lst, "Error : Resnet depth should be either 18, 34, 50, 101, 152"

    key = str(depth)
    stage_counts = {
        '18': [2, 2, 2, 2],
        '34': [3, 4, 6, 3],
        '50': [3, 4, 6, 3],
        '101': [3, 4, 23, 3],
        '152': [3, 8, 36, 3],
    }[key]
    block = BasicBlock if key in ('18', '34') else Bottleneck
    return block, stage_counts


def SResNetX(depth, num_class=10, **kwargs_spike):
    """Build an ``SResNet`` of the given depth.

    Args:
        depth: one of ``imgnet_depth_lst`` (18, 34, 50, 101, 152).
        num_class: output size of the final linear layer.
        **kwargs_spike: spike/LIF options forwarded to ``SResNet``; must
            contain ``'nb_steps'``.

    Returns:
        The constructed ``SResNet``.

    Raises:
        ValueError: if ``depth`` is not a supported depth. Without this check,
            an unsupported depth would return ``None`` and fail much later at
            the first use of the model.
    """
    if depth not in imgnet_depth_lst:
        raise ValueError(
            f"Unsupported SResNet depth {depth!r}; expected one of {imgnet_depth_lst}"
        )
    block, num_blocks = cfg(depth)
    return SResNet(block, num_blocks, num_class, **kwargs_spike)

