import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from quantizers.uniform import *


class MinMaxQuantLinear(nn.Linear):
    """Linear layer that simulates uniform quantization of its weight and input.

    ``self.mode`` selects the forward behaviour:

    * ``"raw"``: floating-point ``F.linear``.
    * ``"quant_forward"``: quantize the weight (``w_quantizer``) and the input
      (``a_quantizer``), then apply ``F.linear``.
    * ``"debug_only_quant_weight"`` / ``"debug_only_quant_act"``: quantize only
      one of the two operands.
    * ``"calib"``: floating-point forward that also appends a detached CPU copy
      of the input to ``self.raw_input``.

    When ``self.rect_p`` is set, ``mode`` is bypassed. The input is divided by
    ``exp(rect_p)`` along its last axis, the weight columns are multiplied by
    the same factor, and both operands are quantized. Otherwise, if
    ``self.input_scaling_factor`` is set, the input is divided by it (along its
    last axis) before the mode dispatch.
    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias: bool = True,
                 mode = "raw",
                 w_bit = 8,
                 a_bit = 8):
        super().__init__(in_features, out_features, bias)
        self.mode = mode
        self.w_quantizer = UniformQuantizer(n_bits = w_bit, symmetric = True, channel_wise = False)
        self.a_quantizer = UniformQuantizer(n_bits = a_bit, symmetric = True, channel_wise = False)
        # Calibration caches; raw_input is filled by forward() in "calib" mode.
        self.raw_input = None
        self.raw_out = None
        self.tmp_input = None
        self.tmp_out = None
        self.calibrated = False
        # Optional log-scale rectification vector: forward() uses exp(rect_p).
        self.rect_p = None
        self.p_init = None
        self.search_mode = False
        self.is_reparam_layer = False
        self.register_buffer('input_scaling_factor', None)

    def forward(self, x):
        """Run the layer according to ``self.rect_p`` / ``self.mode`` (see class docstring).

        Raises:
            NotImplementedError: if ``self.mode`` is not a recognised mode.
        """
        rect_p = getattr(self, 'rect_p', None)
        if rect_p is not None:
            gamma = torch.exp(rect_p)
            view_shape = [1] * (x.dim() - 1) + [-1]
            # Dividing x by gamma and multiplying the matching weight columns by gamma
            # leaves x @ W.T unchanged before quantization.
            x = x / gamma.view(*view_shape)
            if self.search_mode and hasattr(self, 'run_minibatch_ptqsl'):
                self.run_minibatch_ptqsl(x)

            w_rect = self.weight * gamma.view(1, -1)
            w_sim, bias_sim = self.quant_weight_bias(w_override=w_rect)
            x_sim = self.quant_input(x)
            return F.linear(x_sim, w_sim, bias_sim)

        scaling_factor = self.input_scaling_factor
        if scaling_factor is not None:
            view_shape = [1] * (x.dim() - 1) + [-1]
            x = x / scaling_factor.view(*view_shape)

        if self.mode == 'raw':
            out = F.linear(x, self.weight, self.bias)
        elif self.mode == "quant_forward":
            out = self.quant_forward(x)
        elif self.mode == 'debug_only_quant_weight':
            out = self.debug_only_quant_weight(x)
        elif self.mode == 'debug_only_quant_act':
            out = self.debug_only_quant_act(x)
        elif self.mode == 'calib':
            # Detach before caching: the stored calibration inputs must not keep an
            # autograd graph (and the activations it references) alive across batches.
            x_cached = x.detach().cpu()
            raw_input = getattr(self, 'raw_input', None)
            if raw_input is None:
                self.raw_input = x_cached
            else:
                self.raw_input = torch.cat((raw_input.detach().cpu(), x_cached), dim=0)
            out = F.linear(x, self.weight, self.bias)
        else:
            raise NotImplementedError(f"unsupported mode: {self.mode!r}")
        return out

    def quant_weight_bias(self, w_override=None):
        """Return ``(quantized weight, bias)``; quantizes ``w_override`` instead of ``self.weight`` if given.

        The bias is returned unquantized (``None`` when the layer has no bias).
        """
        w = w_override if w_override is not None else self.weight
        return self.w_quantizer(w), self.bias

    def quant_input(self, x):
        """Return the input passed through the activation quantizer."""
        return self.a_quantizer(x)

    def quant_forward(self, x):
        """Linear forward with both weight and input quantized."""
        w_sim, bias_sim = self.quant_weight_bias()
        x_sim = self.quant_input(x)
        return F.linear(x_sim, w_sim, bias_sim)

    def debug_only_quant_weight(self, x):
        """Linear forward with only the weight quantized."""
        w_sim, bias_sim = self.quant_weight_bias()
        return F.linear(x, w_sim, bias_sim)

    def debug_only_quant_act(self, x):
        """Linear forward with only the input quantized."""
        x_sim = self.quant_input(x)
        return F.linear(x_sim, self.weight, self.bias)


class PTQSLQuantLinear(MinMaxQuantLinear):
    """Linear layer set up for PTQSL scale search.

    The weight is treated as ``n_V`` groups of ``crb_rows = out_features // n_V``
    output rows. The weight quantizer (``channel_wise=True``) gets a scale
    parameter of shape ``(n_V, crb_rows, 1)``; the activation quantizer gets a
    single-element scale parameter.
    """
    def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias: bool = True,
                 mode = "raw",
                 w_bit = 8,
                 a_bit = 8,
                 metric = "mse",
                 search_round = 1,
                 eq_n = 100,
                 n_V = 1,
                 token_channel_wise=False):
        super().__init__(in_features, out_features, bias=bias, mode=mode, w_bit=w_bit, a_bit=a_bit)
        # Replace the parent's quantizers: channel-wise symmetric for weights,
        # tensor-wise symmetric for activations.
        self.w_quantizer = UniformQuantizer(n_bits=w_bit, symmetric=True, channel_wise=True)
        self.a_quantizer = UniformQuantizer(n_bits=a_bit, symmetric=True, channel_wise=False)

        # Search configuration.
        self.metric = metric
        self.search_round = search_round
        self.eq_n = eq_n
        self.parallel_eq_n = eq_n
        self.token_channel_wise = token_channel_wise

        # Output rows are split into n_V groups of crb_rows rows each.
        self.n_V = n_V
        self.crb_rows = out_features // n_V

        self.w_quantizer.scale = nn.Parameter(torch.zeros((n_V, self.crb_rows, 1)))
        self.a_quantizer.scale = nn.Parameter(torch.zeros(1))

    def _get_similarity(self, tensor_raw, tensor_sim, metric=None, raw_grad=None):
        """Element-wise similarity (higher is better) between raw and simulated tensors.

        ``"mae"`` gives ``-|raw - sim|`` and ``"mse"`` gives ``-(raw - sim) ** 2``.
        ``raw_grad`` is accepted but unused.

        Raises:
            NotImplementedError: for any other ``metric`` (including ``None``).
        """
        if metric not in ("mae", "mse"):
            raise NotImplementedError(f"metric {metric} not implemented!")
        error = tensor_raw - tensor_sim
        if metric == "mae":
            return -torch.abs(error)
        return -error ** 2

    def quant_weight_bias(self, w_override=None):
        """Quantize the weight group-wise as ``(n_V, crb_rows, in_features)``.

        Returns ``(quantized weight of shape (out_features, in_features), bias)``;
        the bias is returned unquantized and is ``None`` for bias-free layers.
        """
        weight = self.weight if w_override is None else w_override
        grouped = weight.view(self.n_V, self.crb_rows, self.in_features)
        w_sim = self.w_quantizer(grouped).view(self.out_features, self.in_features)
        return w_sim, self.bias


class PTQSLBatchingQuantLinear(PTQSLQuantLinear):
    """PTQSL linear layer whose scale search is run over mini-batches.

    The cached calibration tensors ``self.raw_input`` / ``self.raw_out`` are moved
    to CUDA ``calib_batch_size`` samples at a time, and candidate scales are
    evaluated ``self.parallel_eq_n`` at a time.
    """
    def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias: bool = True,
                 mode = "raw",
                 w_bit = 8,
                 a_bit = 8,
                 metric = "mse",
                 calib_batch_size = 32,
                 search_round = 1,
                 eq_n = 100,
                 n_V = 1,
                 token_channel_wise=False):
        super().__init__(in_features, out_features, bias=bias, mode=mode, w_bit=w_bit, a_bit=a_bit,
                         metric=metric, search_round=search_round, eq_n=eq_n, n_V=n_V, token_channel_wise=token_channel_wise)
        self.calib_batch_size = calib_batch_size

    def _initialize_calib_parameters(self):
        """Set ``calib_size`` and ``parallel_eq_n`` for feeding calibration data.

        ``parallel_eq_n`` is derived from half of the total memory of CUDA
        device 0 and the size of one calibration batch, then balanced so that
        ``eq_n`` splits into equal-sized chunks.

        Raises:
            EnvironmentError: if CUDA is not available.
        """
        self.calib_size = self.raw_input.shape[0]
        if not torch.cuda.is_available():
            raise EnvironmentError("CUDA is not available on this system")
        memory = torch.cuda.get_device_properties(0).total_memory // 2
        numel = (16 * self.raw_input[:self.calib_batch_size].numel() +
                 16 * self.raw_out[:self.calib_batch_size].numel())  # number of parameters on GPU
        # At least one candidate per chunk: a very large batch would otherwise give 0
        # here and a division by zero / zero-step range afterwards.
        parallel_eq_n = max(1, int((memory / 4) // numel))
        self.parallel_eq_n = math.ceil(self.eq_n * 1.0 / math.ceil(self.eq_n * 1.0 / parallel_eq_n))

    def _initialize_weight_scale(self):
        """Per-row symmetric init: ``scale = max|w| / (n_levels - 0.5)`` over ``in_features``."""
        self.w_quantizer.scale.data.copy_(
            self.weight.view(self.n_V, self.crb_rows, self.in_features).abs().amax([2], keepdim=True) /
            (self.w_quantizer.n_levels - 0.5)
        )
        self.w_quantizer.inited = True

    def _initialize_activation_scale(self):
        """Tensor-wise symmetric init: ``scale = max|x| / (n_levels - 0.5)`` over all calibration batches."""
        tmp_a_scales = []
        for b_st in range(0, self.raw_input.shape[0], self.calib_batch_size):
            b_ed = min(self.raw_input.shape[0], b_st + self.calib_batch_size)
            x_ = self.raw_input[b_st:b_ed].cuda()
            a_scale_ = (x_.abs().max() / (self.a_quantizer.n_levels - 0.5)).detach().view(1, 1)
            tmp_a_scales.append(a_scale_)
        tmp_a_scale = torch.cat(tmp_a_scales, dim=0).amax(dim=0, keepdim=False).view(-1)
        self.a_quantizer.scale.data.copy_(tmp_a_scale)
        self.a_quantizer.inited = True

    def _search_best_w_scale(self, weight_scale_candidates):
        """Select, per output row, the weight scale whose output best matches ``self.raw_out``.

        Args:
            weight_scale_candidates: shape (n_candidates, n_V, crb_rows, 1). Only the
                first ``min(eq_n, n_candidates)`` candidates are evaluated.

        Returns:
            Index of the chosen candidate per row, shape (n_V, crb_rows, 1).
        """
        n_cand = min(self.eq_n, weight_scale_candidates.shape[0])
        n_levels = self.w_quantizer.n_levels
        raw_weight = self.weight.view(self.n_V, self.crb_rows, self.in_features).unsqueeze(0)  # 1,n_V,crb_rows,in_features
        batch_similarities = []  # per-batch similarities; summed over batches (same argmax as the mean)
        for b_st in range(0, self.calib_size, self.calib_batch_size):
            b_ed = min(self.calib_size, b_st + self.calib_batch_size)
            x = self.raw_input[b_st:b_ed].cuda()
            raw_out_expanded = self.raw_out[b_st:b_ed].cuda().unsqueeze(-2)  # b,*,1,out_features
            raw_out_expanded = raw_out_expanded.view(*raw_out_expanded.shape[:-1], self.n_V, -1)  # b,*,1,n_V,crb_rows
            # The input quantization does not depend on the weight candidate: do it once per batch.
            x_sim = self.quant_input(x)
            similarities = []
            for p_st in range(0, n_cand, self.parallel_eq_n):
                p_ed = min(n_cand, p_st + self.parallel_eq_n)
                n_chunk = p_ed - p_st
                cur_w_scale = weight_scale_candidates[p_st:p_ed]
                w_sim = (raw_weight / cur_w_scale).round_().clamp_(
                    -n_levels, n_levels - 1
                ).mul_(cur_w_scale)  # n_chunk,n_V,crb_rows,in_features
                w_sim = w_sim.view(-1, self.in_features)  # n_chunk*out_features,in_features
                bias_sim = self.bias.repeat(n_chunk) if self.bias is not None else None
                out_sim = F.linear(x_sim, w_sim, bias_sim)  # b,*,n_chunk*out_features
                out_sim = out_sim.view(*out_sim.shape[:-1], n_chunk, self.n_V, -1)  # b,*,n_chunk,n_V,crb_rows
                similarity = self._get_similarity(raw_out_expanded, out_sim, self.metric)  # b,*,n_chunk,n_V,crb_rows
                if similarity.dim() > 4:
                    similarity = torch.mean(similarity, dim=list(range(1, similarity.dim() - 3)))  # b,n_chunk,n_V,crb_rows
                similarities.append(similarity.sum(dim=0, keepdim=True))  # 1,n_chunk,n_V,crb_rows
            batch_similarities.append(torch.cat(similarities, dim=1))  # 1,n_cand,n_V,crb_rows
        batch_similarities = torch.cat(batch_similarities, dim=0).sum(dim=0, keepdim=False)  # n_cand,n_V,crb_rows
        best_index = batch_similarities.argmax(dim=0).reshape(1, self.n_V, -1, 1)  # 1,n_V,crb_rows,1
        tmp_w_scale = torch.gather(weight_scale_candidates, dim=0, index=best_index)  # 1,n_V,crb_rows,1
        self.w_quantizer.scale.data.copy_(tmp_w_scale.squeeze(0))
        return best_index.squeeze(0)  # n_V,crb_rows,1

    def _search_best_a_scale(self, input_scale_candidates):
        """Select the activation scale whose output best matches ``self.raw_out``.

        Args:
            input_scale_candidates: candidates along the last axis, shape (1, n_candidates)
                as built by ``hyperparameter_searching``. Only the first
                ``min(eq_n, n_candidates)`` are evaluated.

        Returns:
            Index of the chosen candidate, shape (1,).
        """
        n_cand = min(self.eq_n, input_scale_candidates.shape[-1])
        n_levels = self.a_quantizer.n_levels
        # The weight is fixed while the input scale is searched: quantize it once.
        w_sim, bias_sim = self.quant_weight_bias()
        batch_similarities = []  # per-batch similarities; summed over batches (same argmax as the mean)
        for b_st in range(0, self.calib_size, self.calib_batch_size):
            b_ed = min(self.calib_size, b_st + self.calib_batch_size)
            x = self.raw_input[b_st:b_ed].cuda()
            raw_out_expanded = self.raw_out[b_st:b_ed].cuda().unsqueeze(-2)  # b,*,1,out_features
            similarities = []
            for p_st in range(0, n_cand, self.parallel_eq_n):
                p_ed = min(n_cand, p_st + self.parallel_eq_n)
                cur_a_scale = input_scale_candidates[:, p_st:p_ed]
                x_sim = (x.unsqueeze(-1) / cur_a_scale).round_().clamp_(
                    -n_levels, n_levels - 1
                ).mul_(cur_a_scale)  # b,*,in_features,n_chunk
                x_sim = x_sim.permute(*range(x_sim.dim() - 2), -1, -2)  # b,*,n_chunk,in_features
                out_sim = F.linear(x_sim, w_sim, bias_sim)  # b,*,n_chunk,out_features
                similarity = self._get_similarity(raw_out_expanded, out_sim, self.metric)  # b,*,n_chunk,out_features
                similarity = torch.mean(similarity, dim=-1)  # b,*,n_chunk
                if similarity.dim() > 2:
                    similarity = torch.mean(similarity, dim=list(range(1, similarity.dim() - 1)))  # b,n_chunk
                similarities.append(torch.sum(similarity, dim=0, keepdim=True))  # 1,n_chunk
            batch_similarities.append(torch.cat(similarities, dim=1))  # 1,n_cand
        batch_similarities = torch.cat(batch_similarities, dim=0).sum(dim=0, keepdim=False)  # n_cand
        best_index = batch_similarities.argmax(dim=0, keepdim=True).reshape(1, -1)  # 1,1
        tmp_a_scale = torch.gather(input_scale_candidates, dim=-1, index=best_index)
        self.a_quantizer.scale.data.copy_(tmp_a_scale.squeeze(-1))
        return best_index.squeeze(0)

    def hyperparameter_searching(self):
        """Initialise scales, then alternately search weight and activation scales.

        Candidates are ``ratio * initial_scale`` for ``eq_n + 1`` ratios linearly
        spaced from ``eq_alpha`` to ``eq_beta``. The searches evaluate only the
        first ``eq_n`` of them, so the ``eq_beta`` ratio itself is never tried.
        The activation search is skipped when ``a_quantizer.n_bits >= 32``. The
        cached calibration tensors are deleted at the end.
        """
        # Pure search: no gradients are needed, and tracking them would keep every
        # candidate's intermediate tensors alive until the argmax.
        with torch.no_grad():
            self._initialize_calib_parameters()
            self._initialize_weight_scale()
            self._initialize_activation_scale()

            self.eq_alpha, self.eq_beta = 0.01, 1.2
            ratios = torch.tensor(
                [self.eq_alpha + i * (self.eq_beta - self.eq_alpha) / self.eq_n for i in range(self.eq_n + 1)]
            ).cuda()
            weight_scale_candidates = ratios.view(-1, 1, 1, 1) * self.w_quantizer.scale.unsqueeze(0)  # eq_n+1,n_V,crb_rows,1
            input_scale_candidates = ratios.view(1, -1) * self.a_quantizer.scale.unsqueeze(-1)  # 1,eq_n+1

            for e in range(self.search_round):
                self._search_best_w_scale(weight_scale_candidates)
                if self.a_quantizer.n_bits < 32:
                    self._search_best_a_scale(input_scale_candidates)
                else:
                    break

        self.calibrated = True
        del self.raw_input, self.raw_out
        return None
        
        
class AsymmetricallyBatchingQuantLinear(PTQSLBatchingQuantLinear):
    def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias: bool = True,
                 mode = "raw",
                 w_bit = 8,
                 a_bit = 8,
                 metric = "mse",
                 calib_batch_size = 32,
                 search_round = 1,
                 eq_n = 100,
                 n_V = 1,
                 token_channel_wise=False):
        """Build the layer with asymmetric (scale + zero point) weight and activation quantizers."""
        super().__init__(in_features, out_features, bias=bias, mode=mode, w_bit=w_bit, a_bit=a_bit,
                         metric=metric, calib_batch_size=calib_batch_size, search_round=search_round,
                         eq_n=eq_n, n_V=n_V, token_channel_wise=token_channel_wise)

        # Swap the parent's symmetric quantizers for asymmetric ones.
        del self.a_quantizer, self.w_quantizer
        self.w_quantizer = UniformQuantizer(n_bits=w_bit, symmetric=False, channel_wise=True)
        self.a_quantizer = UniformQuantizer(n_bits=a_bit, symmetric=False, channel_wise=False)

        # Activations: a single scale / zero point.
        self.a_quantizer.scale = nn.Parameter(torch.zeros(1))
        self.a_quantizer.zero_point = nn.Parameter(torch.zeros(1))

        # Weights: one scale / zero point per output row, laid out as (n_V, crb_rows, 1).
        per_row_shape = (n_V, self.crb_rows, 1)
        self.w_quantizer.scale = nn.Parameter(torch.zeros(per_row_shape))
        self.w_quantizer.zero_point = nn.Parameter(torch.zeros(per_row_shape))

        self.register_buffer('input_scaling_factor', None)

    def _initialize_weight_scale(self):
        """Min/max init of the per-row asymmetric weight scale and zero point.

        Over the ``in_features`` axis of the weight viewed as
        ``(n_V, crb_rows, in_features)``:
        ``scale = (max - min) / (2 * n_levels - 1)`` and ``zero_point = -min / scale``.
        NOTE(review): a row with max == min yields scale 0 and a non-finite zero point.
        """
        grouped_w = self.weight.view(self.n_V, self.crb_rows, self.in_features)
        w_max = grouped_w.amax([2], keepdim=True)
        w_min = grouped_w.amin([2], keepdim=True)
        quantizer = self.w_quantizer
        quantizer.scale.data.copy_((w_max - w_min) / (2 * quantizer.n_levels - 1))
        # Divide by the stored parameter (not a local) so the arithmetic matches its dtype.
        quantizer.zero_point.data.copy_(-w_min / quantizer.scale)
        quantizer.inited = True

    def _initialize_activation_scale(self):
        """Min/max init of the asymmetric activation scale and zero point.

        The signed minimum and maximum of ``self.raw_input`` are collected batch by
        batch (per last-axis channel when ``a_quantizer.channel_wise``, otherwise over
        the whole tensor). Then ``scale = (max - min) / (2 * n_levels - 1)`` and
        ``zero_point = -min / scale``, so the quantized grid spans ``[min, max]``.
        """
        tmp_a_max, tmp_a_min = [], []
        n_samples = self.raw_input.shape[0]
        for b_st in range(0, n_samples, self.calib_batch_size):
            b_ed = min(n_samples, b_st + self.calib_batch_size)
            x_ = self.raw_input[b_st:b_ed].cuda()
            # Signed extrema: with |x| the minimum is ~0, giving zero_point ~0, and every
            # negative activation would be clipped by the asymmetric [0, 2*n_levels-1] grid.
            if self.a_quantizer.channel_wise:
                reduce_dims = list(range(x_.ndim - 1))
                a_max = x_.amax(reduce_dims, keepdim=False).detach().view(1, -1)
                a_min = x_.amin(reduce_dims, keepdim=False).detach().view(1, -1)
            else:
                a_max = x_.max().detach().view(1, 1)
                a_min = x_.min().detach().view(1, 1)
            tmp_a_max.append(a_max)
            tmp_a_min.append(a_min)
        tmp_a_max = torch.cat(tmp_a_max, dim=0).amax(dim=0, keepdim=False)
        tmp_a_min = torch.cat(tmp_a_min, dim=0).amin(dim=0, keepdim=False)
        self.a_quantizer.scale.data.copy_((tmp_a_max - tmp_a_min) / (2 * self.a_quantizer.n_levels - 1))
        self.a_quantizer.zero_point.data.copy_(-tmp_a_min / self.a_quantizer.scale)
        self.a_quantizer.inited = True

    def _search_best_w_scale_self(self, weight_scale_candidates, weight_zero_point_candidates, topk=1):
        """Pick per-row weight (scale, zero_point) minimising the weight reconstruction MSE.

        Candidates are indexed along dim 0 and evaluated ``parallel_eq_n`` at a time.
        Each candidate fake-quantizes the weight (``round(w / s) + zp``, clamped to
        ``[0, 2 * n_levels - 1]``, then dequantized). Its score is the mean negative
        squared error over ``in_features``.

        Returns:
            Top-``topk`` candidate indices reshaped to (topk, n_V, crb_rows, 1), with the
            leading axis squeezed when ``topk == 1``. Quantizer parameters are only
            written when ``topk == 1``.
        """
        q_max = 2 * self.w_quantizer.n_levels - 1
        raw_weight = self.weight.view(self.n_V, self.crb_rows, self.in_features).unsqueeze(0)  # 1,n_V,crb_rows,in_features
        chunk_scores = []
        for p_st in range(0, self.eq_n, self.parallel_eq_n):
            p_ed = min(self.eq_n, p_st + self.parallel_eq_n)
            scale = weight_scale_candidates[p_st:p_ed]
            zero_point = weight_zero_point_candidates[p_st:p_ed]
            w_quant = ((raw_weight / scale).round_() + zero_point).clamp(0, q_max)
            w_dequant = (w_quant - zero_point) * scale  # chunk,n_V,crb_rows,in_features
            score = self._get_similarity(raw_weight, w_dequant, 'mse').mean(dim=-1, keepdim=False)  # chunk,n_V,crb_rows
            chunk_scores.append(score)
        scores = torch.cat(chunk_scores, dim=0)  # n_candidates,n_V,crb_rows
        best_index = torch.topk(scores, k=topk, dim=0).indices.reshape(topk, self.n_V, -1, 1)
        if topk == 1:
            quantizer = self.w_quantizer
            quantizer.scale.data.copy_(torch.gather(weight_scale_candidates, dim=0, index=best_index).squeeze(0))
            quantizer.zero_point.data.copy_(torch.gather(weight_zero_point_candidates, dim=0, index=best_index).squeeze(0))
            quantizer.inited = True
        return best_index.squeeze(0)

    def _search_best_a_scale_self(self, input_scale_candidates, input_zero_point_candidates, topk=1):
        """Pick activation (scale, zero_point) minimising the input reconstruction MSE.

        Candidates lie along the last axis of ``input_scale_candidates`` /
        ``input_zero_point_candidates``, with one row per channel when
        ``a_quantizer.channel_wise`` and a single row otherwise. Each calibration
        batch is fake-quantized (``round(x / s) + zp``, clamped to
        ``[0, 2 * n_levels - 1]``, dequantized) and compared with the raw input.

        Returns:
            Indices from ``torch.topk`` along the candidate axis, shape (rows, topk).
            Quantizer parameters are only written when ``topk == 1``.
        """
        q_max = 2 * self.a_quantizer.n_levels - 1
        batch_similarities = []  # per-batch similarities; summed over batches (same argmax as the mean)
        for b_st in range(0, self.calib_size, self.calib_batch_size):
            b_ed = min(self.calib_size, b_st + self.calib_batch_size)
            x = self.raw_input[b_st:b_ed].cuda()
            # Reuse the batch already on the GPU instead of copying it from the host twice.
            raw_x = x.unsqueeze(-1)  # b,*,in_features,1
            similarities = []
            for p_st in range(0, self.eq_n, self.parallel_eq_n):
                p_ed = min(self.eq_n, p_st + self.parallel_eq_n)
                cur_a_scale = input_scale_candidates[:, p_st:p_ed]
                cur_a_zero_point = input_zero_point_candidates[:, p_st:p_ed]
                x_quant = ((raw_x / cur_a_scale).round_() + cur_a_zero_point).clamp_(0, q_max)  # b,*,in_features,chunk
                x_dequant = (x_quant - cur_a_zero_point) * cur_a_scale  # b,*,in_features,chunk
                similarity = self._get_similarity(raw_x, x_dequant, 'mse')  # b,*,in_features,chunk
                if similarity.dim() > 3:
                    similarity = torch.mean(similarity, dim=list(range(1, similarity.dim() - 2)))  # b,in_features,chunk
                if not self.a_quantizer.channel_wise:
                    similarity = torch.mean(similarity, dim=1, keepdim=True)  # b,1,chunk
                similarities.append(torch.sum(similarity, dim=0, keepdim=True))  # 1,rows,chunk
            batch_similarities.append(torch.cat(similarities, dim=-1))  # 1,rows,n_candidates
        batch_similarities = torch.cat(batch_similarities, dim=0).sum(dim=0, keepdim=False)  # rows,n_candidates
        _, best_index = torch.topk(batch_similarities, k=topk, dim=-1)  # rows,topk
        if topk == 1:
            tmp_a_scale = torch.gather(input_scale_candidates, dim=-1, index=best_index)
            tmp_a_zero_point = torch.gather(input_zero_point_candidates, dim=-1, index=best_index)
            self.a_quantizer.scale.data.copy_(tmp_a_scale.squeeze(-1))
            self.a_quantizer.zero_point.data.copy_(tmp_a_zero_point.squeeze(-1))
            self.a_quantizer.inited = True
        return best_index
    
    def _search_best_w_scale(self, weight_scale_candidates, weight_zero_point_candidates, topk=1):
        """Pick per-row weight (scale, zero_point) whose layer output best matches ``self.raw_out``.

        Args:
            weight_scale_candidates: shape (n_candidates, n_V, crb_rows, 1).
            weight_zero_point_candidates: indexed like the scales along dim 0.
            topk: number of best candidates to return.

        Returns:
            Candidate indices reshaped to (topk, n_V, crb_rows, 1), with the leading
            axis squeezed when ``topk == 1``. Quantizer parameters are only written
            when ``topk == 1``.
        """
        # Iterate only over candidates that exist: the percentile grid has
        # num_zp * num_scale rows, which can be fewer than eq_n. Slicing past the end
        # would otherwise make bias.repeat()/view() disagree with the real chunk size.
        n_cand = min(self.eq_n, weight_scale_candidates.shape[0])
        q_max = 2 * self.w_quantizer.n_levels - 1
        raw_weight = self.weight.view(self.n_V, self.crb_rows, self.in_features).unsqueeze(0)  # 1,n_V,crb_rows,in_features
        batch_similarities = []  # per-batch similarities; summed over batches (same argmax as the mean)
        for b_st in range(0, self.calib_size, self.calib_batch_size):
            b_ed = min(self.calib_size, b_st + self.calib_batch_size)
            x = self.raw_input[b_st:b_ed].cuda()
            raw_out_expanded = self.raw_out[b_st:b_ed].cuda().unsqueeze(-2)  # b,*,1,out_features
            raw_out_expanded = raw_out_expanded.view(*raw_out_expanded.shape[:-1], self.n_V, -1)  # b,*,1,n_V,crb_rows
            # The input quantization does not depend on the weight candidate: do it once per batch.
            x_sim = self.quant_input(x)
            similarities = []
            for p_st in range(0, n_cand, self.parallel_eq_n):
                p_ed = min(n_cand, p_st + self.parallel_eq_n)
                n_chunk = p_ed - p_st
                cur_w_scale = weight_scale_candidates[p_st:p_ed]
                cur_w_zero_point = weight_zero_point_candidates[p_st:p_ed]
                w_quant = ((raw_weight / cur_w_scale).round_() + cur_w_zero_point).clamp(0, q_max)
                w_dequant = (w_quant - cur_w_zero_point) * cur_w_scale  # n_chunk,n_V,crb_rows,in_features
                w_sim = w_dequant.view(-1, self.in_features)  # n_chunk*out_features,in_features
                bias_sim = self.bias.repeat(n_chunk) if self.bias is not None else None
                out_sim = F.linear(x_sim, w_sim, bias_sim)  # b,*,n_chunk*out_features
                out_sim = out_sim.view(*out_sim.shape[:-1], n_chunk, self.n_V, -1)  # b,*,n_chunk,n_V,crb_rows
                similarity = self._get_similarity(raw_out_expanded, out_sim, self.metric)  # b,*,n_chunk,n_V,crb_rows
                if similarity.dim() > 4:
                    similarity = torch.mean(similarity, dim=list(range(1, similarity.dim() - 3)))  # b,n_chunk,n_V,crb_rows
                similarities.append(similarity.sum(dim=0, keepdim=True))  # 1,n_chunk,n_V,crb_rows
            batch_similarities.append(torch.cat(similarities, dim=1))  # 1,n_cand,n_V,crb_rows
        batch_similarities = torch.cat(batch_similarities, dim=0).sum(dim=0, keepdim=False)  # n_cand,n_V,crb_rows
        _, best_index = torch.topk(batch_similarities, k=topk, dim=0)
        best_index = best_index.reshape(topk, self.n_V, -1, 1)
        if topk == 1:
            tmp_w_scale = torch.gather(weight_scale_candidates, dim=0, index=best_index)
            tmp_w_zero_point = torch.gather(weight_zero_point_candidates, dim=0, index=best_index)
            self.w_quantizer.scale.data.copy_(tmp_w_scale.squeeze(0))
            self.w_quantizer.zero_point.data.copy_(tmp_w_zero_point.squeeze(0))
        return best_index.squeeze(0)
    
    def _search_best_a_scale(self, input_scale_candidates, input_zero_point_candidates, topk=1):
        """Pick activation (scale, zero_point) whose layer output best matches ``self.raw_out``.

        Args:
            input_scale_candidates: candidates along the last axis, shape (1, n_candidates).
            input_zero_point_candidates: same layout as the scales.
            topk: number of best candidates to return.

        Returns:
            Candidate indices, shape (1, topk). Quantizer parameters are only written
            when ``topk == 1``.
        """
        n_cand = min(self.eq_n, input_scale_candidates.shape[-1])
        q_max = 2 * self.a_quantizer.n_levels - 1
        # The weight is fixed while the input parameters are searched: quantize it once.
        w_sim, bias_sim = self.quant_weight_bias()
        batch_similarities = []  # per-batch similarities; summed over batches (same argmax as the mean)
        for b_st in range(0, self.calib_size, self.calib_batch_size):
            b_ed = min(self.calib_size, b_st + self.calib_batch_size)
            x = self.raw_input[b_st:b_ed].cuda()
            raw_out_expanded = self.raw_out[b_st:b_ed].cuda().unsqueeze(-2)  # b,*,1,out_features
            similarities = []
            for p_st in range(0, n_cand, self.parallel_eq_n):
                p_ed = min(n_cand, p_st + self.parallel_eq_n)
                cur_a_scale = input_scale_candidates[:, p_st:p_ed]
                cur_a_zero_point = input_zero_point_candidates[:, p_st:p_ed]
                x_sim = x.unsqueeze(-1)  # b,*,in_features,1
                x_quant = ((x_sim / cur_a_scale).round_() + cur_a_zero_point).clamp_(0, q_max)  # b,*,in_features,chunk
                x_dequant = (x_quant - cur_a_zero_point) * cur_a_scale  # b,*,in_features,chunk
                x_sim = x_dequant.permute(*range(x_dequant.dim() - 2), -1, -2)  # b,*,chunk,in_features
                out_sim = F.linear(x_sim, w_sim, bias_sim)  # b,*,chunk,out_features
                similarity = self._get_similarity(raw_out_expanded, out_sim, self.metric)  # b,*,chunk,out_features
                similarity = torch.mean(similarity, dim=-1)  # b,*,chunk
                if similarity.dim() > 2:
                    similarity = torch.mean(similarity, dim=list(range(1, similarity.dim() - 1)))  # b,chunk
                similarities.append(torch.sum(similarity, dim=0, keepdim=True))  # 1,chunk
            batch_similarities.append(torch.cat(similarities, dim=1))  # 1,n_cand
        batch_similarities = torch.cat(batch_similarities, dim=0).sum(dim=0, keepdim=True)  # 1,n_cand
        _, best_index = torch.topk(batch_similarities, k=topk, dim=-1)  # 1,topk
        if topk == 1:
            tmp_a_scale = torch.gather(input_scale_candidates, dim=-1, index=best_index)
            tmp_a_zero_point = torch.gather(input_zero_point_candidates, dim=-1, index=best_index)
            self.a_quantizer.scale.data.copy_(tmp_a_scale.squeeze(-1))
            # Write through .data like the scale: an in-place copy_ on a grad-requiring
            # leaf Parameter raises when autograd is enabled.
            self.a_quantizer.zero_point.data.copy_(tmp_a_zero_point.squeeze(-1))
        return best_index
        
    def calculate_percentile_weight_candidates(self, l=0.9, r=1.0):
        """Build a grid of per-row asymmetric weight (scale, zero_point) candidates.

        For every output row, the weight range is measured between percentiles
        ``1 - l``..``l`` (narrow) and ``1 - r``..``r`` (wide). ``num_scale`` scales are
        spaced linearly from narrow/(2*n_levels-1) to wide/(2*n_levels-1). Each
        scale is paired with each of ``num_zp`` consecutive integer zero points
        centred on ``n_levels``; candidate ``j * num_scale + i`` is (scale i, zero point j).

        Returns:
            ``(weight_scale_candidates, weight_zero_point_candidates)``:
            scales of shape (num_zp * num_scale, n_V, crb_rows, 1) and integer zero
            points of shape (num_zp * num_scale, n_V, crb_rows, in_features). The zero
            points are a broadcast (``expand``) view: read-only, no per-element storage.
        """
        num_zp = min(16, self.w_quantizer.n_levels * 2)
        # At least one scale per zero point; eq_n < num_zp would otherwise yield no candidates.
        num_scale = max(1, int(self.eq_n / num_zp))
        device = self.weight.device
        grouped_w = self.weight.view(self.n_V, self.crb_rows, self.in_features)
        pct = torch.tensor([l, r]).to(device)
        w_uppers_candidates = torch.quantile(grouped_w, pct, dim=-1).unsqueeze(-1)  # 2,n_V,crb_rows,1
        w_lowers_candidates = torch.quantile(grouped_w, 1 - pct, dim=-1).unsqueeze(-1)  # 2,n_V,crb_rows,1
        delta_min = w_uppers_candidates[0:1] - w_lowers_candidates[0:1]
        delta_max = w_uppers_candidates[1:] - w_lowers_candidates[1:]
        splits = torch.linspace(0, 1, steps=num_scale, device=device)[:, None, None, None] * (delta_max - delta_min)
        weight_scale_candidates = (delta_min + splits).repeat(num_zp, 1, 1, 1) / (2 * self.w_quantizer.n_levels - 1)
        zp_min = int(self.w_quantizer.n_levels - num_zp / 2)
        zp_max = int(self.w_quantizer.n_levels + num_zp / 2)
        zp_candidates = torch.arange(zp_min, zp_max, device=device)
        weight_zero_point_candidates = zp_candidates.repeat_interleave(num_scale)[:, None, None, None]
        # expand (stride-0 view) instead of repeat: same shape and values without
        # materialising n_candidates * out_features * in_features int64 elements.
        weight_zero_point_candidates = weight_zero_point_candidates.expand(-1, self.n_V, self.crb_rows, self.in_features)
        return weight_scale_candidates, weight_zero_point_candidates
    
    def calculate_percentile_activation_candidates(self, l=0.9, r=1.0):
        """Build a grid of asymmetric activation (scale, zero_point) candidates.

        The range of ``self.raw_input`` (moved to CUDA) is measured between
        percentiles ``1 - l``..``l`` and ``1 - r``..``r``. This is done per last-axis
        channel when ``a_quantizer.channel_wise``. Otherwise the tensor is split into
        ``mini_batch_size`` rows, starting at 1 and doubling while
        ``view``/``quantile`` raise ``RuntimeError``, and the per-row percentiles are
        averaged. ``num_scale`` scales (clamped to >= 1e-4) are paired with ``num_zp``
        integer zero points centred on ``n_levels``, and the last candidate is appended
        once more.

        Returns:
            ``(a_scale_candidates, a_zero_point_candidates)``, both of shape
            (rows, num_zp * num_scale + 1), where rows is in_features when
            channel-wise and 1 otherwise.
        """
        num_zp = min(16, self.a_quantizer.n_levels * 2)
        # At least one scale per zero point; eq_n < num_zp would otherwise yield no candidates.
        num_scale = max(1, int(self.eq_n / num_zp))
        x = self.raw_input.cuda()
        pct = torch.tensor([l, r]).to(x.device)
        if self.a_quantizer.channel_wise:
            x_2d = x.view(-1, x.shape[-1])
            a_uppers_candidates = torch.quantile(x_2d, pct, dim=0).transpose(0, 1)  # in_features,2
            a_lowers_candidates = torch.quantile(x_2d, 1 - pct, dim=0).transpose(0, 1)  # in_features,2
        else:
            mini_batch_size = 1
            while True:
                try:
                    rows = x.view(mini_batch_size, -1)
                    a_uppers_candidates = torch.quantile(rows, pct, dim=-1).mean(dim=-1).unsqueeze(0)  # 1,2
                    a_lowers_candidates = torch.quantile(rows, 1 - pct, dim=-1).mean(dim=-1).unsqueeze(0)  # 1,2
                    break
                except RuntimeError:
                    # view() (numel not divisible) or quantile() (row too large) failed:
                    # split into more rows, but stop instead of looping forever once a
                    # row would hold a single element.
                    if mini_batch_size >= x.numel():
                        raise
                    mini_batch_size *= 2
        delta_min = a_uppers_candidates[:, 0:1] - a_lowers_candidates[:, 0:1]
        delta_max = a_uppers_candidates[:, 1:] - a_lowers_candidates[:, 1:]
        splits = torch.linspace(0, 1, steps=num_scale, device=x.device)[None, :] * (delta_max - delta_min)
        a_scale_candidates = ((delta_min + splits).repeat(1, num_zp) / (2 * self.a_quantizer.n_levels - 1)).clamp(min=1e-4)
        a_scale_candidates = torch.cat([a_scale_candidates, a_scale_candidates[..., -1:]], dim=-1)

        zp_min = int(self.a_quantizer.n_levels - num_zp / 2)
        zp_max = int(self.a_quantizer.n_levels + num_zp / 2)
        zp_candidates = torch.arange(zp_min, zp_max, device=x.device)
        a_zero_point_candidates = zp_candidates.repeat_interleave(num_scale)[None, :]
        a_zero_point_candidates = a_zero_point_candidates.repeat(a_scale_candidates.shape[0], 1)
        a_zero_point_candidates = torch.cat([a_zero_point_candidates, a_zero_point_candidates[..., -1:]], dim=-1)
        return a_scale_candidates, a_zero_point_candidates
    
    def hyperparameter_searching(self):
        """Calibrate the weight and activation quantizers of this layer.

        Steps:
          1. ``_initialize_calib_parameters()``.
          2. Build percentile-based (scale, zero_point) candidate grids for the
             weights and for the activations.
          3. Run the ``*_self`` search variants once for weights, then activations.
          4. Alternate the weight and activation searches ``self.search_round``
             times, emptying the CUDA cache before each search.
          5. If ``token_channel_wise`` is set and the cached ``raw_input`` is 3-D
             ``(B, N, C)``, broadcast the activation scale to shape ``(1, N, 1)``
             and store it as a fresh ``nn.Parameter``.

        Afterwards the layer is marked calibrated and the cached ``raw_input`` /
        ``raw_out`` attributes are deleted.

        Returns:
            None
        """
        self._initialize_calib_parameters()

        w_scales, w_zps = self.calculate_percentile_weight_candidates()
        a_scales, a_zps = self.calculate_percentile_activation_candidates()

        self._search_best_w_scale_self(w_scales, w_zps)
        self._search_best_a_scale_self(a_scales, a_zps)

        for _ in range(self.search_round):
            torch.cuda.empty_cache()
            self._search_best_w_scale(w_scales, w_zps)
            torch.cuda.empty_cache()
            self._search_best_a_scale(a_scales, a_zps)

        # `and` short-circuits, so raw_input is only inspected when token-wise mode is on.
        if self.token_channel_wise and self.raw_input.dim() == 3:
            _, n_tokens, _ = self.raw_input.shape
            per_token_scale = self.a_quantizer.scale.expand(1, n_tokens, 1)
            del self.a_quantizer.scale
            self.a_quantizer.scale = nn.Parameter(per_token_scale.clone())

        self.calibrated = True
        # Calibration tensors are no longer needed once parameters are fixed.
        del self.raw_input, self.raw_out
    
    def run_minibatch_ptqsl(self, x):
        """Search the activation scale / zero point for ``x`` in two stages and store them.

        The range statistics (0.1% / 99.9% and 0% / 100% quantiles) are taken
        from at most 1,000,000 randomly sampled elements of ``x``. The candidate
        errors are always computed on the full ``x`` via ``self._evaluate_candidates``.

        Stage 1 (coarse): 16 scales spanning the percentile range to the full
        range (divided by ``2 * n_levels - 1``, clamped to >= 1e-4), crossed
        with up to 16 zero points centred on ``n_levels``.
        Stage 2 (fine): 16 scales within +/- one coarse step of the stage-1
        scale, crossed with the stage-1 zero point and its two neighbours.

        The stage with the lower error wins, and its (scale, zero_point) is
        copied in place into ``self.a_quantizer``.
        """
        with torch.no_grad():
            n_coarse, n_fine = 16, 16
            dev = x.device
            qmax = 2 * self.a_quantizer.n_levels - 1

            # torch.quantile cannot handle arbitrarily large inputs; subsample big tensors.
            flat = x.flatten()
            if flat.numel() > 1000000:
                sample = flat[torch.randint(0, flat.numel(), (1000000,), device=dev)]
            else:
                sample = flat

            probs = torch.tensor([0.999, 1.0], device=dev)
            upper = torch.quantile(sample, probs, dim=-1).view(1, 2)
            lower = torch.quantile(sample, 1.0 - probs, dim=-1).view(1, 2)
            span_lo = upper[:, 0:1] - lower[:, 0:1]
            span_hi = upper[:, 1:] - lower[:, 1:]
            span_gap = span_hi - span_lo

            # ---- stage 1: coarse grid ----
            grid = torch.linspace(0, 1, steps=n_coarse, device=dev)[None, :]
            coarse_scales = ((span_lo + grid * span_gap) / qmax).clamp(min=1e-4)
            n_zp = min(16, self.a_quantizer.n_levels * 2)
            zp_lo = int(self.a_quantizer.n_levels - n_zp / 2)
            zp_hi = int(self.a_quantizer.n_levels + n_zp / 2)
            zp_grid = torch.arange(zp_lo, zp_hi, device=dev)
            s1, z1, err1 = self._evaluate_candidates(
                x,
                coarse_scales.view(-1).repeat_interleave(n_zp),
                zp_grid.repeat(n_coarse),
            )

            # ---- stage 2: fine grid around the stage-1 winner ----
            half_width = (span_gap / (n_coarse - 1)).view(-1) / qmax
            lo = (s1 - half_width).clamp(min=1e-4)
            hi = s1 + half_width
            fine_scales = lo + torch.linspace(0, 1, steps=n_fine, device=dev) * (hi - lo)
            zp_near = torch.arange(z1 - 1, z1 + 2, device=dev).clamp(zp_lo, zp_hi)
            s2, z2, err2 = self._evaluate_candidates(
                x,
                fine_scales.repeat_interleave(len(zp_near)),
                zp_near.repeat(n_fine),
            )

            best_s, best_z = (s2, z2) if err2 < err1 else (s1, z1)
            self.a_quantizer.scale.data.copy_(best_s)
            self.a_quantizer.zero_point.data.copy_(best_z)

    def _evaluate_candidates(self, x, s_cands, z_cands):
        chunk_size = 8 
        x_expanded = x.unsqueeze(-1)
        all_mses = []
        num_cands = s_cands.size(0)
        
        for i in range(0, num_cands, chunk_size):
            s_chunk = s_cands[i : i + chunk_size]
            z_chunk = z_cands[i : i + chunk_size]
            view_shape = (1,) * x.ndim + (-1,)
            s_chunk = s_chunk.view(view_shape)
            z_chunk = z_chunk.view(view_shape)
            
            x_quant = ((x_expanded / s_chunk).round_() + z_chunk).clamp(0, 2 * self.a_quantizer.n_levels - 1)
            x_dequant = (x_quant - z_chunk) * s_chunk
            mse_chunk = (x_expanded - x_dequant).pow(2).sum(dim=list(range(x.ndim)))
            all_mses.append(mse_chunk)
            
        all_mses = torch.cat(all_mses, dim=0)
        best_idx = all_mses.argmin()
        return s_cands[best_idx], z_cands[best_idx], all_mses[best_idx]
    
    def update_weight_quant_params(self, w_rect):
        """Search per-row asymmetric weight quantization parameters for ``w_rect``.

        Two-stage grid search that minimises the summed squared reconstruction
        error of every weight row (reduction over the last axis):

        1. Coarse: ``n_coarse`` scales spanning from the 0.1%-99.9% percentile
           range to the full min-max range of each row, divided by
           ``2 * n_levels - 1`` and clamped to >= 1e-6. The zero point is
           ``round(-row_min / scale)``.
        2. Fine: ``n_fine`` scales within +/- one coarse step of the coarse
           winner, each paired with zero-point offsets -1, 0 and +1.

        The fine grid never contains the coarse winner itself. So, like
        ``run_minibatch_ptqsl`` does for activations, each row keeps whichever
        stage achieved the lower error; refinement can never make it worse.

        Args:
            w_rect: Weight tensor. If the layer has an ``n_V`` attribute it is
                viewed as ``(n_V, crb_rows, in_features)``; otherwise it is
                used as given.

        The chosen scale / zero point are copied in place into
        ``self.w_quantizer.scale`` and ``self.w_quantizer.zero_point``.
        """
        with torch.no_grad():
            if hasattr(self, 'n_V'):
                w_reshaped = w_rect.contiguous().view(self.n_V, self.crb_rows, self.in_features)
            else:
                w_reshaped = w_rect

            n_coarse = 16
            n_fine = 16
            q_max = 2 * self.w_quantizer.n_levels - 1
            device = w_rect.device

            # quantile with a 1-D q puts the q axis first; unsqueeze(-1) keeps a
            # trailing axis so the statistics broadcast against the row elements.
            pct = torch.tensor([0.999, 1.0], device=device)
            w_upper = torch.quantile(w_reshaped, pct, dim=-1).unsqueeze(-1)
            w_lower = torch.quantile(w_reshaped, 1.0 - pct, dim=-1).unsqueeze(-1)
            delta_min = w_upper[0] - w_lower[0]
            delta_max = w_upper[1] - w_lower[1]
            w_min_stat = w_reshaped.amin(dim=-1, keepdim=True).unsqueeze(0)

            # ---- stage 1: coarse scale grid, zero point from the row minimum ----
            splits_coarse = torch.linspace(0, 1, steps=n_coarse, device=device)[:, None, None, None]
            s_cands_1 = (delta_min + splits_coarse * (delta_max - delta_min)) / q_max
            s_cands_1 = s_cands_1.clamp(min=1e-6)
            z_cands_1 = (-w_min_stat / s_cands_1).round()
            mse_1 = self._evaluate_weight_mse(w_reshaped, s_cands_1, z_cands_1)
            best_idx_1 = mse_1.argmin(dim=0)
            min_mse_1 = mse_1.gather(0, best_idx_1.unsqueeze(0))
            gather_idx_1 = best_idx_1.unsqueeze(0).unsqueeze(-1)
            best_s_1 = torch.gather(s_cands_1, 0, gather_idx_1)
            best_z_1 = torch.gather(z_cands_1, 0, gather_idx_1)

            # ---- stage 2: fine scale grid around the coarse winner ----
            coarse_step = (delta_max - delta_min) / (n_coarse - 1) / q_max
            s_min_fine = (best_s_1 - coarse_step).clamp(min=1e-6)
            s_max_fine = best_s_1 + coarse_step
            splits_fine = torch.linspace(0, 1, steps=n_fine, device=device)[:, None, None, None]
            s_cands_2 = s_min_fine + splits_fine * (s_max_fine - s_min_fine)
            # Tile the scale grid three times; block k is paired with zero-point offset k - 1.
            s_cands_2 = s_cands_2.repeat(3, 1, 1, 1)
            z_ideal = (-w_min_stat / s_cands_2).round()
            offsets = torch.tensor([-1, 0, 1], device=device).repeat_interleave(n_fine)[:, None, None, None]
            z_cands_2 = z_ideal + offsets
            mse_2 = self._evaluate_weight_mse(w_reshaped, s_cands_2, z_cands_2)
            best_idx_2 = mse_2.argmin(dim=0)
            min_mse_2 = mse_2.gather(0, best_idx_2.unsqueeze(0))
            gather_idx_2 = best_idx_2.unsqueeze(0).unsqueeze(-1)
            best_s_2 = torch.gather(s_cands_2, 0, gather_idx_2)
            best_z_2 = torch.gather(z_cands_2, 0, gather_idx_2)

            # Keep the fine result only where it strictly beats the coarse one.
            use_fine = (min_mse_2 < min_mse_1).unsqueeze(-1)
            final_s = torch.where(use_fine, best_s_2, best_s_1).squeeze(0)
            final_z = torch.where(use_fine, best_z_2, best_z_1).squeeze(0)
            self.w_quantizer.scale.data.copy_(final_s)
            self.w_quantizer.zero_point.data.copy_(final_z)
            
    def _evaluate_weight_mse(self, w, s, z):
        w_expanded = w.unsqueeze(0)
        w_int = ((w_expanded / s).round_() + z).clamp(0, 2 * self.w_quantizer.n_levels - 1)
        w_dequant = (w_int - z) * s
        return (w_expanded - w_dequant).pow(2).sum(dim=-1) # [N_cand, n_V, crb_rows]


class AsymmetricallyChannelWiseBatchingQuantLinear(AsymmetricallyBatchingQuantLinear):
    """Linear layer whose input activations are calibrated per channel, then reparameterized.

    The parent's activation quantizer is replaced by an asymmetric,
    channel-wise ``UniformQuantizer`` with one scale / zero point per input
    feature. ``reparam`` later folds those per-channel parameters into
    ``prev_layer`` and into this layer's weight/bias. This lets a single
    (layer-wise) activation scale / zero point be used, and the parent
    calibration is then re-run.
    """

    def __init__(self, 
                 in_features: int,
                 out_features: int,
                 bias: bool = True,
                 mode = "raw",
                 w_bit = 8,
                 a_bit = 8,
                 metric = "mse", 
                 calib_batch_size = None,
                 search_round = 1, 
                 eq_n = 100, 
                 n_V=1,
                 token_channel_wise=False):
        super().__init__(in_features, out_features, bias=bias, mode=mode, w_bit=w_bit, a_bit=a_bit,
                         metric=metric, calib_batch_size=calib_batch_size, search_round=search_round, 
                         eq_n=eq_n, n_V=n_V, token_channel_wise=token_channel_wise)
        # Swap in an asymmetric quantizer with one (scale, zero_point) per input channel.
        del self.a_quantizer
        self.a_quantizer = UniformQuantizer(n_bits = a_bit, symmetric = False, channel_wise = True)
        self.a_quantizer.scale = nn.Parameter(torch.zeros((in_features)))
        self.a_quantizer.zero_point = nn.Parameter(torch.zeros((in_features)))
        self._prev_layer = None

    def __setattr__(self, name, value):
        # Store prev_layer directly in __dict__, bypassing nn.Module.__setattr__,
        # so a Module value is NOT registered as a child submodule of this layer.
        # Note: this intercepts the assignment before the property setter below runs.
        if name == "prev_layer":
            self.__dict__['_prev_layer'] = value
        else:
            super().__setattr__(name, value)

    @property
    def prev_layer(self):
        """The layer producing this layer's input; its affine params are rewritten by ``reparam``."""
        return self._prev_layer

    @prev_layer.setter
    def prev_layer(self, layer):
        self._prev_layer = layer

    def hyperparameter_searching(self):
        """Calibrate only the channel-wise activation quantizer.

        Both quantizers must be channel-wise. ``raw_input`` is intentionally
        kept, because ``reparam`` transforms and reuses it.
        """
        assert self.a_quantizer.channel_wise and self.w_quantizer.channel_wise
        self._initialize_calib_parameters()
        a_scale_candidates, a_zero_point_candidates = self.calculate_percentile_activation_candidates()
        self._search_best_a_scale_self(a_scale_candidates, a_zero_point_candidates)
        self.calibrated = True

    @torch.no_grad()
    def reparam_step1(self):
        """Fold per-channel activation scale / zero point into prev_layer and this layer.

        With per-channel scale ``s_c`` and zero point ``z_c``, the channel minimum is
        ``m_c = -z_c * s_c``. The targets are ``s_t = mean(s)`` and
        ``z_t = round(mean(z))``, with ``m_t = -z_t * s_t``. Then
        ``r = s / s_t`` and ``b = m / r - m_t``. ``prev_layer`` is rescaled so that
        it outputs ``x / r - b``. This layer compensates with ``W' = W * r``
        (per input column) and ``bias' = bias + W' @ b``, so that
        ``W' (x / r - b) + bias' == W x + bias``.

        Runs under ``no_grad``: this is a pure parameter rewrite, and tracking it
        would attach an autograd graph to ``r`` / ``b``, and through ``reparam``
        to the cached ``raw_input``.

        NOTE(review): assumes ``prev_layer.weight`` and ``prev_layer.bias`` are
        per-channel vectors of length ``in_features`` (e.g. a LayerNorm affine)
        — TODO confirm; a 2-D weight would be divided along its last axis.

        Returns:
            (r, b, target_channel_scale, target_channel_zero_point)
        """
        self.calibrated = False
        channel_min = -self.a_quantizer.zero_point * self.a_quantizer.scale
        target_channel_scale = torch.mean(self.a_quantizer.scale).view(1)
        target_channel_zero_point = torch.mean(self.a_quantizer.zero_point).round().view(1)
        target_channel_min = -target_channel_zero_point * target_channel_scale
        r = (self.a_quantizer.scale / target_channel_scale)
        b = channel_min / r - target_channel_min
        self.prev_layer.weight.data = self.prev_layer.weight.data / r
        self.prev_layer.bias.data = self.prev_layer.bias.data / r.view(-1) - b
        # Scale input columns first; the bias correction below uses the scaled weight.
        self.weight.data = self.weight.data * r.view(1, -1)
        bias_shift = torch.mm(self.weight.data, b.reshape(-1, 1)).reshape(-1)
        if self.bias is not None:
            self.bias.data = self.bias.data + bias_shift
        else:
            # Create the bias directly from the correction (same device/dtype as weight).
            self.bias = nn.Parameter(bias_shift)
        return r, b, target_channel_scale, target_channel_zero_point

    def reparam(self):
        """Reparameterize to a layer-wise activation quantizer and recalibrate.

        Applies ``reparam_step1`` and transforms the cached calibration input
        the same way as the new ``prev_layer`` output (``x / r - b``). It then
        switches the activation quantizer to a single scale / zero point and
        re-runs ``AsymmetricallyBatchingQuantLinear.hyperparameter_searching``.
        """
        r, b, target_channel_scale, target_channel_zero_point = self.reparam_step1()
        self.raw_input = (self.raw_input.cuda() / r - b).cpu()
        del self.a_quantizer.scale, self.a_quantizer.zero_point
        self.a_quantizer.channel_wise = False
        self.a_quantizer.scale = nn.Parameter(target_channel_scale)
        self.a_quantizer.zero_point = nn.Parameter(target_channel_zero_point)
        AsymmetricallyBatchingQuantLinear.hyperparameter_searching(self)
