import math
from typing import Tuple

import torch
from torch.quantization.observer import _ObserverBase

from mqbench.fake_quantize.quantize_base import _version_under_1100 
from mqbench.utils import sync_tensor, pot_quantization, is_symmetric_quant
# from mqbench.utils.logger import logger
from mqbench.utils.hook import PerChannelLoadHook
import warnings

class ObserverBase(_ObserverBase):
    '''
        Base observer supporting per-tensor and per-channel statistics.

        dtype: quantized dtype; when no custom range is given, quant min/max
            are inferred from it (see ``_calculate_qmin_qmax``).
        qscheme: quantization scheme (affine / symmetric, per-tensor / per-channel).
        reduce_range: use a halved default range (special for fbgemm to avoid overflow).
        quant_min: fixed-point value min.
        quant_max: fixed-point value max.
        ch_axis: per-channel axis, or -1 for per-tensor.
        The arguments above are similar to the torch observer.
        pot_scale: whether the scale is rounded to a power of two.
        factory_kwargs: accepted for signature compatibility; not used here.
    '''

    # Running statistics. They are registered as buffers in __init__ and start
    # at +inf / -inf until the first observation.
    min_val: torch.Tensor
    max_val: torch.Tensor

    def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine,
                 reduce_range=False, quant_min=None, quant_max=None, ch_axis=-1, pot_scale=False,
                 factory_kwargs=None):
        # Remember the requested range. If it spans more than 256 levels, the
        # parent constructor gets a placeholder int8 range instead (presumably
        # because the parent only accepts 8-bit ranges). Afterwards the
        # requested values are restored, and the effective range is
        # recomputed with this class's own _calculate_qmin_qmax.
        stored_min, sotred_max = quant_min, quant_max
        if quant_max is not None and quant_min is not None and (quant_max - quant_min + 1 > 256):
            quant_min, quant_max = -128, 127
        super(ObserverBase, self).__init__(dtype, qscheme, reduce_range, quant_min, quant_max)
        self.quant_min = stored_min
        self.quant_max = sotred_max
        # Falls back to dtype defaults when no custom range was given.
        self.quant_min, self.quant_max = self._calculate_qmin_qmax()
        self.ch_axis = ch_axis
        self.pot_scale = pot_scale
        self.register_buffer("min_val", torch.tensor(float("inf")))
        self.register_buffer("max_val", torch.tensor(float("-inf")))
        # mqbench helper; presumably allows loading per-channel min/max whose
        # shape differs from the scalar initial buffers -- TODO confirm.
        self.load_state_dict_hook = PerChannelLoadHook(self)

    @torch.jit.export
    def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Calculates the quantization parameters.

        Derives scale / zero_point from the recorded ``min_val`` / ``max_val``
        and passes both through ``sync_tensor`` (an mqbench helper, presumably
        a cross-process reduction). If ``pot_scale`` is set, the scale is also
        rounded with ``pot_quantization``.
        """
        scale, zero_point = self._calculate_qparams(self.min_val, self.max_val)
        scale.data = sync_tensor(scale).data
        zero_point.data = sync_tensor(zero_point).data
        if self.pot_scale:
            scale = pot_quantization(scale)
        return scale, zero_point
    
    @torch.jit.export
    def _calculate_qparams(self, min_val: torch.Tensor, max_val: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:  # adapted from the PyTorch observer source
        r"""Calculates the quantization parameters, given min and max
        value tensors. Works for both per tensor and per channel cases

        Symmetric schemes:
          * quint8: scale = max_val / (quant_max - quant_min), zero_point 0
            (a non-negative range starting at 0);
          * qint8: scale = max(|min|, |max|) / ((quant_max - quant_min) / 2),
            zero_point 0;
          * any other dtype raises ``NotImplementedError``.
        Scalar results are returned as 1-element tensors.

        Args:
            min_val: Minimum values per channel
            max_val: Maximum values per channel

        Returns:
            scales: Scales tensor of shape (#channels,)
            zero_points: Zero points tensor of shape (#channels,)
        """
        if min_val.numel() == 0 or max_val.numel() == 0:
            warnings.warn(
                "must run observer before calling calculate_qparams.\
                                    Returning default scale and zero point "
            )
            return torch.tensor([1.0]), torch.tensor([0])

        if min_val.dim() == 0 or max_val.dim() == 0:
            # Still at the +inf / -inf initial values: nothing was observed yet.
            if min_val == float('inf') and max_val == float('-inf'):
                warnings.warn(
                    "must run observer before calling calculate_qparams.\
                                        Returning default scale and zero point "
                )
                return torch.tensor([1.0]), torch.tensor([0])

            assert min_val <= max_val, "min {} should be less than max {}".format(
                min_val, max_val
            )
        else:
            assert torch.all(min_val <= max_val), "min {} should be less than max {}".format(
                min_val, max_val
            )

        quant_min, quant_max = self._calculate_qmin_qmax()
        # Widen the range so that it always contains 0.
        min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
        max_val_pos = torch.max(max_val, torch.zeros_like(max_val))

        device = min_val_neg.device
        scale = torch.ones(min_val_neg.size(), dtype=torch.float32, device=device)
        zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)

        if self.qscheme == torch.per_tensor_symmetric or self.qscheme == torch.per_channel_symmetric:
            # Symmetrised magnitude: max(|min|, |max|).
            max_val_pos = torch.max(-min_val_neg, max_val_pos)
            
            if self.dtype == torch.quint8:
                # Non-negative symmetric quantization: the scale comes from
                # max_val (not the symmetrised max_val_pos), and zero_point
                # keeps its initial value 0.
                # Upstream torch behaviour, intentionally disabled here:
                # if self.has_customized_qrange:
                #     # When customized quantization range is used, down-rounded midpoint of the range is chosen.
                #     zero_point = zero_point.new_full(zero_point.size(), (quant_min + quant_max) // 2)
                # else:
                #     zero_point = zero_point.new_full(zero_point.size(), 128)
                scale = (max_val - 0.) / float(quant_max - quant_min)
                # Non-positive / tiny scales fall back to 1.0.
                scale = torch.where(scale > self.eps, scale, torch.ones_like(scale))
                
            elif self.dtype == torch.qint8:
                # Signed symmetric quantization.
                # Halve the integer span: with zero_point 0, +/-max|x| maps to
                # about +/-(quant_max - quant_min) / 2.
                scale = max_val_pos / (float(quant_max - quant_min) / 2)
                # self.eps comes from the parent observer; it keeps scale > 0.
                scale = torch.max(scale, self.eps)
            else:
                raise NotImplementedError
        elif self.qscheme == torch.per_channel_affine_float_qparams:
            scale = (max_val - min_val) / float(quant_max - quant_min)
            scale = torch.where(scale > self.eps, scale, torch.ones_like(scale))
            # We use the quantize function
            # xq = Round(Xf * inv_scale + zero_point),
            # setting zero_point to (-1 * min *inv_scale) we get
            # Xq = Round((Xf - min) * inv_scale)
            zero_point = -1 * min_val / scale
        else:
            # Affine: spread [min_val_neg, max_val_pos] over the integer range.
            scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
            scale = torch.max(scale, self.eps)
            zero_point = quant_min - torch.round(min_val_neg / scale)
            zero_point = torch.clamp(zero_point, quant_min, quant_max)

        # For scalar values, cast them to Tensors of size 1 to keep the shape
        # consistent with default values in FakeQuantize.
        if len(scale.shape) == 0:
            # TODO: switch to scale.item() after adding JIT support
            scale = torch.tensor([float(scale)], dtype=scale.dtype, device=device)
        if len(zero_point.shape) == 0:
            # TODO: switch to zero_point.item() after adding JIT support
            zero_point = torch.tensor([int(zero_point)], dtype=zero_point.dtype, device=device)
            if self.qscheme == torch.per_channel_affine_float_qparams:
                zero_point = torch.tensor([float(zero_point)], dtype=zero_point.dtype, device=device)

        return scale, zero_point
    
    @torch.jit.export
    def _calculate_qmin_qmax(self) -> Tuple[int, int]:
        r"""Calculates actual qmin and qmax based on the quantization range,
        observer datatype and if range is reduced.

        If the parent marked the range as customised (``has_customized_qrange``),
        the stored ``quant_min`` / ``quant_max`` are returned unchanged.
        Otherwise the 8-bit defaults for qint8 / quint8 are used (halved when
        ``reduce_range``), and ``[0, 15]`` is used for any other dtype.
        """
        if self.has_customized_qrange:
            quant_min, quant_max = self.quant_min, self.quant_max
        else:
            # Fallback onto default 8-bit qmin and qmax calculation if dynamic range is not used.
            if self.dtype == torch.qint8:
                if self.reduce_range:
                    quant_min, quant_max = -64, 63
                else:
                    quant_min, quant_max = -128, 127
            elif self.dtype == torch.quint8:
                if self.reduce_range:
                    quant_min, quant_max = 0, 127
                else:
                    quant_min, quant_max = 0, 255
            else:
                quant_min, quant_max = 0, 15
        return quant_min, quant_max

    @torch.jit.export
    def extra_repr(self):
        # Per-channel min/max are shown as 'List' to keep the repr short.
        return "min_val={}, max_val={} ch_axis={} pot={}".format(self.min_val if self.ch_axis == -1 else 'List',
                                                                 self.max_val if self.ch_axis == -1 else 'List',
                                                                 self.ch_axis, self.pot_scale)


class MinMaxObserver(ObserverBase):
    '''
    Calculate minmax of whole calibration dataset.

    The running min/max only ever widens across batches. With ``ch_axis == -1``
    two scalars are tracked; otherwise there is one min/max per slice along
    ``ch_axis``.
    '''

    def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine,
                 reduce_range=False, quant_min=None, quant_max=None, ch_axis=-1, pot_scale=False,
                 factory_kwargs=None):
        super(MinMaxObserver, self).__init__(dtype, qscheme, reduce_range, quant_min, quant_max,
                                             ch_axis, pot_scale, factory_kwargs)

    def forward(self, x_orig):
        r"""Records the running minimum and maximum of ``x``."""
        if x_orig.numel() == 0:
            return x_orig
        x = x_orig.to(self.min_val.dtype)
        if self.ch_axis != -1:
            # Swap the channel axis with axis 0, then collapse the remaining
            # axes so every row holds all values of one channel.
            order = list(range(x.dim()))
            order[self.ch_axis], order[0] = 0, self.ch_axis
            rows = torch.flatten(x.permute(order), start_dim=1)
            batch_min, batch_max = torch._aminmax(rows, 1)
        else:
            batch_min, batch_max = torch._aminmax(x)
        self.min_val = torch.min(self.min_val, batch_min)
        self.max_val = torch.max(self.max_val, batch_max)
        return x


class MinMaxFloorObserver(ObserverBase):
    '''
    Calculate minmax of whole calibration dataset with floor but round.

    Only the min/max of the most recent batch is kept (``forward`` overwrites
    rather than accumulates), together with the batch itself.
    ``calculate_qparams`` then searches power-of-two scales ``2 ** -s``,
    starting from the exponent implied by min/max, and keeps the one with
    the lowest squared quantization error on that batch.
    ``set_quant_type`` must be called before ``calculate_qparams``.
    '''

    def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine,
                 reduce_range=False, quant_min=None, quant_max=None, ch_axis=-1, pot_scale=False,
                 factory_kwargs=None):
        super(MinMaxFloorObserver, self).__init__(dtype, qscheme, reduce_range, quant_min, quant_max,
                                                  ch_axis, pot_scale, factory_kwargs)
        # quant_type may be 'input', 'param' or 'tensor'. The corresponding
        # exponent search range is 1, 5, 5 and the rounding mode (mth) is
        # 2, 3, 2 (see calculate_qparams).
        self.quant_type = None

    def forward(self, x_orig):
        r"""Records the running minimum and maximum of ``x``."""
        if x_orig.numel() == 0:
            return x_orig
        x = x_orig.to(self.min_val.dtype)
        if self.ch_axis == -1:
            min_val_cur, max_val_cur = torch._aminmax(x)
        else:
            from global_placeholder import logger
            logger.warn('The per-tensor observer does not support per-channel min-max!')
            min_val_cur, max_val_cur = torch._aminmax(x)

        self.min_val = min_val_cur
        self.max_val = max_val_cur
        # Keep the batch so calculate_qparams can measure quantization error on it.
        self._x = x
        return x

    def calculate_qparams(self):
        """Pick a power-of-two scale for the last observed batch.

        Returns:
            (scale, zero_point): scale is ``1 / 2 ** s`` with ``s`` capped at 12.
            zero_point is zero, except for asymmetric schemes with a
            non-negative ``min_val``.

        Raises:
            ValueError: if ``set_quant_type`` has not been called.
        """
        if self.quant_type is None:
            raise ValueError('You should set the observer type before forward!')
        scale_range = 1 if self.quant_type == 'input' else 5
        mth = 3 if self.quant_type == 'param' else 2
        scale, zero_point = self._calculate_qparams(self.min_val, self.max_val)
        if self.quant_min != 0:
            scale.data = scale.data * 0 + max(self.min_val / self.quant_min, self.max_val / self.quant_max)
        else:
            # Unsigned range: min_val / 0 would be +inf (min_val > 0) or nan
            # (min_val == 0), and int(max_scale) below would then raise.
            # Only the upper bound defines the scale here.
            scale.data = scale.data * 0 + self.max_val / self.quant_max
        if scale < 2 ** -15:
            # Tiny or non-positive scale: start the search at exponent 0.
            max_scale = 0
        else:
            # Largest integer s with 2 ** -s >= scale.
            max_scale = 1 / scale
            max_scale = torch.floor(max_scale.log2())
        min_loss = torch.tensor([float('inf')])
        final_scale = max_scale
        max_scale = int(max_scale)
        for s in range(max_scale, max_scale + scale_range):
            _s = 1 / 2 ** s
            if mth == 3:
                # Round to nearest, then clamp to the integer range.
                new_x = _s * torch.clamp(torch.round(self._x / _s), self.quant_min, self.quant_max)
            elif mth == 2:
                # Clamp first, then round. Negative exact halves round towards
                # +inf; everything else uses torch.round.
                new_x = torch.clamp(self._x / _s, self.quant_min, self.quant_max)
                new_x = torch.where((new_x < 0) & (new_x - new_x.floor() == 0.5), new_x.ceil(), new_x.round())
                new_x *= _s
            loss = ((new_x - self._x)**2).sum()
            min_loss = min_loss.to(loss.device)
            if loss < min_loss:
                min_loss = loss
                final_scale = s
        # Cap the exponent so the scale never goes below 2 ** -12.
        final_scale = min(final_scale, 12)
        scale = scale.data * 0 + 1 / (2 ** final_scale)
        zero_point = torch.zeros_like(zero_point)
        if not is_symmetric_quant(self.qscheme):
            if self.min_val >= 0.:
                zero_point = self.quant_min - torch.round(self.min_val / scale)
        sync_tensor(scale)
        sync_tensor(zero_point)
        return scale, zero_point

    def set_quant_type(self, qtype):
        """Set the tensor role ('input', 'param' or 'tensor') that drives the scale search."""
        self.quant_type = qtype


class EMAMinMaxObserver(ObserverBase):
    """Moving average min/max among batches.

    The first batch sets min/max directly. Each later batch is blended in as
    ``old * ema_ratio + current * (1 - ema_ratio)``.
    """

    def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=False,
                 quant_min=None, quant_max=None, ch_axis=-1, pot_scale=False, ema_ratio=0.9,
                 factory_kwargs=None):
        super(EMAMinMaxObserver, self).__init__(dtype, qscheme, reduce_range, quant_min, quant_max,
                                                ch_axis, pot_scale, factory_kwargs)
        self.ema_ratio = ema_ratio

    def forward(self, x_orig):
        r"""Records the running minimum and maximum of ``x``."""
        if x_orig.numel() == 0:
            return x_orig
        x = x_orig.to(self.min_val.dtype)
        if self.ch_axis == -1:
            cur_min, cur_max = torch._aminmax(x)
        else:
            # Channel axis first, then one row per channel.
            dims = list(range(x.dim()))
            dims[self.ch_axis], dims[0] = 0, self.ch_axis
            per_channel = x.permute(dims).flatten(start_dim=1)
            cur_min, cur_max = torch._aminmax(per_channel, 1)

        # The buffers start as scalar +/-inf; that state means "no data yet".
        untouched = self.max_val.numel() <= 1 and self.max_val.isinf()
        if untouched:
            self.min_val = cur_min
            self.max_val = cur_max
            return x
        keep = self.ema_ratio
        blend = 1.0 - keep
        self.min_val = self.min_val * keep + cur_min * blend
        self.max_val = self.max_val * keep + cur_max * blend
        return x


class PoTModeObserver(ObserverBase):
    r"""Records the most frequent Potscale of ``x``.

    Borrow from vitis
    https://github.com/Xilinx/Vitis-AI/blob/master/tools/Vitis-AI-Quantizer/vai_q_pytorch/pytorch_binding/pytorch_nndct/quantization/torchquantizer.py

    Each ``calculate_qparams`` call searches for the best power-of-two
    exponent on the last batch and adds it to a histogram (``self.counter``).
    It then returns the scale of the most frequent exponent seen so far.
    ``set_quant_type`` must be called before ``calculate_qparams``.
    """

    # Exponent range covered by ``self.counter``: bin i counts exponent
    # i + _EXP_MIN, so the 20 bins cover exponents -7 .. 12.
    _EXP_MIN = -7
    _EXP_MAX = 12

    def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=False,
                 quant_min=None, quant_max=None, ch_axis=-1, pot_scale=False, factory_kwargs=None):
        super(PoTModeObserver, self).__init__(dtype, qscheme, reduce_range, quant_min, quant_max, ch_axis, pot_scale, factory_kwargs)
        self.quant_type = None
        self.counter = [0] * (self._EXP_MAX - self._EXP_MIN + 1)

    def forward(self, x_orig):
        r"""Records the running minimum and maximum of ``x``."""
        if x_orig.numel() == 0:
            return x_orig
        x = x_orig.to(self.min_val.dtype)
        if self.ch_axis == -1:
            min_val_cur, max_val_cur = torch._aminmax(x)
        else:
            from global_placeholder import logger
            logger.warn('The per-tensor observer does not support per-channel min-max!')
            min_val_cur, max_val_cur = torch._aminmax(x)

        self.min_val = min_val_cur
        self.max_val = max_val_cur
        # Keep the batch so calculate_qparams can measure quantization error on it.
        self._x = x
        return x

    def calculate_qparams(self):
        """Update the exponent histogram and return the scale of its mode.

        Returns:
            (scale, zero_point): scale is ``1 / 2 ** s`` for the most frequent
            exponent ``s`` seen so far. zero_point is zero, except for
            asymmetric schemes with a non-negative ``min_val``.

        Raises:
            ValueError: if ``set_quant_type`` has not been called.
        """
        if self.quant_type is None:
            raise ValueError('You should set the observer type before forward!')
        scale_range = 1 if self.quant_type == 'input' else 5
        mth = 3 if self.quant_type == 'param' else 2
        scale, zero_point = self._calculate_qparams(self.min_val, self.max_val)
        if self.quant_min != 0:
            scale.data = scale.data * 0 + max(self.min_val / self.quant_min, self.max_val / self.quant_max)
        else:
            # Unsigned (non-negative) range: min_val / 0 is undefined, so only
            # the upper bound defines the scale.
            scale.data = scale.data * 0 + self.max_val / self.quant_max

        if scale < 2 ** -15:
            max_scale = 0
        else:
            # Largest integer s with 2 ** -s >= scale.
            max_scale = 1 / scale
            max_scale = torch.floor(max_scale.log2())
        min_loss = torch.tensor([float('inf')])
        final_scale = max_scale
        # int() fails on +/-inf or nan, i.e. when min/max give a degenerate scale.
        max_scale = int(max_scale)
        for s in range(max_scale, max_scale + scale_range):
            _s = 1 / 2 ** s
            if mth == 3:
                new_x = _s * torch.clamp(torch.round(self._x / _s), self.quant_min, self.quant_max)
            elif mth == 2:
                new_x = torch.clamp(self._x / _s, self.quant_min, self.quant_max)
                new_x = torch.where((new_x < 0) & (new_x - new_x.floor() == 0.5), new_x.ceil(), new_x.round())
                new_x *= _s
            loss = ((new_x - self._x)**2).sum()
            min_loss = min_loss.to(loss.device)
            if loss < min_loss:
                min_loss = loss
                final_scale = s
        final_scale = int(min(final_scale, self._EXP_MAX))
        # Clamp the lower end too. Otherwise an exponent below _EXP_MIN would
        # produce a negative list index, which Python silently wraps to a bin
        # at the other end of the histogram.
        final_scale = max(final_scale, self._EXP_MIN)
        self.counter[final_scale - self._EXP_MIN] += 1
        final_scale = self.counter.index(max(self.counter)) + self._EXP_MIN
        scale = scale.data * 0 + 1 / (2 ** final_scale)
        zero_point = torch.zeros_like(zero_point)
        if not is_symmetric_quant(self.qscheme):
            if self.min_val >= 0.:
                zero_point = self.quant_min - torch.round(self.min_val / scale)
        sync_tensor(scale)
        sync_tensor(zero_point)
        return scale, zero_point

    def set_quant_type(self, qtype):
        """Set the tensor role ('input', 'param' or 'tensor') that drives the scale search."""
        self.quant_type = qtype


class EMAQuantileObserver(ObserverBase):
    """Moving average quantile among batches.

    For each batch, ``|x|`` is histogrammed into ``bins`` bins on
    ``[0, max(-min, max)]``. The clip value is the centre of the first bin at
    which the cumulative count reaches ``threshold * numel``. min/max are the
    batch min/max clipped to ``[-clip, clip]``, EMA-blended after the first
    batch. Per-tensor only.
    """

    def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=False,
                 quant_min=None, quant_max=None, ch_axis=-1, pot_scale=False, ema_ratio=0.9,
                 threshold=0.99999, bins=2048, factory_kwargs=None):
        super(EMAQuantileObserver, self).__init__(dtype, qscheme, reduce_range, quant_min, quant_max,
                                                  ch_axis, pot_scale, factory_kwargs)
        assert self.ch_axis == -1, "Quantile observer only support in per-tensor scheme."
        self.ema_ratio = ema_ratio
        self.threshold = threshold
        self.bins = bins

    def forward(self, x_orig):
        r"""Records the running minimum and maximum of ``x``."""
        if x_orig.numel() == 0:
            return x_orig
        x = x_orig.to(self.min_val.dtype)
        lo, hi = torch._aminmax(x)
        hist_top = torch.max(-lo, hi)
        counts = torch.histc(torch.abs(x), bins=self.bins, min=0., max=hist_top)

        # Walk the histogram until the requested fraction of elements is covered.
        target = self.threshold * x.numel()
        bin_width = hist_top / self.bins
        clip_value = hist_top
        running = 0
        for idx, count in enumerate(counts):
            if running + count >= target:
                clip_value = (idx + 0.5) * bin_width
                break
            running += count

        clipped_min = max(lo, -clip_value)
        clipped_max = min(hi, clip_value)
        first_batch = self.max_val.numel() <= 1 and self.max_val.isinf()
        if first_batch:
            self.min_val = clipped_min
            self.max_val = clipped_max
        else:
            blend = 1.0 - self.ema_ratio
            self.min_val = self.min_val * self.ema_ratio + clipped_min * blend
            self.max_val = self.max_val * self.ema_ratio + clipped_max * blend
        return x


class ClipStdObserver(ObserverBase):
    """Clip std.

    Records the latest batch's range, clipped to ``mean +/- std_scale * std``
    (per tensor, or per channel along ``ch_axis``). The result never exceeds
    the observed [min, max], and outliers beyond ``std_scale`` standard
    deviations are cut off.
    """

    def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=False,
                 quant_min=None, quant_max=None, ch_axis=-1, pot_scale=False, std_scale=2.6,
                 factory_kwargs=None):
        super(ClipStdObserver, self).__init__(dtype, qscheme, reduce_range, quant_min, quant_max,
                                              ch_axis, pot_scale, factory_kwargs)
        self.std_scale = std_scale

    def forward(self, x_orig):
        r"""Records the running minimum and maximum of ``x``."""
        if x_orig.numel() == 0:
            return x_orig
        x = x_orig.to(self.min_val.dtype)
        if self.ch_axis == -1:
            min_val_cur, max_val_cur = torch._aminmax(x)
            mean = x.mean()
            std = x.std()
        else:
            # Channel axis first, then one row per channel.
            x_dim = x.size()
            new_axis_list = list(range(len(x_dim)))
            new_axis_list[self.ch_axis] = 0
            new_axis_list[0] = self.ch_axis
            y = x.permute(new_axis_list)
            y = torch.flatten(y, start_dim=1)
            min_val_cur, max_val_cur = torch._aminmax(y, 1)
            mean = y.mean(1)
            std = y.std(1)

        # Clip: intersect [min, max] with [mean - k*std, mean + k*std].
        # Taking the larger lower bound and the smaller upper bound shrinks
        # the range. min_val <= mean <= max_val still holds, so the range
        # stays valid.
        min_val = torch.maximum(mean - self.std_scale * std, min_val_cur)
        max_val = torch.minimum(mean + self.std_scale * std, max_val_cur)

        self.min_val = min_val
        self.max_val = max_val

        return x


class LSQObserver(ObserverBase):
    '''
    LSQ observer.

    Records mean(|x|) (per tensor, or per channel along ``ch_axis``) and the
    latest batch min/max. The scale is initialised as
    ``2 * mean(|x|) / sqrt(quant_max)``.
    '''

    def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=False,
                 quant_min=None, quant_max=None, ch_axis=-1, pot_scale=False, factory_kwargs=None):
        super(LSQObserver, self).__init__(dtype, qscheme, reduce_range, quant_min, quant_max,
                                          ch_axis, pot_scale, factory_kwargs)
        self.tensor_norm = None

    def forward(self, x_orig):
        if x_orig.numel() == 0:
            return x_orig
        x = x_orig.to(self.min_val.dtype)
        if self.ch_axis != -1:
            # Channel-wise statistics (reached whenever ch_axis != -1): move
            # the channel axis to the front and flatten the rest.
            perm = list(range(x.dim()))
            perm[self.ch_axis], perm[0] = 0, self.ch_axis
            rows = torch.flatten(x.permute(perm), start_dim=1)
            self.tensor_norm = rows.abs().mean(1)
            self.min_val, self.max_val = torch._aminmax(rows, 1)
        else:
            self.tensor_norm = x.abs().mean()
            self.min_val, self.max_val = torch._aminmax(x)
        return x

    def calculate_qparams(self):
        """Return the LSQ initial (scale, zero_point) from the last batch."""
        scale = 2 * self.tensor_norm / math.sqrt(self.quant_max)
        zero_point = torch.zeros_like(self.tensor_norm)
        for qparam in (scale, zero_point):
            sync_tensor(qparam)
        if self.pot_scale:
            scale = pot_quantization(scale)
        if not is_symmetric_quant(self.qscheme):
            zero_point = self.quant_min - torch.round(self.min_val / scale)
        return scale, zero_point


class LSQPlusObserver(ObserverBase):
    '''
    LSQ+ observer.

    Records mean/std (per tensor, or per channel along ``ch_axis``) and the
    latest batch min/max. The scale is initialised as
    ``max(|mean - 3*std|, |mean + 3*std|) / (quant_max - quant_min + 1)``.
    '''

    def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=False,
                 quant_min=None, quant_max=None, ch_axis=-1, pot_scale=False, factory_kwargs=None):

        super(LSQPlusObserver, self).__init__(dtype, qscheme, reduce_range, quant_min, quant_max,
                                              ch_axis, pot_scale, factory_kwargs)
        self.mean = None
        self.std = None

    def forward(self, x_orig):
        if x_orig.numel() == 0:
            return x_orig
        x = x_orig.to(self.min_val.dtype)
        if self.ch_axis == -1:
            self.mean = x.mean()
            self.std = x.std()
            self.min_val, self.max_val = torch._aminmax(x)
        else:
            # compute channel-wise mean
            x_dim = x.size()
            new_axis_list = list(range(len(x_dim)))
            new_axis_list[self.ch_axis] = 0
            new_axis_list[0] = self.ch_axis
            y = x.permute(new_axis_list)
            y = torch.flatten(y, start_dim=1)
            self.mean = y.mean(1)
            self.std = y.std(1)
            # Reduce along dim 1 so min/max are per channel, matching mean/std.
            self.min_val, self.max_val = torch._aminmax(y, 1)

        return x

    def calculate_qparams(self):
        """Return the LSQ+ initial (scale, zero_point) from the last batch."""
        scale = torch.maximum((self.mean - 3 * self.std).abs(),
                              (self.mean + 3 * self.std).abs()) / (self.quant_max - self.quant_min + 1)
        # zero_point must exist before it is passed to sync_tensor.
        zero_point = torch.zeros_like(self.mean)
        sync_tensor(scale)
        sync_tensor(zero_point)
        if self.pot_scale:
            scale = pot_quantization(scale)
        if not is_symmetric_quant(self.qscheme):
            # Only ranges with a non-negative minimum get a shifted zero point.
            # torch.where applies this element-wise, so it also works when
            # min_val holds one value per channel.
            shifted = self.quant_min - torch.round(self.min_val / scale)
            zero_point = torch.where(self.min_val >= 0., shifted, zero_point)
        return scale, zero_point


class MSEObserver(ObserverBase):
    '''
    Calculate mseobserver of whole calibration dataset.

    For each batch, min/max are shrunk in 1% steps. The candidate with the
    lowest L_p error after fake quantization is kept, and results are folded
    into a running min/max across batches.
    '''

    def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine,
                 reduce_range=False, quant_min=None, quant_max=None, ch_axis=-1, pot_scale=False, p=2.0,
                 factory_kwargs=None):
        super(MSEObserver, self).__init__(dtype, qscheme, reduce_range, quant_min, quant_max,
                                          ch_axis, pot_scale, factory_kwargs)
        self.p = p

    def lp_loss(self, pred, tgt, dim=None):
        """
        Mean of ``|pred - tgt| ** p``: over ``dim`` when it is given and
        non-empty, otherwise over all elements.
        """
        err = (pred - tgt).abs().pow(self.p)
        if dim:
            return err.mean(dim)
        return err.mean()

    def mse(self, x: torch.Tensor, x_min: torch.Tensor, x_max: torch.Tensor, iter=80):
        """Per-tensor search over shrunk [x_min, x_max]; returns the best pair."""
        best_score = 1e+10
        best_min = torch.tensor([1.0], dtype=torch.float)
        best_max = torch.tensor([1.0], dtype=torch.float)
        best_min.copy_(x_min)
        best_max.copy_(x_max)
        for step in range(iter):
            shrink = 1.0 - (step * 0.01)
            cand_min = x_min * shrink
            cand_max = x_max * shrink
            scale, zero_point = self._calculate_qparams(cand_min, cand_max)
            x_q = torch.fake_quantize_per_tensor_affine(
                x, scale.item(), int(zero_point.item()),
                self.quant_min, self.quant_max)
            score = self.lp_loss(x_q, x)
            if score < best_score:
                best_score = score
                best_min, best_max = cand_min, cand_max
        return best_min, best_max

    def mse_perchannel(self, x: torch.Tensor, x_min: torch.Tensor, x_max: torch.Tensor, iter=80, ch_axis=0):
        """Per-channel search; each channel keeps its own best shrink factor."""
        assert x_min.shape == x_max.shape
        assert ch_axis >= 0, f'{ch_axis}'
        reduce_dim = tuple(d for d in range(x.dim()) if d != ch_axis)
        best_score = torch.ones_like(x_min) * 1e+10
        best_min = x_min.clone()
        best_max = x_max.clone()
        for step in range(iter):
            shrink = 1.0 - (step * 0.01)
            cand_min = x_min * shrink
            cand_max = x_max * shrink
            scale, zero_point = self._calculate_qparams(cand_min, cand_max)
            if _version_under_1100:
                # Older torch expects an int64 zero_point here.
                zero_point = zero_point.long()
            x_q = torch.fake_quantize_per_channel_affine(
                x, scale, zero_point, ch_axis,
                self.quant_min, self.quant_max)
            score = self.lp_loss(x_q, x, reduce_dim)
            improved = score < best_score
            best_score[improved] = score[improved]
            best_min[improved] = cand_min[improved]
            best_max[improved] = cand_max[improved]
        return best_min, best_max

    def forward(self, x_orig):
        r"""Records the running minimum and maximum of ``x``."""
        if x_orig.numel() == 0:
            return x_orig
        x = x_orig.clone().detach().to(self.min_val.dtype)
        if self.ch_axis == -1:
            lo, hi = torch._aminmax(x)
            lo, hi = self.mse(x, lo, hi, iter=95)
        else:
            axes = list(range(x.dim()))
            axes[self.ch_axis], axes[0] = 0, self.ch_axis
            flat = torch.flatten(x.permute(axes), start_dim=1)
            lo, hi = torch._aminmax(flat, 1)
            lo, hi = self.mse_perchannel(x, lo, hi, iter=80, ch_axis=self.ch_axis)

        self.min_val = torch.min(self.min_val, lo)
        self.max_val = torch.max(self.max_val, hi)
        return x


class EMAMSEObserver(ObserverBase):
    '''
    Calculate mseobserver of whole calibration dataset.

    Like the plain MSE search (shrink min/max in 1% steps, keep the lowest
    L_p error), but results are combined across batches with an exponential
    moving average instead of a running min/max.
    '''
    def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine,
                 reduce_range=False, quant_min=None, quant_max=None, ch_axis=-1, pot_scale=False,
                 p=2.0, ema_ratio=0.9, factory_kwargs=None):
        super(EMAMSEObserver, self).__init__(dtype, qscheme, reduce_range, quant_min, quant_max,
                                             ch_axis, pot_scale, factory_kwargs)
        self.ema_ratio = ema_ratio
        self.p = p

    def lp_loss(self, pred, tgt, dim=None):
        """
        Mean of ``|pred - tgt| ** p``: over ``dim`` when it is given and
        non-empty, otherwise over all elements.
        """
        err = (pred - tgt).abs().pow(self.p)
        if dim:
            return err.mean(dim)
        return err.mean()

    def mse(self, x: torch.Tensor, x_min: torch.Tensor, x_max: torch.Tensor, iter=80):
        """Per-tensor search over shrunk [x_min, x_max]; returns the best pair."""
        best_score = 1e+10
        best_min = torch.tensor([1.0], dtype=torch.float)
        best_max = torch.tensor([1.0], dtype=torch.float)
        best_min.copy_(x_min)
        best_max.copy_(x_max)
        for step in range(iter):
            shrink = 1.0 - (step * 0.01)
            cand_min = x_min * shrink
            cand_max = x_max * shrink
            # Scale and zero point as tensors. For symmetric schemes the
            # scale comes from the largest magnitude.
            scale, zero_point = self._calculate_qparams(cand_min, cand_max)
            x_q = torch.fake_quantize_per_tensor_affine(
                x, scale.item(), int(zero_point.item()),
                self.quant_min, self.quant_max)
            # Mean (not summed) L_p error, so it is independent of tensor size.
            score = self.lp_loss(x_q, x)
            if score < best_score:
                best_score = score
                best_min, best_max = cand_min, cand_max
        return best_min, best_max

    def mse_perchannel(self, x: torch.Tensor, x_min: torch.Tensor, x_max: torch.Tensor, iter=80, ch_axis=0):
        """Per-channel search; each channel keeps its own best shrink factor."""
        assert x_min.shape == x_max.shape
        assert ch_axis >= 0, f'{ch_axis}'
        reduce_dim = tuple(d for d in range(x.dim()) if d != ch_axis)
        best_score = torch.ones_like(x_min) * 1e+10
        best_min = x_min.clone()
        best_max = x_max.clone()
        for step in range(iter):
            shrink = 1.0 - (step * 0.01)
            cand_min = x_min * shrink
            cand_max = x_max * shrink
            scale, zero_point = self._calculate_qparams(cand_min, cand_max)
            if _version_under_1100:
                # Older torch expects an int64 zero_point here.
                zero_point = zero_point.long()
            x_q = torch.fake_quantize_per_channel_affine(
                x, scale, zero_point, ch_axis,
                self.quant_min, self.quant_max)
            score = self.lp_loss(x_q, x, reduce_dim)
            improved = score < best_score
            best_score[improved] = score[improved]
            best_min[improved] = cand_min[improved]
            best_max[improved] = cand_max[improved]
        return best_min, best_max

    def forward(self, x_orig):
        r"""Records the running minimum and maximum of ``x``."""
        if x_orig.numel() == 0:
            return x_orig
        x = x_orig.clone().detach().to(self.min_val.dtype)
        if self.ch_axis == -1:
            # Start from the true min/max and search inwards for the lowest error.
            lo, hi = torch._aminmax(x)
            lo, hi = self.mse(x, lo, hi, iter=95)
        else:
            axes = list(range(x.dim()))
            axes[self.ch_axis], axes[0] = 0, self.ch_axis
            flat = torch.flatten(x.permute(axes), start_dim=1)
            lo, hi = torch._aminmax(flat, 1)
            lo, hi = self.mse_perchannel(x, lo, hi, iter=80, ch_axis=self.ch_axis)

        first_batch = self.max_val.numel() <= 1 and self.max_val.isinf()
        if first_batch:
            # First observation: take the searched values directly.
            self.min_val = lo
            self.max_val = hi
        else:
            # Later observations: exponential moving average.
            blend = 1.0 - self.ema_ratio
            self.min_val = self.min_val * self.ema_ratio + lo * blend
            self.max_val = self.max_val * self.ema_ratio + hi * blend
        return x
