import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from collections import OrderedDict
from torch.nn import init
import math

import numpy as np
#coding:utf8
import torch
import torchvision

import torch.nn as nn
from torch.autograd import Variable
import torchvision.models as models

import numpy as np

def print_model_parm_flops(one_shot_model, input_size=(3, 32, 32), multiply_adds=False):
    """Estimate the cost of one forward pass of ``one_shot_model`` in MFLOPs.

    Forward hooks are attached to every leaf module (a module without
    children) of the supported types, a single random sample is pushed
    through the model, and the per-layer counts are summed:

    * ``Conv2d``: ``batch * out_c * out_h * out_w * (kh * kw * in_c / groups + bias)``
    * ``Linear``: ``batch * (weight elements + bias elements)``
    * ``BatchNorm2d`` / ``ReLU``: number of elements of the layer input
    * ``MaxPool2d`` / ``AvgPool2d``: ``batch * out_c * out_h * out_w * kh * kw``

    Args:
        one_shot_model: the ``nn.Module`` to profile. It is called once with a
            tensor of shape ``(1,) + input_size``.
        input_size: ``(channels, height, width)`` of the probe sample.
            Defaults to the CIFAR-sized ``(3, 32, 32)``.
        multiply_adds: when true, every multiply-accumulate of a convolution
            or linear weight is counted as two operations instead of one.

    Returns:
        The summed operation count divided by ``1e6`` (a float).

    Notes:
        * The probe is created on the device of the model's first parameter
          (CPU when the model has no parameters), so the model is no longer
          required to live on a GPU.
        * The model's train/eval mode is not touched; layers that update
          state during a training-mode forward pass will still do so.
        * All hooks are removed again before returning, so repeated calls do
          not accumulate hooks on the model.
    """
    conv_flops = []
    linear_flops = []
    bn_flops = []
    relu_flops = []
    pooling_flops = []

    # Weight of a single multiply-accumulate.
    mac_cost = 2 if multiply_adds else 1

    def conv_hook(module, inputs, output):
        batch_size = inputs[0].size(0)
        # output[0] is the first sample, i.e. (channels, height, width).
        output_channels, output_height, output_width = output[0].size()

        kernel_ops = (module.kernel_size[0] * module.kernel_size[1]
                      * (module.in_channels // module.groups) * mac_cost)
        bias_ops = 1 if module.bias is not None else 0

        params = output_channels * (kernel_ops + bias_ops)
        conv_flops.append(batch_size * params * output_height * output_width)

    def linear_hook(module, inputs, output):
        # Only 2-D inputs carry an explicit batch dimension here.
        batch_size = inputs[0].size(0) if inputs[0].dim() == 2 else 1

        weight_ops = module.weight.nelement() * mac_cost
        # Linear layers built with bias=False have no bias tensor at all.
        bias_ops = module.bias.nelement() if module.bias is not None else 0

        linear_flops.append(batch_size * (weight_ops + bias_ops))

    def bn_hook(module, inputs, output):
        bn_flops.append(inputs[0].nelement())

    def relu_hook(module, inputs, output):
        relu_flops.append(inputs[0].nelement())

    def pooling_hook(module, inputs, output):
        batch_size = inputs[0].size(0)
        output_channels, output_height, output_width = output[0].size()

        # kernel_size may be a single int or an (h, w) pair.
        kernel = module.kernel_size
        if isinstance(kernel, (tuple, list)):
            kernel_ops = kernel[0] * kernel[1]
        else:
            kernel_ops = kernel * kernel

        params = output_channels * kernel_ops
        pooling_flops.append(batch_size * params * output_height * output_width)

    hooks_by_type = (
        (torch.nn.Conv2d, conv_hook),
        (torch.nn.Linear, linear_hook),
        (torch.nn.BatchNorm2d, bn_hook),
        (torch.nn.ReLU, relu_hook),
        ((torch.nn.MaxPool2d, torch.nn.AvgPool2d), pooling_hook),
    )

    handles = []
    try:
        # modules() yields each module object once, so a leaf shared by
        # several parents is hooked a single time.
        for module in one_shot_model.modules():
            if next(module.children(), None) is not None:
                continue  # only leaf modules are counted
            for layer_types, hook in hooks_by_type:
                if isinstance(module, layer_types):
                    handles.append(module.register_forward_hook(hook))

        first_param = next(one_shot_model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')
        probe = torch.rand(1, *input_size).to(device)

        # Only the hooks' side effects matter; no autograd graph is needed.
        with torch.no_grad():
            one_shot_model(probe)
    finally:
        for handle in handles:
            handle.remove()

    total_flops = (sum(conv_flops) + sum(linear_flops) + sum(bn_flops)
                   + sum(relu_flops) + sum(pooling_flops))
    M_flops = total_flops / 1e6

    return M_flops


# Layer layouts for VGG_CIFAR_PRUNED, keyed by 'VGG<depth>'.
# Each list is read left to right by VGG_CIFAR_PRUNED.make_layers:
#   * an integer is the output-channel count of a 3x3 convolution, which is
#     followed by BatchNorm2d and ReLU;
#   * 'M' is a 2x2 max-pool with stride 2;
#   * 'A' is a 2x2 average-pool with stride 2.
# Only the 'VGG19' layout ends in 'A'; the other layouts end in 'M'.
cfg = {
    'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'A'],
}



class VGG_CIFAR_PRUNED(nn.Module):
    """VGG-style CIFAR classifier whose convolution widths can be overridden.

    The layer layout is taken from the module-level ``cfg`` table under the
    key ``'VGG' + str(depth)``. When ``pruned`` is given, its values replace
    the convolution widths of that layout in order; pooling markers
    (``'M'`` / ``'A'``) keep their positions.

    Args:
        depth: selects the layout ``cfg['VGG<depth>']``.
        num_classes: number of outputs of the final linear classifier.
        pruned: sequence with one channel count per convolution of the
            selected layout, or ``None`` to build the layout unchanged.
            A sequence that is too short raises ``IndexError``.
        logger: optional object with an ``info`` method; when given, the
            parameter count is logged.

    Attributes:
        cfg: the layout actually built (a private copy; the module-level
            table is never modified).
        features: ``nn.Sequential`` of conv/BN/ReLU and pooling layers.
        classifier: ``nn.Linear`` from the last convolution's width to
            ``num_classes``.
    """

    def __init__(self, depth=19, num_classes=10, pruned=None,logger=None):
        super(VGG_CIFAR_PRUNED, self).__init__()
        # Work on a copy so building one model cannot alter the shared table.
        self.cfg = self._apply_pruning(cfg['VGG' + str(depth)], pruned)

        self.features = self.make_layers(self.cfg)
        # forward() flattens the feature map; the classifier is sized to the
        # last convolution's channel count (for the 'VGG19' layout this is
        # the entry at index 19).
        self.classifier = nn.Linear(self._last_conv_channels(self.cfg), num_classes)
        self._initialize_weights()
        pytorch_total_params = sum(p.numel() for p in self.parameters())
        if logger is not None:
            # 20040522 matches the parameter count of this class built from
            # the unmodified 'VGG19' layout with 10 classes.
            logger.info("PARAM: {}M ( {:.3f}% )".format(pytorch_total_params/1000000,pytorch_total_params*100/20040522))

    @staticmethod
    def _apply_pruning(base_cfg, pruned):
        """Return a new layout with conv widths taken from ``pruned`` in order."""
        if pruned is None:
            return list(base_cfg)
        pruned_cfg = []
        next_width = 0
        for entry in base_cfg:
            if entry != 'M' and entry != 'A':
                pruned_cfg.append(pruned[next_width])
                next_width += 1
            else:
                pruned_cfg.append(entry)
        return pruned_cfg

    @staticmethod
    def _last_conv_channels(layer_cfg):
        """Return the width of the last convolution entry in ``layer_cfg``."""
        for entry in reversed(layer_cfg):
            if entry != 'M' and entry != 'A':
                return entry
        raise ValueError("layer configuration contains no convolution")

    def _initialize_weights(self):
        """Initialise conv, batch-norm and linear layers in place."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Normal init with std sqrt(2 / (kh * kw * out_channels)).
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()

    def make_layers(self, cfg):
        """Build the feature extractor described by the layout list ``cfg``.

        Integers become ``Conv2d(3x3, padding=1) + BatchNorm2d + ReLU``;
        ``'M'`` / ``'A'`` become 2x2 max / average pooling with stride 2.
        The first convolution expects 3 input channels.
        """
        layers = []
        in_channels = 3
        for v in cfg:
            if v == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            elif v == 'A':
                layers += [nn.AvgPool2d(kernel_size=2, stride=2)]
            else:
                conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
                in_channels = v

        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the feature extractor, flatten per sample, then classify."""
        x = self.features(x)
        x = x.view(x.size(0), -1)

        x = self.classifier(x)
        return x

"""
class RESNET_PRUNED(nn.Module):
    def __init__(self, num_classes=10,P=[64,64,64,64,128,128,128,128,256,256,256,256,512,512,512,512],logger=None):
        super(RESNET_PRUNED, self).__init__()

        block=ResidualBlock_p
        num_blocks=[2,2,2,2]
        self.inp = 64
        self.conv0 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn0 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)  # compute on input (mem save)

        self.idx=0
        self.P=P
        self.layer1 = self._make_layer(block,  num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block,  num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block,  num_blocks[3], stride=2)
        self.linear = nn.Linear(512, num_classes)

        pytorch_total_params = sum(p.numel() for p in self.parameters())
        if logger !=None:
            logger.info("PARAM: {}M ( {:.3f}% )".format(pytorch_total_params/1000000,pytorch_total_params*100/11173962))




    def _make_layer(self, block, num_block, stride=1):
        layers = []
        # stride when out features*2 is 2 = depth *2 and w,h /2
        strides = [stride] + [1]*(num_block-1)
        for stride in strides:
            if self.P[self.idx]!=0:
                layers.append(block(self.inp, self.P[self.idx],self.P[self.idx+1], stride))
            self.inp = self.P[self.idx+1]
            self.idx+=2
        return nn.Sequential(*layers)

    def forward(self, x):

        out = self.relu(self.bn0(self.conv0(x)))  # layer1
        out = self.layer1(out)   # layer 2,3,4,5  (2 plain blocks )
        out = self.layer2(out)   # layer 6,7,8,9  (2 plain blocks )
        out = self.layer3(out)   # layer 10,11,12,13 (2 plain blocks )
        out = self.layer4(out)  #layer  14,15,16,17 (2 plain blocks )
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)   #layer 18
        return out
"""
class ResidualBlock_p(nn.Module):
    """Residual block made of two 3x3 convolutions with separate widths.

    Main path: ``conv (inp -> oup, given stride) -> bn -> relu ->
    conv2 (oup -> oup2, stride 1) -> bn2``. The block input is added to the
    result and a final ReLU is applied. When ``inp != oup2`` or
    ``stride > 1`` the input first goes through ``shortcut`` (1x1 convolution
    with the same stride plus batch norm); otherwise it is added unchanged
    and no ``shortcut`` attribute is created.

    Args:
        inp: channel count of the block input.
        oup: channel count produced by the first convolution.
        oup2: channel count produced by the second convolution (block output).
        stride: stride of the first convolution and of the shortcut projection.
    """

    def __init__(self, inp, oup,oup2, stride=1):
        super(ResidualBlock_p, self).__init__()
        self.stride = stride
        self.inp = inp
        self.oup2 = oup2

        # First convolution stage (may change resolution).
        self.conv = nn.Conv2d(inp, oup, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(oup)
        # One ReLU module serves both activation sites in forward().
        self.relu = nn.ReLU()

        # Second convolution stage (resolution preserved).
        self.conv2 = nn.Conv2d(oup, oup2, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(oup2)

        if self._needs_projection():
            projection = [
                nn.Conv2d(inp, oup2, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(oup2),
            ]
            self.shortcut = nn.Sequential(*projection)

    def _needs_projection(self):
        """Tell whether the skip connection must be reshaped by ``shortcut``."""
        return self.inp != self.oup2 or self.stride > 1

    def forward(self, x):
        """Return ``relu(main_path(x) + skip(x))``."""
        main = self.conv(x)
        main = self.bn(main)
        main = self.relu(main)
        main = self.bn2(self.conv2(main))

        skip = self.shortcut(x) if self._needs_projection() else x

        return self.relu(main + skip)

class RESNET_PRUNED(nn.Module):
    """Four-stage residual network for CIFAR with per-block prunable widths.

    The network is a 3x3 stem convolution with 64 outputs, four stages of
    two ``ResidualBlock_p`` blocks each (stage strides 1, 2, 2, 2), global
    average pooling and a linear classifier.

    Args:
        num_classes: number of outputs of the final linear layer.
        P: flat sequence of channel counts, two consecutive entries per
            residual block: the width after the block's first convolution
            and the width after its second. A block whose first entry is
            ``0`` is left out of the network entirely.
        logger: optional object with an ``info`` method; when given, the
            parameter count is logged.
    """

    def __init__(self, num_classes=10,P=[64,64,64,64,128,128,128,128,256,256,256,256,512,512,512,512],logger=None):
        super(RESNET_PRUNED, self).__init__()

        block = ResidualBlock_p
        num_blocks = [2, 2, 2, 2]
        # Channel count entering the next block; updated by _make_layer.
        self.inp = 64
        self.conv0 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn0 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)

        # Cursor into P; each block consumes two entries.
        self.idx = 0
        # Private copy so the instance never aliases the shared default list.
        self.P = list(P)
        self.layer1 = self._make_layer(block, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, num_blocks[3], stride=2)
        # Sized from the last kept block (64 if every block was pruned).
        self.linear = nn.Linear(self.inp, num_classes)

        pytorch_total_params = sum(p.numel() for p in self.parameters())
        if logger is not None:
            # NOTE(review): 11173962 is presumably the parameter count of the
            # unpruned network (default P) - confirm.
            logger.info("PARAM: {}M ( {:.3f}% )".format(pytorch_total_params/1000000,pytorch_total_params*100/11173962))

    def _make_layer(self, block, num_block, stride=1):
        """Build one stage of ``num_block`` blocks from the next entries of ``P``.

        The first block of the stage gets ``stride``; the others use stride 1.
        Every block slot consumes two entries of ``P`` whether or not the
        block is kept, so a pruned block (first entry ``0``) no longer stalls
        the cursor and drops all blocks after it.

        NOTE(review): if the first block of a strided stage is pruned, that
        stage performs no down-sampling - confirm this is intended.
        """
        layers = []
        strides = [stride] + [1] * (num_block - 1)
        for block_stride in strides:
            if self.P[self.idx] != 0:
                out_channels = self.P[self.idx + 1]
                layers.append(block(self.inp, self.P[self.idx], out_channels, block_stride))
                # Only a block that exists changes the running channel count.
                self.inp = out_channels
            self.idx += 2
        return nn.Sequential(*layers)

    def forward(self, x):
        """Stem, four residual stages, global average pool, linear classifier."""
        out = self.relu(self.bn0(self.conv0(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Pool to 1x1 regardless of the remaining spatial size.
        out = F.adaptive_avg_pool2d(out, (1, 1))
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

def real_res(P,logger):
    """Build a RESNET_PRUNED from ``P`` and report its FLOPs estimate.

    Args:
        P: channel configuration forwarded to ``RESNET_PRUNED``.
        logger: object with an ``info`` method, or ``None`` to skip logging.

    Returns:
        The value returned by ``print_model_parm_flops`` for the new model
        (in millions of operations).
    """
    # Reference cost used for the "x" ratio in the log line.
    # NOTE(review): presumably the unpruned network's count - confirm.
    flops1 = 556.594186
    # Use the GPU when one is present instead of requiring it.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model_for_flops = RESNET_PRUNED(P=P, logger=logger).to(device)
    flops2 = print_model_parm_flops(model_for_flops)

    if logger is not None:
        logger.info("*********FLOPS ARE {}M  (x {:.3f})**********".format(flops2,flops1/flops2))
    return flops2

def real(P,logger):
    """Build a VGG_CIFAR_PRUNED from ``P`` and report its FLOPs estimate.

    Args:
        P: per-convolution channel counts forwarded as ``pruned`` to
            ``VGG_CIFAR_PRUNED`` (default depth).
        logger: object with an ``info`` method, or ``None`` to skip logging.

    Returns:
        The value returned by ``print_model_parm_flops`` for the new model
        (in millions of operations), mirroring ``real_res``.
    """
    # Reference cost used for the "x" ratio in the log line.
    # NOTE(review): presumably the unpruned network's count - confirm.
    flops1 = 399.17057
    # Use the GPU when one is present instead of requiring it.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model_for_flops = VGG_CIFAR_PRUNED(pruned=P, logger=logger).to(device)
    flops2 = print_model_parm_flops(model_for_flops)

    if logger is not None:
        logger.info("*********FLOPS ARE {}M  (x {:.3f})**********".format(flops2,flops1/flops2))
    return flops2
