import math
import sys
import torch.nn as nn
sys.path.append('../')
sys.path.append('../../')


# Layer recipes for the CIFAR VGG variants, as consumed by the model classes'
# make_layers: each integer is the output-channel count of a 3x3 conv
# (+ BatchNorm + ReLU) layer, 'M' inserts a 2x2 max pool and 'A' a 2x2
# average pool. Every recipe contains exactly five pooling markers.
cfg = dict(
    VGG11=[64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    VGG13=[64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M',
           512, 512, 'M'],
    VGG16=[64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
           512, 512, 512, 'M', 512, 512, 512, 'M'],
    # VGG19 closes with an average pool instead of a fifth max pool.
    VGG19=[64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M',
           512, 512, 512, 512, 'M', 512, 512, 512, 512, 'A'],
)



class VGG_CIFAR_PRUNED(nn.Module):
    """VGG with BatchNorm whose conv widths come from a pruning result.

    The layer layout is the module-level recipe ``cfg['VGG' + str(depth)]``:
    each integer entry becomes a 3x3 conv + BatchNorm + ReLU, ``'M'`` a 2x2
    max pool and ``'A'`` a 2x2 average pool. The integer widths are replaced,
    in order, by the entries of ``pruned``.

    The classifier is a single linear layer over the last conv's channels and
    there is no global pooling, so the feature map must already be 1x1 after
    ``features`` (e.g. 32x32 inputs, which the five 2x2 pools reduce to 1x1).

    Args:
        depth: VGG variant key (11, 13, 16 or 19) into ``cfg``.
        num_classes: number of output logits.
        pruned: conv output widths, one per conv layer of the recipe; extra
            trailing entries are ignored. ``None`` keeps the recipe widths.
        logger: optional object with an ``info`` method; if given, the
            parameter count is logged.

    Raises:
        KeyError: if ``'VGG' + str(depth)`` is not a key of ``cfg``.
        IndexError: if ``pruned`` has fewer entries than the recipe has convs.
    """

    # Parameter count of this architecture built from the unpruned VGG19
    # recipe with num_classes=10; the log reports model size relative to it.
    _UNPRUNED_VGG19_PARAMS = 20040522

    def __init__(self, depth=19, num_classes=10, pruned=None, logger=None):
        super(VGG_CIFAR_PRUNED, self).__init__()
        # Build a fresh list rather than editing cfg[...] in place: in-place
        # edits corrupted the shared recipe and made every model of the same
        # depth share (and overwrite) one self.cfg list.
        self.cfg = self._pruned_layout(cfg['VGG' + str(depth)], pruned)

        self.features = self.make_layers(self.cfg)
        # Size the classifier from the last conv layer instead of the fixed
        # index cfg[19], which only exists in the VGG19 recipe.
        last_conv = [m for m in self.features if isinstance(m, nn.Conv2d)][-1]
        self.classifier = nn.Linear(last_conv.out_channels, num_classes)
        self._initialize_weights()

        if logger is not None:
            pytorch_total_params = sum(p.numel() for p in self.parameters())
            logger.info("PARAM: {}M ( {:.3f}% )".format(
                pytorch_total_params / 1000000,
                pytorch_total_params * 100 / self._UNPRUNED_VGG19_PARAMS))

    @staticmethod
    def _pruned_layout(base, pruned):
        """Return a new layer list: ``base`` with conv widths taken from ``pruned``.

        Pool markers ('M' / 'A') are copied unchanged; ``pruned=None`` returns
        a plain copy of ``base``. ``base`` itself is never modified.
        """
        if pruned is None:
            return list(base)
        layout = []
        width_idx = 0
        for entry in base:
            if entry == 'M' or entry == 'A':
                layout.append(entry)
            else:
                layout.append(pruned[width_idx])
                width_idx += 1
        return layout

    def _initialize_weights(self):
        """Initialise conv, BatchNorm and linear layers in place.

        Conv weights: normal with std sqrt(2 / (k_h * k_w * out_channels))
        (He-style, fan-out). BatchNorm: weight 1, bias 0. Linear weights:
        normal with std 0.01. All biases are zeroed.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                if m.bias is not None:
                    m.bias.data.zero_()

    def make_layers(self, cfg):
        """Build the convolutional trunk described by the layer list ``cfg``.

        Integers become Conv2d(3x3, padding=1) + BatchNorm2d + ReLU(inplace);
        'M' becomes MaxPool2d(2, 2) and 'A' AvgPool2d(2, 2). The first conv
        takes 3 input channels.
        """
        layers = []
        in_channels = 3
        for v in cfg:
            if v == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            elif v == 'A':
                layers += [nn.AvgPool2d(kernel_size=2, stride=2)]
            else:
                conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
                in_channels = v

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits for the batch ``x``.

        ``x`` is presumably (N, 3, H, W) with H, W such that ``features``
        yields a 1x1 map (e.g. 32x32) — TODO confirm against callers.
        """
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

class VGG_CIFAR_PRUNED_I(nn.Module):
    """Pruned VGG with BatchNorm and global average pooling before the classifier.

    Same layout rules as the module-level recipe ``cfg['VGG' + str(depth)]``:
    each integer entry becomes a 3x3 conv + BatchNorm + ReLU, ``'M'`` a 2x2
    max pool and ``'A'`` a 2x2 average pool, with the integer widths replaced
    in order by the entries of ``pruned``. An AdaptiveAvgPool2d((1, 1)) after
    the trunk reduces any remaining spatial extent to 1x1, so the linear layer
    always sees exactly the last conv's channel count.

    Args:
        depth: VGG variant key (11, 13, 16 or 19) into ``cfg``.
        num_classes: number of output logits.
        pruned: conv output widths, one per conv layer of the recipe; extra
            trailing entries are ignored. ``None`` keeps the recipe widths.
        logger: optional object with an ``info`` method; if given, the total
            parameter count is logged.

    Raises:
        KeyError: if ``'VGG' + str(depth)`` is not a key of ``cfg``.
        IndexError: if ``pruned`` has fewer entries than the recipe has convs.
    """

    def __init__(self, depth=19, num_classes=10, pruned=None, logger=None):
        super(VGG_CIFAR_PRUNED_I, self).__init__()
        # Build a fresh list rather than editing cfg[...] in place: in-place
        # edits corrupted the shared recipe and made every model of the same
        # depth share (and overwrite) one self.cfg list.
        self.cfg = self._pruned_layout(cfg['VGG' + str(depth)], pruned)

        self.features = self.make_layers(self.cfg)
        # Size the classifier from the last conv layer instead of the fixed
        # index cfg[19], which only exists in the VGG19 recipe.
        last_conv = [m for m in self.features if isinstance(m, nn.Conv2d)][-1]
        self.classifier = nn.Linear(last_conv.out_channels, num_classes)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self._initialize_weights()

        if logger is not None:
            pytorch_total_params = sum(p.numel() for p in self.parameters())
            logger.info(pytorch_total_params)

    @staticmethod
    def _pruned_layout(base, pruned):
        """Return a new layer list: ``base`` with conv widths taken from ``pruned``.

        Pool markers ('M' / 'A') are copied unchanged; ``pruned=None`` returns
        a plain copy of ``base``. ``base`` itself is never modified.
        """
        if pruned is None:
            return list(base)
        layout = []
        width_idx = 0
        for entry in base:
            if entry == 'M' or entry == 'A':
                layout.append(entry)
            else:
                layout.append(pruned[width_idx])
                width_idx += 1
        return layout

    def _initialize_weights(self):
        """Initialise conv, BatchNorm and linear layers in place.

        Conv weights: normal with std sqrt(2 / (k_h * k_w * out_channels))
        (He-style, fan-out). BatchNorm: weight 1, bias 0. Linear weights:
        normal with std 0.01. All biases are zeroed.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                if m.bias is not None:
                    m.bias.data.zero_()

    def make_layers(self, cfg):
        """Build the convolutional trunk described by the layer list ``cfg``.

        Integers become Conv2d(3x3, padding=1) + BatchNorm2d + ReLU(inplace);
        'M' becomes MaxPool2d(2, 2) and 'A' AvgPool2d(2, 2). The first conv
        takes 3 input channels.
        """
        layers = []
        in_channels = 3
        for v in cfg:
            if v == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            elif v == 'A':
                layers += [nn.AvgPool2d(kernel_size=2, stride=2)]
            else:
                conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
                in_channels = v

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits for the batch ``x``.

        ``x`` is presumably (N, 3, H, W) (the first conv takes 3 channels);
        the adaptive pool collapses whatever spatial size remains to 1x1.
        """
        x = self.features(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x
