import torch
import torch.nn as nn
import numpy as np
from .convolutional import Convolutional
from .dense import Dense
from .pooling import MaxPool,AvgPool
from .batch_norm import Batch_Norm2d
from .relu import ReLU
import torch.nn.functional as F

def conv3x3(in_planes, out_planes, stride=1, bias=False):
    """Create a ``Convolutional`` layer with a 3x3 kernel and padding of 1.

    Args:
        in_planes: Forwarded as the first positional argument of
            ``Convolutional`` (presumably the input channel count).
        out_planes: Forwarded as the second positional argument of
            ``Convolutional`` (presumably the output channel count).
        stride: Stride handed to the layer. Defaults to 1.
        bias: Bias flag handed to the layer. Defaults to False.

    Returns:
        The newly constructed ``Convolutional`` instance.
    """
    layer_kwargs = {"kernel_size": 3, "stride": stride, "padding": 1, "bias": bias}
    return Convolutional(in_planes, out_planes, **layer_kwargs)


def conv1x1(in_planes, out_planes, stride=1, bias=False):
    """Create a ``Convolutional`` layer with a 1x1 kernel.

    Args:
        in_planes: Forwarded as the first positional argument of
            ``Convolutional`` (presumably the input channel count).
        out_planes: Forwarded as the second positional argument of
            ``Convolutional`` (presumably the output channel count).
        stride: Stride handed to the layer. Defaults to 1.
        bias: Bias flag handed to the layer. Defaults to False.

    Returns:
        The newly constructed ``Convolutional`` instance.
    """
    layer_kwargs = {"kernel_size": 1, "stride": stride, "bias": bias}
    return Convolutional(in_planes, out_planes, **layer_kwargs)

class BasicBlock(nn.Module):
    """Two-convolution residual block with an ``analyze`` backward pass.

    ``forward`` runs conv3x3 -> [BN] -> ReLU -> conv3x3 -> [BN], adds the
    shortcut (the input itself, or ``downsample(input)`` when a downsample
    module was supplied) and applies a final ReLU.  ``analyze`` walks the
    same layers in reverse, calling each layer's own ``analyze``.
    """

    # Class-level constant; not read inside this class.
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, beta=None, BN=True):
        """Build the layers of the block.

        Args:
            inplanes: Passed to the first convolution as ``in_planes``.
            planes: Output planes of both convolutions.
            stride: Stride of the first convolution only.
            downsample: Optional module applied to the input on the shortcut
                path; ``None`` means the raw input is used.
            beta: Forwarded unchanged to both ``ReLU`` constructors.
            BN: When true, a ``Batch_Norm2d`` follows each convolution and
                the convolutions are built without bias.
        """
        super().__init__()

        # Plain (non-module) state.  X / h0 / h1 are not read in this class;
        # they are kept because outside code may rely on them.
        self.BN = BN
        self.stride = stride
        self.beta = beta
        self.X = 0
        self.h0 = 0
        self.h1 = 1

        # Sub-modules: the assignment order below fixes the registration
        # order inside nn.Module, so it mirrors the established layout.
        conv_has_bias = not BN
        self.conv1 = conv3x3(inplanes, planes, stride, bias=conv_has_bias)
        if BN:
            self.bn1 = Batch_Norm2d(planes)
            self.bn2 = Batch_Norm2d(planes)
        self.conv2 = conv3x3(planes, planes, bias=conv_has_bias)
        self.downsample = downsample
        self.relu0 = ReLU(beta)
        self.relu1 = ReLU(beta)

    def forward(self, x):
        """Return ``relu1(residual_branch(x) + shortcut(x))``."""
        branch = self.conv1(x)
        if self.BN:
            branch = self.bn1(branch)
        branch = self.relu0(branch)

        branch = self.conv2(branch)
        if self.BN:
            branch = self.bn2(branch)

        # The shortcut is evaluated only after the residual branch is done.
        shortcut = x if self.downsample is None else self.downsample(x)

        # In-place accumulation into the branch tensor.
        branch += shortcut
        return self.relu1(branch)

    def analyze(self, method, R):
        """Propagate ``R`` backwards through the block.

        Args:
            method: Passed through untouched to every layer's ``analyze``.
            R: Quantity arriving at the block output.

        Returns:
            Sum of the shortcut-path result and the residual-path result.
        """
        after_relu = self.relu1.analyze(method, R)

        # Shortcut path: identity unless a downsample module exists.
        if self.downsample is None:
            shortcut_part = after_relu
        else:
            shortcut_part = self.downsample.analyze(method, after_relu)

        # Residual path, layers visited in reverse forward order.
        branch_part = self.bn2.analyze(method, after_relu) if self.BN else after_relu
        branch_part = self.conv2.analyze(method, branch_part)
        branch_part = self.relu0.analyze(method, branch_part)
        if self.BN:
            branch_part = self.bn1.analyze(method, branch_part)
        branch_part = self.conv1.analyze(method, branch_part)

        return shortcut_part + branch_part


class Bottleneck(nn.Module):
    """Three-convolution residual block with an ``analyze`` backward pass.

    ``forward`` runs conv1x1 -> [BN] -> ReLU -> conv3x3 -> [BN] -> ReLU ->
    conv1x1 -> [BN], adds the shortcut (the input itself, or
    ``downsample(input)`` when a downsample module was supplied) and applies
    a final ReLU.  The last convolution outputs ``planes * expansion``
    planes.  ``analyze`` walks the same layers in reverse.
    """

    # Multiplier applied to ``planes`` for the output of the last convolution.
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, beta=None, BN=True):
        """Build the layers of the block.

        Args:
            inplanes: Passed to the first convolution as ``in_planes``.
            planes: Width of the two inner convolutions.
            stride: Stride of the middle (3x3) convolution only.
            downsample: Optional module applied to the input on the shortcut
                path; ``None`` means the raw input is used.
            beta: Forwarded unchanged to the three ``ReLU`` constructors.
            BN: When true, a ``Batch_Norm2d`` follows each convolution and
                the convolutions are built without bias.
        """
        super().__init__()

        # Plain (non-module) state.  X / h0 / h1 / h2 are not read in this
        # class; they are kept because outside code may rely on them.
        # NOTE(review): BasicBlock in this file also keeps ``beta`` as an
        # attribute; this class does not -- confirm whether that is intended.
        self.BN = BN
        self.stride = stride
        self.X = 0
        self.h0 = 0
        self.h1 = 1
        self.h2 = 2

        # Sub-modules: the assignment order below fixes the registration
        # order inside nn.Module, so it mirrors the established layout.
        conv_has_bias = not BN
        expanded = planes * self.expansion
        self.conv1 = conv1x1(inplanes, planes, bias=conv_has_bias)
        if BN:
            self.bn1 = Batch_Norm2d(planes)
            self.bn2 = Batch_Norm2d(planes)
            self.bn3 = Batch_Norm2d(expanded)
        self.conv2 = conv3x3(planes, planes, stride, bias=conv_has_bias)
        self.conv3 = conv1x1(planes, expanded, bias=conv_has_bias)
        self.downsample = downsample
        self.relu0 = ReLU(beta)
        self.relu1 = ReLU(beta)
        self.relu2 = ReLU(beta)

    def forward(self, x):
        """Return ``relu2(residual_branch(x) + shortcut(x))``."""
        branch = self.conv1(x)
        if self.BN:
            branch = self.bn1(branch)
        branch = self.relu0(branch)

        branch = self.conv2(branch)
        if self.BN:
            branch = self.bn2(branch)
        branch = self.relu1(branch)

        branch = self.conv3(branch)
        if self.BN:
            branch = self.bn3(branch)

        # The shortcut is evaluated only after the residual branch is done.
        shortcut = x if self.downsample is None else self.downsample(x)

        # In-place accumulation into the branch tensor.
        branch += shortcut
        return self.relu2(branch)

    def analyze(self, method, R):
        """Propagate ``R`` backwards through the block.

        Args:
            method: Passed through untouched to every layer's ``analyze``.
            R: Quantity arriving at the block output.

        Returns:
            Sum of the shortcut-path result and the residual-path result.
        """
        after_relu = self.relu2.analyze(method, R)

        # Shortcut path: identity unless a downsample module exists.
        if self.downsample is None:
            shortcut_part = after_relu
        else:
            shortcut_part = self.downsample.analyze(method, after_relu)

        # Residual path, layers visited in reverse forward order.
        branch_part = self.bn3.analyze(method, after_relu) if self.BN else after_relu
        branch_part = self.conv3.analyze(method, branch_part)

        branch_part = self.relu1.analyze(method, branch_part)
        if self.BN:
            branch_part = self.bn2.analyze(method, branch_part)
        branch_part = self.conv2.analyze(method, branch_part)

        branch_part = self.relu0.analyze(method, branch_part)
        if self.BN:
            branch_part = self.bn1.analyze(method, branch_part)
        branch_part = self.conv1.analyze(method, branch_part)

        return shortcut_part + branch_part

class Downsample(nn.Module):
    """1x1 convolution, optionally followed by batch norm.

    The convolution maps ``inplanes`` to ``planes * expansion`` planes with
    the given stride.  ``analyze`` visits the same layers in reverse.
    """

    def __init__(self, inplanes, planes, expansion, stride, BN=True):
        """Build the projection layers.

        Args:
            inplanes: Passed to the convolution as ``in_planes``.
            planes: Base width; multiplied by ``expansion`` for the output.
            expansion: Multiplier applied to ``planes``.
            stride: Stride of the 1x1 convolution.
            BN: When true, a ``Batch_Norm2d`` follows the convolution and the
                convolution is built without bias.
        """
        super().__init__()
        self.BN = BN
        out_planes = planes * expansion
        self.conv = conv1x1(inplanes, out_planes, stride, bias=not BN)
        if self.BN:
            self.bn = Batch_Norm2d(out_planes)

    def forward(self, x):
        """Apply the convolution and, if enabled, the batch norm."""
        projected = self.conv(x)
        return self.bn(projected) if self.BN else projected

    def analyze(self, method, R):
        """Propagate ``R`` backwards: batch norm first (if any), then conv.

        Args:
            method: Passed through untouched to each layer's ``analyze``.
            R: Quantity arriving at the output of this module.

        Returns:
            The result of the convolution's ``analyze``.
        """
        incoming = self.bn.analyze(method, R) if self.BN else R
        return self.conv.analyze(method, incoming)
