'''VGG-style model in PyTorch: VGG_7_carry, a 7-layer VGG variant using carry convolutions (quant_layers.mod_conv_carry).'''
import torch
import torch.nn as nn
import torch.nn.functional as F
from quant_layers.quant_activation import q_act
from quant_layers.quant_activation import cyc_relu
from quant_layers.mod_conv_carry import Mod_conv2d_try
from quant_layers.mod_linear import Mod_Linear
from quant_layers.functions import no_grad_mul
import pdb
import numpy as np

class VGG_7_carry(nn.Module):
    """VGG-7-style classifier whose conv2-conv6 are ``Mod_conv2d_try`` carry convs.

    Layer sequence::

        conv1 (3->128, float) -> BN -> ReLU6 -> q_act(s1)
        conv2 (128->128) -> cyc_relu(7) -> BN -> ReLU6 -> q_act(s2) -> maxpool
        conv3 (128->256) -> cyc_relu(8) -> BN -> ReLU6 -> q_act(s3)
        conv4 (256->256) -> cyc_relu(7) -> BN -> ReLU6 -> q_act(s4) -> maxpool
        conv5 (256->512) -> cyc_relu(8) -> BN -> ReLU6 -> q_act(s5)
        conv6 (512->512) -> cyc_relu(7) -> BN -> ReLU6 -> maxpool
        flatten -> Linear(512, 10)

    The ``Linear(512, 10)`` head only fits when the final feature map is 1x1.
    For a 32x32 input the conv output sizes are 30, 28, 12, 10, 3, 3.
    NOTE(review): the fifth positional argument of each ``Mod_conv2d_try``
    (28, 12, 10, 3, 3) equals that output size, presumably by design. Confirm
    against ``quant_layers.mod_conv_carry``.
    """

    # Fixed activation-quantization scales s1..s5, used by q_act after
    # conv1..conv5. Earlier experiments derived them from the activation
    # range, e.g. (max - min) / 2**4. They live on the class so that
    # instances unpickled from before this attribute existed still resolve it.
    act_scales = (1.7, 2.2, 2.3, 2.6, 1.8)

    def __init__(self, vgg_name, act_scales=None):
        """Build the layers.

        Args:
            vgg_name: Unused; kept so the constructor signature is unchanged.
            act_scales: Optional sequence of five scales replacing the
                class-level defaults ``(1.7, 2.2, 2.3, 2.6, 1.8)``.

        Raises:
            ValueError: If ``act_scales`` is given and does not have five entries.
        """
        super(VGG_7_carry, self).__init__()
        # Registration order is kept as-is. It fixes the order of
        # ``parameters()`` (hence optimizer state) and the state_dict keys.
        self.conv1 = nn.Conv2d(3, 128, kernel_size=3, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(128)
        self.conv2 = Mod_conv2d_try(4, 1, 128, 128, 28, kernel_size=3, padding=0, bias=False, carry=True)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = Mod_conv2d_try(4, 1, 128, 256, 12, kernel_size=3, padding=0, bias=False, carry=True)
        self.bn3 = nn.BatchNorm2d(256)
        self.conv4 = Mod_conv2d_try(4, 1, 256, 256, 10, kernel_size=3, padding=0, bias=False, carry=True)
        self.bn4 = nn.BatchNorm2d(256)
        self.conv5 = Mod_conv2d_try(4, 1, 256, 512, 3, kernel_size=3, padding=0, bias=False, carry=True)
        self.bn5 = nn.BatchNorm2d(512)
        self.conv6 = Mod_conv2d_try(4, 1, 512, 512, 3, kernel_size=3, padding=1, bias=False, carry=True)
        self.bn6 = nn.BatchNorm2d(512)
        self.classifier = nn.Linear(512, 10, True)
        self.mp = nn.MaxPool2d(kernel_size=2, stride=2)
        self.qact = q_act(8)
        self.cyc = cyc_relu(8)
        self.cyc7 = cyc_relu(7)
        # Passed to every carry conv as its third argument.
        self.alpha = 1.
        if act_scales is not None:
            act_scales = tuple(act_scales)
            if len(act_scales) != 5:
                raise ValueError(
                    "act_scales must have 5 entries, got %d" % len(act_scales))
            self.act_scales = act_scales

    def _carry_stage(self, x, in_scale, conv, act, bn, mode, track_range=True):
        """Run one carry conv -> activation -> BN -> ReLU6 stage.

        Args:
            x: Stage input (already quantized with ``in_scale``).
            in_scale: Scale handed to ``conv`` together with ``x``.
            conv: A ``Mod_conv2d_try`` layer. It is called as
                ``conv(x, in_scale, self.alpha, mode)`` and returns
                ``(y, ofrate, scale, carry)``.
            act: Cyclic activation, called as ``act(y, scale)``.
            bn: BatchNorm layer applied after ``act``.
            mode: Fourth argument forwarded unchanged to ``conv``.
            track_range: Whether to compute the overflow penalty.

        Returns:
            ``(out, ofrate, carry, overflow)``. ``ofrate`` and ``carry`` come
            straight from ``conv``. ``overflow`` is the mean of
            ``relu(|y / scale| - 2**7)``, or ``None`` when ``track_range``
            is False.
        """
        y, ofrate, scale, carry = conv(x, in_scale, self.alpha, mode)
        overflow = None
        if track_range:
            # Penalizes only magnitudes of y / scale beyond 2**7.
            overflow = F.relu(torch.abs(y / scale) - 2 ** 7).mean()
        # The activation uses the output scale reported by the conv itself.
        out = F.relu6(bn(act(y, scale)))
        return out, ofrate, carry, overflow

    def forward(self, x):
        """Run the network.

        Args:
            x: Input batch for ``conv1`` (3 input channels).

        Returns:
            Tuple ``(logits, outrange, ofrate, oflinear, std_carry)``:

            - ``logits``: shape ``(batch, 10)``.
            - ``outrange``: the overflow penalties of conv2-conv5, summed and
              divided by 5.
            - ``ofrate``: the sum of the ``ofrate`` values from conv2-conv6.
            - ``oflinear``: always 0.
            - ``std_carry``: the carry values of conv2-conv6, summed and
              divided by 5.
        """
        s1, s2, s3, s4, s5 = self.act_scales

        # Stage 1: ordinary float conv, then quantize for the carry convs.
        out = F.relu6(self.bn1(self.conv1(x)))
        out = self.qact(out, s1)

        out, ofrate2, c2, range2 = self._carry_stage(
            out, s1, self.conv2, self.cyc7, self.bn2, 0)
        out = self.mp(self.qact(out, s2))

        out, ofrate3, c3, range3 = self._carry_stage(
            out, s2, self.conv3, self.cyc, self.bn3, 1)
        out = self.qact(out, s3)

        out, ofrate4, c4, range4 = self._carry_stage(
            out, s3, self.conv4, self.cyc7, self.bn4, 0)
        out = self.mp(self.qact(out, s4))

        out, ofrate5, c5, range5 = self._carry_stage(
            out, s4, self.conv5, self.cyc, self.bn5, 1)
        out = self.qact(out, s5)

        # Last carry conv: no overflow penalty, and no re-quantization before pooling.
        out, ofrate6, c6, _ = self._carry_stage(
            out, s5, self.conv6, self.cyc7, self.bn6, 0, track_range=False)
        out = self.mp(out)

        out = out.view(out.size(0), -1)
        out = self.classifier(out)

        outrange = range2 + range3 + range4 + range5
        std_carry = c2 + c3 + c4 + c5 + c6
        oflinear = 0
        # NOTE(review): outrange has four terms (conv2-conv5) but is divided
        # by 5, like std_carry's five. Kept as-is to preserve the loss scale.
        # Confirm whether conv6 was meant to be included or the divisor should be 4.
        return (out, outrange / 5,
                ofrate2 + ofrate3 + ofrate4 + ofrate5 + ofrate6,
                oflinear, std_carry / 5)

