import torch
from torch.nn.parameter import Parameter

from mqbench.fake_quantize.quantize_base import QuantizeBase
from mqbench.utils import is_symmetric_quant, is_tracing_state
from mqbench.utils.hook import PerChannelLoadHook
import global_placeholder
from mqbench.fake_quantize.pure_hooker import get_ema_value, get_axiswise_max
class LearnableFakeQuantize(QuantizeBase):
    r"""Fake-quantize module whose ``scale`` is learned through backpropagation.

    This is an extension of the FakeQuantize module in fake_quantize.py, which
    supports more generalized lower-bit quantization and learning of the scale
    through backpropagation. For literature references, please see the class
    _LearnableFakeQuantizePerTensorOp.

    In addition to the attributes of ``QuantizeBase`` it holds:

    * ``use_grad_scaling`` -- when True, gradients reaching ``scale`` are
      multiplied by ``1 / sqrt(n * quant_max)``, where ``n`` is the number of
      elements per quantization group (the whole tensor, or one channel).
    * ``scale`` -- learnable ``Parameter``; overwritten from the observer
      while ``observer_enabled`` is set.
    * ``zero_point`` -- buffer (not learned); overwritten from the observer and
      clamped (asymmetric) or zeroed (symmetric) before every fake-quant call.
    * ``eps`` -- lower bound enforced on ``scale`` while the observer is off.
    * ``compute_range_regularization`` -- flag kept for API compatibility;
      this class's ``forward`` does not read it.
    """

    def __init__(self, observer, scale=1., zero_point=0., use_grad_scaling=True, **observer_kwargs):
        super(LearnableFakeQuantize, self).__init__(observer, **observer_kwargs)
        self.use_grad_scaling = use_grad_scaling
        self.scale = Parameter(torch.tensor([scale]))
        # zero_point is deliberately a buffer, not a Parameter: it is not learned
        # here, and a Parameter would also allocate gradient memory for nothing.
        self.register_buffer('zero_point', torch.tensor([zero_point]))
        self.register_buffer('eps', torch.tensor([torch.finfo(torch.float32).eps]))
        # Check whether the module will load a state dict;
        # initialize the shape of per-channel 'scale' and 'zero_point' before copying values.
        self.load_state_dict_hook = PerChannelLoadHook(self)
        # Kept so callers that toggle it keep working; not read by forward().
        self.compute_range_regularization = False

    @torch.jit.export
    def extra_repr(self):
        return 'fake_quant_enabled={}, observer_enabled={}, ' \
               'quant_min={}, quant_max={}, dtype={}, qscheme={}, ch_axis={}, ' \
               'scale={}, zero_point={}'.format(
                   self.fake_quant_enabled, self.observer_enabled,
                   self.quant_min, self.quant_max,
                   self.dtype, self.qscheme, self.ch_axis, self.scale if self.ch_axis == -1 else 'List[%s]' % str(self.scale.shape),
                   self.zero_point if self.ch_axis == -1 else 'List')

    def forward(self, X):
        """Optionally observe ``X``, then optionally return its fake-quantized value.

        While the observer is enabled, ``scale`` and ``zero_point`` are overwritten
        with the observer's ``calculate_qparams()``. Otherwise ``scale`` is forced
        positive and at least ``eps``. When fake quantization is disabled, ``X`` is
        returned unchanged.
        """
        if self.observer_enabled[0] == 1:
            self.activation_post_process(X.detach())
            _scale, _zero_point = self.activation_post_process.calculate_qparams()
            _scale = _scale.to(self.scale.device)
            _zero_point = _zero_point.to(self.zero_point.device)

            if self.ch_axis != -1:
                # Per-channel: resize storage to the observer's shape before copying.
                self.scale.data = torch.ones_like(_scale)
                self.zero_point.data = torch.zeros_like(_zero_point.float())

            self.scale.data.copy_(_scale)
            self.zero_point.data.copy_(_zero_point.float())
        else:
            # Keep the learned scale strictly positive.
            self.scale.data.abs_()
            self.scale.data.clamp_(min=self.eps.item())

        if self.fake_quant_enabled[0] == 1:
            if is_symmetric_quant(self.qscheme):
                self.zero_point.data.zero_()
            else:
                self.zero_point.data.clamp_(self.quant_min, self.quant_max)

            if self.is_per_channel:
                if self.use_grad_scaling:
                    grad_factor = 1.0 / (X.numel() / X.shape[self.ch_axis] * self.quant_max) ** 0.5
                else:
                    grad_factor = 1.0
                if is_tracing_state():
                    # Autograd Function that carries an ONNX `symbolic` for export.
                    X = FakeQuantizeLearnablePerchannelAffine.apply(
                        X, self.scale, self.zero_point, self.ch_axis,
                        self.quant_min, self.quant_max, grad_factor)
                else:
                    X = _fake_quantize_learnable_per_channel_affine_training(
                        X, self.scale, self.zero_point, self.ch_axis,
                        self.quant_min, self.quant_max, grad_factor)
            else:
                if self.use_grad_scaling:
                    grad_factor = 1.0 / (X.numel() * self.quant_max) ** 0.5
                else:
                    grad_factor = 1.0
                # Built-in PyTorch op.
                X = torch._fake_quantize_learnable_per_tensor_affine(
                    X, self.scale, self.zero_point,
                    self.quant_min, self.quant_max, grad_factor)

        return X




def _fake_quantize_learnable_per_tensor_affine_training(x, scale, zero_point, quant_min, quant_max, grad_factor):
    """Per-tensor affine fake quantization built from straight-through estimators.

    Computes ``(clamp(round(x / scale + zp), quant_min, quant_max) - zp) * scale``,
    where ``zp`` is ``zero_point`` rounded in the forward pass only. Both roundings
    pass gradients straight through. Gradients reaching ``scale`` and ``zero_point``
    are multiplied by ``grad_factor`` via ``grad_scale``.
    """
    rounded_zp = zero_point + (zero_point.round() - zero_point).detach()
    step = grad_scale(scale, grad_factor)
    offset = grad_scale(rounded_zp, grad_factor)

    q = x / step + offset
    q = torch.clamp(q + (q.round() - q).detach(), quant_min, quant_max)
    return (q - offset) * step

def _fake_quantize_learnable_per_channel_affine_training(x, scale, zero_point, ch_axis, quant_min, quant_max, grad_factor):
    """Per-channel affine fake quantization built from straight-through estimators.

    ``scale`` and ``zero_point`` must hold ``x.shape[ch_axis]`` elements; the
    reshape below requires this. They are reshaped so they broadcast along
    ``ch_axis`` (size 1 on every other axis). The math matches the per-tensor
    variant:
    ``(clamp(round(x / scale + zp), quant_min, quant_max) - zp) * scale``,
    where ``zp`` is rounded in the forward pass only. Gradients reaching
    ``scale``/``zero_point`` are multiplied by ``grad_factor``.
    """
    bcast_shape = [1 for _ in x.shape]
    bcast_shape[ch_axis] = x.shape[ch_axis]

    rounded_zp = zero_point + (zero_point.round() - zero_point).detach()
    step = grad_scale(scale, grad_factor).reshape(bcast_shape)
    offset = grad_scale(rounded_zp, grad_factor).reshape(bcast_shape)

    q = x / step + offset
    q = torch.clamp(q + (q.round() - q).detach(), quant_min, quant_max)
    return (q - offset) * step


def grad_scale(t, scale):
    """Return ``t`` in the forward pass, with its gradient multiplied by ``scale``.

    The forward value is evaluated as ``(t - t*scale) + t*scale``, so it equals
    ``t`` up to floating-point rounding. Only the non-detached ``t*scale`` term
    carries gradient.
    """
    scaled = t * scale
    return (t - scaled).detach() + scaled


class FakeQuantizeLearnablePerchannelAffine(torch.autograd.Function):
    """Export wrapper around per-channel learnable fake quantization.

    ``forward`` delegates to ``_fake_quantize_learnable_per_channel_affine_training``.
    ``symbolic`` emits a single custom ``::FakeQuantizeLearnablePerchannelAffine``
    node, which torch.onnx uses when exporting this Function. No ``backward`` is
    defined. In this file the Function is applied only when ``is_tracing_state()``
    is true; training uses the plain helper directly.
    """
    @staticmethod
    def forward(ctx, x, scale, zero_point, ch_axis, quant_min, quant_max, grad_factor):
        # ctx is unused: nothing is saved because no backward is defined.
        return _fake_quantize_learnable_per_channel_affine_training(x, scale, zero_point, ch_axis,
                                                                    quant_min, quant_max, grad_factor)

    @staticmethod
    def symbolic(g, x, scale, zero_point, ch_axis, quant_min, quant_max, grad_factor):
        # Only quant_min / quant_max are recorded, as integer attributes (the `_i`
        # suffix). ch_axis and grad_factor are accepted but not written to the node.
        # NOTE(review): consumers of the exported graph must infer the channel axis
        # elsewhere -- confirm this is intended.
        return g.op("::FakeQuantizeLearnablePerchannelAffine", x, scale, zero_point, quant_min_i=quant_min, quant_max_i=quant_max)

class ConstrainedLearnableFakeQuantize(QuantizeBase):
    r"""Learnable fake quantizer that can also compute a range-regularization loss.

    This is an extension of the FakeQuantize module in fake_quantize.py, which
    supports more generalized lower-bit quantization and learning of the scale
    through backpropagation. For literature references, please see the class
    _LearnableFakeQuantizePerTensorOp.

    Quantization itself matches the learnable per-tensor / per-channel path.
    When ``compute_range_regularization`` is True and fake quantization is
    enabled, ``forward`` also stores a loss on ``self.range_reg_loss``, computed
    on the *un-quantized* input:

    * ``is_activation=True``: an outlier term ``batch_max / (3 * std)`` plus the
      squared relative gap between per-sample std and an EMA of the mean std
      kept in the ``regular_margin`` buffer. Expects 4-D input, because the
      statistics reduce dims 1-3.
    * otherwise (weights): ``(kurtosis(X) - 1.8) ** 2 * max|X|``.

    Other attributes: ``zero_point`` is a non-learned buffer; ``eps`` bounds
    ``scale`` from below; ``accumulation`` is not read by this class's forward.
    """

    def __init__(self, observer, scale=1., zero_point=0., use_grad_scaling=True, **observer_kwargs):
        super(ConstrainedLearnableFakeQuantize, self).__init__(observer, **observer_kwargs)
        self.use_grad_scaling = use_grad_scaling
        self.scale = Parameter(torch.tensor([scale]))
        # zero_point is deliberately a buffer, not a Parameter: it is not learned
        # here, and a Parameter would also allocate gradient memory for nothing.
        self.register_buffer('zero_point', torch.tensor([zero_point]))
        self.register_buffer('eps', torch.tensor([torch.finfo(torch.float32).eps]))
        # Check whether the module will load a state dict;
        # initialize the shape of per-channel 'scale' and 'zero_point' before copying values.
        self.load_state_dict_hook = PerChannelLoadHook(self)
        self.compute_range_regularization = False
        self.is_activation = False
        self.accumulation = None
        # EMA of the mean per-sample std (activation regularizer only).
        self.register_buffer('regular_margin', torch.tensor([1.]))

    @torch.jit.export
    def extra_repr(self):
        return 'fake_quant_enabled={}, observer_enabled={}, ' \
               'quant_min={}, quant_max={}, dtype={}, qscheme={}, ch_axis={}, ' \
               'scale={}, zero_point={}'.format(
                   self.fake_quant_enabled, self.observer_enabled,
                   self.quant_min, self.quant_max,
                   self.dtype, self.qscheme, self.ch_axis, self.scale if self.ch_axis == -1 else 'List[%s]' % str(self.scale.shape),
                   self.zero_point if self.ch_axis == -1 else 'List')

    def _activation_range_loss(self, X):
        """Return the activation range regularizer and update the ``regular_margin`` EMA.

        The loss has two parts: it aligns per-sample std with its running mean,
        and it penalizes the per-sample maximum relative to 3 * std (outliers).
        Samples whose std is <= 1e-5 are excluded; the original author observed
        all-zero inputs with YOLOX-style data augmentation. If every sample is
        excluded, a zero loss is returned and the EMA is left untouched. Without
        this, the mean of an empty tensor (NaN) would be folded into
        ``regular_margin`` and would stay there permanently.
        """
        batch_std = torch.std(X, dim=[1, 2, 3], keepdim=True)
        # Unlike R2, which penalizes every value above a margin, only the max is
        # penalized here.
        # NOTE(review): assumes get_axiswise_max(..., ch_axis=0) yields one value
        # per sample, indexable by batch position -- confirm.
        batch_max = get_axiswise_max(X.abs(), ch_axis=0)

        threshold = 1e-5
        non_zero_batchs = torch.where(batch_std > threshold)[0]
        if non_zero_batchs.numel() == 0:
            return X.new_zeros(())
        batch_std = batch_std[non_zero_batchs]
        batch_max = batch_max[non_zero_batchs]

        tensor_std = batch_std.mean()
        self.regular_margin = get_ema_value(self.regular_margin, tensor_std.detach(), ema_ratio=0.9)
        norm_gap = (batch_std - self.regular_margin) / self.regular_margin
        inter_loss = norm_gap ** 2
        out_loss = batch_max / (batch_std * 3)

        # The author notes the two terms only work when used together.
        loss = out_loss.mean() + inter_loss.mean()
        if loss > 1000:
            print('locate!', loss, self.regular_margin, out_loss, inter_loss)
        return loss

    def _weight_range_loss(self, X):
        """Return the weight range regularizer ``(kurtosis(X) - 1.8) ** 2 * max|X|``.

        1.8 is the kurtosis of a uniform distribution, presumably the target, as
        in KURE-style regularization.
        """
        mean = torch.mean(X)
        std = torch.std(X)
        kurtosis_val = torch.mean(((X - mean) / std) ** 4)
        # NOTE(review): `+ m - m` is a no-op up to float rounding, but it broadcasts
        # the result to regular_margin's shape ([1]). Kept as-is because callers may
        # rely on that shape -- confirm intent.
        return (kurtosis_val - 1.8) ** 2 * torch.max(X.abs()) + self.regular_margin - self.regular_margin

    def forward(self, X):
        """Optionally observe ``X``, optionally compute the range loss, then fake-quantize.

        While the observer is enabled, ``scale`` and ``zero_point`` come from the
        observer. Otherwise ``scale`` and ``regular_margin`` are forced positive
        and bounded below. When fake quantization is disabled, ``X`` is returned
        unchanged and no loss is computed.
        """
        if self.observer_enabled[0] == 1:
            self.activation_post_process(X.detach())
            _scale, _zero_point = self.activation_post_process.calculate_qparams()
            _scale = _scale.to(self.scale.device)
            _zero_point = _zero_point.to(self.zero_point.device)

            if self.ch_axis != -1:
                # Per-channel: resize storage to the observer's shape before copying.
                self.scale.data = torch.ones_like(_scale)
                self.zero_point.data = torch.zeros_like(_zero_point.float())

            self.scale.data.copy_(_scale)
            self.zero_point.data.copy_(_zero_point.float())
        else:
            self.scale.data.abs_()
            self.scale.data.clamp_(min=self.eps.item())
            self.regular_margin.data.abs_()
            self.regular_margin.data.clamp_(min=1e-8)

        if self.fake_quant_enabled[0] == 1:
            if self.compute_range_regularization:
                if self.is_activation:
                    self.range_reg_loss = self._activation_range_loss(X)
                else:
                    self.range_reg_loss = self._weight_range_loss(X)

            if is_symmetric_quant(self.qscheme):
                self.zero_point.data.zero_()
            else:
                self.zero_point.data.clamp_(self.quant_min, self.quant_max)

            if self.is_per_channel:
                if self.use_grad_scaling:
                    grad_factor = 1.0 / (X.numel() / X.shape[self.ch_axis] * self.quant_max) ** 0.5
                else:
                    grad_factor = 1.0
                if is_tracing_state():
                    X = FakeQuantizeLearnablePerchannelAffine.apply(
                        X, self.scale, self.zero_point, self.ch_axis,
                        self.quant_min, self.quant_max, grad_factor)
                else:
                    X = _fake_quantize_learnable_per_channel_affine_training(
                        X, self.scale, self.zero_point, self.ch_axis,
                        self.quant_min, self.quant_max, grad_factor)
            else:
                if self.use_grad_scaling:
                    grad_factor = 1.0 / (X.numel() * self.quant_max) ** 0.5
                else:
                    grad_factor = 1.0
                # Built-in PyTorch op.
                X = torch._fake_quantize_learnable_per_tensor_affine(
                    X, self.scale, self.zero_point,
                    self.quant_min, self.quant_max, grad_factor)
        return X

