# -*- coding: utf-8 -*-
from model.layer import *

# Layout of the convolutional feature extractor for each supported architecture.
# VGG._make_feature reads the tokens in order:
#   * an integer adds a Conv2d -> BatchNorm2d -> LIFLayer stage with that many
#     output channels,
#   * 'A' adds a 2x2, stride-2 average pooling layer,
#   * 'AA' adds an adaptive average pooling layer with a 1x1 output.
# Each row below ends where a pooling token closes a stage.
feature_cfg = {
    'VGG5': [
        64, 'A',
        128, 128, 'A',
        'AA',
    ],
    'VGG9': [
        64, 'A',
        128, 256, 'A',
        256, 512, 'A',
        512, 'A',
        512,
    ],
    'VGG11': [
        64, 'A',
        128, 256, 'A',
        512, 512, 'A',
        512, 'A',
        512, 512, 'AA',
    ],
    'VGG13': [
        64, 64, 'A',
        128, 128, 'A',
        256, 256, 'A',
        512, 512, 512, 'A',
        512, 'AA',
    ],
    'VGG16': [
        64, 64, 'A',
        128, 128, 'A',
        256, 256, 256, 'A',
        512, 512, 512, 'A',
        512, 512, 512,
    ],
    'VGG19': [
        64, 64, 'A',
        128, 128, 'A',
        256, 256, 256, 256, 'A',
        512, 512, 512, 512, 'A',
        512, 512, 512, 512,
    ],
    'CIFAR': [
        128, 256, 'A',
        512, 'A',
        1024, 512,
    ],
    'VGGSNN2': [
        64, 128, 'A',
        256, 256, 'A',
        512, 512, 'A',
        512, 512, 'A',
    ],
}

# Widths of the fully connected head for each architecture.
# VGG._make_classifier turns every consecutive pair (in, out) into an nn.Linear,
# so the first entry must equal the flattened size of the feature extractor's
# output. The last entry is the default number of classes; VGG.__init__
# substitutes its `num_class` argument for it.
# NOTE(review): 'VGG9' and 'CIFAR' exist in feature_cfg but have no entry here,
# so constructing VGG with either architecture raises KeyError — confirm intent.
clasifier_cfg = {
    'VGG16': [
        2048,
        4096,
        4096,
        10,
    ],
    'VGG5': [
        128,
        10,
    ],
    'VGG11': [
        512,
        10,
    ],
    'VGG13': [
        512,
        10,
    ],
    'VGG19': [
        2048,
        4096,
        4096,
        10,
    ],
    'VGGSNN2': [
        4608,
        10,
    ],
}


class VGG(nn.Module):
    """VGG-style network built from Conv2d/BatchNorm2d/LIFLayer stages.

    The convolutional layout is read from ``feature_cfg[architecture]`` and the
    fully connected head from ``clasifier_cfg[architecture]``; the final width
    of the head is replaced by ``num_class``.

    Args:
        architecture: Key into ``feature_cfg`` and ``clasifier_cfg``. A key that
            is missing from ``clasifier_cfg`` raises ``KeyError``.
        kernel_size: Kernel size of every convolution; padding is
            ``kernel_size // 2``.
        in_channel: Number of channels of the input.
        use_bias: Whether Conv2d and Linear layers carry a bias term.
        bn_type: Stored on the instance; not read anywhere in this class.
        num_class: Output width of the last Linear layer.
        readout_mode: Stored on the instance; not read anywhere in this class.
        **kwargs_spikes: Forwarded unchanged to every ``LIFLayer``. Must contain
            ``'timestep'`` (number of simulation steps); ``'rate_flag'`` is
            optional and defaults to ``False``.
    """

    def __init__(self, architecture='VGG16', kernel_size=3, in_channel=3, use_bias=True,
                 bn_type=None, num_class=10, readout_mode='psp_avg',
                 **kwargs_spikes):
        super().__init__()
        self.kwargs_spikes = kwargs_spikes
        self.nb_steps = kwargs_spikes['timestep']
        self.kernel_size = kernel_size
        self.in_channel = in_channel
        self.use_bias = use_bias
        self.bn_type = bn_type
        self.readout_mode = readout_mode
        self.num_class = num_class
        # Resize a private copy: writing num_class into the module-level table
        # would leak this model's class count into the shared configuration.
        head_cfg = self._resize_head(clasifier_cfg[architecture], num_class)
        self.feature = self._make_feature(feature_cfg[architecture])
        self.classifier = self._make_classifier(head_cfg)
        self._initialize_weights()

        self.rate_flag = kwargs_spikes.get("rate_flag", False)
        wrap_model(self, time_step=self.nb_steps, rate_flag=self.rate_flag)
        self.register_forward_hook(affine_forward_hook)

    @staticmethod
    def _resize_head(config, num_class):
        """Return a copy of ``config`` whose last entry is ``num_class``.

        The input sequence is left untouched.
        """
        resized = list(config)
        resized[-1] = num_class
        return resized

    def _make_feature(self, config):
        """Build the feature extractor described by ``config``.

        ``'A'`` adds a 2x2 stride-2 average pool, ``'AA'`` an adaptive average
        pool with 1x1 output, and any other entry is taken as the output channel
        count of a Conv2d -> BatchNorm2d -> LIFLayer stage.
        """
        layers = []
        channel = self.in_channel
        for token in config:
            if token == 'A':
                layers.append(nn.AvgPool2d(kernel_size=2, stride=2))
            elif token == 'AA':
                layers.append(nn.AdaptiveAvgPool2d((1, 1)))
            else:
                layers.append(nn.Conv2d(in_channels=channel, out_channels=token,
                                        kernel_size=self.kernel_size, stride=1,
                                        padding=self.kernel_size // 2, bias=self.use_bias))
                layers.append(nn.BatchNorm2d(token))
                layers.append(LIFLayer(**self.kwargs_spikes))
                # The next convolution consumes what this one produced.
                channel = token
        return SequentialModule(*layers)

    def _make_classifier(self, config):
        """Build the fully connected head from the widths in ``config``.

        Every consecutive pair of widths becomes a Linear layer followed by a
        LIFLayer, except that the last Linear layer has no LIFLayer after it.
        """
        layers = []
        for fan_in, fan_out in zip(config, config[1:]):
            layers.append(nn.Linear(fan_in, fan_out, bias=self.use_bias))
            layers.append(LIFLayer(**self.kwargs_spikes))
        # Drop the LIFLayer that trails the final Linear layer.
        layers.pop()

        return SequentialModule(*layers)

    def _initialize_weights(self):
        """Re-initialise Conv2d, Linear and BatchNorm2d parameters in place.

        Conv2d/Linear weights are drawn from N(0, 0.5**2) and their biases set to
        zero; BatchNorm2d weights are set to one and biases to zero. Parameters
        that are ``None`` (no bias, non-affine batch norm) are skipped.
        """
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.normal_(m.weight, mean=0.0, std=0.5)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                if m.weight is not None:
                    nn.init.ones_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x):
        """Run the network and return the ``(out, rate)`` pair of the classifier.

        An input with at most four dimensions is repeated along a new leading
        axis of length ``self.nb_steps``. Any other input is permuted with
        ``(1, 0, 2, 3, 4)``, i.e. it must be 5-D; presumably it is laid out as
        (batch, time, C, H, W) — TODO confirm against callers.
        """
        if x.dim() <= 4:
            # expand() yields a stride-0 view, so no per-step copy (and no
            # throw-away zeros tensor) is allocated.
            out = x.expand(self.nb_steps, *x.shape)
        else:
            out = x.permute(1, 0, 2, 3, 4)

        # `rate` starts as the mean of the input over the leading axis.
        rate = out.mean(dim=0)
        out, rate = self.feature(out, rate)
        # Flatten everything after the first two axes of `out` and after the
        # first axis of `rate` before the fully connected head.
        out = out.view(out.shape[0], out.shape[1], -1)
        rate = rate.view(rate.shape[0], -1)
        out, rate = self.classifier(out, rate)
        return out, rate
