import random
from models.layers import *

import torch
import torch.nn as nn


class vgg11(nn.Module):
    """VGG-11-shaped network: eight ``Layer`` blocks, then three linear layers.

    Every ``Layer`` block and every ``LIFSpike`` neuron is called as a
    function returning a pair; ``forward`` keeps the first element as the
    running activation and sums the second elements into one auxiliary
    value that is returned next to the network output.

    NOTE(review): ``Layer``, ``LIFSpike``, ``SeqToANNContainer``,
    ``TensorNormalization`` and ``add_dimention`` are defined outside this
    class; their exact semantics should be confirmed against their
    definitions.
    """

    def __init__(self, num_classes, norm):
        """Build the network.

        Args:
            num_classes: Number of output features of the last linear layer.
            norm: Tuple unpacked into ``TensorNormalization(*norm)``.

        Raises:
            AssertionError: If ``norm`` is not a tuple (this includes ``None``).
        """
        super().__init__()
        if not isinstance(norm, tuple):
            raise AssertionError("Invalid normalization")
        self.norm = TensorNormalization(*norm)

        # Second argument handed to ``add_dimention`` in ``forward``.
        self.T = 4
        # Multiplier for the flattened feature count fed to ``fc1``
        # (presumably the spatial size left after the three poolings -- TODO confirm).
        self.W = 16

        # Consecutive entries are the (input, output) widths of layer1..layer8.
        # The blocks are created in numeric order so that submodule
        # registration order matches the attribute names.
        channel_plan = (3, 64, 128, 256, 512, 512, 512, 512, 512)
        for index, (c_in, c_out) in enumerate(
            zip(channel_plan, channel_plan[1:]), start=1
        ):
            setattr(self, f"layer{index}", Layer(c_in, c_out, 3, 1, 1))

        self.avgpool = SeqToANNContainer(nn.AvgPool2d(2))
        self.flatten = nn.Flatten(2)

        flat_features = 512 * self.W
        hidden_features = 4096
        self.fc1 = SeqToANNContainer(nn.Linear(flat_features, hidden_features))
        self.neuron1 = LIFSpike()
        self.fc2 = SeqToANNContainer(nn.Linear(hidden_features, hidden_features))
        self.neuron2 = LIFSpike()
        self.fc3 = SeqToANNContainer(nn.Linear(hidden_features, num_classes))

        # Re-initialise the weights of every registered submodule by type.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(
                    module.weight, mode='fan_out', nonlinearity='relu'
                )
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, val=1)
                nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Linear):
                nn.init.zeros_(module.bias)

    def forward(self, x):
        """Run the network on ``x``.

        Args:
            x: Input passed to ``self.norm`` and then to
                ``add_dimention(..., self.T)``.

        Returns:
            A pair ``(out, v)``: ``out`` is the result of ``fc3`` and ``v`` is
            the sum of the second return values of all eight ``Layer`` blocks
            and both ``LIFSpike`` neurons.
        """
        x = add_dimention(self.norm(x), self.T)
        v = 0

        # (block, pool_after): average pooling follows layer1, layer3 and layer6.
        feature_stages = (
            (self.layer1, True),
            (self.layer2, False),
            (self.layer3, True),
            (self.layer4, False),
            (self.layer5, False),
            (self.layer6, True),
            (self.layer7, False),
            (self.layer8, False),
        )
        for block, pool_after in feature_stages:
            x, block_v = block(x)
            # Out-of-place add on purpose: no accumulated value is mutated.
            v = v + block_v
            if pool_after:
                x = self.avgpool(x)

        x = self.flatten(x)

        classifier_stages = (
            (self.fc1, self.neuron1),
            (self.fc2, self.neuron2),
        )
        for linear, neuron in classifier_stages:
            x, neuron_v = neuron(linear(x))
            v = v + neuron_v

        return self.fc3(x), v
    