import math
import sys
import torch.nn as nn
sys.path.append('../')
sys.path.append('../../')
from module.model import *
from module.operations import *

# Layer plans for the VGG variants, keyed by 'VGG<depth>'.
# An integer entry is the output-channel count of a 3x3 convolution block;
# 'M' inserts a 2x2 max-pool and 'A' a 2x2 average-pool (this is how the
# make_layers methods in this file interpret the entries).
# Each row below is one stage: its convolutions followed by its pooling layer.
# VGG19 is the only plan whose final pooling is 'A' rather than 'M'.
cfg = {
    'VGG11': [
        64, 'M',
        128, 'M',
        256, 256, 'M',
        512, 512, 'M',
        512, 512, 'M',
    ],
    'VGG13': [
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 'M',
        512, 512, 'M',
        512, 512, 'M',
    ],
    'VGG16': [
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 256, 'M',
        512, 512, 512, 'M',
        512, 512, 512, 'M',
    ],
    'VGG19': [
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 256, 256, 'M',
        512, 512, 512, 512, 'M',
        512, 512, 512, 512, 'A',
    ],
}

# Names exported by a star-import of this module. Both network classes defined
# in this file are listed; previously only VGG_CIFAR was, so a star-import did
# not expose VGG_CIFAR_gated.
__all__ = ["VGG_CIFAR", "VGG_CIFAR_gated"]


class VGG_CIFAR(nn.Module):
    """VGG-style network: conv/batch-norm/ReLU stacks followed by one linear layer.

    The layer plan is read from the module-level ``cfg`` table under the key
    ``'VGG<depth>'``; a depth without an entry raises ``KeyError``. The
    classifier expects exactly 512 flattened features, which is what a 32x32
    input yields after the five 2x2 poolings present in every plan.

    Args:
        depth: Selects the plan ``cfg['VGG<depth>']``.
        num_classes: Number of outputs of the final linear layer.
    """

    def __init__(self, depth=19, num_classes=100):
        super(VGG_CIFAR, self).__init__()
        plan_key = 'VGG{}'.format(depth)
        self.cfg = cfg[plan_key]
        self.features = self.make_layers(self.cfg)
        self.classifier = nn.Linear(512, num_classes)
        self._initialize_weights()

    def _initialize_weights(self):
        """Re-initialise every Conv2d, BatchNorm2d and Linear submodule in place.

        Conv weights are drawn from N(0, sqrt(2 / (kh * kw * out_channels)));
        batch-norm weights become 1; linear weights are drawn from N(0, 0.01);
        all biases become 0.
        """
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                kernel_h, kernel_w = layer.kernel_size
                fan_out = kernel_h * kernel_w * layer.out_channels
                nn.init.normal_(layer.weight, 0, math.sqrt(2. / fan_out))
                if layer.bias is not None:
                    nn.init.constant_(layer.bias, 0)
            elif isinstance(layer, nn.BatchNorm2d):
                nn.init.constant_(layer.weight, 1)
                nn.init.constant_(layer.bias, 0)
            elif isinstance(layer, nn.Linear):
                nn.init.normal_(layer.weight, 0, 0.01)
                nn.init.constant_(layer.bias, 0)

    def make_layers(self, cfg):
        """Turn a layer plan into an ``nn.Sequential`` feature extractor.

        Args:
            cfg: Iterable of plan entries. ``'M'`` adds a 2x2/stride-2 max-pool,
                ``'A'`` a 2x2/stride-2 average-pool, and any other value ``v``
                adds Conv2d(3x3, padding 1, ``v`` output channels) +
                BatchNorm2d + in-place ReLU.

        Returns:
            The layers wrapped in ``nn.Sequential``; the first convolution
            takes 3 input channels.
        """
        stack = []
        prev_width = 3
        for entry in cfg:
            if entry in ('M', 'A'):
                pool_cls = nn.MaxPool2d if entry == 'M' else nn.AvgPool2d
                stack.append(pool_cls(kernel_size=2, stride=2))
                continue
            stack.append(nn.Conv2d(prev_width, entry, kernel_size=3, padding=1))
            stack.append(nn.BatchNorm2d(entry))
            stack.append(nn.ReLU(inplace=True))
            # The next convolution consumes what this one produced.
            prev_width = entry
        return nn.Sequential(*stack)

    def forward(self, x):
        """Run the feature extractor, flatten per sample, and classify."""
        feature_maps = self.features(x)
        batch_size = feature_maps.size(0)
        flattened = feature_maps.view(batch_size, -1)
        return self.classifier(flattened)



class VGG_CIFAR_gated(GatedNet):
    """VGG-style network built from ``GatedConv2d`` layers, with cost accounting.

    The layer plan is read from the module-level ``cfg`` table under the key
    ``'VGG<depth>'``; a depth without an entry raises ``KeyError``. After the
    layers are built, the cost of the un-pruned network is stored in
    ``full_flops`` and ``full_mem`` and, if a logger is given, logged.

    ``GatedNet``, ``GatedConv2d`` and the ``count_*`` helpers come from the
    star-imports at the top of the file and are not defined here.

    Args:
        depth: Selects the plan ``cfg['VGG<depth>']``.
        num_classes: Number of outputs of the final linear layer.
        logger: Optional object with an ``info(str)`` method.
    """

    def __init__(self, depth=19, num_classes=100,logger=None):
        super(VGG_CIFAR_gated, self).__init__()
        self.cfg = cfg['VGG' + str(depth)]
        self.features = self.make_layers(self.cfg)
        self.classifier = nn.Linear(512, num_classes)
        self._initialize_weights()
        # Output widths of the convolutions only (pool markers removed).
        conv_widths = [v for v in self.cfg if v != 'M' and v != 'A']
        # count_flops / count_memory read self.num_classes, so set it first.
        self.num_classes = num_classes
        self.full_flops = self.count_flops(conv_widths)
        self.full_mem = self.count_memory(conv_widths)
        if logger is not None:
            logger.info("flops: {}".format(self.full_flops))
            logger.info("mem: {}".format(self.full_mem))

    def _initialize_weights(self):
        """Re-initialise every Conv2d, BatchNorm2d and Linear submodule in place.

        Conv weights are drawn from N(0, sqrt(2 / (kh * kw * out_channels)));
        batch-norm weights become 1; linear weights are drawn from N(0, 0.01);
        all biases become 0.
        """
        # NOTE(review): GatedConv2d layers are only touched by the first branch
        # if GatedConv2d is an nn.Conv2d subclass -- confirm against its
        # definition.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()

    def make_layers(self, cfg):
        """Turn a layer plan into an ``nn.Sequential`` of gated conv blocks.

        ``'M'`` adds a 2x2/stride-2 max-pool, ``'A'`` a 2x2/stride-2
        average-pool, and any other value ``v`` adds GatedConv2d(3x3,
        padding 1, ``v`` output channels) + BatchNorm2d + in-place ReLU.
        Every GatedConv2d is also appended to ``self.gated_layers``.
        """
        # NOTE(review): self.gated_layers is not created in this class;
        # presumably GatedNet.__init__ sets it up -- confirm.
        layers = []
        in_channels = 3
        for v in cfg:
            if v == 'M':
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            elif v == 'A':
                layers.append(nn.AvgPool2d(kernel_size=2, stride=2))
            else:
                conv2d = GatedConv2d(in_channels, v, kernel_size=3, padding=1)
                self.gated_layers.append(conv2d)
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
                in_channels = v

        return nn.Sequential(*layers)

    def count_flops(self, num_units):
        """Sum the helper-reported FLOPs of the network for given layer widths.

        Args:
            num_units: Output-channel count of each convolution, in network
                order (one entry per non-pool item of ``self.cfg``). Any
                iterable is accepted.

        Returns:
            The accumulated value returned by the ``count_flops_*`` helpers
            for all convolutions, poolings and the final dense layer.
        """
        flops = 0
        idx = 0
        # Spatial size of the input is fixed at 32x32.
        height = 32
        # Prepend the 3 input channels. list() makes this a real concatenation:
        # `[3] + num_units` raises for tuples and is an element-wise add for
        # array-like inputs.
        num_units = [3] + list(num_units)
        for v in self.cfg:
            if v == 'M':
                flops += count_flops_max_pool(height, height, num_units[idx], kernel_size=2, stride=2)
                # True division: height becomes a float under Python 3
                # (values stay exact: 32 -> 16 -> 8 -> 4 -> 2 -> 1).
                height = height / 2
            elif v == 'A':
                # NOTE(review): the average-pool is costed with the max-pool
                # helper -- confirm this is intended.
                flops += count_flops_max_pool(height, height, num_units[idx], kernel_size=2, stride=2)
                height = height / 2
            else:
                # num_units[idx] is this conv's input width, [idx + 1] its output.
                flops += count_flops_conv(height, height, num_units[idx], num_units[idx + 1], 3, padding=1)
                idx = idx + 1
        flops += count_flops_dense(num_units[-1], self.num_classes)

        return flops

    def count_flops_dep(self, num_units, num_units_dep):
        """Return ``count_flops(num_units_dep)``; ``num_units`` is not used."""
        return self.count_flops(num_units_dep)

    def count_memory(self, num_units):
        """Sum the helper-reported memory of the conv and dense layers.

        Args:
            num_units: Output-channel count of each convolution, in network
                order (one entry per non-pool item of ``self.cfg``). Any
                iterable is accepted.

        Returns:
            The accumulated value returned by ``count_memory_conv`` for every
            convolution plus ``count_memory_dense`` for the classifier.
            Pooling layers only halve the tracked spatial size.
        """
        mem = 0
        idx = 0
        # Spatial size of the input is fixed at 32x32.
        height = 32
        # Prepend the 3 input channels; list() makes this a real concatenation
        # for tuples and array-like inputs too.
        num_units = [3] + list(num_units)
        for v in self.cfg:
            if v == 'M' or v == 'A':
                # True division: height becomes a float under Python 3.
                height = height / 2
            else:
                mem += count_memory_conv(height, height, num_units[idx], num_units[idx + 1], 3, padding=1)
                idx = idx + 1
        mem += count_memory_dense(num_units[-1], self.num_classes)

        return mem

    def count_memory_dep(self, num_units, num_units_dep):
        """Return ``count_memory(num_units_dep)``; ``num_units`` is not used."""
        return self.count_memory(num_units_dep)

    def forward(self, x):
        """Run the feature extractor, flatten per sample, and classify."""
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x
