from collections import namedtuple
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd.function import InplaceFunction, Function

import numpy as np
import scipy.io as sio

from .quantizer import Quantizer, PersampleDivisionQuantizer, DivisionQuantizer, DivisionQuantizer2D, DivisionShiftQuantizer2D


class QuantizationConfig:
    """Switches and bit-widths used by the quantized layers in this module.

    Every ``*_quantizer`` method returns a one-argument factory
    (``tensor -> quantizer object``).  The factories read the bit-width
    attributes when they are *called*, not when they are created, so changing
    a ``*_num_bits`` attribute also affects factories handed out earlier.
    """

    def __init__(self):
        # On/off switches.  ``quantize_activation`` and ``quantize_weights``
        # are checked by the layers' forward passes; ``quantize_gradient`` is
        # not read anywhere in the visible part of this file.
        self.quantize_activation = True
        self.quantize_weights = True
        self.quantize_gradient = True
        # Bit-widths, one per tensor role.
        self.activation_num_bits = 6
        self.weight_fw_num_bits = 6
        self.weight_bw_num_bits = 6
        self.bias_num_bits = 6
        self.backward_num_bits = 6

    def activation_quantizer(self):
        """Factory for layer inputs: ``Quantizer`` at ``activation_num_bits``."""
        return lambda x: Quantizer(x, num_bits=self.activation_num_bits, flatten_dims=(1, -1), signed=True, stochastic=True)

    def activation_fc_quantizer(self):
        """Factory for fully-connected layer inputs.

        Uses ``DivisionShiftQuantizer2D`` at ``activation_num_bits``.  The
        previous version read ``backward_num_bits`` here, which only gave the
        right width while both settings happened to be equal.
        """
        return lambda x: DivisionShiftQuantizer2D(x, num_bits=self.activation_num_bits, signed=True, stochastic=True, groups_s=4, groups_c=4, permute=(0, 2, 1))

    def weight_quantizer(self):
        """Factory for forward-pass conv weights (``weight_fw_num_bits``)."""
        return lambda x: Quantizer(x, num_bits=self.weight_fw_num_bits, flatten_dims=(1, -1), signed=True, stochastic=True)

    def weight_fc_quantizer(self):
        """Factory for forward-pass linear weights; passes ``is_zero_point=False``."""
        return lambda x: Quantizer(x, num_bits=self.weight_fw_num_bits, flatten_dims=(1, -1), signed=True, stochastic=True, is_zero_point=False)

    def bias_quantizer(self):
        """Factory for biases (``bias_num_bits``, ``flatten_dims=(0, -1)``)."""
        return lambda x: Quantizer(x, num_bits=self.bias_num_bits, flatten_dims=(0, -1), signed=True, stochastic=True)

    def weight_bw_quantizer(self):
        """Factory for backward-pass conv weights (``weight_bw_num_bits``, permuted ``(1, 0, 2, 3)``)."""
        return lambda x: Quantizer(x, num_bits=self.weight_bw_num_bits, flatten_dims=(1, -1), signed=True, stochastic=True, permute=(1, 0, 2, 3))

    def weight_bw_fc_quantizer(self):
        """Factory for backward-pass linear weights (``weight_bw_num_bits``, permuted ``(1, 0)``)."""
        return lambda x: Quantizer(x, num_bits=self.weight_bw_num_bits, flatten_dims=(1, -1), signed=True, stochastic=True, permute=(1, 0), is_zero_point=False)

    def activation_gradient_quantizer(self):
        """Factory for conv output gradients (``backward_num_bits``)."""
        return lambda x: DivisionShiftQuantizer2D(x, num_bits=self.backward_num_bits, signed=True, stochastic=True, groups_s=4, groups_c=4)

    def activation_gradient_fc_quantizer(self):
        """Factory for linear output gradients (``backward_num_bits``, permuted ``(0, 2, 1)``)."""
        return lambda x: DivisionShiftQuantizer2D(x, num_bits=self.backward_num_bits, signed=True, stochastic=True, groups_s=4, groups_c=4, permute=(0, 2, 1))

# Module-wide settings object; the quantized layers below read it on every
# forward pass.
config = QuantizationConfig()

# Record type with the fields ``range``, ``zero_point`` and ``num_bits``.
QParams = namedtuple('QParams', 'range zero_point num_bits')

# Default flatten-dimension pairs.
# NOTE(review): neither constant is referenced in the visible part of this
# file - confirm whether they are still needed.
_DEFAULT_FLATTEN_GRAD = (0, -1)
_DEFAULT_FLATTEN = (1, -1)


class UniformQuantize(InplaceFunction):
    """Quantize-then-dequantize a tensor, with a straight-through gradient."""

    @staticmethod
    def forward(ctx, input, Quantizer, inplace=False):
        """Return ``input`` after a quantize/dequantize round trip.

        Args:
            input: tensor to process.
            Quantizer: callable taking the tensor and returning an object
                with ``forward()`` (quantize) and ``inverse(q)`` (dequantize).
            inplace: when true, the result is written into ``input`` and
                ``input`` itself is returned; otherwise ``input`` is left
                untouched and a new tensor is returned.
        """
        ctx.inplace = inplace

        if ctx.inplace:
            ctx.mark_dirty(input)
            output = input
        else:
            output = input.clone()

        with torch.no_grad():
            quantizer = Quantizer(output)
            # quantize
            quantized = quantizer.forward()
            # dequantize
            output = quantizer.inverse(quantized)

            if ctx.inplace and output is not input:
                # The quantizer handed back a fresh tensor.  A tensor that was
                # marked dirty must actually be modified and returned, so
                # write the result back into the caller's tensor.
                input.copy_(output)
                output = input

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Pass the gradient through unchanged (straight-through estimator)."""
        # One entry per forward() argument: input, Quantizer, inplace.
        return grad_output, None, None
    

def cal_scale(num_bits_s, l_shift, x):
    """Scale ``x`` by ``2 ** l_shift`` and saturate the result.

    The scaled values are limited to the closed interval
    ``[-2 ** num_bits_s, 2 ** num_bits_s - 1]``.  No rounding is applied.

    Args:
        num_bits_s: exponent that defines the saturation limits.
        l_shift: left-shift amount; ``x`` is multiplied by ``2 ** l_shift``.
        x: torch tensor or NumPy array of any shape.

    Returns:
        A new tensor/array of the same shape as ``x``; ``x`` is not modified.
    """
    upper = 2 ** num_bits_s - 1
    lower = -(2 ** num_bits_s)
    scaled = x * 2 ** l_shift
    # One vectorized clamp replaces the per-element Python loop, which also
    # lifts the old restriction to 1-D input.
    if torch.is_tensor(scaled):
        return torch.clamp(scaled, min=lower, max=upper)
    return np.clip(scaled, lower, upper)


def cal_scale_persample(num_bits_s, l_shift_max, x):
    """Pick a per-element left shift for ``x`` and return the shifted, rounded values.

    The shift is ``-floor(log2(x / 2 ** (num_bits_s - 1)))``, capped from
    above at ``l_shift_max``.

    Returns:
        ``(x_int, l_shift)`` where ``x_int = round(x * 2 ** l_shift)``.
    """
    half_range = 2 ** (num_bits_s - 1)
    shift = -torch.log2(x / half_range).floor()
    # Cap the shift; entries at or below the cap are left as they are.
    too_large = shift > l_shift_max
    shift[too_large] = l_shift_max
    scaled = x * 2 ** shift
    return torch.round(scaled), shift


class RecordGrad(InplaceFunction):
    """Identity in the forward pass; sends the incoming gradient through
    ``record_grad`` in the backward pass."""

    @staticmethod
    def forward(ctx, input, index, iter):
        """Remember ``index`` and ``iter`` for backward and return ``input`` unchanged."""
        ctx.index = index
        ctx.iter = iter
        return input

    @staticmethod
    def backward(ctx, grad_output):
        """Run ``record_grad`` on the gradient and pass its result on."""
        grad_output = record_grad(grad_output, ctx.index, ctx.iter)

        # backward() must return one value per forward() argument
        # (input, index, iter); the old version returned eight.
        return grad_output, None, None


def record_grad(x, index, iter):
    """Apply ``RecordTensor`` to ``x`` with the given ``index`` and ``iter``."""
    # Autograd functions are used through the class; creating an instance
    # first is deprecated in PyTorch and only built a throwaway object.
    return RecordTensor.apply(x, index, iter)
    

class RecordTensor(InplaceFunction):
    """Identity function that periodically dumps quantization ranges to a ``.mat`` file."""

    # Directory that receives the dumps; override on the class to redirect them.
    SAVE_DIR = 'C:/wenjinguo/prj/cpt_cifar_3/cifar10/UQ_wide_all4_piecewise_78000_INTV4_percsquant_nonDSQ_noendmask_GAaroundBN/gabn'
    # A dump is written whenever ``iter`` is a multiple of this value.
    SAVE_EVERY = 100

    @staticmethod
    def forward(ctx, input, index, iter):
        """Return ``input`` unchanged, writing a dump on every ``SAVE_EVERY``-th iteration.

        The file is named ``sc_<index>_<iter>.mat`` and stores
        ``range[:, 0, 0, 0]`` under the key ``'data'``, where ``range`` comes
        from ``calculate_qparams(input, num_bits=4, flatten_dims=(2, -1),
        reduce_dim=None)``.
        """
        ctx.index = index
        ctx.iter = iter

        if ctx.iter % RecordTensor.SAVE_EVERY == 0:
            # Computed only when a dump is due; the result has no other use.
            # NOTE(review): calculate_qparams is not defined or imported in
            # the visible part of this file - confirm where it comes from.
            qparams = calculate_qparams(
                input, num_bits=4, flatten_dims=(2, -1), reduce_dim=None)
            ranges = qparams.range.detach().to(device='cpu').numpy()
            path = f"{RecordTensor.SAVE_DIR}/sc_{ctx.index}_{ctx.iter!s}.mat"
            sio.savemat(path, {'data': ranges[:, 0, 0, 0]})

        return input

    @staticmethod
    def backward(ctx, grad_output):
        """Pass the gradient through unchanged (straight-through estimator)."""
        # One entry per forward() argument: input, index, iter.
        return grad_output, None, None
    

class UniformQuantizeGrad(InplaceFunction):
    """Identity on the way forward; runs ``quantize`` on the gradient on the way back."""

    @staticmethod
    def forward(ctx, input, Quantizer):
        """Stash the quantizer factory for backward and return ``input`` as is."""
        ctx.inplace = False
        ctx.Quantizer = Quantizer
        return input

    @staticmethod
    def backward(ctx, grad_output):
        """Return the quantized gradient for ``input`` and ``None`` for ``Quantizer``."""
        with torch.no_grad():
            quantized_grad = quantize(grad_output, ctx.Quantizer)
        return quantized_grad, None


def quantize(x, Quantizer, inplace=False):
    """Apply ``UniformQuantize`` to ``x`` with the given quantizer factory.

    Args:
        x: tensor to process.
        Quantizer: one-argument callable that builds the quantizer for ``x``.
        inplace: forwarded to ``UniformQuantize``.
    """
    # Autograd functions are used through the class; creating an instance
    # first is deprecated in PyTorch and only built a throwaway object.
    return UniformQuantize.apply(x, Quantizer, inplace)


def quantize_grad(x, Quantizer):
    """Apply ``UniformQuantizeGrad`` to ``x``: ``x`` is returned as is and the
    gradient flowing back through it is quantized with ``Quantizer``."""
    # Autograd functions are used through the class; creating an instance
    # first is deprecated in PyTorch and only built a throwaway object.
    return UniformQuantizeGrad.apply(x, Quantizer)


class QuantMeasure(nn.Module):
    """Module wrapper around ``quantize`` with a fixed quantizer factory.

    The default factory is obtained from ``config.activation_quantizer()``
    once, when this class is defined.
    """

    def __init__(self, inplace=False, quantizer=config.activation_quantizer()):
        super().__init__()
        self.quantizer = quantizer
        self.inplace = inplace

    def forward(self, input):
        """Return ``input`` processed by ``quantize`` using the stored factory."""
        return quantize(input, self.quantizer, inplace=self.inplace)


class QConv2d(nn.Conv2d):
    """``nn.Conv2d`` with quantized input, weights and output gradients.

    Two separately quantized copies of the weight are used: one for the
    forward value and the weight/bias gradients, and one for the gradient
    with respect to the input (see ``conv2d_weight_biprec``).
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1, groups=1, bias=True):
        super(QConv2d, self).__init__(in_channels, out_channels, kernel_size,
                                      stride, padding, dilation, groups, bias)

        self.quantize_input = QuantMeasure()

    def forward(self, input, index=None, iter=None):
        """Run the quantized convolution.

        ``index`` and ``iter`` are accepted but not used here.  As side
        effects, the quantized weight is stored in ``self.qweight``, the
        (possibly quantized) input in ``self.iact``, and the forward-path
        output in ``self.act``.
        """
        if config.quantize_activation:
            qinput = self.quantize_input(input)
        else:
            qinput = input

        if config.quantize_weights:     # TODO weight quantization scheme...
            qweight = quantize(self.weight, config.weight_quantizer())
            qweight_bp = quantize(self.weight, config.weight_bw_quantizer())
        else:
            qweight = self.weight
            # Previously left unset in this branch, which made the call below
            # fail whenever weight quantization was switched off.
            qweight_bp = self.weight

        # The bias is used unquantized.  The earlier code quantized it and
        # then immediately overwrote the result with self.bias; that dead
        # quantization pass has been dropped without changing the bias used.
        # NOTE(review): confirm that an unquantized conv bias is intended.
        qbias = self.bias

        self.qweight = qweight
        self.iact = qinput

        return self.conv2d_weight_biprec(qinput, qweight, qweight_bp, qbias,
                                         self.stride, self.padding, self.dilation, self.groups)

    def conv2d_weight_biprec(self, qinput, qweight_fw, qweight_bp, bias=None, stride=1, padding=0, dilation=1, groups=1):
        """Convolution whose forward and input-gradient paths use different weights.

        The returned tensor is ``out1 + out2 - out2.detach()``: its value is
        that of ``out1`` (up to floating-point rounding), while gradients flow
        through both ``out1`` and ``out2``.
        """
        # Forward value and gradients of weight/bias: the input is detached,
        # so nothing flows back to it along this path.
        out1 = F.conv2d(qinput.detach(), qweight_fw, bias,
                        stride, padding, dilation, groups)
        # Gradient of the input: weight and bias are detached, so only the
        # input receives a gradient along this path.
        out2 = F.conv2d(qinput, qweight_bp.detach(), bias.detach() if bias is not None else None,
                        stride, padding, dilation, groups)

        self.act = out1

        out1 = quantize_grad(out1, config.activation_gradient_quantizer())
        out2 = quantize_grad(out2, config.activation_gradient_quantizer())
        return out1 + out2 - out2.detach()


class QLinear(nn.Linear):
    """``nn.Linear`` with quantized input, weight, bias and output gradient."""

    def __init__(self, in_features, out_features, bias=True,):
        super(QLinear, self).__init__(in_features, out_features, bias)
        self.quantize_input = QuantMeasure(quantizer=config.activation_fc_quantizer())

    def forward(self, input, index=None, iter=None):
        """Run the quantized linear layer.

        ``index`` and ``iter`` are accepted but not used here.
        """
        if config.quantize_activation:
            qinput = self.quantize_input(input)
        else:
            qinput = input

        if config.quantize_weights:
            qweight = quantize(self.weight, config.weight_fc_quantizer())
            # A second, backward-specific weight copy used to be quantized
            # here on every call but was never used; it has been removed.
            # NOTE(review): no bi-precision path like QConv2d's is wired up
            # for linear layers - confirm whether one was intended.
            if self.bias is not None:
                qbias = quantize(self.bias, config.bias_quantizer())
            else:
                qbias = None
        else:
            qweight = self.weight
            qbias = self.bias

        output = F.linear(qinput, qweight, qbias)
        # Identity in the forward pass; quantizes the gradient coming back.
        output = quantize_grad(output, config.activation_gradient_fc_quantizer())

        return output


class QBatchNorm2D(nn.BatchNorm2d):
    """``nn.BatchNorm2d`` that quantizes its input when activation
    quantization is enabled.  Weight and bias are used unquantized."""

    def __init__(self, num_features):
        super(QBatchNorm2D, self).__init__(num_features)
        self.quantize_input = QuantMeasure()

    def forward(self, input):       # TODO: weight is not quantized
        """Batch-normalize the (optionally quantized) input."""
        self._check_input_dim(input)
        if config.quantize_activation:
            qinput = self.quantize_input(input)
        else:
            qinput = input

        qweight = self.weight
        qbias = self.bias

        # exponential_average_factor is set to self.momentum
        # (when it is available) only so that it gets updated
        # in ONNX graph when this node is exported to ONNX.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked = self.num_batches_tracked + 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                # otherwise the exponential moving average factor set above
                # (self.momentum) stays in effect

        # The quantized input is what gets normalized.  The earlier code
        # computed qinput but passed the raw input here, so the quantization
        # had no effect on the output.
        # NOTE(review): this changes the numerics of the layer - confirm that
        # quantizing the batch-norm input is the intended behaviour.
        return F.batch_norm(
            qinput, self.running_mean, self.running_var, qweight, qbias,
            self.training or not self.track_running_stats,
            exponential_average_factor, self.eps)
