import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
import torch.nn.functional as F
import numpy as np
from .dorefanet import *
from typing import Type, Any, Callable, List, Optional, Tuple, Union
import math

# NOTE(review): `epoch` is not referenced anywhere in this file's visible code;
# presumably it is kept for other modules that import it — confirm before removing.
epoch = 1

# Output channel count of each layer in `qsb_net_large.feature`, in order:
# two 16-channel layers, two 32, two 128, two 512, then six 2048.
stage_out_channel = [16, 16, 32, 32, 128, 128, 512, 512] + [2048] * 6

def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` with padding 1.

    With ``stride=1`` the spatial size is preserved; larger strides downsample.
    """
    conv_kwargs = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)


def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free pointwise (1x1) ``nn.Conv2d`` with no padding."""
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )

class firstconv3x3(nn.Module):
    """Stem block: 3x3 conv (padding 1, no bias) -> BatchNorm -> ReLU.

    Args:
        inp: input channel count.
        oup: output channel count.
        stride: convolution stride.
    """

    def __init__(self, inp, oup, stride):
        super().__init__()
        # Attribute names and creation order (conv1, bn1, relu) determine the
        # state_dict keys, so they are kept as-is for checkpoint compatibility.
        self.conv1 = nn.Conv2d(inp, oup, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(oup)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn1(self.conv1(x)))

class LearnableBias(nn.Module):
    """Add a trainable per-channel offset to the input.

    The offset is a parameter of shape ``(1, out_chn, 1, 1)``, initialised to
    zero and expanded to the input's shape in ``forward``.
    """

    def __init__(self, out_chn):
        super().__init__()
        zeros = torch.zeros(1, out_chn, 1, 1)
        self.bias = nn.Parameter(zeros, requires_grad=True)

    def forward(self, x):
        offset = self.bias.expand_as(x)
        return x + offset

class TypeN(nn.Module):
    """Shape-preserving block: a depthwise stage followed by a quantized-conv stage.

    Stage 1 shifts the input by a learnable per-channel bias, applies a depthwise
    3x3 conv + BatchNorm, adds the identity shortcut, then PReLU + BatchNorm.
    Stage 2 shifts again, applies ``binary_activation`` and a quantized conv +
    BatchNorm (3x3 when ``inplanes`` is 16 or 32, otherwise 1x1), adds the
    stage-1 output, then PReLU + BatchNorm.

    The shortcut is the identity and ``relu1`` is sized for ``planes`` channels,
    so the block only works with ``stride == 1`` and ``inplanes == planes``.
    ``downsample`` is unused.
    """

    def __init__(self, inplanes, planes, stride=1, downsample=None, w_bits=1, a_bits=1):
        super().__init__()

        self.stride = stride
        self.inplanes = inplanes
        self.planes = planes

        # Narrow stages (16/32 input channels) use a 3x3 quantized conv, wider ones a 1x1.
        ksize = 3 if inplanes in (16, 32) else 1
        pad = ksize // 2

        # Submodules are created in the original order so the state_dict layout
        # and the RNG-driven weight initialisation sequence are unchanged.
        self.move1c = LearnableBias(inplanes)
        self.binary_activation = QuantizationActivation(a_bits)
        self.depthwiseconv3x3a = nn.Sequential(
            nn.Conv2d(inplanes, inplanes, kernel_size=(3, 3), stride=(stride, stride),
                      padding=(1, 1), groups=inplanes, bias=False),
            nn.BatchNorm2d(inplanes),
        )
        self.shortcut = nn.Sequential()
        self.relu1 = nn.Sequential(nn.PReLU(planes), nn.BatchNorm2d(planes))
        self.move2c = LearnableBias(planes)
        self.binconv3x3a = nn.Sequential(
            QuantizationConv2d(planes, planes, kernel_size=(ksize, ksize), stride=(1, 1),
                               padding=(pad, pad), bias=False, w_bits=w_bits),
            nn.BatchNorm2d(planes),
        )
        self.relu2 = nn.Sequential(nn.PReLU(planes), nn.BatchNorm2d(planes))

    def forward(self, x):
        # Stage 1: depthwise spatial conv with identity skip.
        h = self.depthwiseconv3x3a(self.move1c(x))
        h += self.shortcut(x)
        h = self.relu1(h)

        # Stage 2: quantized conv on the shifted, quantized activation, with skip.
        skip = h
        h = self.binconv3x3a(self.binary_activation(self.move2c(h)))
        h += skip
        return self.relu2(h)


class TypeDS(nn.Module):
    """Downsampling block: halves H and W (rounding up) and doubles channels.

    Four sub-stages, each followed by a PReLU + BatchNorm pair:

    1. Two parallel depthwise 3x3 convs with stride (2, 1) halve the height;
       each gets an average-pooled skip, and the two results are concatenated
       along channels (``inplanes`` -> ``2 * inplanes``).
    2. Quantized conv (3x3 when ``inplanes`` is 16 or 32, otherwise 1x1) applied
       to ``binary_activation(move2c(.))``, with an identity skip.
    3. Depthwise 3x3 conv with stride (1, 2) halves the width, with an
       average-pooled skip.
    4. A second quantized conv, as in stage 2, with an identity skip.

    Args:
        inplanes: input channel count.
        planes: output channel count; stage 1 produces ``2 * inplanes``
            channels, so this must equal ``2 * inplanes``.
        stride: stored on the instance only; the conv strides are fixed at
            (2, 1) and (1, 2) regardless of this value.
        downsample: unused.
        w_bits: weight bit-width passed to ``QuantizationConv2d``.
        a_bits: activation bit-width passed to ``QuantizationActivation``.
    """

    def __init__(self, inplanes, planes, stride=1, downsample=None, w_bits=1, a_bits=1):
        super(TypeDS, self).__init__()

        self.stride = stride
        self.inplanes = inplanes
        self.planes = planes

        # Narrow stages (16/32 input channels) use 3x3 quantized convs, wider ones 1x1.
        ksize = 3 if inplanes in (16, 32) else 1
        pad = ksize // 2

        # Submodule creation order matches the original layout so the state_dict
        # and the RNG-driven initialisation sequence stay identical.
        self.binary_activation = QuantizationActivation(a_bits)

        self.move1c = LearnableBias(inplanes)

        self.depthwiseconv3x3a = nn.Sequential(
            nn.Conv2d(inplanes, inplanes, kernel_size=(3, 3), stride=(2, 1), padding=(1, 1), groups=inplanes, bias=False),
            nn.BatchNorm2d(inplanes),
        )
        self.depthwiseconv3x3b = nn.Sequential(
            nn.Conv2d(inplanes, inplanes, kernel_size=(3, 3), stride=(2, 1), padding=(1, 1), groups=inplanes, bias=False),
            nn.BatchNorm2d(inplanes),
        )

        # The stride-2 conv above (k=3, p=1) outputs ceil(H / 2) rows; ceil_mode
        # makes the pooled skip match that for odd H too (even H is unchanged).
        self.shortcut1 = nn.Sequential(nn.AvgPool2d(kernel_size=(2, 1), stride=(2, 1), ceil_mode=True))

        self.relu1 = nn.Sequential(
            nn.PReLU(planes),
            nn.BatchNorm2d(planes),
        )

        self.move2c = LearnableBias(planes)
        self.binconv3x3a = nn.Sequential(
            QuantizationConv2d(planes, planes, kernel_size=(ksize, ksize), stride=(1, 1), padding=(pad, pad), bias=False, w_bits=w_bits),
            nn.BatchNorm2d(planes),
        )

        self.relu2 = nn.Sequential(
            nn.PReLU(planes),
            nn.BatchNorm2d(planes),
        )

        self.move3c = LearnableBias(planes)

        self.depthwiseconv3x3c = nn.Sequential(
            nn.Conv2d(planes, planes, kernel_size=(3, 3), stride=(1, 2), padding=(1, 1), groups=planes, bias=False),
            nn.BatchNorm2d(planes),
        )

        # Same ceil-vs-floor alignment as shortcut1, along the width axis.
        self.shortcut2 = nn.Sequential(nn.AvgPool2d(kernel_size=(1, 2), stride=(1, 2), ceil_mode=True))

        self.relu3 = nn.Sequential(
            nn.PReLU(planes),
            nn.BatchNorm2d(planes),
        )

        self.move4c = LearnableBias(planes)
        self.binconv3x3b = nn.Sequential(
            QuantizationConv2d(planes, planes, kernel_size=(ksize, ksize), stride=(1, 1), padding=(pad, pad), bias=False, w_bits=w_bits),
            nn.BatchNorm2d(planes),
        )
        self.relu4 = nn.Sequential(
            nn.PReLU(planes),
            nn.BatchNorm2d(planes),
        )

    def forward(self, x):
        # Stage 1: two height-halving depthwise branches sharing one pooled skip.
        out = self.move1c(x)
        out_1 = self.depthwiseconv3x3a(out)
        out_2 = self.depthwiseconv3x3b(out)
        skip = self.shortcut1(x)  # computed once; both branches add the same tensor
        out_1 += skip
        out_2 += skip
        out = torch.cat([out_1, out_2], dim=1)
        out = self.relu1(out)

        # Stage 2: quantized conv with identity skip.
        residual2 = out
        out = self.move2c(out)
        out = self.binary_activation(out)
        out = self.binconv3x3a(out)
        out += residual2
        out = self.relu2(out)

        # Stage 3: width-halving depthwise conv with pooled skip.
        residual3 = out
        out = self.move3c(out)
        out = self.depthwiseconv3x3c(out)
        out += self.shortcut2(residual3)
        out = self.relu3(out)

        # Stage 4: second quantized conv with identity skip.
        residual4 = out
        out = self.move4c(out)
        out = self.binary_activation(out)
        out = self.binconv3x3b(out)
        out += residual4
        out = self.relu4(out)

        return out


class TypeQS(nn.Module):
    """Downsampling block: halves H and W (rounding up) and quadruples channels.

    Four sub-stages, each followed by a PReLU + BatchNorm pair:

    1. Two parallel depthwise 3x3 convs with stride (2, 1) halve the height;
       each gets an average-pooled skip and the results are concatenated
       (``inplanes`` -> ``2 * inplanes``, which must equal ``interim``).
    2. Quantized conv on ``interim`` channels (3x3 when ``inplanes`` is 16 or
       32, otherwise 1x1) applied to ``binary_activation(move2c(.))``, with an
       identity skip.
    3. Two parallel depthwise 3x3 convs with stride (1, 2) halve the width;
       each gets an average-pooled skip and the results are concatenated
       (``interim`` -> ``2 * interim``, which must equal ``planes``).
    4. Quantized conv on ``planes`` channels with an identity skip.

    Args:
        inplanes: input channel count.
        interim: channel count after stage 1; must be ``2 * inplanes``.
        planes: output channel count; must be ``2 * interim``.
        stride: stored on the instance only; the conv strides are fixed.
        downsample: unused.
        w_bits: weight bit-width passed to ``QuantizationConv2d``.
        a_bits: activation bit-width passed to ``QuantizationActivation``.
    """

    def __init__(self, inplanes, interim, planes, stride=1, downsample=None, w_bits=1, a_bits=1):
        super(TypeQS, self).__init__()

        self.stride = stride
        self.inplanes = inplanes
        self.planes = planes

        # Narrow stages (16/32 input channels) use 3x3 quantized convs, wider ones 1x1.
        # The choice depends on `inplanes` for both quantized convs.
        ksize = 3 if inplanes in (16, 32) else 1
        pad = ksize // 2

        # Submodule creation order matches the original layout so the state_dict
        # and the RNG-driven initialisation sequence stay identical. (`move1c`
        # was previously assigned twice; a single assignment is equivalent.)
        self.move1c = LearnableBias(inplanes)
        self.binary_activation = QuantizationActivation(a_bits)

        self.depthwiseconv3x3a = nn.Sequential(
            nn.Conv2d(inplanes, inplanes, kernel_size=(3, 3), stride=(2, 1), padding=(1, 1), groups=inplanes, bias=False),
            nn.BatchNorm2d(inplanes),
        )
        self.depthwiseconv3x3b = nn.Sequential(
            nn.Conv2d(inplanes, inplanes, kernel_size=(3, 3), stride=(2, 1), padding=(1, 1), groups=inplanes, bias=False),
            nn.BatchNorm2d(inplanes),
        )

        # The stride-2 conv above (k=3, p=1) outputs ceil(H / 2) rows; ceil_mode
        # makes the pooled skip match that for odd H too (even H is unchanged).
        self.shortcut1 = nn.Sequential(nn.AvgPool2d(kernel_size=(2, 1), stride=(2, 1), ceil_mode=True))

        self.relu1 = nn.Sequential(
            nn.PReLU(interim),
            nn.BatchNorm2d(interim),
        )

        self.move2c = LearnableBias(interim)
        self.binconv3x3a = nn.Sequential(
            QuantizationConv2d(interim, interim, kernel_size=(ksize, ksize), stride=(1, 1), padding=(pad, pad), bias=False, w_bits=w_bits),
            nn.BatchNorm2d(interim),
        )

        self.relu2 = nn.Sequential(
            nn.PReLU(interim),
            nn.BatchNorm2d(interim),
        )

        self.move3c = LearnableBias(interim)

        self.depthwiseconv3x3c = nn.Sequential(
            nn.Conv2d(interim, interim, kernel_size=(3, 3), stride=(1, 2), padding=(1, 1), groups=interim, bias=False),
            nn.BatchNorm2d(interim),
        )
        self.depthwiseconv3x3d = nn.Sequential(
            nn.Conv2d(interim, interim, kernel_size=(3, 3), stride=(1, 2), padding=(1, 1), groups=interim, bias=False),
            nn.BatchNorm2d(interim),
        )

        # Same ceil-vs-floor alignment as shortcut1, along the width axis.
        self.shortcut2 = nn.Sequential(nn.AvgPool2d(kernel_size=(1, 2), stride=(1, 2), ceil_mode=True))

        self.relu3 = nn.Sequential(
            nn.PReLU(planes),
            nn.BatchNorm2d(planes),
        )

        self.move4c = LearnableBias(planes)
        self.binconv3x3b = nn.Sequential(
            QuantizationConv2d(planes, planes, kernel_size=(ksize, ksize), stride=(1, 1), padding=(pad, pad), bias=False, w_bits=w_bits),
            nn.BatchNorm2d(planes),
        )
        self.relu4 = nn.Sequential(
            nn.PReLU(planes),
            nn.BatchNorm2d(planes),
        )

    def forward(self, x):
        # Stage 1: two height-halving depthwise branches sharing one pooled skip.
        out = self.move1c(x)
        out_1 = self.depthwiseconv3x3a(out)
        out_2 = self.depthwiseconv3x3b(out)
        skip = self.shortcut1(x)  # computed once; both branches add the same tensor
        out_1 += skip
        out_2 += skip
        out = torch.cat([out_1, out_2], dim=1)
        out = self.relu1(out)

        # Stage 2: quantized conv with identity skip.
        residual2 = out
        out = self.move2c(out)
        out = self.binary_activation(out)
        out = self.binconv3x3a(out)
        out += residual2
        out = self.relu2(out)

        # Stage 3: two width-halving depthwise branches sharing one pooled skip.
        residual3 = out
        out = self.move3c(out)
        out_1 = self.depthwiseconv3x3c(out)
        out_2 = self.depthwiseconv3x3d(out)
        skip = self.shortcut2(residual3)
        out_1 += skip
        out_2 += skip
        out = torch.cat([out_1, out_2], dim=1)
        out = self.relu3(out)

        # Stage 4: second quantized conv with identity skip.
        residual4 = out
        out = self.move4c(out)
        out = self.binary_activation(out)
        out = self.binconv3x3b(out)
        out += residual4
        out = self.relu4(out)

        return out


class qsb_net_large(nn.Module):
    """Quantized classification network built from ``stage_out_channel``.

    Layer ``i`` of ``self.feature`` is chosen from the module-level
    ``stage_out_channel`` list:

    * ``i == 0``: ``firstconv3x3(3, c, stride=2)`` stem.
    * channel count changes to 32: ``TypeDS`` (channels x2).
    * channel count changes to 128, 512 or 2048: ``TypeQS`` with
      ``interim = c // 2`` (channels x4).
    * otherwise: shape-preserving ``TypeN``.

    The features are averaged over the spatial axes and fed to a linear
    classifier whose input width is the last stage's channel count.

    Args:
        w_bits: weight bit-width forwarded to every quantized block.
        a_bits: activation bit-width forwarded to every quantized block.
        num_classes: number of output logits.
    """

    def __init__(self, w_bits, a_bits, num_classes=1000):
        super(qsb_net_large, self).__init__()
        quant = dict(w_bits=w_bits, a_bits=a_bits)

        # Layers are constructed in list order, exactly as before, so the
        # state_dict keys and RNG-driven initialisation are unchanged.
        layers = []
        for i, out_ch in enumerate(stage_out_channel):
            if i == 0:
                layers.append(firstconv3x3(3, out_ch, 2))
                continue
            in_ch = stage_out_channel[i - 1]
            if in_ch != out_ch and out_ch == 32:
                layers.append(TypeDS(in_ch, out_ch, 2, **quant))
            elif in_ch != out_ch and out_ch in (128, 512, 2048):
                layers.append(TypeQS(in_ch, out_ch // 2, out_ch, 2, **quant))
            else:
                layers.append(TypeN(in_ch, out_ch, 1, **quant))

        self.feature = nn.Sequential(*layers)

        # NOTE(review): pool1 is not used by forward() (which calls x.mean); it
        # holds no parameters, so it is kept only to preserve the attribute.
        self.pool1 = nn.AdaptiveAvgPool2d(1)
        # Derive the classifier width from the final stage instead of hard-coding 2048.
        self.fc = nn.Linear(stage_out_channel[-1], num_classes)

    def forward(self, x):
        x = self.feature(x)
        # Global average pooling over the two spatial axes (dims 2 and 3).
        x = x.mean([2, 3])
        return self.fc(x)

def FB3(pretrained=False,
        checkpoint_path="model_best_1_1_Mar27_2024qsb_net_large_seeds42.pth.tar",
        show_summary=True):
    """Build a ``qsb_net_large`` with ``w_bits=1``, ``a_bits=1`` and 1000 classes.

    Args:
        pretrained: if True, load weights from ``checkpoint_path``.
        checkpoint_path: checkpoint file containing a ``'state_dict'`` entry.
            Relative paths resolve against the current working directory.
        show_summary: if True (default), print a ``torchinfo`` summary for a
            ``(1, 3, 224, 224)`` input. ``torchinfo`` is imported only in that
            case, so it is not required when ``show_summary=False``.

    Returns:
        The constructed ``qsb_net_large`` module.
    """
    model = qsb_net_large(1, 1, 1000)

    if pretrained:
        from collections import OrderedDict

        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        prefix = "module."

        def _strip_prefix(name):
            # Remove only leading "module." wrappers (presumably added by
            # DataParallel/DDP during training); str.replace would also mangle
            # keys that contain "module." in the middle.
            while name.startswith(prefix):
                name = name[len(prefix):]
            return name

        new_state_dict = OrderedDict(
            (_strip_prefix(n), v) for n, v in checkpoint['state_dict'].items()
        )
        model.load_state_dict(new_state_dict)
        print("Complete load pre-trained model")

    if show_summary:
        # NOTE(review): summary() runs a forward pass; whether it switches the
        # model to eval mode depends on the installed torchinfo version. In
        # train mode it would update BatchNorm running stats — confirm.
        from torchinfo import summary
        summary(model, (1, 3, 224, 224))

    return model
