
import torch
import torch.nn as nn
from dorefanet1d import *
from typing import Type, Any, Callable, List, Optional, Tuple, Union
import math


class LearnableBias(nn.Module):
    """Add a trainable offset, one value per channel, to the input.

    The offset is a ``(1, out_chn, 1, 1)`` parameter that starts at zero and is
    expanded to the input's shape in ``forward``. For a 4-D input (e.g. NCHW)
    this is a per-channel bias on dim 1.

    Args:
        out_chn: number of channels the offset covers.
    """

    def __init__(self, out_chn):
        super().__init__()
        initial = torch.zeros(1, out_chn, 1, 1)
        self.bias = nn.Parameter(initial, requires_grad=True)

    def forward(self, x):
        # expand_as (not implicit broadcasting) so an incompatible input
        # shape raises instead of silently broadcasting to a new shape.
        shift = self.bias.expand_as(x)
        return x + shift

class XBasicBlock1D(nn.Module):
    """Downsampling residual block made of two factorized quantized convolutions.

    The first stage applies a (3, 1) quantized convolution with stride
    ``(stride, 1)`` (reducing height). The second applies a (1, 3) quantized
    convolution with stride ``(1, stride)`` (reducing width). Each stage has
    its own residual shortcut and is followed by a
    LearnableBias -> PReLU -> LearnableBias activation. The block outputs
    ``in_channels * stride`` channels.

    Args:
        in_channels: number of input channels.
        stride: spatial reduction applied once along height and once along
            width; also the channel multiplier of the block.
        w_bits: weight bit-width forwarded to ``QuantizationConv2d``.
        a_bits: activation bit-width forwarded to ``QuantizationActivation``.
    """
    expansion = 1

    def __init__(self, in_channels, stride=1, w_bits=1, a_bits=1):
        super().__init__()
        out_channels = in_channels * stride

        # Height-wise stage: shift -> quantize activations -> (3, 1) conv -> BN.
        self.binconv3x1 = nn.Sequential(
                LearnableBias(in_channels),
                QuantizationActivation(a_bits=a_bits),
                QuantizationConv2d(in_channels, out_channels, kernel_size=(3, 1), stride=(stride, 1), padding=(1, 0), bias=False, w_bits=w_bits),
                nn.BatchNorm2d(out_channels)
            )

        # Width-wise stage: shift -> quantize activations -> (1, 3) conv -> BN.
        self.binconv1x3 = nn.Sequential(
                LearnableBias(out_channels),
                QuantizationActivation(a_bits=a_bits),
                QuantizationConv2d(out_channels, out_channels, kernel_size=(1, 3), stride=(1, stride), padding=(0, 1), bias=False, w_bits=w_bits),
                nn.BatchNorm2d(out_channels)
            )

        # Shifted PReLU activations applied after each residual add.
        self.relu1 = nn.Sequential(
                LearnableBias(out_channels),
                nn.PReLU(out_channels),
                LearnableBias(out_channels)
            )

        self.relu2 = nn.Sequential(
                LearnableBias(out_channels),
                nn.PReLU(out_channels),
                LearnableBias(out_channels)
            )

        # With stride 1 the shapes already match, so the shortcuts are identity.
        self.shortcut1 = nn.Sequential()
        self.shortcut2 = nn.Sequential()

        if stride != 1:
            # The shortcut must match the strided conv in channels (1x1 conv)
            # and in length along the strided axis (average pool).
            # ceil_mode=True makes a stride-2 pool produce ceil(L / 2). That is
            # what the 3-tap, pad-1, stride-2 conv produces, so odd L works too.
            # The default floor mode gives floor(L / 2) and breaks the residual
            # add. For even L the result is unchanged.
            self.shortcut1 = nn.Sequential(
                nn.AvgPool2d(kernel_size=(2, 1), stride=(stride, 1), ceil_mode=True),
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
                nn.BatchNorm2d(out_channels)
            )
            self.shortcut2 = nn.Sequential(
                nn.AvgPool2d(kernel_size=(1, 2), stride=(1, stride), ceil_mode=True),
                nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        # Stage 1: height-wise conv plus its own residual connection.
        out = self.binconv3x1(x) + self.shortcut1(x)
        out = self.relu1(out)
        # Stage 2: width-wise conv plus a residual from stage 1's output.
        out = self.binconv1x3(out) + self.shortcut2(out)
        return self.relu2(out)


class XBasicBlock2D(nn.Module):
    """Single-convolution quantized residual block.

    Computes LearnableBias -> QuantizationActivation -> 3x3 QuantizationConv2d
    -> BatchNorm, adds a shortcut, then applies a
    LearnableBias -> PReLU -> LearnableBias activation.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels (``expansion`` is 1).
        stride: stride of the 3x3 convolution.
        w_bits: weight bit-width forwarded to ``QuantizationConv2d``.
        a_bits: activation bit-width forwarded to ``QuantizationActivation``.
    """
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1, w_bits=1, a_bits=1):
        super().__init__()

        # Residual function.
        self.binconv3x3 = nn.Sequential(
                LearnableBias(in_channels),
                QuantizationActivation(a_bits=a_bits),
                QuantizationConv2d(in_channels, out_channels, kernel_size=(3, 3), stride=stride, padding=(1, 1), bias=False, w_bits=w_bits),
                nn.BatchNorm2d(out_channels)
            )

        # Shifted PReLU applied after the residual add.
        self.relu1 = nn.Sequential(
                LearnableBias(out_channels),
                nn.PReLU(out_channels),
                LearnableBias(out_channels)
            )

        self.shortcut = nn.Sequential()

        if stride != 1 or in_channels != XBasicBlock2D.expansion * out_channels:
            if stride != 1:
                # ceil_mode=True: a stride-2 pool yields ceil(L / 2). That matches
                # the 3x3/pad-1/stride-2 conv output even for odd L, and is
                # unchanged for even L.
                pool = nn.AvgPool2d(kernel_size=2, stride=stride, ceil_mode=True)
            else:
                # Channel change only: the conv keeps the spatial size, so the
                # shortcut must keep it too. A 2x2 pool with stride 1 would
                # shrink H and W by one and break the residual add.
                # Identity keeps index 0 occupied, so the state_dict keys
                # (shortcut.1.*, shortcut.2.*) stay unchanged.
                pool = nn.Identity()
            self.shortcut = nn.Sequential(
                pool,
                nn.Conv2d(in_channels, out_channels * XBasicBlock2D.expansion, kernel_size=1, stride=1, bias=False),
                nn.BatchNorm2d(out_channels * XBasicBlock2D.expansion)
            )

    def forward(self, x):
        out = self.binconv3x3(x) + self.shortcut(x)
        return self.relu1(out)


class ResNet(nn.Module):
    """Quantized ResNet built from stride-1 "2D" blocks and strided "1D" blocks.

    Args:
        block2d: block class used at stride-1 positions. It is constructed as
            ``block2d(in_channels, out_channels, stride, w_bits, a_bits)`` and
            must expose an ``expansion`` attribute.
        block1d: block class used at strided (downsampling) positions. It is
            constructed as ``block1d(in_channels, stride, w_bits, a_bits)`` and
            is assumed to output ``in_channels * stride`` channels.
        num_block: number of blocks in each of the four stages.
        w_bits: weight bit-width forwarded to every block.
        a_bits: activation bit-width forwarded to every block.
        num_classes: output size of the final linear classifier.
    """

    def __init__(self, block2d, block1d, num_block, w_bits, a_bits, num_classes=1000):
        super().__init__()

        self.in_channels = 64
        self.w_bits = w_bits
        self.a_bits = a_bits
        print("self.w_bits:", w_bits)
        print("self.a_bits:", a_bits)

        # Full-precision stem: 7x7 stride-2 conv, then 3x3 stride-2 max-pool.
        self.conv1 = nn.Sequential(
          nn.Conv2d(3, self.in_channels, 7, 2, 3, bias=False),
          nn.BatchNorm2d(self.in_channels),
          nn.ReLU(inplace=True),
          nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
        self.conv2_x = self._make_layer(block2d, block1d, 64, num_block[0], 1)
        self.conv3_x = self._make_layer(block2d, block1d, 128, num_block[1], 2)
        self.conv4_x = self._make_layer(block2d, block1d, 256, num_block[2], 2)
        self.conv5_x = self._make_layer(block2d, block1d, 512, num_block[3], 2)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        # self.in_channels now holds the width produced by conv5_x (512 for
        # expansion 1), so the classifier follows the configured blocks.
        self.fc = nn.Linear(self.in_channels, num_classes)

    def _make_layer(self, block2d, block1d, out_channels, num_blocks, stride):
        """Build one stage as an ``nn.Sequential`` of residual blocks.

        The first block uses ``stride`` and the remaining ``num_blocks - 1``
        blocks use stride 1. The ``[stride]`` seed means at least one block is
        always built. Strided positions use ``block1d``; stride-1 positions use
        ``block2d``. ``self.in_channels`` is updated after every block so the
        next block and the next stage start from the right width.

        Args:
            block2d: block class for stride-1 positions.
            block1d: block class for strided (downsampling) positions.
            out_channels: output channel count of this stage (before expansion).
            num_blocks: number of blocks in this stage.
            stride: stride of the first block of this stage.

        Returns:
            nn.Sequential containing the stage's blocks.

        Raises:
            ValueError: if a strided ``block1d`` (which outputs
                ``in_channels * stride`` channels) cannot reach
                ``out_channels * block2d.expansion``.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for block_stride in strides:
            if block_stride != 1:
                # block1d derives its width from its input, so check that it
                # actually reaches the requested stage width. Otherwise the
                # following blocks would be built for the wrong channel count.
                produced = self.in_channels * block_stride
                expected = out_channels * block2d.expansion
                if produced != expected:
                    raise ValueError(
                        f"block1d with in_channels={self.in_channels} and stride={block_stride} "
                        f"produces {produced} channels, but this stage expects {expected}"
                    )
                layers.append(block1d(self.in_channels, block_stride, self.w_bits, self.a_bits))
                self.in_channels = produced
            else:
                layers.append(block2d(self.in_channels, out_channels, block_stride, self.w_bits, self.a_bits))
                self.in_channels = out_channels * block2d.expansion

        return nn.Sequential(*layers)

    def forward(self, x):
        output = self.conv1(x)
        output = self.conv2_x(output)
        output = self.conv3_x(output)
        output = self.conv4_x(output)
        output = self.conv5_x(output)
        output = self.avg_pool(output)
        # (N, C, 1, 1) -> (N, C) for the linear classifier.
        output = torch.flatten(output, 1)
        output = self.fc(output)

        return output

def resnet18(w_bits, a_bits, **kwargs):
    """Construct the quantized ResNet-18 variant.

    All four stages get four blocks each. ``XBasicBlock2D`` fills stride-1
    positions and ``XBasicBlock1D`` fills the strided first position of the
    downsampling stages (as decided by ``ResNet._make_layer``).

    Args:
        w_bits: weight bit-width.
        a_bits: activation bit-width.
        **kwargs: extra keyword arguments forwarded to ``ResNet``
            (e.g. ``num_classes``).

    Returns:
        A ``ResNet`` instance.
    """
    stage_depths = [4, 4, 4, 4]
    return ResNet(XBasicBlock2D, XBasicBlock1D, stage_depths, w_bits, a_bits, **kwargs)
