import logging
import numpy as np
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Union
import numpy as np
logger = logging.getLogger(__name__)


class StraightThrough(nn.Module):
    def __init__(self, channel_num: int = 1):
        super().__init__()

    def forward(self, input):
        return input

def smooth_value(x: torch.Tensor):
    max_value = torch.max(x).cpu().item()
    min_value = torch.min(x).cpu().item()
    q1 = np.percentile(x.cpu().numpy(), 25)
    q2 = np.percentile(x.cpu().numpy(), 50)
    q3 = np.percentile(x.cpu().numpy(), 75)
    return min_value, q1, q2, q3, max_value

def round_ste(x: torch.Tensor):
    """
    Implement Straight-Through Estimator for rounding operation.
    """
    return (x.round() - x).detach() + x

def floor_ste(x: torch.Tensor):
    """
    Implement Straight-Through Estimator for floor operation.
    """
    return (x.floor() - x).detach() + x

def lp_loss(pred, tgt, p=2.0, reduction='none'):
    """
    loss function measured in L_p Norm
    """
    if reduction == 'none':
        return (pred-tgt).abs().pow(p).sum(1).mean()
    else:
        return (pred-tgt).abs().pow(p).mean()

class UniformAffineQuantizer(nn.Module):
    """
    PyTorch Function that can be used for asymmetric quantization (also called uniform affine
    quantization). Quantizes its argument in the forward pass, passes the gradient 'straight
    through' on the backward pass, ignoring the quantization that occurred.
    Based on https://arxiv.org/abs/1806.08342.

    :param n_bits: number of bit for quantization
    :param symmetric: if True, the zero_point should always be 0
    :param channel_wise: if True, compute scale and zero_point in each channel
    :param scale_method: determines the quantization scale and zero point
    """
    def __init__(self, n_bits: int = 8, symmetric: bool = False, channel_wise: bool = False, scale_method: str = 'max',
                 leaf_param: bool = False, always_zero: bool = False, prob: float = 1.0, num_timesteps: int = 100, smooth_type: str = 'normal'):
        super(UniformAffineQuantizer, self).__init__()
        self.sym = symmetric
        # assert 2 <= n_bits <= 8, 'bitwidth not supported'
        self.n_bits = n_bits
        self.n_levels = 2 ** self.n_bits if not self.sym else 2 ** (self.n_bits - 1) - 1
        # self.n_levels = 2 ** self.n_bits
        self.delta = None
        self.zero_point = None
        self.inited = False
        self.leaf_param = leaf_param
        self.channel_wise = channel_wise
        self.scale_method = scale_method
        self.running_stat = False
        self.always_zero = always_zero
        if self.leaf_param:
            self.x_min, self.x_max = None, None
        """for activation quantization"""
        self.running_min = None
        self.running_max = None
        self.running_scale = None
        self.running_zero = None

        """mse params"""
        # self.scale_method = scale_method
        self.one_side_dist = None
        self.num = 100
        self.eps = torch.tensor(1e-8, dtype=torch.float32)
        """do like dropout"""
        self.prob = prob
        self.is_training = False
        """DM para"""
        self.timesteps = num_timesteps
        self.time = 0
        self.t = None
        self.is_mix_steps = False
        self.list_delta = [None] * self.timesteps
        self.stop_train = False

    def set_inited(self, inited: bool = True):  # inited manually
        self.inited = inited
    def set_time(self, time: int = 0):  # inited manually
        self.time = time
    def set_t(self, t):
        self.t = t
    def set_steps_state(self, is_mix_steps: bool = True):
        self.is_mix_steps = is_mix_steps

    def update_quantize_range(self, x_scale, x_zero):
        if self.running_scale is None:
            self.running_scale = x_scale
            self.running_zero = x_zero
        self.running_scale = 0.1 * x_scale + 0.9 * self.running_scale
        self.running_zero = 0.1 * x_zero + 0.9 * self.running_zero
        return self.running_scale, self.running_zero

    def lp_loss(self, pred, tgt, p=2.0):
        x = (pred - tgt).abs().pow(p)
        if not self.channel_wise:
            return x.mean()
        else:
            y = torch.flatten(x, 1)
            return y.mean(1)

    def calculate_qparams(self, min_val, max_val):
        # one_dim or one element
        quant_min, quant_max = 0, self.n_levels - 1
        min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
        max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
        scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
        scale = torch.max(scale, self.eps)
        zero_point = quant_min - torch.round(min_val_neg / scale)
        zero_point = torch.clamp(zero_point, quant_min, quant_max)
        return scale, zero_point

    # @profile
    def quantize(self, x: torch.Tensor, x_max, x_min):
        delta, zero_point = self.calculate_qparams(x_min, x_max)
        x_int = torch.round(x / delta)
        x_quant = torch.clamp(x_int + zero_point, 0, self.n_levels - 1)
        x_float_q = (x_quant - zero_point) * delta
        return x_float_q

    def search_quantize(self, x: torch.Tensor, x_max, x_min):
        delta, zero_point = self.calculate_qparams(x_min, x_max)
        if self.channel_wise:
            new_shape = [1] * len(x.shape)
            new_shape[0] = x.shape[0]
            delta = delta.reshape(new_shape)
            zero_point = zero_point.reshape(new_shape)
        x_int = torch.round(x / delta)
        x_quant = torch.clamp(x_int + zero_point, 0, self.n_levels - 1)
        x_float_q = (x_quant - zero_point) * delta
        return x_float_q

    def get_scale_zeropoint(self, x):
        if self.scale_method != "mse":
            raise NotImplementedError
        scale, zero_point = self.init_quantization_scale_2(x, self.channel_wise)
        return scale, zero_point

    def init_quantization_scale_channel(self, x: torch.Tensor):
        with torch.no_grad():
            scale, zero_point = self.get_scale_zeropoint(x)
            if self.leaf_param:
                scale, zero_point = self.update_quantize_range(scale, zero_point)
        return scale, zero_point

    def init_quantization_scale_1(
        self, x_clone: torch.Tensor, channel_wise: bool = False
    ):
        if channel_wise:
            # determine the scale and zero point channel-by-channel
            delta, zero_point = self.init_quantization_scale_channel(x_clone)
            new_shape = [1] * len(x_clone.shape)
            new_shape[0] = x_clone.shape[0]
            delta = delta.reshape(new_shape)
            zero_point = zero_point.reshape(new_shape)
        else:
            delta, zero_point = self.init_quantization_scale_channel(x_clone)
        return delta, zero_point

    def forward(self, x: torch.Tensor):
        if self.inited is False:
            if self.leaf_param:
                if self.scale_method != 'channel_time' and self.channel_wise is False:
                    delta, zero_point = self.init_quantization_scale_1(x, self.channel_wise)
                elif self.scale_method != 'channel_time' and self.channel_wise: #
                    delta, zero_point = self.init_quantization_scale_for_activation(x, self.channel_wise)
                elif self.scale_method == 'channel_time':
                    delta, zero_point = self.init_quantization_scale_for_same_activation(x, self.channel_wise)
                else:
                    raise NotImplementedError
                if self.delta == None:
                    self.delta = torch.nn.Parameter(torch.stack([delta] * self.timesteps, dim=0))
                    self.zero_point = torch.nn.Parameter(torch.stack([zero_point] * self.timesteps, dim=0))
                else:
                    self.delta[self.time] = delta
                    self.zero_point[self.time] = zero_point
            else:
                if self.scale_method == 'mse':
                    self.delta, self.zero_point = self.init_quantization_scale_1(x, self.channel_wise)
                elif self.scale_method == 'max':
                    self.delta, self.zero_point = self.init_quantization_scale_2(x, self.channel_wise)
                else:
                    raise NotImplementedError

        # start quantization
        if self.leaf_param:
            if self.is_mix_steps: # parallel quantization during training
                if len(x.shape) == 3:
                    assert x.size(0) % len(self.t)==0
                    n_heads = x.size(0) // len(self.t)
                else:
                    assert x.size(0) == len(self.t)
                if self.channel_wise:
                    if len(x.shape) == 3:
                        cur_delta = self.delta[self.t].squeeze(1)
                        zero_point = self.zero_point[self.t].squeeze(1)
                        cur_delta = torch.repeat_interleave(cur_delta, repeats=n_heads, dim=0)
                        zero_point = torch.repeat_interleave(zero_point, repeats=n_heads, dim=0)
                    else:
                        cur_delta = self.delta[self.t].squeeze(1)
                        zero_point = self.zero_point[self.t].squeeze(1)
                else:
                    cur_delta = self.delta[self.t]
                    zero_point = self.zero_point[self.t]
                    for dim in range(x.dim()-1):
                        cur_delta = cur_delta.unsqueeze(1)
                        zero_point = zero_point.unsqueeze(1)
                    if len(x.shape) == 3:
                        cur_delta = torch.repeat_interleave(cur_delta, repeats=n_heads, dim=0)
                        zero_point = torch.repeat_interleave(zero_point, repeats=n_heads, dim=0)
                if self.scale_method == 'channel_time':
                    x_int = torch.full_like(x, 1)
                    x_quant = x_int
                    x_dequant = x_quant * cur_delta
                else:
                    x_int = round_ste(x / cur_delta) + round_ste(zero_point)
                    if self.sym:
                        x_quant = torch.clamp(x_int, -self.n_levels - 1, self.n_levels)
                    else:
                        x_quant = torch.clamp(x_int, 0, self.n_levels - 1)
                    x_dequant = (x_quant - round_ste(zero_point)) * cur_delta
            else: # single quantization during sampling
                if self.scale_method == 'channel_time':
                    x_int = torch.full_like(x, 1)
                    x_quant = x_int
                    x_dequant = x_quant * self.delta[self.time]
                else:
                    x_int = round_ste(x / self.delta[self.time]) + round_ste(self.zero_point[self.time])
                    if self.sym:
                        x_quant = torch.clamp(x_int, -self.n_levels - 1, self.n_levels)
                    else:
                        x_quant = torch.clamp(x_int, 0, self.n_levels - 1)
                    x_dequant = (x_quant - round_ste(self.zero_point[self.time])) * self.delta[self.time]
        else:
            x_int = round_ste(x / self.delta) + round_ste(self.zero_point)
            if self.sym:
                x_quant = torch.clamp(x_int, -self.n_levels - 1, self.n_levels)
            else:
                x_quant = torch.clamp(x_int, 0, self.n_levels - 1)
            x_dequant = (x_quant - round_ste(self.zero_point)) * self.delta

        if self.is_training and self.prob < 1.0:
            x_ans = torch.where(torch.rand_like(x) < self.prob, x_dequant, x)
        else:
            x_ans = x_dequant
        return x_ans

    def init_quantization_scale_2(self, x: torch.Tensor, channel_wise: bool = False):
        delta, zero_point = None, None
        if channel_wise:
            x_clone = x.clone().detach()
            n_channels = x_clone.shape[0]
            if len(x.shape) == 4:
                x_max = x_clone.abs().max(dim=-1)[0].max(dim=-1)[0].max(dim=-1)[0]
            elif len(x.shape) == 3:
                x_max = x_clone.abs().max(dim=-1)[0].max(dim=-1)[0]
            else:
                x_max = x_clone.abs().max(dim=-1)[0]
            delta = x_max.clone()
            zero_point = x_max.clone()
            # determine the scale and zero point channel-by-channel
            for c in range(n_channels):
                delta[c], zero_point[c] = self.init_quantization_scale_2(x_clone[c], channel_wise=False)
            if len(x.shape) == 4:
                delta = delta.view(-1, 1, 1, 1)
                zero_point = zero_point.view(-1, 1, 1, 1)
            elif len(x.shape) == 3:
                delta = delta.view(-1, 1, 1)
                zero_point = zero_point.view(-1, 1, 1)
            else:
                delta = delta.view(-1, 1)
                zero_point = zero_point.view(-1, 1)
        else:
            if self.leaf_param:
                self.x_min = x.data.min()
                self.x_max = x.data.max()
            if 'max' in self.scale_method:
                x_min = min(x.min().item(), 0)
                x_max = max(x.max().item(), 0)
                if x_max == x_min:
                    delta = x_max
                    zero_point = 0 / delta
                    return delta, zero_point
                if 'scale' in self.scale_method:
                    x_min = x_min * (self.n_bits + 2) / 8
                    x_max = x_max * (self.n_bits + 2) / 8

                x_absmax = max(abs(x_min), x_max)
                if self.sym:
                    # x_min, x_max = -x_absmax if x_min < 0 else 0, x_absmax
                    delta = x_absmax / self.n_levels
                else:
                    delta = float(x.max().item() - x.min().item()) / (self.n_levels - 1)
                if delta < 1e-8:
                    warnings.warn('Quantization range close to zero: [{}, {}]'.format(x_min, x_max))
                    delta = 1e-8

                zero_point = round(-x_min / delta) if not (self.sym or self.always_zero) else 0
                delta = torch.tensor(delta).type_as(x)
            elif self.scale_method == 'mse':
                x_max = x.max()
                x_min = x.min()
                if x_max == x_min:
                    delta = x_max
                    zero_point = 0 / delta
                    return delta, zero_point
                best_score = 1e+10
                for i in range(0,80,2):
                    new_max = x_max * (1.0 - (i * 0.01))
                    new_min = x_min * (1.0 - (i * 0.01))
                    x_q = self.quantize(x, new_max, new_min)
                    # L_p norm minimization as described in LAPQ
                    # https://arxiv.org/abs/1911.07190
                    score = lp_loss(x, x_q, p=2.4, reduction='all')
                    if score < best_score:
                        best_score = score
                        delta = (new_max - new_min) / (2 ** self.n_bits - 1) \
                            if not self.always_zero else new_max / (2 ** self.n_bits - 1)
                        zero_point = (- new_min / delta).round() if not self.always_zero else 0
            else:
                raise NotImplementedError
        return delta, zero_point

    def init_quantization_scale_for_activation(self, x: torch.Tensor, channel_wise: bool = False):
        delta, zero_point = None, None
        if channel_wise:
            x_clone = x.clone().detach()
            if len(x.shape) == 4:
                n_channels = x_clone.shape[1]
            elif len(x.shape) == 3: 
                n_channels = x_clone.shape[1] # tokenwise quantization for QKV
                # n_channels = x_clone.shape[-1] # tokenwise quantization for linear layers
            else:
                n_channels = x_clone.shape[1]

            if len(x.shape) == 4:
                x_max = x_clone.abs().max(dim=0)[0].max(dim=1)[0].max(dim=1)[0]
            elif len(x.shape) == 2:
                x_max = x_clone.abs().max(dim=0)[0]
            elif len(x.shape) == 3:
                x_max = x_clone.abs().max(dim=0)[0].max(dim=1)[0]
                # x_max = x_clone.abs().max(dim=0)[0].max(dim=0)[0]
            else:
                raise NotImplementedError

            delta = x_max.clone()
            zero_point = x_max.clone()
            # determine the scale and zero point channel-by-channel
            for c in range(n_channels):
                if len(x.shape) == 3:
                    delta[c], zero_point[c] = self.init_quantization_scale_for_activation(x_clone[:, c, :], channel_wise=False)
                elif len(x.shape) == 4:
                    delta[c], zero_point[c] = self.init_quantization_scale_for_activation(x_clone[:,c,:,:], channel_wise=False)
                else:
                    delta[c], zero_point[c] = self.init_quantization_scale_for_activation(x_clone[:, c], channel_wise=False)
            if len(x.shape) == 4:
                delta = delta.view(1, -1, 1, 1)
                zero_point = zero_point.view(1, -1, 1, 1)
            elif len(x.shape) == 2:
                delta = delta.view(1, -1)
                zero_point = zero_point.view(1, -1)
            elif len(x.shape) == 3:
                delta = delta.view(1, -1, 1)
                zero_point = zero_point.view(1, -1, 1)
            else:
                raise NotImplementedError
        else:
            if self.leaf_param:
                self.x_min = x.data.min()
                self.x_max = x.data.max()

            if "max" in self.scale_method:
                x_min = min(x.min().item(), 0)
                x_max = max(x.max().item(), 0)

                if 'scale' in self.scale_method:
                    x_min = x_min * (self.n_bits + 2) / 8
                    x_max = x_max * (self.n_bits + 2) / 8

                x_absmax = max(abs(x_min), x_max)
                if self.sym:
                    # x_min, x_max = -x_absmax if x_min < 0 else 0, x_absmax
                    delta = x_absmax / self.n_levels
                else:
                    delta = float(x.max().item() - x.min().item()) / (self.n_levels - 1)
                if delta < 1e-8:
                    warnings.warn('Quantization range close to zero: [{}, {}]'.format(x.max().item(), x.min().item()))
                    delta = 1e-8

                zero_point = round(-x_min / delta) if not (self.sym or self.always_zero) else 0
                delta = torch.tensor(delta).type_as(x)
                zero_point = torch.tensor(zero_point).type_as(x)
            else:
                x_clone = x.clone().detach()
                x_max = x_clone.max()
                x_min = x_clone.min()
                if x_max == x_min:
                    delta = x_max
                    zero_point = 0 / delta
                    return delta, zero_point
                best_score = 1e+10
                self.x_min = x_min
                self.x_max = x_max
                # RepQ method
                for pct in [0.999, 0.9999, 0.99999]:
                    try:
                        new_max = torch.quantile(x_clone.reshape(-1), pct)
                        new_min = torch.quantile(x_clone.reshape(-1), 1.0 - pct)
                    except:
                        new_max = torch.tensor(np.percentile(
                            x_clone.reshape(-1).cpu(), pct * 100),
                            device=x_clone.device,
                            dtype=torch.float32)
                        new_min = torch.tensor(np.percentile(
                            x_clone.reshape(-1).cpu(), (1 - pct) * 100),
                            device=x_clone.device,
                            dtype=torch.float32)   
                    x_q = self.quantize(x_clone, new_max, new_min)
                    score = lp_loss(x_clone, x_q, p=2, reduction='all')
                    # score = new_lp_loss(x_clone.view(x_clone.shape[0], -1), x_q.view(x_q.shape[0], -1))

                    if score < best_score:
                        best_score = score
                        delta = (new_max - new_min) / (2 ** self.n_bits - 1)
                        delta = torch.clamp(delta, min=1e-8)  # TODO: Added, examine effect
                        zero_point = (- new_min / delta).round()
        return delta, zero_point

    def bitwidth_refactor(self, refactored_bit: int):
        self.n_bits = refactored_bit
        self.n_levels = 2 ** self.n_bits

    def extra_repr(self):
        s = 'bit={n_bits}, scale_method={scale_method}, symmetric={sym}, channel_wise={channel_wise},' \
            ' leaf_param={leaf_param}'
        return s.format(**self.__dict__)

class SmoothQuantizer(nn.Module):
    def __init__(self, module_type: str = 'linear', alpha: float = 0.5, timesteps: int = 100, smooth_type: str = 'normal'):
        super(SmoothQuantizer, self).__init__()
        self.module_type = module_type
        self.alpha = alpha
        self.smooth_type = smooth_type
        self.max_input = None
        self.max_weight = None
        self.time = 0
        self.timesteps = timesteps
        self.scales = None
        self.running_scales = None
        self.all_scales = [None]*timesteps
        self.all_var = [None]*timesteps
        self.inited = True

    def update_quantize_range(self, max_input, max_weight):
        if self.max_input is None:
            self.max_input = max_input
            self.max_weight = max_weight
        self.max_input = 0.1 * max_input + 0.9 * self.max_input
        self.max_weight = 0.1 * max_weight + 0.9 * self.max_weight

    def update_scales(self, scales):
        if self.running_scales is None:
            self.running_scales = scales
        self.running_scales = 0.1 * scales + 0.9 * self.running_scales
        return self.running_scales

    def calculate_scales_1(self, input, weight):
        if self.module_type == 'conv2d':
            tensor_permuted = input.permute(1, 0, 2, 3)
            l = tensor_permuted.reshape(tensor_permuted.size(0), -1)
            max_input = l.abs().max(dim=1)[0].clamp(min=1e-5)

            tensor_permuted = weight.permute(1, 0, 2, 3)
            l = tensor_permuted.reshape(tensor_permuted.size(0), -1)
            max_weight = l.abs().max(dim=1)[0].clamp(min=1e-5)

            scales = (
                (max_input.pow(self.alpha) / max_weight.pow(1 - self.alpha))
                .clamp(min=1e-5)
            )
            scales = torch.unsqueeze(scales, dim=0)
            scales = torch.unsqueeze(scales, dim=-1)
            scales = torch.unsqueeze(scales, dim=-1)

        elif self.module_type == 'linear':
            max_input = input.abs().max(dim=0)[0].clamp(min=1e-5)
            max_weight = weight.abs().max(dim=0)[0].clamp(min=1e-5)

            scales = (
                (max_input.pow(self.alpha) / max_weight.pow(1 - self.alpha))
                .clamp(min=1e-5)
            )
            scales = torch.unsqueeze(scales, dim=0)
            
        else:
            raise ValueError(
                "Unexpect quantization module"
            )
        return scales
    def calculate_scales_2(self, input: torch.Tensor, threshold: float):
        if self.module_type == 'conv2d':
            tensor_permuted = input.permute(1, 0, 2, 3)
            l = tensor_permuted.reshape(tensor_permuted.size(0), -1)
            max_input = l.abs().max(dim=1)[0].clamp(min=1e-5)

            maxsv, inds = torch.sort(input.contiguous().view(-1), descending=True)
            basl = int(len(input.contiguous().view(-1))*threshold) # hyperparameter
            baseline =  maxsv[basl]
            
            scales = (max_input / baseline).clamp(min=1e-5)
            scales[max_input<=baseline] = 1.
            scales = torch.unsqueeze(scales, dim=0)
            scales = torch.unsqueeze(scales, dim=-1)
            scales = torch.unsqueeze(scales, dim=-1)
        elif self.module_type == 'linear':
            max_input = input.abs().max(dim=0)[0].clamp(min=1e-5)

            maxsv, inds = torch.sort(input.contiguous().view(-1), descending=True)
            basl = int(len(input.contiguous().view(-1))*threshold) # hyperparameter
            baseline =  maxsv[basl]
            
            scales = (max_input / baseline).clamp(min=1e-5)
            scales[max_input<=baseline] = 1.
            scales = torch.unsqueeze(scales, dim=0)
        else:
            raise ValueError(
                "Unexpect quantization module"
            )
        return scales
    def calculate_scales_3(self, input: torch.Tensor, weight: torch.Tensor, threshold: float):
        if self.module_type == 'conv2d':
            tensor_permuted = input.permute(1, 0, 2, 3)
            l = tensor_permuted.reshape(tensor_permuted.size(0), -1)
            max_input = l.abs().max(dim=1)[0].clamp(min=1e-5)

            tensor_permuted = weight.permute(1, 0, 2, 3)
            l = tensor_permuted.reshape(tensor_permuted.size(0), -1)
            max_weight = l.abs().max(dim=1)[0].clamp(min=1e-5)

            maxsv, inds = torch.sort(max_input, descending=True)
            basl = int(len(max_input)*threshold) # hyperparameter
            scale_inds = inds[:basl]
            scales = torch.ones(weight.size(1)).to(weight.device)
            for i in scale_inds:
                scale = (
                    (max_input[i].pow(self.alpha) / max_weight[i].pow(1-self.alpha))
                    .clamp(min=1e-5)
                )
                scales[i] = scale

            scales = torch.unsqueeze(scales, dim=0)
            scales = torch.unsqueeze(scales, dim=-1)
            scales = torch.unsqueeze(scales, dim=-1)
        elif self.module_type == 'linear':
            max_input = input.abs().max(dim=0)[0].clamp(min=1e-5)
            max_weight = weight.abs().max(dim=0)[0].clamp(min=1e-5)

            maxsv, inds = torch.sort(max_input, descending=True)
            basl = int(len(max_input)*threshold) # hyperparameter
            scale_inds = inds[:basl]
            scales = torch.ones(weight.size(1)).to(weight.device)
            for i in scale_inds:
                scale = (
                    (max_input[i].pow(self.alpha) / max_weight[i].pow(1-self.alpha))
                    .clamp(min=1e-5)
                )
                scales[i] = scale
    
            scales = torch.unsqueeze(scales, dim=0)
        else:
            raise ValueError(
                "Unexpect quantization module"
            )
        return scales
    def calculate_scales_4(self, input: torch.Tensor, weight: torch.Tensor):
        if self.module_type == 'conv2d':
            '''input_channel max-min'''
            tensor_permuted = input.permute(1, 0, 2, 3)
            l = tensor_permuted.reshape(tensor_permuted.size(0), -1)
            max_input = l.abs().max(dim=1)[0].clamp(min=1e-6)

            l = weight.reshape(weight.size(0), weight.size(1), -1)
            weight_abs_max = l.abs().max(dim=2)[0].clamp(min=1e-5)

            weight_max = l.max(dim=2)[0].clamp(min=1e-5)
            weight_min = l.min(dim=2)[0].clamp(max=-1e-5)

            weight_out_max, weight_maxin_indx = weight_max.max(dim=1)
            weight_out_min, weight_minin_indx = weight_min.min(dim=1)
            weight_absin_max, _ = weight_abs_max.max(dim=0)
            scales = torch.ones(weight.size(1)).to(weight.device)
            for i in range(weight.size(1)):
                if torch.any(weight_maxin_indx == i) or torch.any(weight_minin_indx == i) : #the channels include weight out max and min
                    continue
                scale = (
                    (max_input[i].pow(self.alpha) / weight_absin_max[i].pow(1-self.alpha))
                    .clamp(min=1e-5)
                )
                scale_max_limit, _ = (weight_out_max / weight_max[:,i]).min(dim=0)
                scale_min_limit, _ = (weight_out_min / weight_min[:,i]).min(dim=0)
                scales[i], _ = torch.stack([scale, scale_max_limit, scale_min_limit]).min(dim=0)
            scales = torch.unsqueeze(scales, dim=0)
            scales = torch.unsqueeze(scales, dim=-1)
            scales = torch.unsqueeze(scales, dim=-1)
            return scales
        elif self.module_type == 'conv1d':
            '''input_channel max'''
            tensor_permuted = input.permute(1, 0, 2)
            l = tensor_permuted.reshape(tensor_permuted.size(0), -1)
            max_input = l.abs().max(dim=1)[0].clamp(min=1e-6)

            l = weight.reshape(weight.size(0), weight.size(1), -1)
            weight_abs_max = l.abs().max(dim=2)[0].clamp(min=1e-5)

            weight_max = l.max(dim=2)[0].clamp(min=1e-5)
            weight_min = l.min(dim=2)[0].clamp(max=-1e-5)

            weight_out_max, weight_maxin_indx = weight_max.max(dim=1)
            weight_out_min, weight_minin_indx = weight_min.min(dim=1)
            weight_absin_max, _ = weight_abs_max.max(dim=0)
            scales = torch.ones(weight.size(1)).to(weight.device)
            for i in range(weight.size(1)):
                if torch.any(weight_maxin_indx == i) or torch.any(weight_minin_indx == i) : #the channels include weight out max and min
                    continue
                scale = (
                    (max_input[i].pow(self.alpha) / weight_absin_max[i].pow(1-self.alpha))
                    .clamp(min=1e-5)
                )
                scale_max_limit, _ = (weight_out_max / weight_max[:,i]).min(dim=0)
                scale_min_limit, _ = (weight_out_min / weight_min[:,i]).min(dim=0)
                scales[i], _ = torch.stack([scale, scale_max_limit, scale_min_limit]).min(dim=0)
            scales = torch.unsqueeze(scales, dim=0)
            scales = torch.unsqueeze(scales, dim=-1)
            return scales
        elif self.module_type == 'linear':
            if input.dim() == 3:
                tensor_permuted = input.permute(2, 0, 1)
            else:
                tensor_permuted = input.permute(1, 0)
            l = tensor_permuted.reshape(tensor_permuted.size(0), -1)
            max_input = l.abs().max(dim=1)[0].clamp(min=1e-6)

            weight_abs_max = weight.abs().max(dim=1)[0].clamp(min=1e-5)

            weight_out_max, weight_maxin_indx = weight.max(dim=1)
            weight_out_min, weight_minin_indx = weight.min(dim=1)
            weight_out_max = weight_out_max.clamp(min=1e-5)
            weight_out_min = weight_out_min.clamp(max=-1e-5)
            weight_absin_max, _ = weight.abs().max(dim=0)
            scales = torch.ones(weight.size(1)).to(weight.device)
            for i in range(weight.size(1)):
                if torch.any(weight_maxin_indx == i) or torch.any(weight_minin_indx == i) : #the channels include weight out max and min
                    continue
                scale = (
                    (max_input[i].pow(self.alpha) / weight_absin_max[i].pow(1-self.alpha))
                    .clamp(min=1e-5)
                )
                scale_max_limit, _ = (weight_out_max / weight[:,i].clamp(min=1e-5)).min(dim=0)
                scale_min_limit, _ = (weight_out_min / weight[:,i].clamp(max=-1e-5)).min(dim=0)
                scales[i], _ = torch.stack([scale, scale_max_limit, scale_min_limit]).min(dim=0)
            if input.dim() == 3:
                scales = torch.unsqueeze(scales, dim=0)
                scales = torch.unsqueeze(scales, dim=0)
            else:
                scales = torch.unsqueeze(scales, dim=0)
            return scales
        else:
            raise ValueError(
                "Unexpect quantization module"
            )

    def init_smooth_scale(self, input: torch.Tensor, weight: torch.Tensor, split: int):
        if split:
            if self.smooth_type == 'smoothquant':
                '''smoothquant'''
                scales = self.calculate_scales_1(input[:, split:, :, :], weight[:, split:, :, :])
                scales1 = self.calculate_scales_1(input[:, :split, :, :], weight[:, :split, :, :])
                scales = torch.cat([scales, scales1], dim=1)
            elif self.smooth_type == 'DGQ':
                '''precential smooth'''
                scales = self.calculate_scales_2(input[:, split:, :, :], threshold=0.005)
                scales1 = self.calculate_scales_2(input[:, :split, :, :], threshold=0.005)
                scales = torch.cat([scales, scales1], dim=1)
            elif self.smooth_type == 'AWQ':
                '''AWQ smooth'''
                scales = self.calculate_scales_3(input[:, split:, :, :], weight[:, split:, :, :], threshold=0.01)
                scales1 = self.calculate_scales_3(input[:, :split, :, :], weight[:, :split, :, :], threshold=0.01)
                scales = torch.cat([scales, scales1], dim=1)
            elif self.smooth_type == 'weight-aware':
                '''weight aware smooth'''
                scales = self.calculate_scales_4(input[:, split:, :, :], weight[:, split:, :, :])
                scales1 = self.calculate_scales_4(input[:, :split, :, :], weight[:, :split, :, :])
                scales = torch.cat([scales, scales1], dim=1)
            else:
                raise ValueError("Unexpect smooth_type")
        else:
            if self.smooth_type == 'smoothquant':
                '''smoothquant'''
                scales = self.calculate_scales_1(input, weight)
            elif self.smooth_type == 'DGQ':
                '''precential smooth'''
                scales = self.calculate_scales_2(input, threshold=0.005)
            elif self.smooth_type == 'AWQ':
                '''AWQ smooth'''
                scales = self.calculate_scales_3(input, weight, threshold=0.01)
            elif self.smooth_type == 'weight-aware':
                '''weight aware smooth'''
                scales = self.calculate_scales_4(input, weight)
            else:
                raise ValueError("Unexpect smooth_type")
        scales = self.update_scales(scales)
        return scales

    def forward(self, input: torch.Tensor, weight: torch.Tensor, split: int):
        if self.inited is False:
            self.scales = torch.nn.Parameter(self.init_smooth_scale(input=input, weight=weight, split=split))

    def set_inited(self, inited: bool = True):  # inited manually
        self.inited = inited

    def set_time(self, time: int = 0):  # inited manually
        self.time = time

class QuantModule(nn.Module):
    """
    Quantized Module that can perform quantized convolution or normal convolution.
    To activate quantization, please use set_quant_state function.
    """
    def __init__(self, org_module: Union[nn.Conv2d, nn.Linear, nn.Conv1d], weight_quant_params: dict = {},
                 act_quant_params: dict = {}, disable_act_quant: bool = False, act_quant_mode: str = 'qdiff'):
        super(QuantModule, self).__init__()
        # self.module = org_module
        self.weight_quant_params = weight_quant_params
        self.act_quant_params = act_quant_params
        self.stop_train = False
        if isinstance(org_module, nn.Conv2d):
            self.fwd_kwargs = dict(stride=org_module.stride, padding=org_module.padding,
                                   dilation=org_module.dilation, groups=org_module.groups)
            self.fwd_func = F.conv2d
            self.module_str = 'conv2d'
            self.weight = org_module.weight
            self.org_weight = org_module.weight.data.clone()
            if org_module.bias is not None:
                self.bias = org_module.bias
                self.org_bias = org_module.bias.data.clone()
            else:
                self.bias = None
                self.org_bias = None
        elif isinstance(org_module, nn.Conv1d):
            self.fwd_kwargs = dict(stride=org_module.stride, padding=org_module.padding,
                                   dilation=org_module.dilation, groups=org_module.groups)
            self.fwd_func = F.conv1d
            self.module_str = 'conv1d'
            self.weight = org_module.weight
            self.org_weight = org_module.weight.data.clone()
            if org_module.bias is not None:
                self.bias = org_module.bias
                self.org_bias = org_module.bias.data.clone()
            else:
                self.bias = None
                self.org_bias = None
        elif isinstance(org_module, nn.Linear):
            self.fwd_kwargs = dict()
            self.fwd_func = F.linear
            self.module_str = 'linear'
            self.weight = org_module.weight
            self.org_weight = org_module.weight.data.clone()
            if org_module.bias is not None:
                self.bias = org_module.bias
                self.org_bias = org_module.bias.data.clone()
            else:
                self.bias = None
                self.org_bias = None
        else:
            raise ValueError(
                "Unexpect quantization module"
            )

        # de-activate the quantized forward default
        self.use_weight_quant = False
        self.use_act_quant = False
        self.can_recon = True
        self.act_quant_mode = act_quant_mode
        self.disable_act_quant = disable_act_quant
        # initialize quantizer
        self.weight_quantizer = UniformAffineQuantizer(**self.weight_quant_params)
        if self.act_quant_mode == 'qdiff':
            self.act_quantizer = UniformAffineQuantizer(**self.act_quant_params)
        self.split = 0

        self.smooth_quantizer = SmoothQuantizer(module_type=self.module_str, alpha=0.50, timesteps=act_quant_params['num_timesteps'], \
            smooth_type=act_quant_params['smooth_type'])

        self.activation_function = StraightThrough()
        self.ignore_reconstruction = False
        self.extra_repr = org_module.extra_repr
        self.smooth_is_training = False

    def forward(self, input: torch.Tensor, split: int = 0):
        if split != 0 and self.split != 0:
            assert(split == self.split)
        elif split != 0:
            logger.info(f"split at {split}!")
            self.split = split
            self.set_split()
        train_weight = self.weight

        if self.smooth_quantizer.inited == False:
            self.smooth_quantizer(input, self.org_weight, self.split)
        elif self.smooth_quantizer.inited == True and self.smooth_quantizer.scales is not None:
            if self.smooth_is_training:
                input = input / self.smooth_quantizer.scales
                train_weight = self.org_weight.data * self.smooth_quantizer.scales
            else:
                input = input / self.smooth_quantizer.scales

        if not self.disable_act_quant and self.use_act_quant:
            if self.split != 0:
                if self.act_quant_mode == 'qdiff':
                    input_0 = self.act_quantizer(input[:, :self.split, ...])
                    input_1 = self.act_quantizer_0(input[:, self.split:, ...])
                input = torch.cat([input_0, input_1], dim=1)
            else:
                if self.act_quant_mode == 'qdiff':
                    input = self.act_quantizer(input)

        if self.use_weight_quant:
            if self.split != 0:
                weight_0 = self.weight_quantizer(train_weight[:, :self.split, ...])
                weight_1 = self.weight_quantizer_0(train_weight[:, self.split:, ...])
                weight = torch.cat([weight_0, weight_1], dim=1)
            else:
                weight = self.weight_quantizer(train_weight)
            bias = self.bias
        else:
            if self.smooth_is_training:
                weight = train_weight
            else:
                weight = self.org_weight
            bias = self.org_bias

        out = self.fwd_func(input, weight, bias, **self.fwd_kwargs)
        out = self.activation_function(out)
        return out

    def set_quant_state(self, weight_quant: bool = False, act_quant: bool = False):
        self.use_weight_quant = weight_quant
        self.use_act_quant = act_quant

    def set_t(self, t):
        for name, module in self.named_modules():
            if isinstance(module, UniformAffineQuantizer):
                if module.leaf_param:
                    module.set_t(t)

    def set_smooth_weight(self):
        if self.smooth_quantizer.scales is not None:
            if self.smooth_quantizer.scales.dim() == self.weight.dim():
                self.weight.data = self.weight.data * self.smooth_quantizer.scales
                self.org_weight = self.org_weight * self.smooth_quantizer.scales
            elif self.smooth_quantizer.scales.dim() == 3 and self.weight.dim() == 2: #for the attention linear on imagenet and stablediffusion
                scales = self.smooth_quantizer.scales.squeeze(dim=1)
                self.weight.data = self.weight.data * scales
                self.org_weight = self.org_weight * scales
            else:
                raise ValueError(
                    "Unexpect smooth weight"
                )

    def set_smooth_training(self, smooth_is_training: bool = False):
        self.smooth_is_training = smooth_is_training

    def set_split(self):
        self.weight_quantizer_0 = UniformAffineQuantizer(**self.weight_quant_params)
        if self.act_quant_mode == 'qdiff':
            self.act_quantizer_0 = UniformAffineQuantizer(**self.act_quant_params)
