'''ResNet-18 variant built from the quant_layers modules (Res18_mod) in PyTorch.'''
import torch
import torch.nn as nn
import torch.nn.functional as F
from quant_layers.quant_activation import q_act
from quant_layers.quant_activation import cyc_relu
from quant_layers.mod_conv import Mod_conv2d_try
from quant_layers.mod_linear import Mod_Linear
from quant_layers.functions import no_grad_mul
import pdb
import numpy as np

class Res18_mod(nn.Module):
    """ResNet-18-style classifier whose residual convolutions are ``Mod_conv2d_try``.

    Layout: a full-precision 7x7 stride-2 stem convolution (3 -> 64 channels),
    batch norm, ReLU6, activation quantization and a 3x3 stride-2 max pool,
    followed by eight two-convolution residual blocks (widths 64, 64, 128, 128,
    256, 256, 512, 512), global average pooling and a 512 -> 1000 linear
    classifier.  Blocks 3, 5 and 7 halve the spatial size and use a 1x1
    stride-2 ``Mod_conv2d_try`` + batch norm as projection shortcut; the other
    blocks use an empty ``nn.Sequential`` (identity) shortcut.

    Args:
        name: Accepted for interface compatibility; not used by this class.
        c_bits: Forwarded as the second positional argument of every
            ``Mod_conv2d_try`` (presumably a bit width -- confirm against
            ``quant_layers.mod_conv``).
    """

    # Fixed activation scale handed to ``q_act`` and, as the input scale, to the
    # convolution that consumes the quantized activation.
    ACT_SCALE = 0.8
    # |conv_output / scale| values above this limit contribute to ``outrange``.
    OUTRANGE_LIMIT = 2 ** 7

    def __init__(self, name, c_bits):
        super(Res18_mod, self).__init__()

        def conv3x3(c_in, c_out, stride=1):
            # 3x3, padding 1, no bias; leading ``4`` is the first positional
            # argument of Mod_conv2d_try (meaning defined in quant_layers).
            return Mod_conv2d_try(4, c_bits, c_in, c_out, kernel_size=3,
                                  stride=stride, padding=1, bias=False)

        def conv1x1_down(c_in, c_out):
            # 1x1 stride-2 projection used on the shortcut of downsampling blocks.
            return Mod_conv2d_try(4, c_bits, c_in, c_out, kernel_size=1,
                                  stride=2, bias=False)

        # NOTE: sub-modules are registered in exactly this order on purpose;
        # it fixes state_dict key order, parameters() order and the order in
        # which random initialisation consumes the RNG.

        # Stem (regular, non-quantized convolution).
        self.conv0 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn0 = nn.BatchNorm2d(64)

        # Residual block 1 (64 -> 64, identity shortcut).
        self.conv1 = conv3x3(64, 64)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv1_sc = nn.Sequential()
        self.conv2 = conv3x3(64, 64)
        self.bn2 = nn.BatchNorm2d(64)

        # Residual block 2 (64 -> 64, identity shortcut).
        self.conv3 = conv3x3(64, 64)
        self.bn3 = nn.BatchNorm2d(64)
        self.conv3_sc = nn.Sequential()
        self.conv4 = conv3x3(64, 64)
        self.bn4 = nn.BatchNorm2d(64)

        # Residual block 3 (64 -> 128, stride 2, projection shortcut).
        self.conv5 = conv3x3(64, 128, stride=2)
        self.bn5 = nn.BatchNorm2d(128)
        self.conv6 = conv3x3(128, 128)
        self.bn6 = nn.BatchNorm2d(128)
        self.conv5_sc = conv1x1_down(64, 128)
        self.bn5_sc = nn.BatchNorm2d(128)

        # Residual block 4 (128 -> 128, identity shortcut).
        self.conv7 = conv3x3(128, 128)
        self.bn7 = nn.BatchNorm2d(128)
        self.conv7_sc = nn.Sequential()
        self.conv8 = conv3x3(128, 128)
        self.bn8 = nn.BatchNorm2d(128)

        # Residual block 5 (128 -> 256, stride 2, projection shortcut).
        self.conv9 = conv3x3(128, 256, stride=2)
        self.bn9 = nn.BatchNorm2d(256)
        self.conv10 = conv3x3(256, 256)
        self.bn10 = nn.BatchNorm2d(256)
        self.conv9_sc = conv1x1_down(128, 256)
        self.bn9_sc = nn.BatchNorm2d(256)

        # Residual block 6 (256 -> 256, identity shortcut).
        self.conv11 = conv3x3(256, 256)
        self.bn11 = nn.BatchNorm2d(256)
        self.conv11_sc = nn.Sequential()
        self.conv12 = conv3x3(256, 256)
        self.bn12 = nn.BatchNorm2d(256)

        # Residual block 7 (256 -> 512, stride 2, projection shortcut).
        self.conv13 = conv3x3(256, 512, stride=2)
        self.bn13 = nn.BatchNorm2d(512)
        self.conv14 = conv3x3(512, 512)
        self.bn14 = nn.BatchNorm2d(512)
        self.conv13_sc = conv1x1_down(256, 512)
        self.bn13_sc = nn.BatchNorm2d(512)

        # Residual block 8 (512 -> 512, identity shortcut).
        self.conv15 = conv3x3(512, 512)
        self.bn15 = nn.BatchNorm2d(512)
        self.conv15_sc = nn.Sequential()
        self.conv16 = conv3x3(512, 512)
        self.bn16 = nn.BatchNorm2d(512)

        # Head and shared parameter-free layers.
        self.classifier = nn.Linear(512, 1000)
        self.mp = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.qact = q_act(8)
        self.cyc = cyc_relu(8)
        # Third argument of every Mod_conv2d_try call; plain attribute so
        # callers can change it between forward passes.
        self.alpha = 1.

        # Explicitly (re)set every normalisation layer to weight=1, bias=0.
        for m in self.modules():
            if isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _quant_conv(self, conv, x, scale, ofrate, outrange):
        """Run one ``Mod_conv2d_try`` and update the overflow bookkeeping.

        Appends the rate returned by ``conv`` to ``ofrate`` (in place), adds the
        per-sample mean of ``relu(|out / s| - OUTRANGE_LIMIT)`` to ``outrange``
        and passes the convolution output through ``self.cyc``.

        Returns:
            ``(out, outrange)`` -- the ``cyc`` output and the updated accumulator.
        """
        out, rate, out_scale = conv(x, scale, self.alpha)
        ofrate.append(rate)
        excess = F.relu(torch.abs(out / out_scale) - self.OUTRANGE_LIMIT)
        outrange = outrange + excess.view(out.shape[0], -1).mean(1)
        out = self.cyc(out, out_scale)
        return out, outrange

    def _residual_block(self, x, scale, conv_a, bn_a, conv_b, bn_b,
                        shortcut, shortcut_bn, ofrate, outrange):
        """Apply one residual block up to and including the post-addition ReLU6.

        ``shortcut_bn is None`` marks an identity shortcut (``shortcut`` is then
        called with the input only); otherwise ``shortcut`` is a
        ``Mod_conv2d_try`` followed by ``shortcut_bn``.  The shortcut is
        evaluated before the main path so that ``ofrate`` / ``outrange`` are
        accumulated in a fixed order.

        Returns:
            ``(out, outrange)`` -- block output (not yet re-quantized) and the
            updated accumulator.
        """
        if shortcut_bn is None:
            identity = shortcut(x)
        else:
            identity, outrange = self._quant_conv(shortcut, x, scale, ofrate, outrange)
            identity = shortcut_bn(identity)

        # First conv -> BN -> ReLU6 -> activation quantization.
        out, outrange = self._quant_conv(conv_a, x, scale, ofrate, outrange)
        out = bn_a(out)
        out = F.relu6(out)
        out = self.qact(out, self.ACT_SCALE)

        # Second conv -> BN, then residual addition and ReLU6.
        out, outrange = self._quant_conv(conv_b, out, self.ACT_SCALE, ofrate, outrange)
        out = bn_b(out)
        out += identity
        out = F.relu6(out)
        return out, outrange

    def forward(self, x):
        """Classify a batch of 3-channel images.

        Args:
            x: Input tensor fed to the 3-input-channel stem convolution.

        Returns:
            A 4-tuple ``(logits, outrange, ofrate_total, oflinear_total)``:

            * ``logits`` -- classifier output, shape ``(batch, 1000)``.
            * ``outrange`` -- shape ``(batch,)``; sum over all 19
              ``Mod_conv2d_try`` layers of the per-sample mean amount by which
              ``|conv_out / scale|`` exceeds ``OUTRANGE_LIMIT``.
            * ``ofrate_total`` -- length-``batch`` tensor in which every entry
              is the sum of the rates returned by the ``Mod_conv2d_try`` layers.
            * ``oflinear_total`` -- length-``batch`` tensor of zeros (the linear
              layer reports no rate).

            The last two tensors are created on the same device as ``logits``.
        """
        scale = self.ACT_SCALE
        outrange = 0
        ofrate = []

        # Stem: conv -> BN -> ReLU6 -> quantize -> max pool.
        out = self.conv0(x)
        out = self.bn0(out)
        out = F.relu6(out)
        out = self.qact(out, scale)
        out = self.mp(out)

        # (first conv, first BN, second conv, second BN, shortcut, shortcut BN)
        blocks = (
            (self.conv1, self.bn1, self.conv2, self.bn2, self.conv1_sc, None),
            (self.conv3, self.bn3, self.conv4, self.bn4, self.conv3_sc, None),
            (self.conv5, self.bn5, self.conv6, self.bn6, self.conv5_sc, self.bn5_sc),
            (self.conv7, self.bn7, self.conv8, self.bn8, self.conv7_sc, None),
            (self.conv9, self.bn9, self.conv10, self.bn10, self.conv9_sc, self.bn9_sc),
            (self.conv11, self.bn11, self.conv12, self.bn12, self.conv11_sc, None),
            (self.conv13, self.bn13, self.conv14, self.bn14, self.conv13_sc, self.bn13_sc),
            (self.conv15, self.bn15, self.conv16, self.bn16, self.conv15_sc, None),
        )
        last = len(blocks) - 1
        for idx, (conv_a, bn_a, conv_b, bn_b, shortcut, shortcut_bn) in enumerate(blocks):
            out, outrange = self._residual_block(
                out, scale, conv_a, bn_a, conv_b, bn_b,
                shortcut, shortcut_bn, ofrate, outrange)
            # Every block output is re-quantized except the last one, which
            # goes straight into average pooling.
            if idx != last:
                out = self.qact(out, scale)

        # Linear head.
        out = self.avgpool(out)
        out = torch.flatten(out, 1)
        oflinear = 0
        out = self.classifier(out)

        # Place the bookkeeping tensors next to the logits instead of forcing
        # CUDA, so the forward pass does not fail on a CUDA-less host.
        batch = out.shape[0]
        ofrate_total = torch.tensor([np.sum(ofrate)] * batch).to(out.device)
        oflinear_total = torch.tensor([oflinear] * batch).to(out.device)
        return out, outrange, ofrate_total, oflinear_total

