import torch
import torch.nn as nn
import torch.nn.functional as F
from quant_layers.quant_activation import q_act
from quant_layers.quant_activation import cyc_relu
from quant_layers.mod_conv import Mod_conv2d_try
from quant_layers.mod_linear import Mod_Linear
import numpy as np

class AlexNet(nn.Module):
    """AlexNet variant whose conv1-conv4 and two hidden FC layers are quantized.

    ``conv0`` and ``classifier`` are ordinary ``nn.Conv2d`` / ``nn.Linear``
    layers.  ``conv1``-``conv4`` are ``Mod_conv2d_try`` and ``linear1``/
    ``linear2`` are ``Mod_Linear`` (project-local, see ``quant_layers``).
    Activations entering each quantized layer pass through ``q_act`` at a
    fixed, hand-tuned scale.
    """

    def __init__(self, c_bits, num_classes=1000):
        """Build the network.

        Args:
            c_bits: forwarded as the second positional argument of every
                ``Mod_conv2d_try``.  The linear layers use a hard-coded ``1``
                in that position.  Its meaning is defined in ``quant_layers``
                (presumably a bit-width -- confirm).
            num_classes: output size of the final full-precision classifier.
        """
        super().__init__()
        # NOTE: construction order is kept identical to the original so that
        # parameter initialisation consumes the RNG in the same sequence.
        self.qact = q_act(8)
        self.mp = nn.MaxPool2d(kernel_size=3, stride=2)
        self.conv0 = nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2, bias=False)
        self.bn0 = nn.BatchNorm2d(64)
        self.conv1 = Mod_conv2d_try(4, c_bits, 64, 192, kernel_size=5, padding=2, bias=False)
        self.bn1 = nn.BatchNorm2d(192)
        self.conv2 = Mod_conv2d_try(4, c_bits, 192, 384, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(384)
        self.conv3 = Mod_conv2d_try(4, c_bits, 384, 256, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(256)
        self.conv4 = Mod_conv2d_try(4, c_bits, 256, 256, kernel_size=3, padding=1, bias=False)
        self.bn4 = nn.BatchNorm2d(256)
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        self.linear1 = Mod_Linear(4, 1, 256 * 6 * 6, 4096, False)
        self.bn5 = nn.BatchNorm1d(num_features=4096)
        self.linear2 = Mod_Linear(4, 1, 4096, 4096, False)
        self.bn6 = nn.BatchNorm1d(num_features=4096)
        self.classifier = nn.Linear(4096, num_classes)
        # Passed as the third argument to every quantized layer call.
        self.alpha = 1.0
        self.cyc = cyc_relu(8)

    @staticmethod
    def _overflow_penalty(out, scale):
        """Per-sample mean of how far ``|out / scale|`` exceeds 2**7.

        Returns a tensor of shape ``(out.shape[0],)``.  Elements within
        [-128, 128] after scaling contribute 0.
        """
        # 2**7 is presumably the magnitude limit of a signed 8-bit code -- confirm.
        excess = F.relu(torch.abs(out / scale) - 2 ** 7)
        return excess.reshape(out.shape[0], -1).mean(1)

    @staticmethod
    def _per_sample(value, batch, device):
        """Repeat ``value`` ``batch`` times as a tensor on ``device``."""
        # Build directly on the target device instead of calling ``.cuda()``.
        # That keeps the model usable on CPU and on non-default GPUs.
        return torch.tensor([value] * batch, device=device)

    def _quant_block(self, layer, bn, out, in_scale):
        """Run one quantized layer followed by cyc -> BatchNorm -> ReLU6.

        Args:
            layer: a ``Mod_conv2d_try`` / ``Mod_Linear`` instance.  It is
                called as ``layer(out, in_scale, self.alpha)`` and must return
                a 3-tuple ``(output, stat, out_scale)``.
            bn: BatchNorm module applied after ``self.cyc``.
            out: input activations (already passed through ``self.qact``).
            in_scale: scale the input was quantized with.

        Returns:
            ``(activations, stat, penalty)``.  ``stat`` is the layer's second
            return value, passed through unchanged.  ``penalty`` is the
            per-sample overflow penalty of the raw layer output, measured
            before ``cyc``.
        """
        out, stat, out_scale = layer(out, in_scale, self.alpha)
        penalty = self._overflow_penalty(out, out_scale)
        out = self.cyc(out, out_scale)
        out = F.relu6(bn(out))
        return out, stat, penalty

    def forward(self, x):
        """Run the network.

        Args:
            x: input batch; ``conv0`` expects 3 input channels.

        Returns:
            A 4-tuple:
            - logits from ``classifier`` (last dim ``num_classes``);
            - ``outrange``: per-sample sum of overflow penalties over all six
              quantized layers, shape ``(batch,)``;
            - ``np.sum`` of the conv layers' second return values, repeated
              once per sample;
            - the two linear layers' second return values, repeated once per
              sample.  (What those values mean is defined in ``quant_layers``.
              The names ``ofrate``/``oflinear`` suggest overflow rates --
              confirm.)
            The last two tensors are created on the same device as the logits.
        """
        outrange = 0
        ofrate = []
        oflinear = []

        # Stem: full-precision conv, then quantize activations.
        out = F.relu6(self.bn0(self.conv0(x)))
        s0 = 5
        out = self.qact(out, s0)
        out = self.mp(out)

        # The fixed input scales below are hand-tuned constants.  Each layer's
        # *returned* scale is used only for its own penalty and ``cyc`` call.
        out, stat, penalty = self._quant_block(self.conv1, self.bn1, out, s0)
        ofrate.append(stat)
        outrange += penalty
        out = self.mp(out)

        s1 = 0.4
        out = self.qact(out, s1)
        out, stat, penalty = self._quant_block(self.conv2, self.bn2, out, s1)
        ofrate.append(stat)
        outrange += penalty

        s2 = 0.4
        out = self.qact(out, s2)
        out, stat, penalty = self._quant_block(self.conv3, self.bn3, out, s2)
        ofrate.append(stat)
        outrange += penalty

        s3 = 0.4
        out = self.qact(out, s3)
        out, stat, penalty = self._quant_block(self.conv4, self.bn4, out, s3)
        ofrate.append(stat)
        outrange += penalty
        out = self.mp(out)

        s4 = 0.8
        out = self.qact(out, s4)
        out = self.avgpool(out)
        out = torch.flatten(out, 1)
        out, stat, penalty = self._quant_block(self.linear1, self.bn5, out, s4)
        oflinear.append(stat)
        outrange += penalty

        sl = 0.8
        out = self.qact(out, sl)
        out, stat, penalty = self._quant_block(self.linear2, self.bn6, out, sl)
        oflinear.append(stat)
        outrange += penalty

        out = self.classifier(out)

        batch, device = out.shape[0], out.device
        return (
            out,
            outrange,
            self._per_sample(np.sum(ofrate), batch, device),
            self._per_sample(oflinear, batch, device),
        )

# test()
