
import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F

class ResidualBlock_p(nn.Module):
    """Residual block whose two 3x3 convolutions may have different widths.

    The main path is ``conv3x3(inp -> oup, stride) -> BN -> ReLU ->
    conv3x3(oup -> oup2, stride 1) -> BN``.  Its result is added to the skip
    path and passed through a final ReLU.

    The skip path is the unmodified input when ``inp == oup2`` and
    ``stride <= 1``.  Otherwise it is a ``conv1x1(inp -> oup2, stride) -> BN``
    projection stored as ``self.shortcut`` (the attribute only exists in that
    case).

    Args:
        inp: number of input channels.
        oup: number of channels produced by the first convolution.
        oup2: number of channels produced by the second convolution, i.e. the
            block's output width.
        stride: stride of the first convolution and of the projection.
    """

    def __init__(self, inp, oup, oup2, stride=1):
        super().__init__()
        self.stride = stride
        self.inp = inp
        self.oup2 = oup2

        # First half of the main path; this is the only 3x3 conv that strides.
        self.conv = nn.Conv2d(
            inp, oup, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn = nn.BatchNorm2d(oup)
        # A single ReLU module is reused after the first BN and after the sum.
        self.relu = nn.ReLU()

        # Second half of the main path; keeps the spatial size.
        self.conv2 = nn.Conv2d(
            oup, oup2, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(oup2)

        if self._uses_projection():
            projection = [
                nn.Conv2d(inp, oup2, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(oup2),
            ]
            self.shortcut = nn.Sequential(*projection)

    def _uses_projection(self):
        """Return True when the skip path cannot be a plain identity."""
        return self.inp != self.oup2 or self.stride > 1

    def forward(self, x):
        """Return ``relu(main_path(x) + skip_path(x))``."""
        main = self.conv(x)
        main = self.bn(main)
        main = self.relu(main)
        main = self.conv2(main)
        main = self.bn2(main)

        skip = self.shortcut(x) if self._uses_projection() else x
        return self.relu(main + skip)

class RESNET_PRUNED(nn.Module):
    """Residual network whose per-convolution widths are given by ``P``.

    Structure: a 3x3 stride-1 stem convolution (3 -> 64 channels) with BN and
    ReLU, four stages of up to two ``ResidualBlock_p`` blocks each (stage
    strides 1, 2, 2, 2), global average pooling and a linear classifier.

    ``P`` is read two entries per block: ``P[2*k]`` is the width of the
    block's first convolution and ``P[2*k + 1]`` the width of its second
    convolution (the block output).  A block whose first entry is ``0`` is
    left out of the network.

    Args:
        num_classes: size of the classifier output.
        P: sequence of 16 channel counts (8 blocks x 2 convolutions).  When
            omitted, widths 64/128/256/512 are used, four entries each.
        logger: optional object with an ``info`` method; when given, the
            parameter count is reported through it.
    """

    def __init__(self, num_classes=10, P=None, logger=None):
        super(RESNET_PRUNED, self).__init__()
        if P is None:
            # Built per call so instances never share one mutable default list.
            P = [64] * 4 + [128] * 4 + [256] * 4 + [512] * 4

        block = ResidualBlock_p
        num_blocks = [2, 2, 2, 2]
        # Channel count of the tensor entering the next block to be built.
        self.inp = 64
        self.conv0 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn0 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)  # in-place to save memory

        # Cursor into P; _make_layer advances it by two per block.
        self.idx = 0
        self.P = P
        self.layer1 = self._make_layer(block, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, num_blocks[3], stride=2)
        self.linear = nn.Linear(self.inp, num_classes)

        pytorch_total_params = sum(p.numel() for p in self.parameters())
        if logger is not None:
            # 11173962 is presumably the parameter count of the unpruned
            # network, making the percentage "fraction of parameters kept"
            # -- TODO confirm.
            logger.info("PARAM: {}M ( {:.3f}% )".format(pytorch_total_params/1000000,pytorch_total_params*100/11173962))

    def _make_layer(self, block, num_block, stride=1):
        """Build one stage of up to ``num_block`` blocks from the next entries of ``P``.

        Only the first block slot of the stage uses ``stride``; the remaining
        slots use stride 1.  Each slot consumes two entries of ``self.P``
        whether or not a block is built for it, so a pruned block (first entry
        ``0``) does not stall the cursor and hide the blocks after it.

        Args:
            block: callable ``block(inp, mid, out, stride)`` returning a module.
            num_block: number of block slots in this stage.
            stride: stride handed to the first slot.

        Returns:
            An ``nn.Sequential`` of the blocks that were built (possibly empty).
        """
        layers = []
        strides = [stride] + [1] * (num_block - 1)
        for block_stride in strides:
            mid_channels = self.P[self.idx]
            if mid_channels != 0:
                out_channels = self.P[self.idx + 1]
                layers.append(block(self.inp, mid_channels, out_channels, block_stride))
                self.inp = out_channels
            # NOTE(review): when the skipped slot is the first of a stage its
            # stride is dropped as well, so that stage does not downsample.
            self.idx += 2
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the classifier output for the image batch ``x``."""
        out = self.relu(self.bn0(self.conv0(x)))  # stem
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.adaptive_avg_pool2d(out, (1, 1))
        out = out.view(out.size(0), -1)  # (batch, channels)
        out = self.linear(out)
        return out

class RESNET_PRUNED_I(nn.Module):
    """Residual network with a 7x7 stem whose block widths are given by ``P``.

    Structure: a 7x7 stride-2 stem convolution (3 -> 64 channels) with BN and
    ReLU, a 3x3 stride-2 max pool, four stages of up to two
    ``ResidualBlock_p`` blocks each (stage strides 1, 2, 2, 2), global average
    pooling and a linear classifier.

    ``P`` is read two entries per block: ``P[2*k]`` is the width of the
    block's first convolution and ``P[2*k + 1]`` the width of its second
    convolution (the block output).  A block whose first entry is ``0`` is
    left out of the network.

    Args:
        num_classes: size of the classifier output.
        P: sequence of 16 channel counts (8 blocks x 2 convolutions).  When
            omitted, widths 64/128/256/512 are used, four entries each.
        logger: optional object with an ``info`` method.  The parameter count
            is reported through it when given and printed otherwise.
    """

    def __init__(self, num_classes=10, P=None, logger=None):
        super(RESNET_PRUNED_I, self).__init__()
        # Identity test: `P == None` compares elementwise on array-like P and
        # then fails when used as a condition.
        if P is None:
            P = [64] * 4 + [128] * 4 + [256] * 4 + [512] * 4

        block = ResidualBlock_p
        num_blocks = [2, 2, 2, 2]
        # Channel count of the tensor entering the next block to be built.
        self.inp = 64
        self.conv0 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn0 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)  # in-place to save memory
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Cursor into P; _make_layer advances it by two per block.
        self.idx = 0
        self.P = P
        self.layer1 = self._make_layer(block, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, num_blocks[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        self.linear = nn.Linear(self.inp, num_classes)

        pytorch_total_params = sum(p.numel() for p in self.parameters())
        # 11181642 is presumably the parameter count of the unpruned network,
        # making the percentage "fraction of parameters kept" -- TODO confirm.
        summary = "PARAM: {}M ( {:.3f}% )".format(pytorch_total_params/1000000,pytorch_total_params*100/11181642)
        if logger is not None:
            logger.info(summary)
        else:
            print(summary)

    def _make_layer(self, block, num_block, stride=1):
        """Build one stage of up to ``num_block`` blocks from the next entries of ``P``.

        Only the first block slot of the stage uses ``stride``; the remaining
        slots use stride 1.  Each slot consumes two entries of ``self.P``
        whether or not a block is built for it, so a pruned block (first entry
        ``0``) does not stall the cursor and hide the blocks after it.

        Args:
            block: callable ``block(inp, mid, out, stride)`` returning a module.
            num_block: number of block slots in this stage.
            stride: stride handed to the first slot.

        Returns:
            An ``nn.Sequential`` of the blocks that were built (possibly empty).
        """
        layers = []
        strides = [stride] + [1] * (num_block - 1)
        for block_stride in strides:
            mid_channels = self.P[self.idx]
            if mid_channels != 0:
                out_channels = self.P[self.idx + 1]
                layers.append(block(self.inp, mid_channels, out_channels, block_stride))
                self.inp = out_channels
            # NOTE(review): when the skipped slot is the first of a stage its
            # stride is dropped as well, so that stage does not downsample.
            self.idx += 2
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the classifier output for the image batch ``x``."""
        out = self.relu(self.bn0(self.conv0(x)))  # stem
        out = self.maxpool(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avgpool(out)
        out = out.view(out.size(0), -1)  # (batch, channels)
        out = self.linear(out)
        return out
