import math
from model.layer import *

# Layer layouts consumed by VGG._make_layers: an int is the output-channel
# count of a convolution layer, and 'A' inserts a 2x2 average pool (stride 2).
cfg = dict(
    VGG5=[64, 'A', 128, 128, 'A'],
    VGG9=[64, 'A', 128, 256, 'A', 256, 512, 'A', 512, 'A', 512],
    VGG11=[64, 'A', 128, 256, 'A', 512, 512, 'A', 512, 'A', 512, 512],
    VGG13=[64, 64, 'A', 128, 128, 'A', 256, 256, 'A', 512, 512, 512, 'A', 512],
    VGG16=[64, 64, 'A', 128, 128, 'A', 256, 256, 256, 'A',
           512, 512, 512, 'A', 512, 512, 512],
    VGG19=[64, 64, 'A', 128, 128, 'A', 256, 256, 256, 256, 'A',
           512, 512, 512, 512, 'A', 512, 512, 512, 512],
    # Custom layout; VGG also appends an AdaptiveAvgPool2d((2, 2)) and a
    # fully-connected classifier for this one.
    CIFAR=[128, 256, 'A', 512, 'A', 1024, 512],
)


class VGG_block(nn.Module):
    """Convolution -> TemporalBN -> LIF spiking unit.

    Args:
        in_channel: number of input channels of the convolution.
        out_channel: number of output channels (also the TemporalBN width).
        kernel_size: square kernel size; padding is ``(kernel_size - 1) // 2``,
            which keeps the spatial size for odd kernels at stride 1.
        stride: convolution stride.
        bias: whether the convolution has a bias term.
        **kwargs_spikes: forwarded verbatim to ``LIF``; must contain
            ``'nb_steps'``, which is also passed to ``tdLayer`` / ``TemporalBN``.
    """

    def __init__(self, in_channel, out_channel, kernel_size, stride, bias, **kwargs_spikes):
        super().__init__()
        conv2d = nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size,
                           padding=(kernel_size - 1) // 2, stride=stride, bias=bias)
        steps = kwargs_spikes['nb_steps']
        # tdLayer is a project helper; presumably it applies the wrapped
        # module at every time step — TODO confirm in model.layer.
        self.conv = tdLayer(conv2d, nb_steps=steps)
        self.bn = TemporalBN(in_channels=out_channel, nb_steps=steps)
        self.spike = LIF(**kwargs_spikes)

    def forward(self, x):
        """Apply conv, then batch-norm, then the spiking nonlinearity."""
        return self.spike(self.bn(self.conv(x)))


class VGG(nn.Module):
    """Spiking VGG-style network built from a layer spec in the module-level ``cfg``.

    The first convolution and its ``nn.BatchNorm2d`` run once on the static
    input.  Their output is then repeated along a trailing time axis of length
    ``nb_steps`` and passed through the remaining (spiking) feature layers.
    For the ``'CIFAR'`` layout with a non-MNIST dataset, a spiking
    fully-connected classifier follows.

    Args:
        vgg_name: key into ``cfg`` selecting the layer layout.
        labels: output size of the final linear layer.
        dataset: ``'MNIST'`` selects 1 input channel; anything else selects 3.
        kernel_size: convolution kernel size (padding ``(kernel_size - 1) // 2``).
        dropout: stored on the instance; not used by this class.
        use_bias: whether convolution / linear layers have a bias term.
        **kwargs_spikes: forwarded to every ``LIF`` and ``VGG_block``; must
            contain ``'nb_steps'``.
    """

    def __init__(self, vgg_name='VGG16', labels=10, dataset='CIFAR10', kernel_size=3, dropout=0,
                 use_bias=True, **kwargs_spikes):
        super(VGG, self).__init__()
        self.kwargs_spikes = kwargs_spikes
        self.nb_steps = kwargs_spikes['nb_steps']
        self.dataset = dataset
        self.kernel_size = kernel_size
        self.dropout = dropout  # NOTE(review): unused anywhere in this class.
        self.use_bias = use_bias
        self.vgg_name = vgg_name
        self.features = self._make_layers(cfg[vgg_name])
        # Only this layout gets a classifier.  Its input size matches the 512
        # channels of the last 'CIFAR' conv pooled to 2x2 by _make_layers.
        # Other layouts build no classifier, so forward() cannot run for them.
        if vgg_name == 'CIFAR' and dataset != 'MNIST':
            self.classifier = nn.Sequential(
                tdLayer(nn.Linear(512 * 2 * 2, 1024, bias=use_bias), nb_steps=self.nb_steps),
                LIF(**kwargs_spikes),
                tdLayer(nn.Linear(1024, 512, bias=use_bias), nb_steps=self.nb_steps),
                LIF(**kwargs_spikes),
                tdLayer(nn.Linear(512, labels, bias=use_bias), nb_steps=self.nb_steps),
                LIF(readout=True, **kwargs_spikes),
            )
        self._initialize_weights2()

    def reset_mask(self):
        """Call ``reset_mask()`` on every ``NoisySpike`` submodule."""
        for m in self.modules():
            if isinstance(m, NoisySpike):
                m.reset_mask()

    def _repeat_over_time(self, x):
        """Return ``x`` with a trailing time axis of length ``self.nb_steps``.

        The result is a stride-0 view (no copy): every time slice aliases ``x``.
        No zero tensor is allocated to broadcast against.
        """
        return x.unsqueeze(-1).expand(*x.shape, self.nb_steps)

    def forward(self, x):
        """Run the network on a static input batch.

        Args:
            x: input images; presumably ``(batch, channels, H, W)`` — TODO confirm.

        Returns:
            The output of the classifier's final (readout) ``LIF`` layer.

        Raises:
            AttributeError: if this configuration was built without a classifier.
        """
        if not hasattr(self, 'classifier'):
            raise AttributeError(
                f"VGG({self.vgg_name!r}, dataset={self.dataset!r}) has no classifier; "
                "forward() is only supported for vgg_name='CIFAR' with a non-MNIST dataset")
        self.reset_mask()
        # Static stem: first conv + BN, applied before a time axis exists.
        x = self.features[1](self.features[0](x))
        out = self._repeat_over_time(x)
        out = self.features[2:](out)
        # Flatten all dims between batch (first) and time (last).  reshape
        # instead of view so non-contiguous pooled outputs also work.
        out = out.reshape(out.shape[0], -1, out.shape[-1])
        return self.classifier(out)

    def _initialize_weights2(self):
        """Initialize parameters in place.

        Conv2d weights ~ N(0, sqrt(2 / (k_h * k_w * out_channels))) (fan-out
        He scaling).  Conv2d/Linear biases are zeroed.  BatchNorm2d gets
        weight 1 and bias 0.  Linear weights keep PyTorch's default init.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layers(self, cfg):
        """Build the feature extractor from a layer spec list.

        Ints are conv output channels; ``'A'`` is a 2x2 average pool with
        stride 2.  The first conv is a plain ``nn.Conv2d`` + ``nn.BatchNorm2d``,
        which ``forward`` runs before adding the time axis.  Later convs are
        ``VGG_block``s.  The ``'CIFAR'`` layout additionally ends with an
        adaptive average pool to 2x2.
        """
        layers = []
        in_channels = 1 if self.dataset == 'MNIST' else 3
        stride = 1
        first_conv = True
        for x in cfg:
            if x == 'A':
                layers.append(tdLayer(nn.AvgPool2d(kernel_size=2, stride=2), nb_steps=self.nb_steps))
                continue
            if first_conv:
                layers.append(
                    nn.Conv2d(in_channels, x, kernel_size=self.kernel_size, padding=(self.kernel_size - 1) // 2,
                              stride=stride, bias=self.use_bias))
                layers.append(nn.BatchNorm2d(x))
                first_conv = False
            else:
                layers.append(
                    VGG_block(in_channels, x, kernel_size=self.kernel_size, stride=stride, bias=self.use_bias,
                              **self.kwargs_spikes))
            in_channels = x
        if self.vgg_name == 'CIFAR':
            layers.append(tdLayer(nn.AdaptiveAvgPool2d((2, 2)), nb_steps=self.nb_steps))

        return nn.Sequential(*layers)

    def set_noisy_rate(self, p):
        """Set attribute ``p`` on every ``NoisySpike`` submodule.

        ``p`` is presumably the noise probability — TODO confirm in model.layer.
        """
        for m in self.modules():
            if isinstance(m, NoisySpike):
                m.p = p


if __name__ == '__main__':
    # VGG reads kwargs_spikes['nb_steps'] unconditionally in __init__, so it
    # must be supplied or construction raises KeyError.  The value is an
    # arbitrary example; LIF may require further keyword args — TODO confirm.
    net = VGG('VGG16', dataset='CIFAR10', nb_steps=4)
    print(net)
