import torch
from torch.nn.parameter import Parameter

from mqbench.fake_quantize.quantize_base import QuantizeBase
from mqbench.utils import is_symmetric_quant, is_tracing_state
from mqbench.utils.hook import PerChannelLoadHook
import global_placeholder
import numpy

class PureHooker(QuantizeBase):
    r"""Pass-through "fake quantizer" that only computes a range-regularization loss.

    ``forward`` always returns its input unchanged. When fake quantization is
    enabled and ``compute_range_regularization`` is set, it stores a loss
    tensor in ``self.range_reg_loss`` for the caller to collect:

    * activations (``is_activation=True``): mean of per-sample
      ``max|x| / (3 * std)`` plus mean squared relative gap between each
      sample's std and an EMA of the batch-mean std kept in the
      ``regular_margin`` buffer;
    * weights: ``(kurtosis - 1.8) ** 2 * max|w|``.
    """

    # Samples whose std is at or below this are treated as all-zero inputs
    # (seen with YOLOX-style data augmentation) and excluded from the loss.
    _ZERO_STD_THRESHOLD = 1e-5
    # Weight given to the previous value when updating the ``regular_margin`` EMA.
    _MARGIN_EMA_RATIO = 0.9
    # Lower bound on the weight std used as the kurtosis denominator, so
    # constant (e.g. all-zero) weights give a finite loss instead of NaN.
    _KURTOSIS_STD_EPS = 1e-12

    def __init__(self, observer, scale=1., zero_point=0., use_grad_scaling=True, **observer_kwargs):
        # ``scale``, ``zero_point`` and ``use_grad_scaling`` are accepted for
        # signature compatibility with learnable fake quantizers; they are unused.
        super(PureHooker, self).__init__(observer, **observer_kwargs)
        # A buffer, not a Parameter: it is only updated by EMA (never by
        # gradient descent), so it needs no gradient storage.
        self.register_buffer('regular_margin', torch.tensor([1.]))
        self.compute_range_regularization = False
        self.is_activation = False
        self.accumulation = None

    @torch.jit.export
    def extra_repr(self):
        # ``scale`` / ``zero_point`` are not registered on this module, so they
        # are not reported; the margin buffer is reported instead.
        return 'fake_quant_enabled={}, observer_enabled={}, ' \
               'quant_min={}, quant_max={}, dtype={}, qscheme={}, ch_axis={}, ' \
               'regular_margin={}'.format(
                   self.fake_quant_enabled, self.observer_enabled,
                   self.quant_min, self.quant_max,
                   self.dtype, self.qscheme, self.ch_axis,
                   self.regular_margin)

    def forward(self, X):
        """Return ``X`` unchanged, optionally recording ``self.range_reg_loss``.

        Args:
            X: input tensor. The activation path reduces over dims 1-3, so a
                4-D tensor is required there (presumably NCHW - TODO confirm).

        Returns:
            ``X`` itself.
        """
        if self.observer_enabled[0] != 1:
            # Keep the margin strictly positive; the activation loss divides by it.
            self.regular_margin.data.abs_()
            self.regular_margin.data.clamp_(min=1e-8)

        if self.fake_quant_enabled[0] == 1 and self.compute_range_regularization:
            if self.is_activation:
                self._activation_range_loss(X)
            else:
                self._weight_range_loss(X)

        return X

    def _activation_range_loss(self, X):
        """Set ``range_reg_loss`` for an activation and update the margin EMA."""
        batch_std = torch.std(X, dim=[1, 2, 3], keepdim=True)
        # Per-sample max |x|, shaped (N, 1, 1, 1) like batch_std for 4-D X.
        batch_max = get_axiswise_max(X.abs(), ch_axis=0)

        # Drop (near) all-zero samples: their std is ~0 and would blow up the
        # ratio below. NaN stds (e.g. single-element samples) fail the test too.
        valid = torch.where(batch_std > self._ZERO_STD_THRESHOLD)[0]
        if valid.numel() == 0:
            # Nothing usable in this batch: the mean of an empty tensor is NaN
            # and would poison the EMA permanently, so leave it untouched and
            # contribute no loss.
            self.range_reg_loss = X.new_zeros(())
            return
        batch_std = batch_std[valid]
        batch_max = batch_max[valid]

        self.regular_margin = get_ema_value(
            self.regular_margin, batch_std.mean().detach(),
            ema_ratio=self._MARGIN_EMA_RATIO)
        # Squared relative gap to the running std: aligns statistics across samples.
        inter_loss = ((batch_std - self.regular_margin) / self.regular_margin) ** 2
        # Max |x| measured in units of 3 std: penalizes outliers.
        out_loss = batch_max / (batch_std * 3)
        # Both terms are used together.
        self.range_reg_loss = out_loss.mean() + inter_loss.mean()

        if self.range_reg_loss > 1000:
            # Debug trace for exploding loss values.
            print('locate!', self.range_reg_loss, self.regular_margin, out_loss, inter_loss)

    def _weight_range_loss(self, X):
        """Set ``range_reg_loss`` for a weight tensor from its per-tensor kurtosis."""
        mean = torch.mean(X)
        std = torch.std(X).clamp(min=self._KURTOSIS_STD_EPS)
        kurtosis_val = torch.mean(((X - mean) / std) ** 4)
        # 1.8 is the kurtosis of a continuous uniform distribution; the squared
        # deviation from it is scaled by max |w|.
        loss = (kurtosis_val - 1.8) ** 2 * torch.max(X.abs())
        # Shape (1,), matching what the former ``+ margin - margin`` produced,
        # without that expression's floating-point rounding.
        self.range_reg_loss = loss.reshape(1)


def get_indicator(X):
    """Return the largest absolute value in ``X`` as a 0-d tensor."""
    # An alternative indicator that was tried here: 2 * X.std()
    return torch.max(torch.abs(X))

def get_ema_value(main_value, new_value, ema_ratio):
    """Blend ``new_value`` into ``main_value`` as an exponential moving average.

    Returns ``ema_ratio * main_value + (1 - ema_ratio) * new_value``.
    ``new_value`` must be a tensor; it is detached first, so no gradient
    flows back through it.
    """
    kept = ema_ratio * main_value
    blended_in = (1 - ema_ratio) * new_value.detach()
    return kept + blended_in

def get_axiswise_max(x, ch_axis=1):
    """Return the maximum of ``x`` for each index along ``ch_axis``.

    Only ``ch_axis`` keeps its size; every other axis is reduced to size 1
    (kept, not squeezed). The result therefore has the same rank as ``x`` and
    broadcasts against it for inputs of any rank, not just 4-D.

    Args:
        x: tensor with at least 2 dimensions.
        ch_axis: axis whose granularity is kept; negative values are allowed.

    Returns:
        Tensor of shape ``[1, ..., x.shape[ch_axis], ..., 1]``.
    """
    axes = list(range(x.dim()))
    # Swap ch_axis to the front so all remaining axes can be flattened together.
    axes[0], axes[ch_axis] = axes[ch_axis], axes[0]
    flat_x = torch.flatten(x.permute(axes), start_dim=1)
    per_axis_max = torch.max(flat_x, dim=1)[0]

    shape = [1] * x.dim()
    shape[ch_axis] = x.shape[ch_axis]
    return per_axis_max.reshape(shape)

def grad_scale(t, scale):
    """Return ``t`` in the forward pass, but with its gradient multiplied by ``scale``.

    The forward value equals ``t`` up to floating-point rounding.
    """
    scaled = t * scale
    # The detached difference restores the forward value without carrying gradient.
    return (t - scaled).detach() + scaled

