import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.init as init
import numpy as np

import pdb

def conv_init(m):
    """Xavier-initialise a convolution module's weight and zero its bias, in place.

    The signature (one module in, nothing returned) makes this usable with
    ``nn.Module.apply``. A module is treated as a convolution when its class
    name contains ``'Conv'`` (e.g. ``Conv2d``, ``ConvTranspose2d``); every
    other module is left untouched.

    Args:
        m: the module to (possibly) initialise.
    """
    if 'Conv' not in m.__class__.__name__:
        return
    # A container/custom class whose name merely contains 'Conv' may have no
    # weight of its own; skip it instead of raising AttributeError.
    if getattr(m, 'weight', None) is None:
        return
    # sqrt(2) equals nn.init.calculate_gain('relu').
    init.xavier_uniform_(m.weight, gain=np.sqrt(2))
    # Layers built with bias=False expose ``bias`` as None.
    if getattr(m, 'bias', None) is not None:
        init.constant_(m.bias, 0)

def cfg(depth):
    """Return the VGG convolutional layout for the requested depth.

    The layout is a flat list read left to right: an integer ``c`` stands for
    a 3x3 convolution with ``c`` output channels, and the string ``'mp'``
    stands for a 2x2 max-pool. All depths share five stages of width 64, 128,
    256, 512 and 512 (each closed by ``'mp'``); the depth only changes how
    many convolutions each stage contains.

    Args:
        depth: number of weight layers; one of 11, 13, 16 or 19.

    Returns:
        A freshly built list describing the layers.

    Raises:
        AssertionError: if ``depth`` is unsupported (assert-based, so the
            check is skipped under ``python -O`` and a KeyError results).
    """
    supported_depths = [11, 13, 16, 19]
    assert depth in supported_depths, "Error : VGGnet depth should be either 11, 13, 16, 19"

    stage_widths = (64, 128, 256, 512, 512)
    # Number of 3x3 convolutions in each of the five stages, keyed by depth.
    convs_per_stage = {
        '11': (1, 1, 2, 2, 2),
        '13': (2, 2, 2, 2, 2),
        '16': (2, 2, 3, 3, 3),
        '19': (2, 2, 4, 4, 4),
    }

    layout = []
    for width, n_convs in zip(stage_widths, convs_per_stage[str(depth)]):
        layout += [width] * n_convs
        layout.append('mp')
    return layout

def conv3x3(in_planes, out_planes, stride=1):
    """Build a 3x3 ``nn.Conv2d`` with padding 1 and a learnable bias.

    With ``stride=1`` the padding keeps the spatial size unchanged.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride (default 1).
    """
    conv = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=True,
    )
    return conv

class VGG(nn.Module):
    """VGG network whose convolutional part is built from ``cfg(depth)``.

    The model is ``features`` (stacks of 3x3 conv + BatchNorm2d + ReLU
    separated by 2x2 max-pools, see ``_make_layers``), then an adaptive
    average pool to a fixed 7x7 map, then a three-layer fully-connected
    classifier. The adaptive pool makes the classifier's input size
    independent of the input resolution. However, each layout contains five
    stride-2 max-pools, so inputs must be at least 32x32.

    Args:
        depth: VGG depth forwarded to ``cfg`` (11, 13, 16 or 19).
        num_classes: output size of the final linear layer.
        dropout_rate: dropout probability, applied in the classifier only.
        in_channels: channels of the input images (default 3).
    """

    def __init__(self, depth, num_classes, dropout_rate=0., in_channels=3):
        super().__init__()
        self.features = self._make_layers(cfg(depth), in_channels)
        # Fixed 7x7 output so the first Linear's size does not depend on H, W.
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        # Every cfg() layout ends with a 512-channel stage, hence 512 * 7 * 7.
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        """Compute class logits.

        Args:
            x: a 4-D image batch of shape (N, in_channels, H, W).

        Returns:
            Tensor of shape (N, num_classes).
        """
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)  # keep the batch dim, flatten C x 7 x 7
        return self.classifier(x)

    def _make_layers(self, layout, in_channels=3):
        """Build the convolutional feature extractor from a ``cfg``-style layout.

        Args:
            layout: sequence of ints and ``'mp'`` markers. Each int becomes a
                3x3 conv with that many output channels, followed by
                BatchNorm2d and an in-place ReLU. Each ``'mp'`` becomes a 2x2,
                stride-2 max-pool.
            in_channels: input channels of the first convolution.

        Returns:
            An ``nn.Sequential`` containing the layers in order.
        """
        layers = []
        in_planes = in_channels
        for x in layout:
            if x == 'mp':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                layers += [conv3x3(in_planes, x), nn.BatchNorm2d(x), nn.ReLU(inplace=True)]
                in_planes = x
        return nn.Sequential(*layers)



if __name__ == "__main__":
    # Smoke test: push one random 32x32 RGB image through VGG-16 and print the
    # output shape. Plain tensors are used directly (autograd.Variable is a
    # deprecated no-op wrapper).
    net = VGG(16, 10)
    net.eval()  # use BatchNorm running stats instead of updating them
    with torch.no_grad():  # no autograd graph needed just to inspect the shape
        y = net(torch.randn(1, 3, 32, 32))
    print(y.size())
