import math

import torch.nn as nn
import torch.nn.functional as F

# Public API of this module: the model class and the VGG-16-BN factories.
__all__ = [
    'VGG',
    'vgg16_bn',
    'vgg16_under_d2_bn',
    'vgg16_under_d4_bn',
    'vgg16_under_d8_bn',
    'vgg16_under_d16_bn',
]


# Based on: https://github.com/bearpaw/pytorch-classification/blob/master/models/cifar/vgg.py

class VGG(nn.Module):
    """VGG-style classifier: a conv feature extractor plus a 3-layer MLP head.

    The head expects ``features(x)`` to flatten to exactly ``int(512 * d)``
    values per sample. For example, ``make_layers(cfg['D'], d=d)`` on 32x32
    inputs reaches that size, because its five 2x2 max-pools reduce the
    spatial size to 1x1.

    Args:
        features: module mapping the input batch to a feature map.
        d: width multiplier; the hidden layers of the head have
            ``int(512 * d)`` units.
        num_classes: number of output logits.
    """

    def __init__(self, features, d=1, num_classes=10):
        super().__init__()
        width = int(512 * d)
        self.features = features
        # Attribute names and registration order are preserved so that
        # state_dict keys (and module iteration order) stay unchanged.
        self.dropout1 = nn.Dropout()
        self.linear1 = nn.Linear(width, width)
        self.dropout2 = nn.Dropout()
        self.linear2 = nn.Linear(width, width)
        self.linear3 = nn.Linear(width, num_classes)

        # He/Kaiming-style (fan-out) init for conv weights:
        # N(0, sqrt(2 / n)) with n = kernel_h * kernel_w * out_channels.
        # Linear and BatchNorm layers keep PyTorch's default initialisation.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                nn.init.normal_(m.weight, mean=0.0, std=math.sqrt(2.0 / n))
                # Convs built with bias=False have no bias to reset.
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes)."""
        return self.linear3(self.get_features(x))

    def get_features(self, x):
        """Return the penultimate activations (after the second ReLU).

        The result has ``int(512 * d)`` values per sample; dropout is active
        only in training mode.
        """
        x = self.features(x)
        # flatten(1) keeps the batch dimension; unlike view() it also accepts
        # non-contiguous feature maps and empty batches.
        x = x.flatten(1)
        x = self.dropout1(x)
        x = F.relu(self.linear1(x))
        x = self.dropout2(x)
        x = F.relu(self.linear2(x))
        return x


def make_layers(cfg, d=1, batch_norm=False, in_channels=3):
    """Build the convolutional feature extractor for a VGG configuration.

    Args:
        cfg: sequence of layer specs. An int ``v`` adds a 3x3 convolution
            (padding 1) with ``int(d * v)`` output channels, followed by an
            optional BatchNorm2d and a ReLU. The string ``'M'`` adds a 2x2
            max-pool with stride 2.
        d: width multiplier applied to every convolution's channel count.
        batch_norm: insert ``nn.BatchNorm2d`` between each conv and its ReLU.
        in_channels: channel count of the network input (3 by default).

    Returns:
        ``nn.Sequential`` containing the layers in ``cfg`` order.

    Raises:
        ValueError: if ``d`` scales some convolution down to zero channels.
    """
    layers = []
    for v in cfg:
        if v == 'M':
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        out_channels = int(d * v)
        # A zero-width conv would build silently but yield a broken network.
        if out_channels < 1:
            raise ValueError(
                f"width multiplier d={d} reduces a {v}-channel conv to "
                f"{out_channels} channels"
            )
        layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
        if batch_norm:
            layers.append(nn.BatchNorm2d(out_channels))
        layers.append(nn.ReLU(inplace=True))
        in_channels = out_channels
    return nn.Sequential(*layers)


# Layer specs consumed by make_layers. Each int is a 3x3 conv's output
# channel count (before width scaling by d); 'M' is a 2x2 max-pool.
# Counting the convs plus the three Linear layers of the VGG head gives
# 11 (A), 13 (B), 16 (D) and 19 (E) weight layers.
cfg = dict(
    A=[64, 'M',
       128, 'M',
       256, 256, 'M',
       512, 512, 'M',
       512, 512, 'M'],
    B=[64, 64, 'M',
       128, 128, 'M',
       256, 256, 'M',
       512, 512, 'M',
       512, 512, 'M'],
    D=[64, 64, 'M',
       128, 128, 'M',
       256, 256, 256, 'M',
       512, 512, 512, 'M',
       512, 512, 512, 'M'],
    E=[64, 64, 'M',
       128, 128, 'M',
       256, 256, 256, 256, 'M',
       512, 512, 512, 512, 'M',
       512, 512, 512, 512, 'M'],
)


def _vgg16_bn_scaled(d, **kwargs):
    """Build a batch-normalised VGG-16 (configuration "D") at width ``d``.

    The same multiplier ``d`` is applied to the conv stack and to the
    classifier head, so their widths match. ``kwargs`` go to :class:`VGG`.
    """
    return VGG(make_layers(cfg['D'], d=d, batch_norm=True), d=d, **kwargs)


def vgg16_bn(**kwargs):
    """VGG 16-layer model (configuration "D") with batch normalization.

    ``kwargs`` are forwarded to :class:`VGG` (e.g. ``num_classes``). An
    optional ``d`` keyword (default 1, full width) is applied to both the
    conv layers and the head, so the two stay consistent.
    """
    d = kwargs.pop('d', 1)
    return _vgg16_bn_scaled(d, **kwargs)


def vgg16_under_d2_bn(**kwargs):
    """VGG-16 (configuration "D") with batch norm, at 1/2 channel width."""
    return _vgg16_bn_scaled(1 / 2, **kwargs)


def vgg16_under_d4_bn(**kwargs):
    """VGG-16 (configuration "D") with batch norm, at 1/4 channel width."""
    return _vgg16_bn_scaled(1 / 4, **kwargs)


def vgg16_under_d8_bn(**kwargs):
    """VGG-16 (configuration "D") with batch norm, at 1/8 channel width."""
    return _vgg16_bn_scaled(1 / 8, **kwargs)


def vgg16_under_d16_bn(**kwargs):
    """VGG-16 (configuration "D") with batch norm, at 1/16 channel width."""
    return _vgg16_bn_scaled(1 / 16, **kwargs)








 
