import math
import sys
sys.path.append('../')
sys.path.append('../../')
import torch.nn.functional as F
import torch.nn as nn
from set.operations import *
from set.set import *
from util.flops import *
from util.memory import *
# Layer layouts keyed by 'VGG<depth>'.  Each entry is consumed in order by
# ``make_layers``: an integer is the output-channel count of a gated 3x3
# convolution, 'M' inserts a 2x2 max-pool and 'A' a 2x2 average-pool.
# One row below corresponds to one stage (convolutions followed by a pool).
cfg = {
    'VGG11': [
        64, 'M',
        128, 'M',
        256, 256, 'M',
        512, 512, 'M',
        512, 512, 'M',
    ],
    'VGG13': [
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 'M',
        512, 512, 'M',
        512, 512, 'M',
    ],
    'VGG16': [
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 256, 'M',
        512, 512, 512, 'M',
        512, 512, 512, 'M',
    ],
    # VGG19 is the only layout that ends with an average-pool ('A').
    'VGG19': [
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 256, 256, 'M',
        512, 512, 512, 512, 'M',
        512, 512, 512, 512, 'A',
    ],
}


class GatedNet(nn.Module):
    """Base class for networks whose prunable layers carry gates.

    Subclasses append their gated layers to ``self.gated_layers`` and
    implement the ``count_*`` hooks.  Every other method here simply fans a
    call out over ``self.gated_layers`` or combines the hook results with the
    ``full_flops`` / ``full_mem`` baselines.
    """

    def __init__(self):
        super(GatedNet, self).__init__()
        # Plain Python list that subclasses fill with their gated layers.
        self.gated_layers = []
        self.use_gate = False
        # Baselines for the ratios below; subclasses overwrite them.
        self.full_flops = 1.
        self.full_mem = 1.

    @staticmethod
    def _per_layer_kwargs(argdicts, num_layers):
        """Return a sequence holding one kwargs dict per gated layer.

        ``None`` means "no extra arguments"; a single dict is shared by all
        layers; anything else is assumed to already be indexable per layer
        and is returned unchanged.
        """
        if argdicts is None:
            argdicts = {}
        if isinstance(argdicts, dict):
            return [argdicts] * num_layers
        return argdicts

    def build_gate(self, gate_fn, argdicts=None):
        """Enable gating and call ``build_gate`` on every gated layer.

        Args:
            gate_fn: forwarded unchanged to each layer's ``build_gate``.
            argdicts: ``None``, one dict of keyword arguments used for every
                layer, or a sequence with one dict per layer.
        """
        self.use_gate = True
        per_layer = self._per_layer_kwargs(argdicts, len(self.gated_layers))
        for i, layer in enumerate(self.gated_layers):
            layer.build_gate(gate_fn, **per_layer[i])

    def build_gate_dep(self, gate_fn, argdicts=None):
        """Same as :meth:`build_gate` but calls each layer's ``build_gate_dep``."""
        self.use_gate = True
        per_layer = self._per_layer_kwargs(argdicts, len(self.gated_layers))
        for i, layer in enumerate(self.gated_layers):
            layer.build_gate_dep(gate_fn, **per_layer[i])

    def reset_dep(self):
        """Call ``reset()`` on the ``dgate`` of every gated layer."""
        for layer in self.gated_layers:
            layer.dgate.reset()

    def get_reg(self):
        """Return the sum of ``get_reg()`` over all gated layers (0. if none)."""
        reg = 0.
        for layer in self.gated_layers:
            reg += layer.get_reg()
        return reg

    def get_reg_dep(self):
        """Return the sum of ``get_reg_dep()`` over all gated layers (0. if none)."""
        reg = 0.
        for layer in self.gated_layers:
            reg += layer.get_reg_dep()
        return reg

    def get_pruned_size(self):
        """Return ``get_num_active()`` of every gated layer, in layer order."""
        return [layer.get_num_active() for layer in self.gated_layers]

    def get_pruned_weight(self):
        """Return ``get_weight_nonactive()`` of every gated layer, in layer order."""
        return [layer.get_weight_nonactive() for layer in self.gated_layers]

    def get_pruned_weight_sum(self):
        """Return the sum of ``get_weight_nonactive()`` over all gated layers."""
        total = 0
        for layer in self.gated_layers:
            total += layer.get_weight_nonactive()
        return total

    def get_weight(self):
        """Return ``get_weight()`` of every gated layer, in layer order."""
        return [layer.get_weight() for layer in self.gated_layers]

    def get_pruned_size_dep(self):
        """Return ``int(layer.dgate.num_active)`` for every gated layer."""
        return [int(layer.dgate.num_active) for layer in self.gated_layers]

    # ----- hooks that concrete networks must implement -----

    def count_flops(self, num_units):
        """Return the FLOP count for the given per-layer unit counts."""
        raise NotImplementedError

    def count_flops_dep(self, num_units, num_units_dep):
        """Return the FLOP count given both kinds of per-layer unit counts."""
        raise NotImplementedError

    def count_memory(self, num_units):
        """Return the memory count for the given per-layer unit counts."""
        raise NotImplementedError

    def count_memory_dep(self, num_units, num_units_dep):
        """Return the memory count given both kinds of per-layer unit counts."""
        raise NotImplementedError

    # ----- ratios against the full-size baselines -----

    def get_speedup(self):
        """Return ``full_flops / count_flops(current active sizes)``."""
        pruned = self.get_pruned_size()
        return float(self.full_flops) / float(self.count_flops(pruned))

    def get_speedup_dep(self):
        """Return ``full_flops / count_flops_dep(active sizes, dgate sizes)``."""
        pruned = self.get_pruned_size()
        pruned_dep = self.get_pruned_size_dep()
        return float(self.full_flops) / \
                float(self.count_flops_dep(pruned, pruned_dep))

    def get_memory_saving(self):
        """Return ``count_memory(current active sizes) / full_mem``.

        Note the direction: this is the remaining share of the baseline,
        the inverse orientation of :meth:`get_speedup`.
        """
        pruned = self.get_pruned_size()
        return float(self.count_memory(pruned)) / float(self.full_mem)

    def get_memory_saving_dep(self):
        """Return ``count_memory_dep(active sizes, dgate sizes) / full_mem``."""
        pruned = self.get_pruned_size()
        pruned_dep = self.get_pruned_size_dep()
        return float(self.count_memory_dep(pruned, pruned_dep)) / \
                float(self.full_mem)

class VGG_CIFAR(GatedNet):
    """VGG-style gated network for 3x32x32 inputs.

    The feature extractor is built from the module-level ``cfg`` table
    (``GatedConv2d`` + ``BatchNorm2d`` + ``ReLU`` per integer entry, a 2x2
    pool per 'M'/'A' entry) and is followed by a single linear classifier.
    ``set_apply()`` must be called before ``forward`` because that is where
    ``self.set_func`` is created.
    """

    def __init__(self, depth=19, num_classes=10,logger=None):
        """Build the network.

        Args:
            depth: selects the layout ``cfg['VGG' + str(depth)]``.
            num_classes: output size of the classifier.
            logger: optional object with an ``info`` method; when given, the
                parameter count and the FLOP / memory baselines are logged.
        """
        super(VGG_CIFAR, self).__init__()
        # Copy the layout so this instance never aliases the shared
        # module-level list (another model editing it must not affect us).
        self.cfg = list(cfg['VGG' + str(depth)])
        self.features = self.make_layers(self.cfg)
        self.classifier = nn.Linear(512, num_classes)
        self._initialize_weights()
        # Output widths of the conv layers only (pool markers dropped).
        conv_widths = [v for v in self.cfg if v != 'M' and v != 'A']
        # count_flops / count_memory read self.num_classes, so set it first.
        self.num_classes = num_classes
        self.full_flops = self.count_flops(conv_widths)
        self.full_mem = self.count_memory(conv_widths)
        total_params = sum(p.numel() for p in self.parameters())
        if logger is not None:
            logger.info(total_params)
            logger.info("flops: {}".format(self.full_flops))
            logger.info("mem: {}".format(self.full_mem))

    def _initialize_weights(self):
        """Initialise conv, batch-norm and linear parameters in place."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Normal init with std sqrt(2 / (kh * kw * out_channels)).
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()

    def make_layers(self, cfg):
        """Translate a layout list into an ``nn.ModuleList``.

        Every created ``GatedConv2d`` is also appended to
        ``self.gated_layers``.  The first convolution takes 3 input channels.
        """
        layers = []
        in_channels = 3
        for v in cfg:
            if v == 'M':
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            elif v == 'A':
                layers.append(nn.AvgPool2d(kernel_size=2, stride=2))
            else:
                conv2d = GatedConv2d(in_channels, v, kernel_size=3, padding=1)
                self.gated_layers.append(conv2d)
                layers.extend([conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)])
                in_channels = v

        return nn.ModuleList(layers)

    def set_apply(self):
        """Create ``self.set_func`` (a ``SetTransformer(3072,1,3072)``); return True.

        3072 equals 3*32*32, the number of elements ``forward`` reshapes the
        transformer output into.
        """
        self.set_func = SetTransformer(3072,1,3072)
        return True

    def count_flops(self, num_units):
        """Return the FLOP count for the given conv widths.

        Args:
            num_units: list with one channel count per conv layer, in order.
        """
        flops = 0
        idx = 0
        height = 32
        # Prepend the 3 input channels so num_units[idx] is always the input
        # width and num_units[idx + 1] the output width of conv number idx.
        num_units = [3] + num_units
        for v in self.cfg:
            if v == 'M' or v == 'A':
                # Both pool kinds are costed with the max-pool counter.
                flops += count_flops_max_pool(height, height, num_units[idx], kernel_size=2, stride=2)
                height = height / 2
            else:
                flops += count_flops_conv(height, height, num_units[idx], num_units[idx + 1], 3, padding=1)
                idx = idx + 1
        # The classifier is always nn.Linear(512, num_classes) in this class.
        flops += count_flops_dense(512, self.num_classes)
        return flops

    def count_flops_dep(self, num_units, num_units_dep):
        """Return ``count_flops(num_units_dep)``; ``num_units`` is ignored."""
        return self.count_flops(num_units_dep)

    def count_memory(self, num_units):
        """Return the memory count for the given conv widths.

        Args:
            num_units: list with one channel count per conv layer, in order.
        """
        mem = 0
        idx = 0
        height = 32
        num_units = [3] + num_units
        for v in self.cfg:
            if v == 'M' or v == 'A':
                # Pools contribute nothing themselves; they only halve the size.
                height = height / 2
            else:
                mem += count_memory_conv(height, height, num_units[idx], num_units[idx + 1], 3, padding=1)
                idx = idx + 1
        mem += count_memory_dense(512, self.num_classes)
        return mem

    def count_memory_dep(self, num_units, num_units_dep):
        """Return ``count_memory(num_units_dep)``; ``num_units`` is ignored."""
        return self.count_memory(num_units_dep)

    def forward(self, x,s):
        """Run ``x`` through the network, carrying the side input ``s`` along.

        ``s`` is first mapped by ``self.set_func`` and reshaped to
        ``(1, 3, 32, 32)``; it then follows the same pool / ReLU steps as
        ``x`` (batch-norm is applied to ``x`` only) and is passed to, and
        updated by, every ``GatedConv2d``.  Only the classifier output for
        ``x`` is returned.
        """
        s = self.set_func(s).view(1,3,32,32)
        for layer in self.features:
            if isinstance(layer, nn.MaxPool2d):
                x = layer(x)
                s = F.max_pool2d(s, kernel_size=2, stride=2)
            elif isinstance(layer, nn.AvgPool2d):
                x = layer(x)
                s = F.avg_pool2d(s, kernel_size=2, stride=2)
            elif isinstance(layer, GatedConv2d):
                x, s = layer(x, s)
            elif isinstance(layer, nn.BatchNorm2d):
                x = layer(x)
            elif isinstance(layer, nn.ReLU):
                x = layer(x)
                s = F.relu(s)

        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x
class VGG_CIFAR_PRUNED_GATED(GatedNet):
    """Gated VGG for 3x32x32 inputs built with already-pruned conv widths.

    The layout ``cfg['VGG' + str(depth)]`` is copied and its integer entries
    are replaced, in order, by the values in ``pruned``.  The classifier
    input size is the width of the last convolution.  ``set_apply()`` must be
    called before ``forward`` because that is where ``self.set_func`` is
    created.
    """

    def __init__(self, depth=19, num_classes=10,pruned=None,logger=None):
        """Build the pruned network.

        Args:
            depth: selects the layout ``cfg['VGG' + str(depth)]``.
            num_classes: output size of the classifier.
            pruned: indexable with one width per conv layer, in layer order;
                ``None`` keeps the widths of the unpruned layout.
            logger: optional object with an ``info`` method used to report
                parameter, FLOP and memory figures.
        """
        super(VGG_CIFAR_PRUNED_GATED, self).__init__()

        # Work on a private copy: writing the pruned widths into the shared
        # module-level table would silently change every model built later.
        self.cfg = self._apply_pruned_widths(cfg['VGG' + str(depth)], pruned)
        conv_widths = [v for v in self.cfg if v != 'M' and v != 'A']

        self.features = self.make_layers(self.cfg)
        # Last conv width feeds the classifier (for depth 19 this is the
        # entry at index 19 of the layout; other depths have other indices).
        self.classifier = nn.Linear(conv_widths[-1], num_classes)

        self._initialize_weights()
        # count_flops / count_memory read self.num_classes, so set it first.
        self.num_classes = num_classes
        self.full_flops = self.count_flops(conv_widths)
        self.full_mem = self.count_memory(conv_widths)
        total_params = sum(p.numel() for p in self.parameters())
        if logger is not None:
            # NOTE(review): values are divided by 1000 but labelled "M" --
            # confirm the intended unit before relying on these numbers.
            logger.info("PARAM: {}M ".format(total_params/1000))
            # NOTE(review): 20040522 is presumably the parameter count of the
            # unpruned VGG19 model; the printed figure is this model's size
            # relative to it, shown under the label "PRUNED" -- confirm.
            logger.info("{:.3f}% PRUNED".format(total_params*100/20040522))
            logger.info("flops: {} M".format(self.full_flops/1000))
            logger.info("mem: {} M".format(self.full_mem/1000))

    @staticmethod
    def _apply_pruned_widths(base_cfg, pruned):
        """Return a new layout with conv widths taken from ``pruned``.

        Pool markers ('M' / 'A') are kept in place; ``base_cfg`` is never
        modified.  With ``pruned=None`` a plain copy is returned.
        """
        if pruned is None:
            return list(base_cfg)
        new_cfg = []
        conv_idx = 0
        for entry in base_cfg:
            if entry != 'M' and entry != 'A':
                new_cfg.append(pruned[conv_idx])
                conv_idx += 1
            else:
                new_cfg.append(entry)
        return new_cfg

    def _initialize_weights(self):
        """Initialise conv, batch-norm and linear parameters in place."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Normal init with std sqrt(2 / (kh * kw * out_channels)).
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()

    def make_layers(self, cfg):
        """Translate a layout list into an ``nn.ModuleList``.

        Every created ``GatedConv2d`` is also appended to
        ``self.gated_layers``.  The first convolution takes 3 input channels.
        """
        layers = []
        in_channels = 3
        for v in cfg:
            if v == 'M':
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            elif v == 'A':
                layers.append(nn.AvgPool2d(kernel_size=2, stride=2))
            else:
                conv2d = GatedConv2d(in_channels, v, kernel_size=3, padding=1)
                self.gated_layers.append(conv2d)
                layers.extend([conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)])
                in_channels = v

        return nn.ModuleList(layers)

    def set_apply(self):
        """Create ``self.set_func`` (a ``SetTransformer(3072,1,3072)``); return True.

        3072 equals 3*32*32, the number of elements ``forward`` reshapes the
        transformer output into.
        """
        self.set_func = SetTransformer(3072,1,3072)
        return True

    def count_flops(self, num_units):
        """Return the FLOP count for the given conv widths.

        Args:
            num_units: list with one channel count per conv layer, in order.
        """
        flops = 0
        idx = 0
        height = 32
        # Prepend the 3 input channels so num_units[idx] is always the input
        # width and num_units[idx + 1] the output width of conv number idx.
        num_units = [3] + num_units
        for v in self.cfg:
            if v == 'M' or v == 'A':
                # Both pool kinds are costed with the max-pool counter.
                flops += count_flops_max_pool(height, height, num_units[idx], kernel_size=2, stride=2)
                height = height / 2
            else:
                flops += count_flops_conv(height, height, num_units[idx], num_units[idx + 1], 3, padding=1)
                idx = idx + 1
        # The classifier input is the last conv width, not a fixed 512.
        flops += count_flops_dense(num_units[-1], self.num_classes)
        return flops

    def count_flops_dep(self, num_units, num_units_dep):
        """Return ``count_flops(num_units_dep)``; ``num_units`` is ignored."""
        return self.count_flops(num_units_dep)

    def count_memory(self, num_units):
        """Return the memory count for the given conv widths.

        Args:
            num_units: list with one channel count per conv layer, in order.
        """
        mem = 0
        idx = 0
        height = 32
        num_units = [3] + num_units
        for v in self.cfg:
            if v == 'M' or v == 'A':
                # Pools contribute nothing themselves; they only halve the size.
                height = height / 2
            else:
                mem += count_memory_conv(height, height, num_units[idx], num_units[idx + 1], 3, padding=1)
                idx = idx + 1
        # The classifier input is the last conv width, not a fixed 512.
        mem += count_memory_dense(num_units[-1], self.num_classes)
        return mem

    def count_memory_dep(self, num_units, num_units_dep):
        """Return ``count_memory(num_units_dep)``; ``num_units`` is ignored."""
        return self.count_memory(num_units_dep)

    def forward(self, x,s):
        """Run ``x`` through the network, carrying the side input ``s`` along.

        ``s`` is first mapped by ``self.set_func`` and reshaped to
        ``(1, 3, 32, 32)``; it then follows the same pool / ReLU steps as
        ``x`` (batch-norm is applied to ``x`` only) and is passed to, and
        updated by, every ``GatedConv2d``.  Only the classifier output for
        ``x`` is returned.
        """
        s = self.set_func(s).view(1,3,32,32)
        for layer in self.features:
            if isinstance(layer, nn.MaxPool2d):
                x = layer(x)
                s = F.max_pool2d(s, kernel_size=2, stride=2)
            elif isinstance(layer, nn.AvgPool2d):
                x = layer(x)
                s = F.avg_pool2d(s, kernel_size=2, stride=2)
            elif isinstance(layer, GatedConv2d):
                x, s = layer(x, s)
            elif isinstance(layer, nn.BatchNorm2d):
                x = layer(x)
            elif isinstance(layer, nn.ReLU):
                x = layer(x)
                s = F.relu(s)

        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x
