import torch
import torchvision
from torchvision.models import resnet18
from torch import nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from torch.autograd.function import  Function
from torch.autograd import Variable
from torch.autograd.function import InplaceFunction
from resnet_cifar10 import *


# Loss callable registered for each supported dataset name. The MNIST variants
# share the functional ``F.cross_entropy``; every other dataset gets its own
# ``nn.CrossEntropyLoss`` module instance.
CRITERIONS = {
    **dict.fromkeys(
        ('mnist', 'mnist_hetero', 'mnist_hetero_qnn', 'mnist_qnn'),
        F.cross_entropy,
    ),
    **{
        name: nn.CrossEntropyLoss()
        for name in (
            'fashion mnist',
            'cifar 10',
            'cifar_10_hetero',
            'cifar_10_hetero_qnn',
            'tinyimagenet',
            'celeba',
        )
    },
}

# Optimizer class (not instance) registered for each supported dataset name.
# Everything trains with SGD except 'fashion mnist', which uses Adam.
OPTIMIZERS = {
    name: (torch.optim.Adam if name == 'fashion mnist' else torch.optim.SGD)
    for name in (
        'mnist',
        'mnist_hetero',
        'mnist_hetero_qnn',
        'mnist_qnn',
        'fashion mnist',
        'cifar 10',
        'cifar_10_hetero',
        'cifar_10_hetero_qnn',
        'tinyimagenet',
        'celeba',
    )
}

class UniformQuantizeSawb(InplaceFunction):
    """Uniform fake-quantization with a statistics-based clipping threshold.

    The forward pass snaps a copy of ``input`` onto a uniform grid and returns
    it in the original floating-point scale; the backward pass forwards the
    incoming gradient unchanged (straight-through estimator).
    """

    @staticmethod
    def forward(ctx, input, c1, c2, Qp, Qn):
        """Quantize a copy of ``input``.

        Args:
            ctx: autograd context (unused).
            input: tensor to quantize; it is not modified.
            c1: coefficient applied to ``sqrt(mean(input ** 2))``.
            c2: coefficient applied to ``mean(|input|)``.
            Qp: largest integer level of the grid.
            Qn: smallest integer level of the grid.

        Returns:
            A new tensor with the same shape as ``input`` whose values are
            integer levels in ``[Qn, Qp]`` multiplied by the grid step.
        """
        output = input.clone()

        with torch.no_grad():
            # Clipping threshold: c1 * RMS(input) - c2 * mean(|input|).
            clip = (c1 * torch.sqrt(torch.mean(input ** 2))) - (c2 * torch.mean(input.abs()))
            # Grid step so that (Qp - Qn) steps span a width of 2 * clip.
            scale = 2 * clip / (Qp - Qn)
            if scale == 0:
                # Degenerate grid (e.g. an all-zero input): dividing by the
                # step would compute 0 / 0 and fill the result with NaNs.
                # A zero step can only represent 0, which is also what the
                # arithmetic below yields for every non-zero element.
                return output.zero_()
            output.div_(scale)
            output.clamp_(Qn, Qp).round_()
            output.mul_(scale)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Pass the gradient straight through; the four scalar args get None."""
        # straight-through estimator
        grad_input = grad_output
        return grad_input, None, None, None, None
    
class GradStochasticClippingQ(Function):
    """Identity in the forward pass; optionally quantizes the gradient in backward.

    When ``quantizeBwd`` is truthy, the incoming gradient is stochastically
    pruned below a threshold ``alpha`` and the surviving values are rounded
    stochastically to ``alpha`` times a power of two. The procedure is repeated
    ``repeatBwd`` times and the results are averaged. Otherwise the gradient
    passes through unchanged.
    """

    @staticmethod
    def forward(ctx, x, quantizeBwd,repeatBwd):
        # The two Python scalars are wrapped in tensors so they can be stored
        # with save_for_backward and read back in backward().
        ctx.save_for_backward(torch.tensor(quantizeBwd),torch.tensor(repeatBwd))
        return x

    @staticmethod
    def backward(ctx, grad_output):
        # Returns one gradient per forward argument: the (possibly quantized)
        # gradient for ``x`` and None for the two configuration scalars.
        quant,repeatBwd = ctx.saved_tensors
        if quant:
            out = []
            # Each repetition draws fresh random numbers; the loop index is unused.
            for i in range(repeatBwd):

                # NOTE(review): this is the signed maximum, not the maximum
                # magnitude. If every gradient entry is <= 0, alpha is
                # non-positive and the log2 calls below receive non-positive
                # ratios -- confirm whether grad_output.abs().max() was intended.
                mx = torch.max(grad_output)
                bits = 3
                # With bits = 3 this is mx / 2**7, i.e. mx / 128.
                alpha = mx / 2**(2**bits-1)

                # Per-element random threshold, alpha times a uniform [0, 1) sample.
                alphaEps = alpha * torch.rand(grad_output.shape,device=grad_output.device)

                grad_abs = grad_output.abs()

                # Stochastic underflow: magnitudes below alpha are first lifted
                # to +/-alpha, then zeroed when they fall below the random
                # threshold. For alpha > 0 an entry with |g| < alpha therefore
                # survives as +/-alpha with probability |g| / alpha.
                grad_input = torch.where(grad_abs < alpha , alpha*torch.sign(grad_output), grad_output)
                grad_input = torch.where(grad_abs < alphaEps, torch.tensor([0], dtype=torch.float32,device=grad_output.device), grad_input)

                # Stochastic power-of-two rounding of |g| / alpha: add uniform
                # noise of width 2**floor(log2(|g| / alpha)), then take
                # floor(log2(. * 4/3)), which steps up to the next power of two
                # once the noisy ratio reaches 1.5x the current one.
                grad_inputQ = grad_input.clone()
                noise = (2 ** torch.floor(torch.log2((grad_inputQ.abs() / alpha)) )) * grad_inputQ.new(grad_inputQ.shape).uniform_(-0.5,0.5)
                grad_inputQ = 2 ** torch.floor(torch.log2( ((grad_inputQ.abs() / alpha) + noise) *4/3 ) ) * alpha

                # Never fall below the power-of-two level just under |g|, then
                # restore the sign of the pruned gradient.
                grad_inputQ =  torch.sign(grad_input) * torch.where(grad_inputQ < (alpha * (2 ** torch.floor(torch.log2(((grad_input.abs()/alpha)) )))),alpha *  (2 ** torch.floor(torch.log2(((grad_input.abs()/alpha)  ) ))), grad_inputQ)
                # Entries that are exactly zero after pruning stay exactly zero.
                grad_inputQ = torch.where(grad_input == 0, torch.tensor([0], dtype=torch.float,device=grad_output.device), grad_inputQ)

                out.append(grad_inputQ)
            # Average the independent stochastic samples.
            grad_input = sum(out) / repeatBwd

        else:

            # Quantization disabled: pass the gradient through untouched.
            grad_input = grad_output
        return grad_input,None,None

class Linear_LUQ(nn.Linear):
    """Fully connected layer with optional forward/backward quantization.

    With ``quantizeFwd`` enabled, the weight and the input are passed through
    ``UniformQuantizeSawb`` before the matrix multiply. The result is always
    routed through ``GradStochasticClippingQ``, which quantizes the incoming
    gradient during backpropagation when ``quantizeBwd`` is enabled.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Activation / weight bit widths.
        self.abits = 4
        self.wbits = 4

        # Integer level bounds handed to the quantizer. Weights use a
        # symmetric range; activations start with a non-negative range.
        weight_bound = 2 ** (self.wbits - 1)
        self.QnW = -weight_bound
        self.QpW = weight_bound
        self.QnA = 0
        self.QpA = 2 ** self.abits - 1

        # Quantization switches for the forward and backward passes.
        self.quantizeFwd = True
        self.quantizeBwd = True

        # Coefficients of the quantizer's clipping-threshold formula.
        self.c1 = 12.1
        self.c2 = 12.2

        self.stochastic = True
        # Number of stochastic gradient samples averaged in backward.
        self.repeatBwd = 1

    def forward(self, input):
        """Apply the (optionally quantized) linear map to ``input``."""
        if not self.quantizeFwd:
            result = F.linear(input, self.weight, self.bias)
        else:
            quant_weight = UniformQuantizeSawb.apply(
                self.weight, self.c1, self.c2, self.QpW, self.QnW
            )

            # The first input containing a negative value switches the
            # activation lower bound to the signed range; the change is kept
            # on the module for all later calls.
            if torch.min(input) < 0:
                self.QnA = -2 ** (self.abits - 1)

            quant_input = UniformQuantizeSawb.apply(
                input, self.c1, self.c2, self.QpA, self.QnA
            )
            result = F.linear(quant_input, quant_weight, self.bias)

        return GradStochasticClippingQ.apply(result, self.quantizeBwd, self.repeatBwd)
    
class Conv2d_LUQ(nn.Conv2d):
    """2-D convolution with optional forward/backward quantization.

    With ``quantizeFwd`` enabled, the weight and the input are passed through
    ``UniformQuantizeSawb`` before the convolution. The result is always routed
    through ``GradStochasticClippingQ``, which quantizes the incoming gradient
    during backpropagation when ``quantizeBwd`` is enabled. Both switches are
    off by default for this layer.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Activation / weight bit widths.
        self.abits = 4
        self.wbits = 4

        # Integer level bounds handed to the quantizer. Weights use a
        # symmetric range; activations start with a non-negative range.
        weight_bound = 2 ** (self.wbits - 1)
        self.QnW = -weight_bound
        self.QpW = weight_bound
        self.QnA = 0
        self.QpA = 2 ** self.abits - 1

        # Quantization switches for the forward and backward passes.
        self.quantizeFwd = False
        self.quantizeBwd = False

        # Coefficients of the quantizer's clipping-threshold formula.
        self.c1 = 12.1
        self.c2 = 12.2

        self.stochastic = True
        # Number of stochastic gradient samples averaged in backward.
        self.repeatBwd = 1

    def forward(self, input):
        """Apply the (optionally quantized) convolution to ``input``."""
        if not self.quantizeFwd:
            result = F.conv2d(
                input, self.weight, self.bias,
                stride=self.stride, padding=self.padding,
                dilation=self.dilation, groups=self.groups,
            )
        else:
            quant_weight = UniformQuantizeSawb.apply(
                self.weight, self.c1, self.c2, self.QpW, self.QnW
            )

            # The first input containing a negative value switches the
            # activation lower bound to the signed range; the change is kept
            # on the module for all later calls.
            if torch.min(input) < 0:
                self.QnA = -2 ** (self.abits - 1)

            quant_input = UniformQuantizeSawb.apply(
                input, self.c1, self.c2, self.QpA, self.QnA
            )
            result = F.conv2d(
                quant_input, quant_weight, self.bias,
                stride=self.stride, padding=self.padding,
                dilation=self.dilation, groups=self.groups,
            )

        return GradStochasticClippingQ.apply(result, self.quantizeBwd, self.repeatBwd)

def get_criterion(dataset_name):
    """Return the loss callable registered for ``dataset_name``.

    Raises ``KeyError`` when the name has no entry in ``CRITERIONS``.
    """
    criterion = CRITERIONS[dataset_name]
    return criterion

def get_optimizer(dataset_name):
    """Return the optimizer class registered for ``dataset_name``.

    Raises ``KeyError`` when the name has no entry in ``OPTIMIZERS``.
    """
    optimizer_cls = OPTIMIZERS[dataset_name]
    return optimizer_cls

def get_model(dataset_name):
    """Build a fresh model for ``dataset_name``.

    Args:
        dataset_name: one of 'mnist', 'mnist_hetero', 'mnist_hetero_qnn',
            'mnist_qnn', 'fashion mnist', 'cifar 10', 'cifar_10_hetero',
            'cifar_10_hetero_qnn', 'celeba' or 'tinyimagenet'.

    Returns:
        A newly constructed ``nn.Module``.

    Raises:
        ValueError: if ``dataset_name`` is not one of the names above.
            Previously an unknown name silently returned ``None``.
    """
    # Guard clauses instead of a lookup table: each constructor name is only
    # resolved when its branch is actually taken.
    if dataset_name in ("mnist", "mnist_hetero"):
        return MnistNet()
    if dataset_name in ("mnist_hetero_qnn", "mnist_qnn"):
        return MnistNet_QNN()
    if dataset_name == "fashion mnist":
        return FashionMnistNet()
    if dataset_name in ("cifar 10", "cifar_10_hetero"):
        return resnet20_cifar()
    if dataset_name == "cifar_10_hetero_qnn":
        return resnet20_cifar_qnn()
    if dataset_name == 'celeba':
        # ResNet-18 initialised from a checkpoint file looked up relative to
        # the current working directory, with the final layer replaced by a
        # freshly initialised 2-output classifier.
        model = resnet18()
        model.load_state_dict(torch.load('resnet18-5c106cde.pth'))
        model.fc = nn.Linear(512, 2, True)
        return model
    if dataset_name == 'tinyimagenet':
        return resnet20_tinyimagenet()
    raise ValueError(f"Unknown dataset name: {dataset_name!r}")

class MnistNet(nn.Module):
    """Two-layer perceptron: flatten -> Linear -> ReLU -> Linear.

    Args:
        input_size: number of features per sample after flattening.
        hidden_size: width of the hidden layer.
        num_classes: number of output logits.
    """

    def __init__(self, input_size=784, hidden_size=32, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(in_features=input_size, out_features=hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(in_features=hidden_size, out_features=num_classes)

    def forward(self, x):
        """Return logits of shape ``(batch, num_classes)`` for input ``x``."""
        batch = x.size(0)
        flat = x.view(batch, -1)  # collapse everything but the batch axis
        hidden = self.relu(self.fc1(flat))
        return self.fc2(hidden)
    
class MnistNet_QNN(nn.Module):
    """Two-layer perceptron whose first layer is the quantized ``Linear_LUQ``.

    Args:
        input_size: number of features per sample after flattening.
        hidden_size: width of the hidden layer.
        num_classes: number of output logits.
    """

    def __init__(self, input_size=784, hidden_size=32, num_classes=10):
        super().__init__()
        # Only the first layer is quantized; the classifier head is a plain
        # full-precision linear layer.
        self.fc1 = Linear_LUQ(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(in_features=hidden_size, out_features=num_classes)

    def forward(self, x):
        """Return logits of shape ``(batch, num_classes)`` for input ``x``."""
        batch = x.size(0)
        flat = x.view(batch, -1)  # collapse everything but the batch axis
        hidden = self.relu(self.fc1(flat))
        return self.fc2(hidden)

class FashionMnistNet(nn.Module):
    """Small CNN: two conv/batch-norm/ReLU/max-pool stages and three linear layers.

    ``fc1`` expects ``64 * 6 * 6`` features, which is what the two stages
    produce for single-channel 28x28 inputs. The output has 10 logits.
    """

    def __init__(self):
        super().__init__()

        # 3x3 conv with padding keeps the spatial size; the pool halves it.
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        # Unpadded 3x3 conv trims one pixel per side; the pool halves again.
        self.layer2 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

        self.fc1 = nn.Linear(in_features=64*6*6, out_features=600)
        # The dropout acts on the 2-D (batch, features) output of fc1.
        # nn.Dropout2d is meant for 3-D/4-D inputs and PyTorch deprecates
        # feeding it 2-D tensors (announced to become an error), pointing to
        # plain dropout as the way to retain the behaviour.
        self.drop = nn.Dropout(0.25)
        self.fc2 = nn.Linear(in_features=600, out_features=120)
        self.fc3 = nn.Linear(in_features=120, out_features=10)

    def forward(self, x):
        """Return logits of shape ``(batch, 10)`` for image batch ``x``."""
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.view(out.size(0), -1)  # flatten all but the batch axis
        out = self.fc1(out)
        out = self.drop(out)
        out = self.fc2(out)
        out = self.fc3(out)

        return out
    
class CelebaNet(nn.Module):
    """CNN with two conv/batch-norm/ReLU/max-pool stages and three linear layers.

    ``fc1`` expects 193600 (= 64 * 55 * 55) flattened features, which is what
    the two stages produce for 3-channel inputs of e.g. 224x224 pixels. The
    network returns a single value per sample.

    Note that ``fc1`` alone holds 193600 * 10000 (about 1.9e9) weights.
    """

    def __init__(self):
        super().__init__()

        # 3x3 conv with padding keeps the spatial size; the pool halves it.
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        # Unpadded 3x3 conv trims one pixel per side; the pool halves again.
        self.layer2 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

        self.fc1 = nn.Linear(in_features=193600, out_features=10000)
        # The dropout acts on the 2-D (batch, features) output of fc1.
        # nn.Dropout2d is meant for 3-D/4-D inputs and PyTorch deprecates
        # feeding it 2-D tensors (announced to become an error), pointing to
        # plain dropout as the way to retain the behaviour.
        self.drop = nn.Dropout(0.25)
        self.fc2 = nn.Linear(in_features=10000, out_features=500)
        self.fc3 = nn.Linear(in_features=500, out_features=1)

    def forward(self, x):
        """Return a ``(batch, 1)`` tensor for image batch ``x``."""
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.view(out.size(0), -1)  # flatten all but the batch axis
        # (A leftover debug print of the flattened shape was removed here; it
        # wrote to stdout on every forward pass.)
        out = self.fc1(out)
        out = self.drop(out)
        out = self.fc2(out)
        out = self.fc3(out)

        return out
#     def __init__(self):
#         super(CelebaNet, self).__init__()
        
#         self.layer_1 = self.make_block()
#         self.layer_2 = self.make_block()
#         self.layer_3 = self.make_block()
#         self.layer_4 = self.make_block()
#         self.fc = nn.Linear(in_features=1, out_features=2)
    
#     def make_block(self):
#         return nn.Sequential(nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=0),
#                              nn.BatchNorm2d(32),
#                              nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
#                              nn.ReLU())
    
#     def forward(self, x):

#         x = self.layer_1(x)
#         x = self.layer_2(x)
#         x = self.layer_3(x)
#         x = self.layer_4(x)
#         out = self.fc(x)
#         return out