import torch
from torch.autograd.function import InplaceFunction, Function
import torch.nn as nn
import torch.nn.functional as F
import math

import scipy.io as sio

from .quantize_bp_ef_qlinear_all4 import RecordGrad, calculate_qparams

def record_grad(x, index, iter):
    """Pass ``x`` through the ``RecordGrad`` autograd Function.

    Returns ``RecordGrad.apply(x, index, iter)``. What gets recorded is defined
    by ``RecordGrad`` in ``quantize_bp_ef_qlinear_all4`` -- NOTE(review):
    presumably the gradient of ``x`` tagged by ``index``/``iter``; confirm
    there.
    """
    # Call ``apply`` on the class: instantiating an autograd Function is
    # deprecated in recent PyTorch (emits a DeprecationWarning).
    return RecordGrad.apply(x, index, iter)


class UniformQuantize(InplaceFunction):
    """Uniform fake-quantization (quantize, then dequantize) with an STE backward.

    ``forward`` maps ``input`` onto ``2**num_bits`` evenly spaced levels that
    span ``[min_value, max_value]``, clamps and rounds, then maps the levels
    back to the input scale. ``backward`` is a straight-through estimator: the
    incoming gradient is passed to ``input`` unchanged.

    ``forward`` arguments (positional, in this order, via ``apply``):
        input: tensor to quantize.
        num_bits: number of quantization bits.
        min_value, max_value: quantization range. Whichever is ``None`` is
            estimated from ``input`` reshaped to
            ``(input.shape[0] // num_chunks, -1)``: the per-row min (max) is
            averaged over rows.
        stochastic: add uniform noise in ``[-0.5, 0.5)`` before rounding.
        inplace: overwrite ``input`` instead of a clone of it.
        enforce_true_zero: use an integer zero point so that 0.0 is exactly
            representable.
        num_chunks: see ``min_value``; defaults to ``input.shape[0]``, i.e. a
            single row covering the whole tensor.
        out_half: cast the result to ``float16`` when ``num_bits <= 16``.
        index, iter: only referenced by the disabled statistics dump below.
    """

    @staticmethod
    def forward(ctx, input, num_bits=8, min_value=None, max_value=None,
                stochastic=False, inplace=False, enforce_true_zero=False,
                num_chunks=None, out_half=False, index=None, iter=None):
        if num_chunks is None:
            num_chunks = input.shape[0]
        if min_value is None or max_value is None:
            # reshape (not view) so non-contiguous inputs are accepted too;
            # ``y`` is only used to measure the range.
            y = input.reshape(input.shape[0] // num_chunks, -1)
        if min_value is None:
            min_value = y.min(-1)[0].mean(-1)
        if max_value is None:
            max_value = y.max(-1)[0].mean(-1)
        ctx.inplace = inplace
        ctx.num_bits = num_bits
        ctx.min_value = min_value
        ctx.max_value = max_value
        ctx.stochastic = stochastic

        if ctx.inplace:
            ctx.mark_dirty(input)
            output = input
        else:
            output = input.clone()

        # record
        # if iter is not None:
        #     qparams = calculate_qparams(
        #         input, num_bits=num_bits, flatten_dims=(2, -1), reduce_dim=None)
        #     if iter % 100 == 0:
        #         temp = qparams.range.to(device='cpu').numpy()
        #         sio.savemat('C:/wenjinguo/prj/cpt_cifar_3/cifar10/UQ_wide_all4_piecewise_78000_INTV4_percsquant_nonDSQ_noendmask_GAaroundBN4bit/gabn/sc_' + index + '_' + str(iter) + '.mat',
        #                     {'data': temp[:, :, 0, 0]})

        qmin = 0.
        qmax = 2.**num_bits - 1.
        scale = (max_value - min_value) / (qmax - qmin)
        # Guard against a zero (or inverted) range.
        scale = max(scale, 1e-8)

        if enforce_true_zero:
            initial_zero_point = qmin - min_value / scale
            # make zero exactly represented
            if initial_zero_point < qmin:
                zero_point = qmin
            elif initial_zero_point > qmax:
                zero_point = qmax
            else:
                zero_point = initial_zero_point
            zero_point = int(zero_point)
            output.div_(scale).add_(zero_point)
        else:
            output.add_(-min_value).div_(scale).add_(qmin)

        if ctx.stochastic:
            noise = output.new(output.shape).uniform_(-0.5, 0.5)
            output.add_(noise)
        output.clamp_(qmin, qmax).round_()  # quantize

        if enforce_true_zero:
            output.add_(-zero_point).mul_(scale)  # dequantize
        else:
            output.add_(-qmin).mul_(scale).add_(min_value)  # dequantize
        if out_half and num_bits <= 16:
            output = output.half()
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator. Autograd expects one entry per forward()
        # input (trailing extras must be None), so return all 11 slots; a
        # shorter tuple fails whenever apply() received more arguments.
        return (grad_output,) + (None,) * 10


class UniformQuantizeGrad(InplaceFunction):
    """Identity in the forward pass; fake-quantizes the gradient in backward.

    ``backward`` quantizes ``grad_output`` with ``UniformQuantize`` using the
    stored ``num_bits``/``stochastic``/``inplace``. A ``None`` ``min_value`` or
    ``max_value`` is replaced by the global min/max of ``grad_output``.
    """

    @staticmethod
    def forward(ctx, input, num_bits=8, min_value=None, max_value=None,
                stochastic=True, inplace=False, index=None, iter=None):
        ctx.inplace = inplace
        ctx.num_bits = num_bits
        ctx.min_value = min_value
        ctx.max_value = max_value
        ctx.stochastic = stochastic
        ctx.index = index
        ctx.iter = iter
        return input

    @staticmethod
    def backward(ctx, grad_output):
        min_value = (float(grad_output.min()) if ctx.min_value is None
                     else ctx.min_value)
        max_value = (float(grad_output.max()) if ctx.max_value is None
                     else ctx.max_value)
        grad_input = UniformQuantize.apply(
            grad_output, ctx.num_bits, min_value, max_value, ctx.stochastic,
            ctx.inplace, False, None, False, ctx.index, ctx.iter)
        # One slot per forward() input; only ``input`` receives a gradient.
        return (grad_input,) + (None,) * 7

def quantize(x, num_bits=8, min_value=None, max_value=None, num_chunks=None, stochastic=False, inplace=False):
    """Fake-quantize ``x`` with ``UniformQuantize`` (straight-through gradient).

    Args:
        x: tensor to quantize.
        num_bits: number of quantization bits.
        min_value, max_value: quantization range; ``None`` means "estimate
            from ``x``" (see ``UniformQuantize``).
        num_chunks: row grouping used when estimating the range.
        stochastic: use stochastic rounding.
        inplace: overwrite ``x``.

    Returns:
        The dequantized tensor.
    """
    # Arguments must follow UniformQuantize.forward's positional order:
    # (..., stochastic, inplace, enforce_true_zero, num_chunks). Passing
    # num_chunks right after max_value would land it in ``stochastic``.
    return UniformQuantize.apply(x, num_bits, min_value, max_value,
                                 stochastic, inplace, False, num_chunks)

def quantize_grad(x, num_bits=8, min_value=None, max_value=None, stochastic=True, inplace=False, index=None, iter=None):
    """Return ``x`` unchanged while fake-quantizing its gradient in backward.

    Delegates to ``UniformQuantizeGrad``. A ``None`` ``min_value``/``max_value``
    is taken from the gradient's global min/max, and rounding is stochastic by
    default (``stochastic=True``).
    """
    # Invoke ``apply`` on the class: instantiating an autograd Function is
    # deprecated in recent PyTorch.
    return UniformQuantizeGrad.apply(x, num_bits, min_value, max_value, stochastic, inplace, index, iter)


class QuantMeasure(nn.Module):
    """Fake-quantize activations using a measured (or running) min/max range.

    Training mode: the range is measured on the current input as the
    per-sample min/max over all non-leading elements, averaged over the
    leading dimension, and the running buffers are updated as
    ``running = momentum * running + (1 - momentum) * measured``. Note that the
    new measurement gets weight ``1 - momentum``, the reverse of
    ``nn.BatchNorm``'s convention.

    Eval mode: the running range is used.
    """

    def __init__(self, num_bits=8, momentum=0.1):
        super(QuantMeasure, self).__init__()
        self.register_buffer('running_min', torch.zeros(1))
        self.register_buffer('running_max', torch.zeros(1))
        self.momentum = momentum
        self.num_bits = num_bits

    def forward(self, input):
        if self.training:
            # reshape (not view) so non-contiguous inputs do not raise.
            flat = input.detach().reshape(input.size(0), -1)
            min_value = flat.min(-1)[0].mean()
            max_value = flat.max(-1)[0].mean()
            self.running_min.mul_(self.momentum).add_(
                min_value * (1 - self.momentum))
            self.running_max.mul_(self.momentum).add_(
                max_value * (1 - self.momentum))
        else:
            min_value = self.running_min
            max_value = self.running_max
        return quantize(input, self.num_bits, min_value=float(min_value), max_value=float(max_value), num_chunks=16)
    

class L1LN(nn.Module):
    """L1-norm normalization with running statistics, taken per last-dim index.

    In training mode ``x`` must be 3-D, ``(B, C, L)``. For each of the ``L``
    positions of the last dimension, statistics are computed over all
    ``B * C`` values: ``mean`` and ``scale = 1 / (mean(|x|) + eps)``. The
    absolute values are not centered on ``mean``. The output is
    ``(x - mean) * scale``, optionally followed by ``* weight + bias``, so
    ``num_features`` should equal ``L``.

    Running statistics follow
    ``running = momentum * running + (1 - momentum) * batch_stat``, i.e. the
    new value gets weight ``1 - momentum`` (the reverse of ``nn.BatchNorm``).
    ``running_var`` stores the running *scale* (an inverse L1 norm), not a
    variance. ``dim`` and ``num_chunks`` are stored but unused by ``forward``.
    """

    def __init__(self, num_features, dim=1, momentum=0.1, affine=True, num_chunks=16, eps=1e-5):
        super(L1LN, self).__init__()
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.zeros(num_features))

        self.momentum = momentum
        self.dim = dim
        if affine:
            self.bias = nn.Parameter(torch.Tensor(num_features))
            self.weight = nn.Parameter(torch.Tensor(num_features))
        else:
            # Register as None so the ``is not None`` checks work without
            # affine parameters (otherwise attribute access raises).
            self.register_parameter('bias', None)
            self.register_parameter('weight', None)
        self.eps = eps
        self.num_chunks = num_chunks
        self.reset_params()

    def reset_params(self):
        if self.weight is not None:
            self.weight.data.uniform_()
        if self.bias is not None:
            self.bias.data.zero_()

    def forward(self, x):
        if self.training:
            B, C, L = x.shape
            # One row per last-dim position, holding all B * C values.
            y = x.permute(2, 0, 1).reshape(L, -1)
            mean = y.mean(-1)
            norm_1 = y.abs().sum(-1)
            scale = 1 / (norm_1 / (B * C) + self.eps)

            # Update running stats outside autograd so no graph is recorded.
            with torch.no_grad():
                self.running_mean.mul_(self.momentum).add_(
                    mean * (1 - self.momentum))
                self.running_var.mul_(self.momentum).add_(
                    scale * (1 - self.momentum))
        else:
            mean = self.running_mean
            scale = self.running_var

        out = (x - mean.view(1, 1, mean.size(0))) * \
            scale.view(1, 1, scale.size(0))

        if self.weight is not None:
            out = out * self.weight.view(1, 1, self.weight.size(0))

        if self.bias is not None:
            out = out + self.bias.view(1, 1, self.bias.size(0))

        return out


class L2LN(nn.Module):
    """Standard layer normalization over the last dimension of a (B, T, C) input.

    Computes ``(x - mean) / sqrt(var + eps)`` per ``(b, t)`` row, where ``var``
    is the biased (population) variance over ``C``. The result is then
    scaled/shifted by ``weight``/``bias`` when ``affine`` is True.
    ``running_mean``/``running_var`` and ``dim`` are stored but not used by
    ``forward`` (presumably kept so existing checkpoints still load -- confirm).
    """

    def __init__(self, num_features, dim=1, affine=True, eps=1e-5):
        super(L2LN, self).__init__()
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.zeros(num_features))

        self.dim = dim
        if affine:
            self.bias = nn.Parameter(torch.Tensor(num_features))
            self.weight = nn.Parameter(torch.Tensor(num_features))
        else:
            # Register as None so reset_params/forward can test for them.
            self.register_parameter('bias', None)
            self.register_parameter('weight', None)
        self.eps = eps
        self.reset_params()

    def reset_params(self):
        if self.weight is not None:
            self.weight.data.uniform_()
        if self.bias is not None:
            self.bias.data.zero_()

    def forward(self, x):
        B, T, C = x.size()
        mean = x.sum(-1, keepdim=True) / C  # B,T,1
        xshift = x - mean  # B,T,C
        var = (xshift**2).sum(-1, keepdim=True) / C  # B,T,1
        rstd = (var + self.eps) ** -0.5  # B,T,1 ; inverse standard deviation
        out = xshift * rstd  # B,T,C

        # Optional affine transform (broadcast over the last dimension).
        if self.weight is not None:
            out = out * self.weight
        if self.bias is not None:
            out = out + self.bias

        return out
    

class L2LNQ(nn.Module):
    """Quantized layer normalization over the last dimension of a (B, T, C) input.

    Pipeline in ``forward``:
      1. ``x`` is fake-quantized by a ``QuantMeasure`` with ``num_bits`` bits.
      2. Per-row mean and inverse std (``rstd``) are computed over ``C``.
      3. ``rstd`` is fake-quantized using its own min/max as the range.
      4. ``weight`` (range = its min/max) and ``bias`` (range estimated
         inside ``quantize``) are fake-quantized before the affine step.
      5. If ``num_bits_grad`` is not None, the output's gradient is
         fake-quantized with that many bits in the backward pass.

    ``running_mean``/``running_var`` and ``dim`` are stored but unused by
    ``forward``.
    """

    def __init__(self, num_features, dim=0, affine=True, eps=1e-5, num_bits=8, num_bits_grad=8):
        super(L2LNQ, self).__init__()

        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.zeros(num_features))

        self.dim = dim
        if affine:
            self.bias = nn.Parameter(torch.Tensor(num_features))
            self.weight = nn.Parameter(torch.Tensor(num_features))
        else:
            # Register as None so reset_params/forward can test for them.
            self.register_parameter('bias', None)
            self.register_parameter('weight', None)
        self.num_bits = num_bits
        self.num_bits_grad = num_bits_grad
        self.quantize_input = QuantMeasure(self.num_bits)
        self.eps = eps
        self.reset_params()

    def reset_params(self):
        if self.weight is not None:
            self.weight.data.uniform_()
        if self.bias is not None:
            self.bias.data.zero_()

    def forward(self, x):
        x = self.quantize_input(x)

        B, T, C = x.size()
        mean = x.sum(-1, keepdim=True) / C  # B,T,1
        xshift = x - mean  # B,T,C
        var = (xshift**2).sum(-1, keepdim=True) / C  # B,T,1
        rstd = (var + self.eps) ** -0.5  # B,T,1 ; inverse standard deviation
        # Quantize rstd over its own observed range.
        rstd = quantize(rstd, num_bits=self.num_bits, min_value=float(
            rstd.min()), max_value=float(rstd.max()))
        out = xshift * rstd  # B,T,C

        if self.weight is not None:
            qweight = quantize(self.weight, num_bits=self.num_bits,
                               min_value=float(self.weight.min()),
                               max_value=float(self.weight.max()))
            out = out * qweight.view(1, 1, qweight.size(0))

        if self.bias is not None:
            qbias = quantize(self.bias, num_bits=self.num_bits)
            out = out + qbias.view(1, 1, qbias.size(0))
        if self.num_bits_grad is not None:
            out = quantize_grad(out, num_bits=self.num_bits_grad)

        return out


class L1LNQ(nn.Module):
    """Quantized L1-norm normalization, statistics taken per last-dim index.

    ``x`` is first fake-quantized by a ``QuantMeasure``. In training mode it
    must be 3-D, ``(B, C, L)``. For each of the ``L`` last-dim positions,
    ``mean`` and ``scale = 1 / (mean(|x|) + eps)`` are computed over the
    ``B * C`` values (absolute values are not centered on ``mean``). In eval
    mode the running buffers are used instead.

    ``scale``, ``weight`` and ``bias`` are fake-quantized to ``num_bits``
    before use. If ``num_bits_grad`` is not None, the output's gradient is
    fake-quantized in the backward pass.

    Running statistics follow
    ``running = momentum * running + (1 - momentum) * batch_stat`` (the
    reverse of ``nn.BatchNorm``'s momentum convention). ``running_var`` holds
    the running *scale*, not a variance. ``dim`` is stored but unused.
    """

    def __init__(self, num_features, dim=0, momentum=0.1, affine=True, eps=1e-5, num_bits=8, num_bits_grad=8):
        super(L1LNQ, self).__init__()

        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.zeros(num_features))

        self.momentum = momentum
        self.dim = dim
        if affine:
            self.bias = nn.Parameter(torch.Tensor(num_features))
            self.weight = nn.Parameter(torch.Tensor(num_features))
        else:
            # Register as None so reset_params/forward can test for them.
            self.register_parameter('bias', None)
            self.register_parameter('weight', None)
        self.num_bits = num_bits
        self.num_bits_grad = num_bits_grad
        self.quantize_input = QuantMeasure(self.num_bits)
        self.eps = eps
        self.reset_params()

    def reset_params(self):
        if self.weight is not None:
            self.weight.data.uniform_()
        if self.bias is not None:
            self.bias.data.zero_()

    def forward(self, x):
        x = self.quantize_input(x)

        if self.training:
            B, C, L = x.shape
            # One row per last-dim position, holding all B * C values.
            y = x.permute(2, 0, 1).reshape(L, -1)
            mean = y.mean(-1)
            norm_1 = y.abs().sum(-1)
            scale = 1 / (norm_1 / (B * C) + self.eps)

            # Update running stats outside autograd so no graph is recorded.
            with torch.no_grad():
                self.running_mean.mul_(self.momentum).add_(
                    mean * (1 - self.momentum))
                self.running_var.mul_(self.momentum).add_(
                    scale * (1 - self.momentum))
        else:
            mean = self.running_mean
            scale = self.running_var

        # Quantize the scale over its own observed range.
        scale = quantize(scale, num_bits=self.num_bits, min_value=float(
            scale.min()), max_value=float(scale.max()))
        out = (x - mean.view(1, 1, mean.size(0))) * \
            scale.view(1, 1, scale.size(0))

        if self.weight is not None:
            qweight = quantize(self.weight, num_bits=self.num_bits,
                               min_value=float(self.weight.min()),
                               max_value=float(self.weight.max()))
            out = out * qweight.view(1, 1, qweight.size(0))

        if self.bias is not None:
            qbias = quantize(self.bias, num_bits=self.num_bits)
            out = out + qbias.view(1, 1, qbias.size(0))
        if self.num_bits_grad is not None:
            out = quantize_grad(out, num_bits=self.num_bits_grad)

        # out = record_grad(out, index, iter)

        return out