import math

import torch.nn as nn
import torch.nn.init as init

'''
Slight variation of: https://github.com/alecwangcq/EigenDamage-Pytorch
                    and https://github.com/liuzhuang13/slimming 

Commit: 0e2174f80294773e76a8cb73c0bd03f1b7fd2cc

Licence: MIT

'''

def _weights_init(m):
    classname = m.__class__.__name__
    # print(classname)
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight)
        if m.bias is not None:
            m.bias.data.fill_(0)
    elif isinstance(m, nn.BatchNorm2d):
        if m.weight is not None:
            m.weight.data.fill_(1.0)
            m.bias.data.zero_()

class VGG(nn.Module):
    '''
    VGG-style network: a convolutional feature extractor, a 2x2 average pool,
    then a single linear classifier on the flattened result.

    Args:
        features: module producing the feature maps (e.g. the output of
            ``make_layers``). The classifier's input size is taken from the
            ``out_channels`` of its last ``nn.Conv2d``.
        mult_array: conv-layer indices whose width was doubled when building
            ``features``. Only consulted as a fallback when ``features``
            contains no ``nn.Conv2d``; ``None`` is treated as empty.
        width_factor: global channel multiplier (used by the same fallback).
        num_classes: number of output logits.
    '''
    def __init__(self, features, mult_array, width_factor=1.0, num_classes=10):
        super(VGG, self).__init__()
        # Descriptive metadata. NOTE(review): macs_forward is a fixed constant and is
        # not recomputed for width_factor / mult_array — confirm that is intended.
        self.model_type = 'vgg19'
        self.macs_forward = 398136320
        self.features = features

        last_layer_size = self._feature_channels(features)
        if last_layer_size is None:
            # Fallback to the config-'E' rule: its last conv layer has index 15.
            if mult_array is None:
                mult_array = []
            base = 1024 if 15 in mult_array else 512
            last_layer_size = int(width_factor * base)

        self.classifier = nn.Linear(last_layer_size, num_classes)

        self._initialize_weights()

    @staticmethod
    def _feature_channels(features):
        '''Return ``out_channels`` of the last ``nn.Conv2d`` in ``features``, or None.'''
        last_conv = None
        for m in features.modules():
            if isinstance(m, nn.Conv2d):
                last_conv = m
        return None if last_conv is None else last_conv.out_channels

    def forward(self, x):
        x = self.features(x)
        # Functional pooling: avoids constructing a new nn.AvgPool2d every call.
        x = nn.functional.avg_pool2d(x, kernel_size=2)
        x = x.view(x.size(0), -1)
        return self.classifier(x)

    def _initialize_weights(self):
        '''Initialise every Conv2d, BatchNorm2d and Linear submodule in place.'''
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # He-normal with fan_out = k_h * k_w * out_channels.
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                # BN scale starts at 0.5 (not the usual 1.0); affine=False layers
                # have no weight/bias.
                if m.weight is not None:
                    m.weight.data.fill_(0.5)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                if m.bias is not None:
                    m.bias.data.zero_()

def make_layers(cfg, batch_norm=False, mult_arr=None, width_factor=1.0):
    '''Build a VGG feature extractor from a layer configuration.

    Args:
        cfg: sequence of ints (3x3 conv output channels) and 'M' (2x2 max-pool).
        batch_norm: insert ``nn.BatchNorm2d`` after each conv when True.
        mult_arr: indices (0-based, counting conv layers only) whose channel
            count is doubled before ``width_factor`` is applied. ``None``
            (the default) means no layer is doubled.
        width_factor: multiplier applied to every conv's channel count
            (truncated with ``int``).

    Returns:
        ``nn.Sequential`` of the layers in ``cfg`` order; the first conv
        expects 3 input channels. Convs are created without bias.
    '''
    # The None default previously crashed on the membership test below.
    doubled = mult_arr if mult_arr is not None else ()

    layers = []
    layer_idx = 0
    in_channels = 3
    for v in cfg:
        if v == 'M':
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        if layer_idx in doubled:
            v = 2 * v
        v = int(width_factor * v)
        conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1, bias=False)
        if batch_norm:
            layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
        else:
            layers += [conv2d, nn.ReLU(inplace=True)]
        layer_idx += 1
        in_channels = v
    return nn.Sequential(*layers)


# Layer configurations consumed by make_layers: an int is a 3x3 conv with that
# many output channels, 'M' is a 2x2 max-pool. Conv-layer counts are
# 'A' = 8, 'B' = 10, 'D' = 13, 'E' = 16.
# Unlike the others, 'E' has no trailing 'M', so it downsamples only four times.
cfg = {
    'A': [
        64, 'M',
        128, 'M',
        256, 256, 'M',
        512, 512, 'M',
        512, 512, 'M',
    ],
    'B': [
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 'M',
        512, 512, 'M',
        512, 512, 'M',
    ],
    'D': [
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 256, 'M',
        512, 512, 512, 'M',
        512, 512, 512, 'M',
    ],
    'E': [
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 256, 256, 'M',
        512, 512, 512, 512, 'M',
        512, 512, 512, 512,
    ],
}

def vgg19_bn(mult_array=None, width_factor=1.0, **kwargs):
    """VGG 19-layer model (configuration 'E') with batch normalization.

    Args:
        mult_array: conv-layer indices (0-15 for config 'E') whose channel count
            is doubled. ``None`` (the default) means no layer is doubled.
        width_factor: global channel-width multiplier.
        **kwargs: forwarded to ``VGG`` (e.g. ``num_classes``).

    Returns:
        A ``VGG`` instance; its weights are already initialised by ``VGG.__init__``.
    """
    # None instead of a mutable [] default, so calls never share one list object.
    if mult_array is None:
        mult_array = []
    features = make_layers(cfg['E'], batch_norm=True, mult_arr=mult_array,
                           width_factor=width_factor)
    return VGG(features, mult_array, width_factor, **kwargs)