import torch as t

from .quantizer import Quantizer
from .range_estimators import MSE_Estimator
def round_pass(x):
    """Round ``x`` to the nearest integer with a straight-through gradient.

    The forward value is ``(x.round() - x) + x``, i.e. ``x.round()`` up to
    floating-point rounding. Only the plain ``x`` term carries gradient, so
    the backward pass treats the op as identity (``d out / d x == 1``).
    """
    rounding_offset = (x.round() - x).detach()
    return x + rounding_offset

class EWGS_discretizer(t.autograd.Function):
    """Identity in the forward pass, gradient rescaling in the backward pass.

    ``forward(x_in, diff, delta)`` returns ``x_in`` unchanged and saves
    ``diff`` and ``delta`` for the backward pass. ``backward`` returns
    ``g * (1 + delta * sign(g) * diff)`` as the gradient of ``x_in``. It
    returns no gradient for ``diff`` or ``delta``.
    """

    @staticmethod
    def forward(ctx, x_in, diff, delta):
        # Keep only what the backward pass needs to rescale the gradient.
        ctx.save_for_backward(diff, delta)
        return x_in

    @staticmethod
    def backward(ctx, g):
        diff, delta_tensor = ctx.saved_tensors
        # ``delta`` arrives as a one-element tensor; use it as a Python scalar.
        delta_value = float(delta_tensor)
        scale = 1 + delta_value * t.sign(g) * diff
        return g * scale, None, None

def grad_scale(x, scale):
    """Pass ``x`` through unchanged while multiplying its gradient by ``scale``.

    The forward value is ``x * scale + (x - x * scale)``, i.e. ``x`` up to
    floating-point rounding. Only the ``x * scale`` term carries gradient, so
    ``d out / d x == scale``.
    """
    scaled = x * scale
    return scaled + (x - scaled).detach()


class LsqQuan(Quantizer):
    """Fake quantizer with a learnable step size ``s`` (LSQ-style).

    ``forward`` maps ``x`` to ``round(clamp(x / s, thd_neg, thd_pos)) * s``.
    It uses a straight-through estimator for the rounding. When EWGS is
    enabled, the result also goes through ``EWGS_discretizer``. That function
    rescales the incoming gradient ``g`` by ``1 + ewgs * sign(g) * (x - x_q)``
    in the backward pass.

    Parameters
    ----------
    bit: bit width used to derive the integer clipping range.
    all_positive: quantize to the unsigned range ``[0, 2^b - 1]``.
    symmetric: for signed quantization, use ``[-2^(b-1) + 1, 2^(b-1) - 1]``
        instead of ``[-2^(b-1), 2^(b-1) - 1]``. Incompatible with
        ``all_positive``.
    per_channel: in ``init_from``, create one step size per slice along dim 0
        instead of a single per-tensor step size.
    ewgs: EWGS delta. ``None``, or the string ``'None'`` (any case),
        disables EWGS.
    """

    def __init__(self, bit, all_positive=False, symmetric=False, per_channel=False, ewgs=None):
        super().__init__(bit)
        self.bit = bit
        self.all_positive = all_positive
        self.symmetric = symmetric
        self.per_channel = per_channel
        # EWGS delta; see _ewgs_enabled() for which values disable it.
        self.ewgs = ewgs
        # Integer clipping range used by forward(). compute_thd also asserts
        # that unsigned (all_positive) quantization is not symmetric.
        self.thd_neg, self.thd_pos = self.compute_thd(bit)
        self.s = t.nn.Parameter(t.ones(1))
        self.EWGS_discretizer = EWGS_discretizer.apply

    def compute_thd(self, bits):
        """Return the integer clipping range ``(thd_neg, thd_pos)`` for ``bits``.

        - unsigned (``all_positive``): ``[0, 2^b - 1]``
        - signed symmetric: ``[-2^(b-1) + 1, 2^(b-1) - 1]``
        - signed asymmetric: ``[-2^(b-1), 2^(b-1) - 1]``

        Tensor or float results are converted to Python ints.
        """
        if self.all_positive:
            assert not self.symmetric, "Positive quantization cannot be symmetric"
            thd_neg = 0
            thd_pos = 2 ** bits - 1
        elif self.symmetric:
            thd_neg = - 2 ** (bits - 1) + 1
            thd_pos = 2 ** (bits - 1) - 1
        else:
            thd_neg = - 2 ** (bits - 1)
            thd_pos = 2 ** (bits - 1) - 1

        if isinstance(thd_neg, t.Tensor):
            thd_neg = int(thd_neg.cpu().item())
            thd_pos = int(thd_pos.cpu().item())
        elif isinstance(thd_neg, float):
            thd_neg = int(thd_neg)
            thd_pos = int(thd_pos)

        return thd_neg, thd_pos

    def init_from(self, x, *args, **kwargs):
        """Initialise the step size ``s`` from a sample tensor ``x``.

        Per-channel: ``s`` is replaced by a new Parameter holding
        ``mean(|x|) * 2 / sqrt(thd_pos)``, reduced over every dim except 0.
        Per-tensor: an ``MSE_Estimator`` chooses ``(xmin, xmax)`` and
        ``set_quant_range`` converts that range into ``s``.
        Extra positional and keyword arguments are accepted and ignored.
        """
        if self.per_channel:
            # keepdim keeps s broadcastable against x in forward().
            # NOTE(review): this rebinds self.s to a new Parameter object, so
            # an optimizer built before init_from() keeps the old one.
            self.s = t.nn.Parameter(
                x.detach().abs().mean(dim=list(range(1, x.dim())), keepdim=True) * 2 / (self.thd_pos ** 0.5))
        else:
            scale_estimator = MSE_Estimator(per_channel=self.per_channel, quantizer=self, n_bits=self.bit)
            xmin, xmax = scale_estimator(x)
            s_init = self.set_quant_range(xmin, xmax, self.bit)
            self.s.data.copy_(s_init)

    def _tensorize_min_max(self, x_min, x_max):
        """
        Converts provided min max range into tensors
        Parameters
        ----------
        x_min: float or PyTorch 1D tensor
        x_max: float or PyTorch 1D tensor

        Returns
        -------
        x_min: PyTorch Tensor 0 or 1-D, clipped to be <= 0
        x_max: PyTorch Tensor 0 or 1-D, clipped to be >= 1e-8

        Raises
        ------
        ValueError: if a multi-element range is given while per_channel is False.
        """
        # Convert each bound independently so a float/tensor mix also works.
        if not t.is_tensor(x_min):
            x_min = t.tensor(x_min).float()
        if not t.is_tensor(x_max):
            x_max = t.tensor(x_max).float()

        if x_min.dim() > 0 and len(x_min) > 1 and not self.per_channel:
            raise ValueError(
                "x_min and x_max must be a float or 1-D Tensor"
                " for per-tensor quantization (per_channel=False)"
            )
        # Ensure we always use zero and avoid division by zero
        x_min = t.min(x_min, t.zeros_like(x_min))
        x_max = t.max(x_max, t.ones_like(x_max) * 1e-8)

        return x_min, x_max

    def set_quant_range(self, x_min, x_max, b):
        """Derive the step size from a real-valued range and store it in ``s``.

        The step is ``max(|x_min|, x_max) / int_max``, where ``int_max`` is
        the upper integer threshold for ``b`` bits. The result is copied into
        ``self.s`` and also returned as a 0-d tensor.

        NOTE(review): ``.item()`` below requires a single-element range, so
        this only supports per-tensor quantization. The clamp range used by
        forward() is fixed in __init__ and is not updated from ``b``.
        """
        # Keep the caller's original (pre-tensorization) bounds.
        self.x_min_fp32, self.x_max_fp32 = x_min, x_max
        x_min, x_max = self._tensorize_min_max(x_min, x_max)
        self._signed = x_min.min() < 0

        x_absmax = t.max(x_min.abs(), x_max)
        _, int_max = self.compute_thd(b)
        delta = (x_absmax / int_max).detach()
        # Collapse to a 0-d tensor on the same device.
        delta = t.tensor(delta.item(), device=delta.device)
        self.s.data.copy_(delta)

        return delta

    def _ewgs_enabled(self):
        """Return True when ``self.ewgs`` holds an actual EWGS delta."""
        if self.ewgs is None:
            return False
        # The string 'None' (e.g. a value coming from a text config) also
        # disables EWGS.
        return not (isinstance(self.ewgs, str) and self.ewgs.strip().lower() == 'none')

    def forward(self, x):
        """Fake-quantize ``x`` with step size ``self.s``.

        Returns ``round(clamp(x / s, thd_neg, thd_pos)) * s``. The rounding
        uses a straight-through gradient. When EWGS is enabled, the output is
        also passed through ``EWGS_discretizer`` so that the backward gradient
        is rescaled using the quantization error ``x - x_q``.
        """
        # LSQ step-size gradient scaling (grad_scale) is not applied here;
        # the raw parameter is used directly as the step size.
        s_scale = self.s
        x_in = x
        x = x / s_scale
        x = t.clamp(x, self.thd_neg, self.thd_pos)
        x = round_pass(x)
        x = x * s_scale
        if self._ewgs_enabled():
            # Quantization error; only the EWGS backward pass reads it.
            diff = x_in.detach() - x.detach()
            x = self.EWGS_discretizer(x, diff, t.Tensor([float(self.ewgs)]))
        return x
