from csv import Sniffer
import torch
import time
import torch.nn.functional as F
import torch.nn as nn
import numpy as np

from .utils.hadamard import Transform_Dict, next_power_of_2
from .utils.tensor_visualizer import viz
# Module-wide cache of transform matrices. The functions below obtain them via
# ``transform_cache.get_or_register(kind, size, ...)`` (kinds used here:
# "hadamard" and "low_rank") and wrap the returned NumPy array with
# ``torch.from_numpy``; presumably each matrix is built once per (kind, size,
# options) and reused — verify against ``Transform_Dict``.
transform_cache = Transform_Dict()


def fake_groupwise_token_asymmetric_quantization(
    input: torch.Tensor, quantize_bit, group_size=128
):
    """Simulate asymmetric min/max quantization with groups along the channels.

    Each token's ``num_head * sep_dim`` channels are split into contiguous
    groups of ``group_size``. Every group is quantized to ``2**quantize_bit``
    levels using its own min/max and immediately dequantized, so the result
    has the input's shape and dtype but carries the quantization error.

    Args:
        input: Tensor of shape ``[batch, num_head, seq_len, sep_dim]``.
        quantize_bit: Number of bits per quantized value.
        group_size: Channels per quantization group; must divide
            ``num_head * sep_dim``.

    Returns:
        Dequantized tensor with the same shape and dtype as ``input``.

    Raises:
        ValueError: If ``group_size`` does not divide ``num_head * sep_dim``.
    """
    batch, num_head, seq_len, sep_dim = input.shape
    dtype = input.dtype
    # [B, H, S, D] -> [B, S, H*D]; quantize in fp32.
    flat = (
        input.permute(0, 2, 1, 3).contiguous().view(batch, seq_len, sep_dim * num_head)
    ).float()
    num_groups = (sep_dim * num_head) // group_size
    if num_groups * group_size != flat.shape[-1]:
        raise ValueError("group_size should be a factor of the last dimension size")

    input_in_groups = flat.view(batch, seq_len, num_groups, group_size)

    mx = input_in_groups.amax(dim=-1, keepdim=True)
    mn = input_in_groups.amin(dim=-1, keepdim=True)

    scale = (mx - mn) / (2**quantize_bit - 1)
    # A constant group has mx == mn, so scale == 0 and (x - mn) / scale would be
    # 0 / 0 = NaN. Any non-zero scale reproduces such a group exactly (every
    # code rounds to 0 and dequantizes to mn), so substitute 1.
    scale = torch.where(scale == 0, torch.ones_like(scale), scale)
    codes = F.relu((input_in_groups - mn) / scale).round_()
    dequantized_input_in_groups = codes * scale + mn

    # [B, S, G, group_size] -> [B, H, S, D]
    dequantized_input = dequantized_input_in_groups.view(
        batch, seq_len, num_head, sep_dim
    )
    return dequantized_input.permute(0, 2, 1, 3).type(dtype)

def fake_groupwise_channel_asymmetric_quantization(
    input: torch.Tensor, quantize_bit, group_size=128
):
    """Simulate asymmetric min/max quantization with groups along the tokens.

    Every channel (one of ``num_head * sep_dim``) is split into groups of
    ``group_size`` consecutive tokens. Each group is quantized to
    ``2**quantize_bit`` levels with its own min/max and dequantized. Trailing
    tokens that do not fill a whole group (``seq_len % group_size``) are passed
    through unchanged.

    Args:
        input: Tensor of shape ``[batch, num_head, seq_len, sep_dim]``.
        quantize_bit: Number of bits per quantized value.
        group_size: Tokens per quantization group.

    Returns:
        Dequantized tensor with the same shape and dtype as ``input``.
    """
    batch, num_head, seq_len, sep_dim = input.shape
    dtype = input.dtype
    num_channels = num_head * sep_dim
    # Quantize in at least fp32, like the token-wise variant; wider inputs
    # (e.g. fp64) keep their precision.
    compute_dtype = torch.promote_types(dtype, torch.float32)
    # [B, H, S, D] -> [B, S, H*D]
    flat = (
        input.permute(0, 2, 1, 3).contiguous().view(batch, seq_len, num_channels)
    ).to(compute_dtype)

    group_num = seq_len // group_size
    fixed_length = group_num * group_size
    residual_input = flat[:, fixed_length:, :]
    # [B, G, group_size, C]: statistics are reduced over the tokens of a group.
    fixed_input = flat[:, :fixed_length, :].view(batch, group_num, group_size, num_channels)

    mx = fixed_input.amax(dim=-2, keepdim=True)
    mn = fixed_input.amin(dim=-2, keepdim=True)

    scale = (mx - mn) / (2**quantize_bit - 1)
    # A channel that is constant within a group has scale == 0, which would
    # turn (x - mn) / scale into 0 / 0 = NaN. Any non-zero scale reproduces such
    # a group exactly, so substitute 1.
    scale = torch.where(scale == 0, torch.ones_like(scale), scale)
    codes = F.relu((fixed_input - mn) / scale).round_()
    dequantized_input = (codes * scale + mn).view(batch, fixed_length, num_channels)
    dequantized_input = torch.cat((dequantized_input, residual_input), dim=1)

    # [B, S, H*D] -> [B, H, S, D]
    dequantized_input = dequantized_input.view(batch, seq_len, num_head, sep_dim)
    return dequantized_input.permute(0, 2, 1, 3).type(dtype)

def fake_groupwise_channel_fp8_quantization(
    input: torch.Tensor, fp8_format="e4m3", group_size=128
):
    """
    FP8 quantization with group-wise scale operations.

    Tokens are split into groups of ``group_size`` consecutive positions.
    Within a group, every channel (``num_head * sep_dim`` of them) gets its own
    scale that maps the channel's max-abs value onto the format's largest
    normal value. The scaled values are rounded to the nearest FP8 value
    (ties to even, with correct carry into the next power of two) and then
    scaled back. Subnormal values are flushed to zero.

    Trailing tokens that do not fill a whole group are passed through
    unchanged. If ``seq_len < group_size``, a copy of ``input`` is returned
    unquantized.

    Args:
        input: Input tensor of shape [batch, num_head, seq_len, sep_dim]
        fp8_format: FP8 format, either "e4m3" or "e5m2"
        group_size: Number of tokens per quantization group

    Returns:
        Dequantized tensor with the original shape and dtype

    Raises:
        ValueError: If ``fp8_format`` is not "e4m3" or "e5m2".
    """
    batch, num_head, seq_len, sep_dim = input.shape
    dtype = input.dtype

    # Only the mantissa width and the normal range matter for this simulation.
    if fp8_format == "e4m3":
        # E4M3: 1 sign bit, 4 exponent bits (bias 7), 3 mantissa bits.
        # max_normal = (2 - 2**-3) * 2**7 = 1.875 * 128 = 240.0 (the top
        # exponent code is treated as reserved, IEEE-style; the OCP "E4M3FN"
        # variant would reach 448).
        mantissa_bits = 3
        max_normal = 240.0
        min_normal = 2**(-6)  # smallest positive normal number
    elif fp8_format == "e5m2":
        # E5M2: 1 sign bit, 5 exponent bits (bias 15), 2 mantissa bits.
        mantissa_bits = 2
        max_normal = 57344.0  # largest normal number in E5M2
        min_normal = 2**(-14)  # smallest positive normal number
    else:
        raise ValueError(f"Unsupported FP8 format: {fp8_format}")

    num_channels = num_head * sep_dim
    # [B, H, S, D] -> [B, S, H*D]
    input_reshaped = input.permute(0, 2, 1, 3).contiguous().view(batch, seq_len, num_channels)
    group_num = seq_len // group_size

    if group_num == 0:
        return input.clone()

    fixed_length = group_num * group_size
    main_part = input_reshaped[:, :fixed_length, :]
    residual_part = input_reshaped[:, fixed_length:, :]

    # [B, G, group_size, C]: one scale per (token group, channel).
    float_input = main_part.view(batch, group_num, group_size, num_channels).float()
    max_abs_vals = float_input.abs().amax(dim=2, keepdim=True)
    # The epsilon keeps all-zero groups finite; they dequantize to exact zeros.
    group_scale = max_normal / (max_abs_vals + 1e-8)
    # Clipping should be mostly unnecessary after scaling.
    scaled_input = torch.clamp(float_input * group_scale, -max_normal, max_normal)

    # Round magnitudes to the nearest representable normal FP8 value.
    abs_scaled = scaled_input.abs()
    zero_mask = abs_scaled < min_normal  # zeros and would-be subnormals
    # Placeholder 1.0 where flushed so log2 never sees 0.
    safe_abs = torch.where(zero_mask, torch.ones_like(abs_scaled), abs_scaled)
    exponent = torch.floor(torch.log2(safe_abs))
    # Spacing between adjacent FP8 values in the binade [2**e, 2**(e+1)).
    step = 2.0 ** (exponent - mantissa_bits)
    # round() on the integer significand gives ties-to-even. Rounding up from
    # the top of a binade lands exactly on 2**(e+1), which is representable;
    # clamping the mantissa field instead would round such values down.
    quantized_abs = (torch.round(safe_abs / step) * step).clamp(max=max_normal)
    dequantized_scaled = torch.where(
        zero_mask, torch.zeros_like(quantized_abs), torch.sign(scaled_input) * quantized_abs
    )

    # Undo the group-wise scaling.
    dequantized_main_part = (dequantized_scaled / group_scale).view(batch, fixed_length, num_channels)
    dequantized_input = torch.cat(
        (dequantized_main_part, residual_part.to(dequantized_main_part.dtype)), dim=1
    )

    # [B, S, H*D] -> [B, H, S, D]
    dequantized_input = dequantized_input.view(batch, seq_len, num_head, sep_dim)
    return dequantized_input.permute(0, 2, 1, 3).contiguous().type(dtype)

def fake_groupwise_token_fp8_quantization(
    input: torch.Tensor, fp8_format="e4m3", group_size=128
):
    """
    FP8 quantization with token-wise grouping (similar to fake_groupwise_token_asymmetric_quantization).

    Groups are formed along the channel dimension (``num_head * sep_dim``) of
    each token. Each group is scaled so that its max-abs value maps onto the
    format's largest normal value. Values are then rounded to the nearest FP8
    value (ties to even, with correct carry into the next power of two) and
    scaled back. Subnormal values are flushed to zero.

    Trailing channels that do not fill a whole group are passed through
    unchanged. If ``num_head * sep_dim < group_size``, a copy of ``input`` is
    returned unquantized.

    Args:
        input: Input tensor of shape [batch, num_head, seq_len, sep_dim]
        fp8_format: FP8 format, either "e4m3" or "e5m2"
        group_size: Number of channels per quantization group

    Returns:
        Dequantized tensor with the original shape and dtype

    Raises:
        ValueError: If ``fp8_format`` is not "e4m3" or "e5m2".
    """
    batch, num_head, seq_len, sep_dim = input.shape
    dtype = input.dtype

    # Only the mantissa width and the normal range matter for this simulation.
    if fp8_format == "e4m3":
        # E4M3: 1 sign bit, 4 exponent bits (bias 7), 3 mantissa bits.
        # max_normal = (2 - 2**-3) * 2**7 = 240.0 (top exponent code reserved).
        mantissa_bits = 3
        max_normal = 240.0
        min_normal = 2**(-6)  # smallest positive normal number
    elif fp8_format == "e5m2":
        # E5M2: 1 sign bit, 5 exponent bits (bias 15), 2 mantissa bits.
        mantissa_bits = 2
        max_normal = 57344.0  # largest normal number in E5M2
        min_normal = 2**(-14)  # smallest positive normal number
    else:
        raise ValueError(f"Unsupported FP8 format: {fp8_format}")

    num_channels = num_head * sep_dim
    # [B, H, S, D] -> [B, S, H*D]
    input_reshaped = input.permute(0, 2, 1, 3).contiguous().view(batch, seq_len, num_channels)
    float_input = input_reshaped.float()

    num_groups = num_channels // group_size

    if num_groups == 0:
        return input.clone()

    fixed_length = num_groups * group_size
    main_part = float_input[..., :fixed_length]
    residual_part = float_input[..., fixed_length:]

    # [B, S, G, group_size]: one scale per (token, channel group).
    input_in_groups = main_part.view(batch, seq_len, num_groups, group_size)
    max_abs_vals = input_in_groups.abs().amax(dim=-1, keepdim=True)
    # The epsilon keeps all-zero groups finite; they dequantize to exact zeros.
    group_scale = max_normal / (max_abs_vals + 1e-8)
    # Clipping should be mostly unnecessary after scaling.
    scaled_input = torch.clamp(input_in_groups * group_scale, -max_normal, max_normal)

    # Round magnitudes to the nearest representable normal FP8 value.
    abs_scaled = scaled_input.abs()
    zero_mask = abs_scaled < min_normal  # zeros and would-be subnormals
    # Placeholder 1.0 where flushed so log2 never sees 0.
    safe_abs = torch.where(zero_mask, torch.ones_like(abs_scaled), abs_scaled)
    exponent = torch.floor(torch.log2(safe_abs))
    # Spacing between adjacent FP8 values in the binade [2**e, 2**(e+1)).
    step = 2.0 ** (exponent - mantissa_bits)
    # round() on the integer significand gives ties-to-even. Rounding up from
    # the top of a binade lands exactly on 2**(e+1), which is representable;
    # clamping the mantissa field instead would round such values down.
    quantized_abs = (torch.round(safe_abs / step) * step).clamp(max=max_normal)
    dequantized_scaled = torch.where(
        zero_mask, torch.zeros_like(quantized_abs), torch.sign(scaled_input) * quantized_abs
    )

    # Undo the group-wise scaling.
    dequantized_main_part = (dequantized_scaled / group_scale).view(batch, seq_len, fixed_length)
    dequantized_input = torch.cat(
        (dequantized_main_part, residual_part.to(dequantized_main_part.dtype)), dim=-1
    )

    # [B, S, H*D] -> [B, H, S, D]
    dequantized_input = dequantized_input.view(batch, seq_len, num_head, sep_dim)
    return dequantized_input.permute(0, 2, 1, 3).contiguous().type(dtype)

def fake_poweriteration_group(input: torch.Tensor, loop, rank, device, p_base_input, q_base_input):
    """Low-rank approximation of every trailing [dim2, dim3] matrix by power iteration.

    For each leading index the matrix ``A = input[b, h]`` is approximated as
    ``Q @ P^T``. Every iteration computes ``Q = A @ P`` and then
    ``P = A^T @ Q``. Only the last iteration orthonormalizes: QR of ``P``
    before the first product, QR of ``Q`` after it. So for ``loop >= 1`` the
    result equals ``Q @ Q^T @ A``, a projection of ``A`` onto the span of ``Q``.

    Args:
        input: 4-D tensor ``[batch, dim1, dim2, dim3]``; the work is done on a
            float32 copy.
        loop: Number of iterations. With ``loop == 0`` no iteration runs and
            the result is just the product of the starting bases.
        rank: Width of the randomly initialized bases.
        device: Device the random bases are moved to.
        p_base_input, q_base_input: Optional starting bases. They are used (as
            float32 copies) only when BOTH are given; otherwise ``torch.rand``
            bases are drawn, so results vary unless the RNG is seeded.
            ``p_base`` is right-multiplied onto ``A``, so its last two dims
            must be ``[dim3, r]``. The initial ``q_base`` is only read when
            ``loop == 0``, because every iteration overwrites it first.

    Returns:
        ``Q @ P^T`` viewed as ``[batch, dim1, dim2, dim3]`` and cast to
        ``input.dtype``.
    """
    # All matmuls below are batched over the two leading dims; nothing is
    # flattened to 2-D.
    dtype = input.dtype
    batch, dim1, dim2, dim3 = input.shape

    output = input.float().clone()
    if p_base_input is not None and q_base_input is not None:
        p_base = p_base_input.clone().float()
        q_base = q_base_input.clone().float()
    else:
        # Uniform [0, 1) starting bases, created by torch.rand on its default
        # device and then moved to `device`.
        p_base = torch.rand(batch,dim1,dim3, rank).to(device)
        q_base = torch.rand(batch,dim1,dim2, rank).to(device)
    # Cost: 2 * loop batched matmuls plus two QR factorizations (final
    # iteration, or the early-exit path below).
    for i in range(loop):
        # Snapshot the bases so a blown-up iteration can be redone from them.
        p_base_prev = p_base.clone()
        q_base_prev = q_base.clone()
        if i == loop - 1:
            p_base = torch.linalg.qr(p_base).Q
        q_base = output @ p_base
        if i == loop - 1:
            q_base = torch.linalg.qr(q_base).Q
        p_base = output.mT @ q_base
        # Intermediate iterations are unnormalized, so magnitudes can grow
        # geometrically. If they get too large, redo one normalized step from
        # the previous bases and stop early.
        # NOTE(review): only a large *positive* mean triggers this; a blow-up
        # with a large negative mean, or a NaN mean, is not caught here.
        if (q_base.mean() > 1e15 or p_base.mean() > 1e15) and i != loop - 1:
            p_base_prev = torch.linalg.qr(p_base_prev).Q
            q_base_prev = output @ p_base_prev
            q_base_prev = torch.linalg.qr(q_base_prev).Q
            p_base_prev = output.mT @ q_base_prev
            p_base = p_base_prev
            q_base = q_base_prev
            break
    output = q_base @ p_base.mT
    # q @ p^T normally already has this shape; the view restates it.
    output = output.view(batch, dim1, dim2, dim3)

    output = output.type(dtype)

    # if output.isnan().any():
    #     return input
    return output

def fake_poweriteration_group_svd(input: torch.Tensor, loop, rank, device, p_base_input, q_base_input):
    """Power-iteration low-rank factors of every trailing [dim2, dim3] matrix.

    Runs the same iteration as ``fake_poweriteration_group`` but returns the
    two factors instead of their product. This is NOT a full SVD: no separate
    singular values are produced.

    Args:
        input: 4-D tensor ``[batch, dim1, dim2, dim3]``; the work is done on a
            float32 copy.
        loop: Number of iterations (see the note on ``loop == 0`` below).
        rank: Width of the randomly initialized bases.
        device: Device the random bases are moved to.
        p_base_input, q_base_input: Optional starting bases, used only when
            both are given; otherwise ``torch.rand`` bases are drawn.

    Returns:
        ``(Q, P^T)`` as float32 tensors (``dtype`` is not applied). For
        ``loop >= 1``, ``Q`` has orthonormal columns (it comes from a QR) and
        ``P^T = Q^T @ A``, so ``Q @ P^T`` approximates ``A``. For ``loop == 0``
        the starting bases are returned unchanged.
    """
    # Matmuls below are batched over the two leading dims.
    dtype = input.dtype
    batch, dim1, dim2, dim3 = input.shape

    output = input.float().clone()
    if p_base_input is not None and q_base_input is not None:
        p_base = p_base_input.clone().float()
        q_base = q_base_input.clone().float()
    else:
        # Uniform [0, 1) starting bases, moved to `device`.
        p_base = torch.rand(batch,dim1,dim3, rank).to(device)
        q_base = torch.rand(batch,dim1,dim2, rank).to(device)
    # Cost: 2 * loop batched matmuls plus two QR factorizations.
    for i in range(loop):
        # Snapshot the bases so a blown-up iteration can be redone from them.
        p_base_prev = p_base.clone()
        q_base_prev = q_base.clone()
        if i == loop - 1:
            p_base = torch.linalg.qr(p_base).Q
        q_base = output @ p_base
        if i == loop - 1:
            q_base = torch.linalg.qr(q_base).Q
        p_base = output.mT @ q_base
        # Unnormalized intermediate iterations can blow up. If so, redo one
        # normalized step from the previous bases and stop early.
        # NOTE(review): only a large positive mean is detected (not a large
        # negative or NaN mean).
        if (q_base.mean() > 1e15 or p_base.mean() > 1e15) and i != loop - 1:
            p_base_prev = torch.linalg.qr(p_base_prev).Q
            q_base_prev = output @ p_base_prev
            q_base_prev = torch.linalg.qr(q_base_prev).Q
            p_base_prev = output.mT @ q_base_prev
            p_base = p_base_prev
            q_base = q_base_prev
            break
    return q_base, p_base.mT


def fake_groupwise_channel_asymmetric_quantization_cluster(input,cluster_num,group_size=128):
    """Token-grouped asymmetric fake quantization with ``cluster_num`` steps.

    Like ``fake_groupwise_channel_asymmetric_quantization``, except the number
    of quantization steps is given directly. Each channel's ``[min, max]``
    range within a group of ``group_size`` tokens is split into
    ``cluster_num`` equal steps. Trailing tokens that do not fill a whole group
    are passed through unchanged.

    Args:
        input: Tensor of shape ``[batch, num_head, seq_len, sep_dim]``.
        cluster_num: Number of quantization steps per group (the GEARS helpers
            in this module pass ``2 ** bits - 1``).
        group_size: Tokens per quantization group.

    Returns:
        Dequantized tensor with the same shape and dtype as ``input``.
    """
    batch, num_head, seq_len, sep_dim = input.shape
    dtype = input.dtype
    num_channels = num_head * sep_dim
    # Quantize in at least fp32; wider inputs keep their precision.
    compute_dtype = torch.promote_types(dtype, torch.float32)
    # [B, H, S, D] -> [B, S, H*D]
    flat = (
        input.permute(0, 2, 1, 3).contiguous().view(batch, seq_len, num_channels)
    ).to(compute_dtype)

    group_num = seq_len // group_size
    fixed_length = group_num * group_size
    residual_input = flat[:, fixed_length:, :]
    # [B, G, group_size, C]: statistics are reduced over the tokens of a group.
    fixed_input = flat[:, :fixed_length, :].view(batch, group_num, group_size, num_channels)

    mx = fixed_input.amax(dim=-2, keepdim=True)
    mn = fixed_input.amin(dim=-2, keepdim=True)

    scale = (mx - mn) / cluster_num
    # A channel that is constant within a group has scale == 0, which would
    # turn (x - mn) / scale into 0 / 0 = NaN. Any non-zero scale reproduces
    # such a group exactly, so substitute 1.
    scale = torch.where(scale == 0, torch.ones_like(scale), scale)
    codes = F.relu((fixed_input - mn) / scale).round_()
    dequantized_input = (codes * scale + mn).view(batch, fixed_length, num_channels)
    dequantized_input = torch.cat((dequantized_input, residual_input), dim=1)

    # [B, S, H*D] -> [B, H, S, D]
    dequantized_input = dequantized_input.view(batch, seq_len, num_head, sep_dim)
    return dequantized_input.permute(0, 2, 1, 3).type(dtype)

def fake_groupwise_token_asymmetric_quantization_cluster(input,cluster_num,group_size=128):
    """Channel-grouped asymmetric fake quantization with ``cluster_num`` steps.

    Like ``fake_groupwise_token_asymmetric_quantization``, except the number of
    quantization steps is given directly. Each token's ``num_head * sep_dim``
    channels are split into groups of ``group_size``, and each group's
    ``[min, max]`` range is divided into ``cluster_num`` equal steps.

    Args:
        input: Tensor of shape ``[batch, num_head, seq_len, sep_dim]``.
        cluster_num: Number of quantization steps per group (the GEARS helpers
            in this module pass ``2 ** bits - 1``).
        group_size: Channels per group; must divide ``num_head * sep_dim``.

    Returns:
        Dequantized tensor with the same shape and dtype as ``input``.

    Raises:
        ValueError: If ``group_size`` does not divide ``num_head * sep_dim``.
    """
    batch, num_head, seq_len, sep_dim = input.shape
    dtype = input.dtype
    # Quantize in at least fp32; wider inputs keep their precision.
    compute_dtype = torch.promote_types(dtype, torch.float32)
    # [B, H, S, D] -> [B, S, H*D]
    flat = (
        input.permute(0, 2, 1, 3).contiguous().view(batch, seq_len, sep_dim * num_head)
    ).to(compute_dtype)
    num_groups = (sep_dim * num_head) // group_size
    if num_groups * group_size != flat.shape[-1]:
        raise ValueError("group_size should be a factor of the last dimension size")

    input_in_groups = flat.view(batch, seq_len, num_groups, group_size)

    mx = input_in_groups.amax(dim=-1, keepdim=True)
    mn = input_in_groups.amin(dim=-1, keepdim=True)

    scale = (mx - mn) / cluster_num
    # A constant group has scale == 0, which would turn (x - mn) / scale into
    # 0 / 0 = NaN. Any non-zero scale reproduces such a group exactly, so
    # substitute 1.
    scale = torch.where(scale == 0, torch.ones_like(scale), scale)
    codes = F.relu((input_in_groups - mn) / scale).round_()
    dequantized_input_in_groups = codes * scale + mn

    # [B, S, G, group_size] -> [B, H, S, D]
    dequantized_input = dequantized_input_in_groups.view(
        batch, seq_len, num_head, sep_dim
    )
    return dequantized_input.permute(0, 2, 1, 3).type(dtype)

########################################################################################

def gears_channelQ(input, quantize_bit, group_size=128,sparsity=0.0):
    """Channel-wise fake quantization with outliers kept at full precision.

    For every channel (a row of ``seq_len`` tokens), the ``k`` smallest and
    ``k`` largest values are set aside and temporarily replaced by the row
    mean. The tensor is then fake-quantized along tokens with
    ``fake_groupwise_channel_asymmetric_quantization_cluster`` using
    ``2 ** quantize_bit - 1`` steps. Finally, the set-aside values are written
    back unquantized.

    Args:
        input: Tensor of shape ``[batch, num_head, seq_len, sep_dim]``; it is
            not modified.
        quantize_bit: Bits per quantized value.
        group_size: Tokens per quantization group.
        sparsity: Fraction used to derive ``k`` (see the note below).

    Returns:
        float32 tensor with the same shape as ``input``.
    """
    batch, num_head, seq_len, sep_dim = input.shape
    element_num = batch * num_head * seq_len * sep_dim
    sparsity_num = int(element_num * sparsity)
    # NOTE(review): k is derived per token (roughly
    # sparsity * num_head * sep_dim / 2), but the outlier search runs along
    # rows of seq_len tokens. The effective outlier fraction therefore equals
    # `sparsity` only when seq_len == num_head * sep_dim. The formula is kept
    # for compatibility; the clamp only stops topk from raising on short
    # sequences.
    sparsity_pertoken = min(int(sparsity_num / batch / seq_len/2), seq_len)

    # Remove outliers: [B, H, S, D] -> [B, H*D, S], one row of tokens per
    # channel. clone(memory_format=contiguous_format) always gives a private
    # buffer. input.float() is `input` itself for fp32 tensors, and
    # contiguous() does not copy when the permuted layout is already
    # contiguous (e.g. seq_len == 1 or sep_dim == 1). Without the clone, the
    # in-place scatter_ below would modify the caller's tensor.
    output = (
        input.float()
        .permute(0, 1, 3, 2)
        .clone(memory_format=torch.contiguous_format)
        .view(batch, sep_dim * num_head, seq_len)
    )
    smallest_value, smallest_indices = torch.topk(output, sparsity_pertoken, dim=-1, largest=False)
    largest_value, largest_indices = torch.topk(output, sparsity_pertoken, dim=-1, largest=True)
    average = output.mean(dim=-1, keepdim=True).expand_as(output)
    output.scatter_(-1, smallest_indices, average.gather(-1, smallest_indices))
    output.scatter_(-1, largest_indices, average.gather(-1, largest_indices))
    output = output.view(batch, num_head, sep_dim, seq_len).permute(0, 1, 3, 2)

    # Quantization
    output = fake_groupwise_channel_asymmetric_quantization_cluster(
        output, 2 ** quantize_bit - 1, group_size)

    # Restore the original (unquantized) values at the outlier positions.
    output = (
        output.permute(0, 1, 3, 2).contiguous().view(batch, sep_dim * num_head, seq_len)
    )
    output.scatter_(-1, smallest_indices, smallest_value)
    output.scatter_(-1, largest_indices, largest_value)

    return output.view(batch, num_head, sep_dim, seq_len).permute(0, 1, 3, 2)
    
def gears_tokenQ(input, quantize_bit, group_size=128,sparsity=0.0):
    """Token-wise fake quantization with outliers kept at full precision.

    For every token (a row of ``num_head * sep_dim`` channels), the ``k``
    smallest and ``k`` largest values are set aside and temporarily replaced
    by the row mean. The tensor is then fake-quantized in channel groups with
    ``fake_groupwise_token_asymmetric_quantization_cluster`` using
    ``2 ** quantize_bit - 1`` steps. Finally, the set-aside values are written
    back unquantized.

    Args:
        input: Tensor of shape ``[batch, num_head, seq_len, sep_dim]``; it is
            not modified.
        quantize_bit: Bits per quantized value.
        group_size: Channels per quantization group.
        sparsity: Fraction of elements kept as outliers (split evenly between
            the smallest and the largest values of each token).

    Returns:
        float32 tensor with the same shape as ``input``.
    """
    batch, num_head, seq_len, sep_dim = input.shape
    num_channels = sep_dim * num_head
    element_num = batch * num_head * seq_len * sep_dim
    sparsity_num = int(element_num * sparsity)
    # Outliers per token on each side; clamped so topk is never asked for more
    # elements than a token has.
    sparsity_pertoken = min(int(sparsity_num / batch / seq_len/2), num_channels)

    # Remove outliers: [B, H, S, D] -> [B, S, H*D]. clone() guarantees a
    # private buffer. For fp32 input with num_head == 1 the permuted layout is
    # already contiguous, so float()/contiguous() would alias `input`, and the
    # in-place scatter_ below would modify the caller's tensor.
    output = (
        input.float()
        .permute(0, 2, 1, 3)
        .clone(memory_format=torch.contiguous_format)
        .view(batch, seq_len, num_channels)
    )
    smallest_value, smallest_indices = torch.topk(output, sparsity_pertoken, dim=-1, largest=False)
    largest_value, largest_indices = torch.topk(output, sparsity_pertoken, dim=-1, largest=True)
    average = output.mean(dim=-1, keepdim=True).expand_as(output)
    output.scatter_(-1, smallest_indices, average.gather(-1, smallest_indices))
    output.scatter_(-1, largest_indices, average.gather(-1, largest_indices))
    output = output.view(batch, seq_len, num_head, sep_dim).permute(0, 2, 1, 3)

    # Quantization
    output = fake_groupwise_token_asymmetric_quantization_cluster(
        output, 2 ** quantize_bit - 1, group_size)

    # Restore the original (unquantized) values at the outlier positions.
    output = (
        output.permute(0, 2, 1, 3).contiguous().view(batch, seq_len, num_channels)
    )
    output.scatter_(-1, smallest_indices, smallest_value)
    output.scatter_(-1, largest_indices, largest_value)

    return output.view(batch, seq_len, num_head, sep_dim).permute(0, 2, 1, 3)

def gears_hadamard_channelQ(input, quantize_bit, group_size=128,sparsity=0.0):
    """Channel-wise outlier-aware fake quantization in a Hadamard-rotated basis.

    Steps:

    1. Zero-pad each channel's row of ``seq_len`` tokens to the next power of
       two, multiply it by the cached Hadamard matrix, divide by
       ``sqrt(size_m)``, and truncate back to ``seq_len`` entries.
    2. Set aside the ``k`` smallest and ``k`` largest coefficients per row,
       replacing them by the row mean.
    3. Fake-quantize along tokens with
       ``fake_groupwise_channel_asymmetric_quantization_cluster`` using
       ``2 ** quantize_bit - 1`` steps.
    4. Restore the set-aside coefficients.
    5. Apply the same scaled Hadamard multiplication again to map back.

    Args:
        input: Tensor of shape ``[batch, num_head, seq_len, sep_dim]``.
        quantize_bit: Bits per quantized value.
        group_size: Tokens per quantization group.
        sparsity: Fraction used to derive ``k`` (see the note below).

    Returns:
        float32 tensor with the same shape as ``input``.
    """
    output = input.float()
    batch, num_head, seq_len, sep_dim = input.shape
    element_num = batch * num_head * seq_len * sep_dim
    sparsity_num = int(element_num * sparsity)
    # NOTE(review): k is derived per token (roughly
    # sparsity * num_head * sep_dim / 2), but outliers are searched along rows
    # of seq_len coefficients. The clamp only stops topk from raising when
    # seq_len is shorter than k.
    sparsity_pertoken = min(int(sparsity_num / batch / seq_len/2), seq_len)

    # [B, H, S, D] -> [B, H*D, S]: one row of tokens per channel.
    output = (
        output.permute(0, 1, 3, 2).contiguous().view(batch, sep_dim * num_head, seq_len)
    )

    # Prepare Hadamard transform
    H = torch.from_numpy(transform_cache.get_or_register("hadamard", seq_len)).to(input.device, dtype=output.dtype)
    size_m = next_power_of_2(seq_len)

    # Apply Hadamard transform.
    # NOTE(review): assuming H is size_m x size_m, the coefficients beyond
    # seq_len are discarded here and replaced by zeros before the inverse
    # below. The round trip is then lossy whenever seq_len is not a power of
    # two, even without quantization. Confirm this is intended.
    padded_output = F.pad(output, (0, size_m - seq_len))
    padded_output = (padded_output @ H) / np.sqrt(size_m)
    output = padded_output[:, :, :seq_len]

    # Set aside the smallest and largest k coefficients (replaced by the mean).
    smallest_value, smallest_indices = torch.topk(output, sparsity_pertoken, dim=-1, largest=False)
    largest_value, largest_indices = torch.topk(output, sparsity_pertoken, dim=-1, largest=True)
    average = output.mean(dim=-1, keepdim=True).expand_as(output)
    output.scatter_(-1, smallest_indices, average.gather(-1, smallest_indices))
    output.scatter_(-1, largest_indices, average.gather(-1, largest_indices))

    # Quantization
    output = output.view(batch, num_head, sep_dim, seq_len).permute(0, 1, 3, 2)
    output = fake_groupwise_channel_asymmetric_quantization_cluster(
        output, 2 ** quantize_bit - 1, group_size)

    # Restore the original (unquantized) coefficients at the outlier positions.
    output = (
        output.permute(0, 1, 3, 2).contiguous().view(batch, sep_dim * num_head, seq_len)
    )
    output.scatter_(-1, smallest_indices, smallest_value)
    output.scatter_(-1, largest_indices, largest_value)

    # De-transform (the scaled Hadamard matrix is its own inverse when
    # H @ H == size_m * I).
    padded_output = F.pad(output, (0, size_m - seq_len))
    padded_output = (padded_output @ H) / np.sqrt(size_m)
    output = padded_output[:, :, :seq_len]

    return output.view(batch, num_head, sep_dim, seq_len).permute(0, 1, 3, 2)

def gears_hadamard_tokenQ(input, quantize_bit, group_size=128,sparsity=0.0):
    """Token-wise outlier-aware fake quantization in a Hadamard-rotated basis.

    Steps:

    1. Zero-pad each token's ``num_head * sep_dim`` channels to the next power
       of two, multiply by the cached Hadamard matrix, divide by
       ``sqrt(padded_width)``, and truncate back to the original width.
    2. Set aside the ``k`` smallest and ``k`` largest coefficients per token,
       replacing them by the token mean.
    3. Fake-quantize with ``fake_groupwise_token_asymmetric_quantization_cluster``
       using ``2 ** quantize_bit - 1`` steps.
    4. Restore the set-aside coefficients.
    5. Apply the same rotation again to map back.

    Args:
        input: Tensor of shape ``[batch, num_head, seq_len, sep_dim]``.
        quantize_bit: Bits per quantized value.
        group_size: Channels per quantization group.
        sparsity: Fraction of elements kept as outliers.

    Returns:
        float32 tensor with the same shape as ``input``.
    """
    batch, num_head, seq_len, sep_dim = input.shape
    width = sep_dim * num_head
    outlier_total = int(batch * num_head * seq_len * sep_dim * sparsity)
    k = int(outlier_total / batch / seq_len / 2)

    # [B, H, S, D] -> [B, S, H*D]: one row of channels per token.
    tokens = input.float().permute(0, 2, 1, 3).contiguous().view(batch, seq_len, width)

    padded_width = next_power_of_2(width)
    hadamard = torch.from_numpy(
        transform_cache.get_or_register("hadamard", width)
    ).to(input.device, dtype=tokens.dtype)

    def rotate(rows):
        # Zero-pad to the power-of-two width, apply H / sqrt(width), drop padding.
        padded = F.pad(rows, (0, padded_width - width))
        return ((padded @ hadamard) / np.sqrt(padded_width))[:, :, :width]

    rotated = rotate(tokens)

    # Stash the extreme coefficients of each token and flatten them to the mean.
    low_vals, low_idx = torch.topk(rotated, k, dim=-1, largest=False)
    high_vals, high_idx = torch.topk(rotated, k, dim=-1, largest=True)
    row_mean = rotated.mean(dim=-1, keepdim=True).expand_as(rotated)
    rotated.scatter_(-1, low_idx, row_mean.gather(-1, low_idx))
    rotated.scatter_(-1, high_idx, row_mean.gather(-1, high_idx))

    quantized = fake_groupwise_token_asymmetric_quantization_cluster(
        rotated.view(batch, seq_len, num_head, sep_dim).permute(0, 2, 1, 3),
        2 ** quantize_bit - 1,
        group_size,
    )

    # Put the stashed coefficients back, then rotate back.
    restored = quantized.permute(0, 2, 1, 3).contiguous().view(batch, seq_len, width)
    restored.scatter_(-1, low_idx, low_vals)
    restored.scatter_(-1, high_idx, high_vals)

    return rotate(restored).view(batch, seq_len, num_head, sep_dim).permute(0, 2, 1, 3)
    
def pcc_svd_channel(input, rank = 0, loop=1):
    """Low-rank approximation of each head's [seq_len, sep_dim] slice in a whitened channel basis.

    Builds the per-head channel Gram matrix ``X^T @ X`` (plus ``1e-5`` on the
    diagonal) and factors it as ``L @ L^T`` with Cholesky. It then runs
    ``fake_poweriteration_group`` on ``X @ L^{-1}`` and maps the result back
    with ``(.) @ L``.

    If a linear-algebra step raises ``LinAlgError``, it falls back to running
    the power iteration in a Hadamard basis. The channel axis is zero-padded
    to the next power of two and multiplied by the cached Hadamard matrix
    before and after the approximation, dividing by the padded size once. The
    result is then truncated back to ``sep_dim``.

    Args:
        input: 4-D tensor ``[batch, num_head, seq_len, sep_dim]``.
        rank: Target rank passed to ``fake_poweriteration_group``.
        loop: Number of power iterations.

    Returns:
        float32 tensor with the same shape as ``input``.
    """
    x = input.float()
    _, _, _, sep_dim = input.shape

    # Channel Gram matrix with a small diagonal jitter for Cholesky stability.
    gram = x.mT @ x
    gram += 1e-5 * torch.eye(sep_dim, device=input.device, dtype=input.dtype)

    try:
        lower = torch.linalg.cholesky(gram)
        lower_inv = torch.linalg.inv(lower)
        approx = fake_poweriteration_group(x @ lower_inv, loop, rank, input.device, None, None)
        return approx @ lower
    except torch.linalg.LinAlgError:
        padded_dim = next_power_of_2(sep_dim)
        hadamard = torch.from_numpy(
            transform_cache.get_or_register("hadamard", sep_dim)
        ).to(input.device, dtype=x.dtype)
        rotated = F.pad(x, (0, padded_dim - sep_dim)) @ hadamard
        approx = fake_poweriteration_group(rotated, loop, rank, input.device, None, None)
        return ((approx @ hadamard) / padded_dim)[..., :sep_dim]

def pcc_svd_token(input, rank = 0, loop=1):
    """Low-rank approximation of each head's [seq_len, sep_dim] slice after token-side whitening.

    Builds the per-head token Gram matrix ``X @ X^T`` (plus ``1e-5`` on the
    diagonal) and factors it as ``L @ L^T`` with Cholesky. It then runs
    ``fake_poweriteration_group`` on ``L^{-1} @ X`` and maps the result back
    with ``L @ (.)``. If a linear-algebra step raises ``LinAlgError``, the
    power iteration is applied to ``X`` directly.

    Args:
        input: 4-D tensor ``[batch, num_head, seq_len, sep_dim]``.
        rank: Target rank passed to ``fake_poweriteration_group``.
        loop: Number of power iterations.

    Returns:
        float32 tensor with the same shape as ``input``.
    """
    x = input.float()
    _, _, seq_len, _ = input.shape

    # Token Gram matrix with a small diagonal jitter for Cholesky stability.
    gram = x @ x.mT
    gram += 1e-5 * torch.eye(seq_len, device=input.device, dtype=input.dtype)

    try:
        lower = torch.linalg.cholesky(gram)
        lower_inv = torch.linalg.inv(lower)
        approx = fake_poweriteration_group(lower_inv @ x, loop, rank, input.device, None, None)
        return lower @ approx
    except torch.linalg.LinAlgError:
        return fake_poweriteration_group(x, loop, rank, input.device, None, None)

def outlier_removal_hla_channel(input, quantize_bit, group_size=128,sparsity=0.0, hla_rank=0):
    """Channel-wise outlier removal followed by a Hadamard low-rank approximation.

    The ``k`` smallest and ``k`` largest values of every channel (a row of
    ``seq_len`` tokens) are set aside and replaced by the row mean. Each
    per-head token vector of length ``sep_dim`` is then zero-padded to the
    next power of two and projected with the cached ``"low_rank"`` transform
    as ``x @ H^T @ H / size_m``. Finally, the set-aside values are written
    back unchanged.

    ``quantize_bit`` and ``group_size`` are accepted for signature
    compatibility but are not used.

    Args:
        input: Tensor of shape ``[batch, num_head, seq_len, sep_dim]``; it is
            not modified.
        sparsity: Fraction used to derive ``k`` (see the note below).
        hla_rank: Rank passed to the ``"low_rank"`` transform.

    Returns:
        float32 tensor with the same shape as ``input``.
    """
    batch, num_head, seq_len, sep_dim = input.shape
    element_num = batch * num_head * seq_len * sep_dim
    sparsity_num = int(element_num * sparsity)
    # NOTE(review): k is derived per token (roughly
    # sparsity * num_head * sep_dim / 2), but outliers are searched along rows
    # of seq_len tokens. The clamp only stops topk from raising on short
    # sequences.
    sparsity_pertoken = min(int(sparsity_num / batch / seq_len/2), seq_len)

    # Remove outliers: [B, H, S, D] -> [B, H*D, S]. clone() guarantees a
    # private buffer. For fp32 input whose permuted layout is already
    # contiguous (seq_len == 1 or sep_dim == 1), float()/contiguous() would
    # alias `input`, and scatter_ would modify the caller's tensor.
    output = (
        input.float()
        .permute(0, 1, 3, 2)
        .clone(memory_format=torch.contiguous_format)
        .view(batch, sep_dim * num_head, seq_len)
    )
    smallest_value, smallest_indices = torch.topk(output, sparsity_pertoken, dim=-1, largest=False)
    largest_value, largest_indices = torch.topk(output, sparsity_pertoken, dim=-1, largest=True)
    average = output.mean(dim=-1, keepdim=True).expand_as(output)
    output.scatter_(-1, smallest_indices, average.gather(-1, smallest_indices))
    output.scatter_(-1, largest_indices, average.gather(-1, largest_indices))
    output = output.view(batch, num_head, sep_dim, seq_len).permute(0, 1, 3, 2)

    # Hadamard Low-rank Approximation
    size_m = next_power_of_2(sep_dim)
    H = torch.from_numpy(transform_cache.get_or_register("low_rank", sep_dim, r=hla_rank, freq="high")).to(input.device, dtype=output.dtype)
    padded_output = F.pad(output, (0, size_m - sep_dim))
    padded_output = (padded_output @ H.mT)
    padded_output = (padded_output @ H) / size_m
    output = padded_output[:, :, :, :sep_dim]

    # Restore the original values at the outlier positions.
    output = (
        output.permute(0, 1, 3, 2).contiguous().view(batch, sep_dim * num_head, seq_len)
    )
    output.scatter_(-1, smallest_indices, smallest_value)
    output.scatter_(-1, largest_indices, largest_value)

    return output.view(batch, num_head, sep_dim, seq_len).permute(0, 1, 3, 2)

def outlier_removal_hla_token(input, quantize_bit, group_size=128,sparsity=0.0, hla_rank=0):
    """Token-wise outlier removal followed by a Hadamard low-rank approximation.

    The ``k`` smallest and ``k`` largest values of every token (a row of
    ``num_head * sep_dim`` channels) are set aside and replaced by the row
    mean. Each per-head token vector of length ``sep_dim`` is then zero-padded
    to the next power of two and projected with the cached ``"low_rank"``
    transform as ``x @ H^T @ H / size_m``. Finally, the set-aside values are
    written back unchanged.

    ``quantize_bit`` and ``group_size`` are accepted for signature
    compatibility but are not used.

    Args:
        input: Tensor of shape ``[batch, num_head, seq_len, sep_dim]``; it is
            not modified.
        sparsity: Fraction of elements kept as outliers.
        hla_rank: Rank passed to the ``"low_rank"`` transform.

    Returns:
        float32 tensor with the same shape as ``input``.
    """
    batch, num_head, seq_len, sep_dim = input.shape
    num_channels = sep_dim * num_head
    element_num = batch * num_head * seq_len * sep_dim
    sparsity_num = int(element_num * sparsity)
    # Outliers per token on each side, clamped to the row length for topk.
    sparsity_pertoken = min(int(sparsity_num / batch / seq_len/2), num_channels)

    # Remove outliers: [B, H, S, D] -> [B, S, H*D]. clone() guarantees a
    # private buffer. For fp32 input with num_head == 1, float()/contiguous()
    # would alias `input`, and scatter_ would modify the caller's tensor.
    output = (
        input.float()
        .permute(0, 2, 1, 3)
        .clone(memory_format=torch.contiguous_format)
        .view(batch, seq_len, num_channels)
    )
    smallest_value, smallest_indices = torch.topk(output, sparsity_pertoken, dim=-1, largest=False)
    largest_value, largest_indices = torch.topk(output, sparsity_pertoken, dim=-1, largest=True)
    average = output.mean(dim=-1, keepdim=True).expand_as(output)
    output.scatter_(-1, smallest_indices, average.gather(-1, smallest_indices))
    output.scatter_(-1, largest_indices, average.gather(-1, largest_indices))
    output = output.view(batch, seq_len, num_head, sep_dim).permute(0, 2, 1, 3)

    # Hadamard Low-rank Approximation
    size_m = next_power_of_2(sep_dim)
    H = torch.from_numpy(transform_cache.get_or_register("low_rank", sep_dim, r=hla_rank, freq="high")).to(input.device, dtype=output.dtype)
    padded_output = F.pad(output, (0, size_m - sep_dim))
    padded_output = (padded_output @ H.mT)
    padded_output = (padded_output @ H) / size_m
    output = padded_output[:, :, :, :sep_dim]

    # Restore the original values at the outlier positions.
    output = (
        output.permute(0, 2, 1, 3).contiguous().view(batch, seq_len, num_channels)
    )
    output.scatter_(-1, smallest_indices, smallest_value)
    output.scatter_(-1, largest_indices, largest_value)

    return output.view(batch, seq_len, num_head, sep_dim).permute(0, 2, 1, 3)

def outlier_removal_svd_channel(input ,sparsity=0.0, loop=1, rank=0, transform="", group_size=0, lora_quant=False):
    """Channel-wise outlier removal followed by a transform-assisted low-rank approximation.

    Steps:
      1. View the tensor as rows of ``seq_len`` values, one row per
         (batch, head, channel). In each row, replace the ``sparsity_pertoken``
         smallest and largest entries with the row mean.
      2. Low-rank approximate the cleaned tensor with
         ``fake_poweriteration_group`` inside the basis chosen by ``transform``.
      3. Write the saved outlier values back at their original positions.
      4. If ``lora_quant`` is set, fake-quantize the result channel-wise to
         8 bits with ``group_size``.

    Args:
        input: 4-D tensor, unpacked as (batch, num_head, seq_len, sep_dim).
        sparsity: Fraction of all elements used to size the outlier budget.
        loop: Power-iteration count forwarded to the low-rank helpers.
        rank: Target rank. Multiplied by 4 when ``lora_quant`` is True.
        transform: One of "cov", "hadamard", "pca", "pca-basis-none",
            "pca-basis-hadamard", "happi", "happi-v2", "pca-basis-pca", "none".
            "cov" and "pca" fall back to a Hadamard-rotated low-rank
            approximation when torch raises a linear-algebra error.
        group_size: Group size for the optional 8-bit quantization. Presumably
            it must be non-zero when ``lora_quant`` is True (TODO confirm
            against fake_groupwise_channel_asymmetric_quantization).
        lora_quant: Enables the rank*4 / 8-bit quantization path.

    Returns:
        Tensor of shape (batch, num_head, seq_len, sep_dim) before the
        optional quantization step. The caller's ``input`` is not modified by
        the outlier replacement.

    Raises:
        ValueError: If ``transform`` is not one of the names above.
    """
    output = input.float()
    batch, num_head, seq_len, sep_dim = input.shape
    element_num = batch * num_head * seq_len * sep_dim
    sparsity_num = int(element_num * sparsity)
    # NOTE(review): each row below holds seq_len values, but this count works
    # out to roughly num_head * sep_dim * sparsity / 2 per row, because it
    # divides by seq_len exactly like the token-wise variant. torch.topk raises
    # if it exceeds seq_len. Confirm the intended per-channel budget.
    sparsity_pertoken = int(sparsity_num / batch / seq_len/2)
    
    # Remove outliers.
    # Use clone() rather than contiguous(): for a float32 input whose permuted
    # layout is already contiguous (e.g. seq_len == 1), contiguous() returns a
    # view of the caller's tensor, and scatter_ below would overwrite it.
    output = (
        output.permute(0, 1, 3, 2)
        .clone(memory_format=torch.contiguous_format)
        .view(batch, sep_dim * num_head, seq_len)
    )
    smallest_value, smallest_indices = torch.topk(output, sparsity_pertoken, dim=-1, largest=False)
    largest_value, largest_indices = torch.topk(output, sparsity_pertoken, dim=-1, largest=True)
    average = output.mean(dim=-1, keepdim=True)
    expanded_average = average.expand_as(output)
    output.scatter_(-1, smallest_indices, expanded_average.gather(-1, smallest_indices))
    output.scatter_(-1, largest_indices, expanded_average.gather(-1, largest_indices))
    output = output.view(batch, num_head, sep_dim, seq_len).permute(0, 1, 3, 2)

    # Quantization
    # output = fake_groupwise_channel_asymmetric_quantization(
    #     output, quantize_bit=8, group_size=group_size
    # )
    if lora_quant:
        # The low-rank part is quantized below, so a larger rank is affordable.
        rank = rank * 4

    if transform == "cov":
        # SVD in a Cholesky-whitened space built from the (sep_dim x sep_dim) Gram matrix.
        try:
            # Make Covariance Matrix
            cov_matrix = output.mT @ output
            jitter = 1e-5 * torch.eye(sep_dim, device=input.device, dtype=input.dtype)
            cov_matrix += jitter

            S = torch.linalg.cholesky(cov_matrix)
            S_inv = torch.linalg.inv(S)

            whitened_output = output @ S_inv
            whitend_output_lr = fake_poweriteration_group(whitened_output, loop, rank, input.device, None, None)
            output = whitend_output_lr @ S
        except torch._C._LinAlgError as e:
            # Fallback: low-rank approximation in the Hadamard-rotated space.
            # output = fake_poweriteration_group(output, loop, rank, input.device, None, None)
            size_m = next_power_of_2(sep_dim)
            H = torch.from_numpy(transform_cache.get_or_register("hadamard", sep_dim)).to(input.device, dtype=output.dtype)
            padded_output = F.pad(output, (0, size_m - sep_dim))
            padded_output = (padded_output @ H)
            padded_output = fake_poweriteration_group(padded_output, loop, rank, input.device, None, None)
            padded_output = (padded_output @ H) / size_m
            output = padded_output[:, :, :, :sep_dim]
    elif transform == "hadamard":
        size_m = next_power_of_2(sep_dim)
        H = torch.from_numpy(transform_cache.get_or_register("hadamard", sep_dim)).to(input.device, dtype=output.dtype)
        padded_output = F.pad(output, (0, size_m - sep_dim))
        padded_output = (padded_output @ H)
        padded_output = fake_poweriteration_group(padded_output, loop, rank, input.device, None, None)
        padded_output = (padded_output @ H) / size_m
        output = padded_output[:, :, :, :sep_dim]
    elif transform == "pca":
        # Real SVD. The third return of torch.linalg.svd is Vh, bound here as UT.
        try:
            cov_matrix = output.mT @ output
            U, sig, UT = torch.linalg.svd(cov_matrix, full_matrices=False)
            output = output @ UT
            output = fake_poweriteration_group(output, loop, rank, input.device, None, None)
            output = output @ U
        except torch._C._LinAlgError as e:
            # Fallback: low-rank approximation in the Hadamard-rotated space.
            size_m = next_power_of_2(sep_dim)
            H = torch.from_numpy(transform_cache.get_or_register("hadamard", sep_dim)).to(input.device, dtype=output.dtype)
            padded_output = F.pad(output, (0, size_m - sep_dim))
            padded_output = (padded_output @ H)
            padded_output = fake_poweriteration_group(padded_output, loop, rank, input.device, None, None)
            padded_output = (padded_output @ H) / size_m
            output = padded_output[:, :, :, :sep_dim]
    elif transform == "pca-basis-none":
        # Seed the power iteration with bases estimated from the Gram matrix.
        cov_matrix = output.mT @ output
        left_vec, right_vec = fake_poweriteration_group_svd(cov_matrix, loop, rank, device=input.device, p_base_input=None, q_base_input=None)
        output = fake_poweriteration_group(output, loop, rank, input.device, right_vec.mT, left_vec)
    elif transform == "pca-basis-hadamard":
        size_m = next_power_of_2(sep_dim)
        H = torch.from_numpy(transform_cache.get_or_register("hadamard", sep_dim)).to(input.device, dtype=output.dtype)
        padded_output = F.pad(output, (0, size_m - sep_dim))
        padded_output = (padded_output @ H) / (size_m ** 0.5)

        # Calculate the covariance matrix (mean-centred along seq_len)
        mean = torch.mean(padded_output, dim=2, keepdim=True)
        centered_output = padded_output - mean
        cov_matrix = centered_output.mT @ centered_output
        
        left_vec, right_vec = fake_poweriteration_group_svd(cov_matrix.clone(), loop, rank, input.device, None, None)
        padded_output = fake_poweriteration_group(padded_output.clone(), loop, rank, input.device, right_vec.mT, left_vec)
        padded_output = (padded_output @ H) / (size_m ** 0.5)
        output = padded_output[:, :, :, :sep_dim]
    elif transform == "happi":
        size_m = next_power_of_2(sep_dim)
        H = torch.from_numpy(transform_cache.get_or_register("hadamard", sep_dim)).to(input.device, dtype=output.dtype)
        padded_output = F.pad(output, (0, size_m - sep_dim))
        
        # calculate covariance matrix in original space
        # mean = torch.mean(padded_output, dim=2, keepdim=True)
        mean = 0.0
        centered_output = padded_output - mean
        # NOTE(review): no jitter is added here. When sep_dim is not a power of
        # two, the zero-padded columns make this matrix singular, so Cholesky
        # will typically fail and the except branch runs.
        cov_matrix = centered_output.mT @ centered_output

        try:
            # cholesky whitening
            # NOTE(review): this multiplies by S and reconstructs with S_inv,
            # which is the reverse of the "cov" branch above. Confirm intent.
            S = torch.linalg.cholesky(cov_matrix)
            S_inv = torch.linalg.inv(S)
            whitened_output = padded_output @ S
            
            # happi
            whitened_output = (whitened_output @ H) / (size_m ** 0.5)
            cov_matrix = (H @ cov_matrix @ H) / size_m
            left_vec, right_vec = fake_poweriteration_group_svd(cov_matrix.clone(), loop, rank, input.device, None, None)
            whitened_output_lr = fake_poweriteration_group(whitened_output.clone(), loop, rank, input.device, right_vec.mT, left_vec)
            whitened_output_lr = (whitened_output_lr @ H) / (size_m ** 0.5)

            # reconstruct
            padded_output = whitened_output_lr @ S_inv
        except torch._C._LinAlgError as e:
            # Fallback: same Hadamard-basis approximation without whitening.
            padded_output = (padded_output @ H) / (size_m ** 0.5)
            cov_matrix = (H @ cov_matrix @ H) / size_m
            left_vec, right_vec = fake_poweriteration_group_svd(cov_matrix.clone(), loop, rank, input.device, None, None)
            padded_output_lr = fake_poweriteration_group(padded_output.clone(), loop, rank, input.device, right_vec.mT, left_vec)
            padded_output = (padded_output_lr @ H) / (size_m ** 0.5)
        output = padded_output[:, :, :, :sep_dim]
    elif transform == "happi-v2":
        size_m = next_power_of_2(sep_dim)
        H = torch.from_numpy(transform_cache.get_or_register("hadamard", sep_dim)).to(input.device, dtype=output.dtype)
        padded_output = F.pad(output, (0, size_m - sep_dim))
        cov_matrix = padded_output.mT @ padded_output

        # Project onto a fixed rank-8 basis first, then run the Hadamard-basis approximation.
        whitening_rank = 8
        left_vec_whitening, _ = fake_poweriteration_group_svd(cov_matrix.clone(), loop, whitening_rank, input.device, None, None)
        left_vec_whitening = F.pad(left_vec_whitening, (0, size_m - whitening_rank))
        padded_output = padded_output @ left_vec_whitening
        
        padded_output = (padded_output @ H) / (size_m ** 0.5)
        cov_matrix = (H @ cov_matrix @ H ) / size_m
        left_vec, right_vec = fake_poweriteration_group_svd(cov_matrix.clone(), loop, rank, input.device, None, None)
        padded_output = fake_poweriteration_group(padded_output.clone(), loop, rank, input.device, right_vec.mT, left_vec)
        padded_output = (padded_output @ H) / (size_m ** 0.5)
        
        padded_output = padded_output @ left_vec_whitening.mT
        
        output = padded_output[:, :, :, :sep_dim]

    elif transform == "pca-basis-pca":
        cov_matrix = output.mT @ output
        left_vec, right_vec = fake_poweriteration_group_svd(cov_matrix.clone(), loop, rank, device=input.device, p_base_input=None, q_base_input=None)
        output = output @ left_vec
        output = fake_poweriteration_group(output, loop, rank, input.device, right_vec, left_vec)
        output = output @ left_vec.mT
    elif transform == "none":
        output = fake_poweriteration_group(output, loop, rank, input.device, None, None)
    else:
        raise ValueError(f"Invalid svd transform: {transform}")
    
    # Restore the original values at the smallest and largest k indices
    output = (
        output.permute(0, 1, 3, 2).contiguous().view(batch, sep_dim * num_head, seq_len)
    )
    output.scatter_(-1, smallest_indices, smallest_value)
    output.scatter_(-1, largest_indices, largest_value)
    
    output = output.view(batch, num_head, sep_dim, seq_len).permute(0, 1, 3, 2)

    if lora_quant:
        # output = fake_groupwise_channel_fp8_quantization(output, fp8_format="e5m2", group_size=group_size)
        output = fake_groupwise_channel_asymmetric_quantization(output, quantize_bit=8, group_size=group_size)

    return output

def outlier_removal_svd_token(input, sparsity=0.0, loop=1, rank=0, transform="", group_size=0, lora_quant=False):
    """Token-wise outlier removal followed by a transform-assisted low-rank approximation.

    Steps:
      1. View the tensor as rows of ``num_head * sep_dim`` values, one row per
         (batch, token). In each row, replace the ``sparsity_pertoken``
         smallest and largest entries with the row mean.
      2. Low-rank approximate the cleaned tensor with
         ``fake_poweriteration_group`` inside the basis chosen by ``transform``.
      3. Write the saved outlier values back at their original positions.
      4. If ``lora_quant`` is set, fake-quantize the result token-wise to
         8 bits with ``group_size``.

    Args:
        input: 4-D tensor, unpacked as (batch, num_head, seq_len, sep_dim).
        sparsity: Fraction of all elements used to size the outlier budget.
            Half of the per-token budget goes to the smallest values and half
            to the largest.
        loop: Power-iteration count forwarded to the low-rank helpers.
        rank: Target rank. Multiplied by 4 when ``lora_quant`` is True.
        transform: One of "cov", "hadamard", "pca", "pca-basis-none",
            "pca-basis-hadamard", "happi", "happi-v2", "pca-basis-pca", "none".
            "pca-basis-hadamard", "happi" and "happi-v2" share one code path
            here.
        group_size: Group size for the optional 8-bit quantization. Must be a
            non-zero divisor of ``num_head * sep_dim`` when ``lora_quant`` is
            True.
        lora_quant: Enables the rank*4 / 8-bit quantization path.

    Returns:
        Tensor of shape (batch, num_head, seq_len, sep_dim). The caller's
        ``input`` is not modified by the outlier replacement.

    Raises:
        ValueError: If ``transform`` is not one of the names above.
    """
    output = input.float()
    batch, num_head, seq_len, sep_dim = output.shape
    element_num = batch * num_head * seq_len * sep_dim
    sparsity_num = int(element_num * sparsity)
    sparsity_pertoken = int(sparsity_num / batch / seq_len/2)

    # Remove outliers.
    # Use clone() rather than contiguous(): for a float32 input whose permuted
    # layout is already contiguous (e.g. num_head == 1 or seq_len == 1),
    # contiguous() returns a view of the caller's tensor, and scatter_ below
    # would overwrite it.
    output = (
        output.permute(0, 2, 1, 3)
        .clone(memory_format=torch.contiguous_format)
        .view(batch, seq_len, sep_dim * num_head)
    )
    smallest_value, smallest_indices = torch.topk(output, sparsity_pertoken, dim=-1, largest=False)
    largest_value, largest_indices = torch.topk(output, sparsity_pertoken, dim=-1, largest=True)
    average = output.mean(dim=-1, keepdim=True)
    expanded_average = average.expand_as(output)
    output.scatter_(-1, smallest_indices, expanded_average.gather(-1, smallest_indices))
    output.scatter_(-1, largest_indices, expanded_average.gather(-1, largest_indices))
    output = output.view(batch, seq_len, num_head, sep_dim).permute(0, 2, 1, 3)

    # Quantization
    # output = fake_groupwise_token_asymmetric_quantization(
    #     output, quantize_bit=8, group_size=group_size
    # )
    if lora_quant:
        # The low-rank part is quantized below, so a larger rank is affordable.
        rank = rank * 4

    if transform == "cov":
        try:
            # SVD in a Cholesky-whitened space built from the (seq_len x seq_len) Gram matrix.
            cov_matrix = output @ output.mT
            jitter = 1e-5 * torch.eye(seq_len, device=input.device, dtype=input.dtype)
            cov_matrix += jitter

            S = torch.linalg.cholesky(cov_matrix)
            S_inv = torch.linalg.inv(S)
            whitened_output = S_inv @ output
            whitend_output_lr = fake_poweriteration_group(whitened_output, loop, rank, input.device, None, None)
            output = S @ whitend_output_lr
        except torch._C._LinAlgError as e:
            # Fallback: plain (untransformed) low-rank approximation.
            output = fake_poweriteration_group(output, loop, rank, input.device, None, None)
            # size_m = next_power_of_2(sep_dim)
            # H = torch.from_numpy(transform_cache.get_or_register("hadamard", sep_dim)).to(input.device, dtype=output.dtype)
            # padded_output = F.pad(output, (0, size_m - sep_dim))
            # padded_output = (padded_output @ H)
            # padded_output = fake_poweriteration_group(padded_output, loop, rank, input.device, None, None)
            # padded_output = (padded_output @ H) / size_m
            # output = padded_output[:, :, :, :sep_dim]
    elif transform == "hadamard":
        size_m = next_power_of_2(sep_dim)
        H = torch.from_numpy(transform_cache.get_or_register("hadamard", sep_dim)).to(input.device, dtype=output.dtype)
        padded_output = F.pad(output, (0, size_m - sep_dim))
        padded_output = (padded_output @ H)
        padded_output = fake_poweriteration_group(padded_output, loop, rank, input.device, None, None)
        padded_output = (padded_output @ H) / size_m
        output = padded_output[:, :, :, :sep_dim]
    elif transform == "pca":
        # The third return of torch.linalg.svd is Vh, bound here as UT.
        try:
            cov_matrix = output.mT @ output
            U, sig, UT = torch.linalg.svd(cov_matrix, full_matrices=False)
            output = output @ UT
            output = fake_poweriteration_group(output, loop, rank, input.device, None, None)
            output = output @ U
        except torch._C._LinAlgError as e:
            # Fallback: low-rank approximation in the Hadamard-rotated space.
            size_m = next_power_of_2(sep_dim)
            H = torch.from_numpy(transform_cache.get_or_register("hadamard", sep_dim)).to(input.device, dtype=output.dtype)
            padded_output = F.pad(output, (0, size_m - sep_dim))
            padded_output = (padded_output @ H)
            padded_output = fake_poweriteration_group(padded_output, loop, rank, input.device, None, None)
            padded_output = (padded_output @ H) / size_m
            output = padded_output[:, :, :, :sep_dim]
    elif transform == "pca-basis-none":
        # Seed the power iteration with bases estimated from the Gram matrix.
        cov_matrix = output.mT @ output
        left_vec, right_vec = fake_poweriteration_group_svd(cov_matrix, loop, rank, device=input.device, p_base_input=None, q_base_input=None)
        output = fake_poweriteration_group(output, loop, rank, input.device, right_vec.mT, left_vec)
    elif transform == "pca-basis-hadamard" or transform == "happi" or transform == "happi-v2":
        size_m = next_power_of_2(sep_dim)
        H = torch.from_numpy(transform_cache.get_or_register("hadamard", sep_dim)).to(input.device, dtype=output.dtype)
        padded_output = F.pad(output, (0, size_m - sep_dim))
        padded_output = (padded_output @ H) / (size_m ** 0.5)

        # Calculate the covariance matrix (mean-centred along seq_len)
        mean = torch.mean(padded_output, dim=2, keepdim=True)
        # mean = 0.0
        centered_output = padded_output - mean
        cov_matrix = centered_output.mT @ centered_output
        
        left_vec, right_vec = fake_poweriteration_group_svd(cov_matrix.clone(), loop, rank, input.device, None, None)
        padded_output = fake_poweriteration_group(padded_output.clone(), loop, rank, input.device, right_vec.mT, None)
        padded_output = (padded_output @ H) / (size_m ** 0.5)
        output = padded_output[:, :, :, :sep_dim]
    elif transform == "pca-basis-pca":
        cov_matrix = output.mT @ output
        left_vec, right_vec = fake_poweriteration_group_svd(cov_matrix.clone(), loop, rank, device=input.device, p_base_input=None, q_base_input=None)
        output = output @ left_vec
        output = fake_poweriteration_group(output, loop, rank, input.device, right_vec, left_vec)
        output = output @ left_vec.mT
    elif transform == "none":
        output = fake_poweriteration_group(output, loop, rank, input.device, None, None)
    else:
        raise ValueError(f"Invalid svd transform: {transform}")
    
    # Restore the original values at the smallest and largest k indices
    output = (
        output.permute(0, 2, 1, 3).contiguous().view(batch, seq_len, sep_dim * num_head)
    )
    output.scatter_(-1, smallest_indices, smallest_value)
    output.scatter_(-1, largest_indices, largest_value)
    
    output = output.view(batch, seq_len, num_head, sep_dim).permute(0, 2, 1, 3)

    if lora_quant:
        # output = fake_groupwise_token_fp8_quantization(output, fp8_format="e5m2", group_size=group_size)
        output = fake_groupwise_token_asymmetric_quantization(output, quantize_bit=8, group_size=group_size)

    return output

########################################################################################

def gearslkivi_tokenQ_new(input, quantize_bit, group_size=128,sparsity=0.0,rank = 0,loop=1): ####
    """GEAR-style value compression: token-wise quantization plus a low-rank
    correction of the quantization residual.

    The ``gears_tokenQ`` result is cast to half precision. The residual
    ``input - quantized`` is approximated with ``fake_poweriteration_group``
    (``rank``, ``loop``) and added back.
    """
    source = input.float()
    # Snapshot taken before quantization, so the residual is measured against
    # the untouched values even if gears_tokenQ works in place.
    pristine = source.clone()
    quantized = gears_tokenQ(source, quantize_bit, group_size, sparsity).half()
    residual_lr = fake_poweriteration_group(
        pristine - quantized, loop, rank, source.device, None, None
    )
    return quantized + residual_lr

def gearslkivi_channelQ_new(input, quantize_bit, group_size=128,sparsity=0.0,rank = 0,loop=1): ####
    """GEAR-style key compression: channel-wise quantization plus a low-rank
    correction of the quantization residual.

    The ``gears_channelQ`` result is cast to half precision. The residual
    ``input - quantized`` is approximated with ``fake_poweriteration_group``
    (``rank``, ``loop``) and added back.
    """
    source = input.float()
    # Snapshot taken before quantization, so the residual is measured against
    # the untouched values even if gears_channelQ works in place.
    pristine = source.clone()
    quantized = gears_channelQ(source, quantize_bit, group_size, sparsity).half()
    residual_lr = fake_poweriteration_group(
        pristine - quantized, loop, rank, source.device, None, None
    )
    return quantized + residual_lr
     
def tokenwise_gearlkivi_channelQ(input, quantize_bit, group_size=128,r=0,loop=1): ####
    """Channel-wise fake quantization plus a rank-``r`` correction of its residual.

    ``input`` is fake-quantized with
    ``fake_groupwise_channel_asymmetric_quantization``. The residual is
    approximated with ``fake_poweriteration_group`` (``r``, ``loop``) and
    added back to the quantized tensor. The quantization helper unpacks
    ``input.shape`` into four dimensions, so a non-4-D input raises there.
    """
    snapshot = input.clone()
    quantized = fake_groupwise_channel_asymmetric_quantization(input, quantize_bit, group_size)
    correction = fake_poweriteration_group(snapshot - quantized, loop, r, input.device, None, None)
    return quantized + correction

def tokenwise_gearlkivi_tokenQ(input, quantize_bit, group_size=128,r=0,loop=1):
    """Token-wise fake quantization plus a rank-``r`` correction of its residual.

    ``input`` is fake-quantized with
    ``fake_groupwise_token_asymmetric_quantization``. The residual is
    approximated with ``fake_poweriteration_group`` (``r``, ``loop``) and
    added back to the quantized tensor. The quantization helper unpacks
    ``input.shape`` into four dimensions, so a non-4-D input raises there.
    """
    snapshot = input.clone()
    quantized = fake_groupwise_token_asymmetric_quantization(input, quantize_bit, group_size)
    correction = fake_poweriteration_group(snapshot - quantized, loop, r, input.device, None, None)
    return quantized + correction

def pcc_channelQ(input, quantize_bit, group_size=128, sparsity=0.0, rank = 0,loop=1, kv_transform=""):
    """Two-stage key compression: outlier-aware low-rank approximation, then
    channel-wise quantization of the residual.

    Args:
        input: 4-D key tensor, unpacked downstream as
            (batch, num_head, seq_len, sep_dim).
        quantize_bit: Bit width for ``gears_channelQ`` on the residual.
        group_size: Group size for both quantization steps.
        sparsity: Outlier fraction forwarded to both stages.
        rank, loop: Low-rank target rank and power-iteration count.
        kv_transform: Transform name forwarded to
            ``outlier_removal_svd_channel``. ``None`` and ``""`` (the default)
            both select ``"none"``.

    Returns:
        Half-precision ``low_rank + quantized_residual``.
    """
    input_float = input.float()
    # Map both None and the default "" to "none". Otherwise the empty string
    # reaches outlier_removal_svd_channel, which rejects it with ValueError.
    transform = kv_transform or "none"
    output = outlier_removal_svd_channel(input_float.clone(), sparsity, loop, rank, transform=transform, group_size=group_size, lora_quant=True)

    error = input_float - output

    error = gears_channelQ(error.clone(), quantize_bit, group_size, sparsity)

    return (output + error).half()

def pcc_tokenQ(input, quantize_bit, group_size=128, sparsity=0.0, rank = 0,loop=1, kv_transform=""):
    """Two-stage value compression: outlier-aware low-rank approximation, then
    token-wise quantization of the residual.

    Args:
        input: 4-D value tensor, unpacked downstream as
            (batch, num_head, seq_len, sep_dim).
        quantize_bit: Bit width for ``gears_tokenQ`` on the residual.
        group_size: Group size for both quantization steps.
        sparsity: Outlier fraction forwarded to both stages.
        rank, loop: Low-rank target rank and power-iteration count.
        kv_transform: Transform name forwarded to
            ``outlier_removal_svd_token``. ``None`` and ``""`` (the default)
            both select ``"none"``.

    Returns:
        Half-precision ``low_rank + quantized_residual``.
    """
    input_float = input.float()
    # Map both None and the default "" to "none". Otherwise the empty string
    # reaches outlier_removal_svd_token, which rejects it with ValueError.
    transform = kv_transform or "none"
    output = outlier_removal_svd_token(input_float.clone(), sparsity, loop, rank, transform=transform, group_size=group_size, lora_quant=True)

    error = input_float - output

    error = gears_tokenQ(error.clone(), quantize_bit, group_size, sparsity)

    return (output + error).half()

def palu_channelQ(input, quantize_bit, group_size=128, sparsity=0.0, rank = 0, loop=1):
    """Palu-style key compression: PCC-SVD low-rank approximation, optionally
    followed by Hadamard channel-wise quantization.

    A ``quantize_bit`` of 0 skips quantization entirely. ``sparsity`` is
    accepted for signature parity but not used. The result is cast to half
    precision.
    """
    approximation = pcc_svd_channel(input.float(), rank, loop)
    if quantize_bit == 0:
        return approximation.half()
    return gears_hadamard_channelQ(approximation, quantize_bit, group_size).half()

def palu_tokenQ(input, quantize_bit, group_size=128, sparsity=0.0, rank = 0, loop=1):
    """Palu-style value compression: PCC-SVD low-rank approximation, optionally
    followed by Hadamard token-wise quantization.

    A ``quantize_bit`` of 0 skips quantization entirely. ``sparsity`` is
    accepted for signature parity but not used. The result is cast to half
    precision.
    """
    approximation = pcc_svd_token(input.float(), rank, loop)
    if quantize_bit == 0:
        return approximation.half()
    return gears_hadamard_tokenQ(approximation, quantize_bit, group_size).half()

########################################################################################

def compress_sandbox(input, quantize_bit, group_size, sparsity=0.0, rank = 0, loop = 1, first_method = None, first_transform = None, second_method = None, second_transform = None, input_tensor = "key", hla_rank = 0):
    """Experimental two-stage compressor.

    The first stage approximates ``input`` (``output``). The second stage
    compresses the residual ``input - output`` (``error``). The function
    returns ``(output + error).half()``.

    Args:
        input: 4-D tensor, unpacked as (b, h, s, d).
        quantize_bit, group_size, sparsity: Forwarded to the quantization and
            outlier-removal helpers.
        rank, loop: Low-rank target rank and power-iteration count.
        first_method, second_method: ``"quantization"`` or ``"low_rank"``.
        first_transform, second_transform: Variant inside the chosen method
            (see the branches below). ``None`` selects the plain variant.
        input_tensor: ``"key"`` selects the channel-wise helpers. Any other
            value is treated as ``"value"`` (token-wise helpers), including in
            the ``"cov"`` branches.
        hla_rank: Rank for the ``"hla"`` / ``"hla_outlier"`` transforms.

    Raises:
        ValueError: For an unknown method or transform name.
    """
    input_float = input.float()
    b, h, s, d = input_float.shape

    if first_method == "quantization":
        if first_transform == "hadamard":
            output = gears_hadamard_channelQ(input_float.clone(), quantize_bit, group_size, sparsity) if input_tensor == "key" else gears_hadamard_tokenQ(input_float.clone(), quantize_bit, group_size, sparsity)
        elif first_transform is None:
            output = gears_channelQ(input_float.clone(), quantize_bit, group_size, sparsity) if input_tensor == "key" else gears_tokenQ(input_float.clone(), quantize_bit, group_size, sparsity)
        else:
            raise ValueError(f"Invalid first_transform: {first_transform}")
    elif first_method == "low_rank":
        if first_transform == "hadamard":
            size_m = next_power_of_2(d)
            pad_size = size_m - d
            padded_input = F.pad(input_float, (0, pad_size))
            H = torch.from_numpy(transform_cache.get_or_register("hadamard", d, r=rank)).to(input.device, dtype=input_float.dtype)
            
            H_input = padded_input @ H
            H_input_lr = fake_poweriteration_group(H_input, loop, rank, input.device, None, None)
            input_lr_padded = (H_input_lr @ H) / size_m
            output = input_lr_padded[:, :, :, :d]
        elif first_transform == "cov":
            try:
                if input_tensor == "key":
                    # Whiten along the head dimension (d x d Gram matrix).
                    cov_matrix = input_float.mT @ input_float
                    jitter = 1e-5 * torch.eye(d, device=input_float.device, dtype=input_float.dtype)
                    cov_matrix += jitter

                    S = torch.linalg.cholesky(cov_matrix)
                    S_inv = torch.linalg.inv(S)
                    whitened_input = input_float @ S_inv
                    whitend_input_lr = fake_poweriteration_group(whitened_input, loop, rank, input.device, None, None)
                    output = whitend_input_lr @ S
                else:
                    # Any non-"key" tensor is handled as "value", matching the
                    # other branches. An explicit `== "value"` test here used
                    # to leave `output` unbound for other names.
                    # Whiten along the sequence dimension (s x s Gram matrix).
                    cov_matrix = input_float @ input_float.mT
                    jitter = 1e-5 * torch.eye(s, device=input_float.device, dtype=input_float.dtype)
                    cov_matrix += jitter

                    S = torch.linalg.cholesky(cov_matrix)
                    S_inv = torch.linalg.inv(S)
                    whitened_input = S_inv @ input_float
                    whitend_input_lr = fake_poweriteration_group(whitened_input, loop, rank, input.device, None, None)
                    output = S @ whitend_input_lr
            except torch._C._LinAlgError as e:
                print(f"compress_sandbox first_transform: Cholesky decomposition failed. Error: {e}")
                # Fallback to non-whitened LR approximation on error
                output = fake_poweriteration_group(input_float.clone(), loop, rank, input.device, None, None)
        elif first_transform == "outlier":
            output = outlier_removal_svd_channel(input_float.clone(), sparsity, loop, rank, transform="none") if input_tensor == "key" else outlier_removal_svd_token(input_float.clone(), sparsity, loop, rank, transform="none")
        elif first_transform == "hadamard_outlier":
            output = outlier_removal_svd_channel(input_float.clone(), sparsity, loop, rank, transform="hadamard") if input_tensor == "key" else outlier_removal_svd_token(input_float.clone(), sparsity, loop, rank, transform="hadamard")
        elif first_transform == "pca-basis-none":
            output = outlier_removal_svd_channel(input_float.clone(), sparsity, loop, rank, transform="pca-basis-none", group_size=group_size, lora_quant=True) if input_tensor == "key" else outlier_removal_svd_token(input_float.clone(), sparsity, loop, rank, transform="pca-basis-none", group_size=group_size, lora_quant=True)
        elif first_transform == "pca-basis-hadamard":
            output = outlier_removal_svd_channel(input_float.clone(), sparsity, loop, rank, transform="pca-basis-hadamard", group_size=group_size, lora_quant=True) if input_tensor == "key" else outlier_removal_svd_token(input_float.clone(), sparsity, loop, rank, transform="pca-basis-hadamard", group_size=group_size, lora_quant=True)
        elif first_transform == "happi":
            output = outlier_removal_svd_channel(input_float.clone(), sparsity, loop, rank, transform="happi", group_size=group_size, lora_quant=True) if input_tensor == "key" else outlier_removal_svd_token(input_float.clone(), sparsity, loop, rank, transform="happi", group_size=group_size, lora_quant=True)
        elif first_transform == "happi-v2":
            output = outlier_removal_svd_channel(input_float.clone(), sparsity, loop, rank, transform="happi-v2", group_size=group_size, lora_quant=True) if input_tensor == "key" else outlier_removal_svd_token(input_float.clone(), sparsity, loop, rank, transform="happi-v2", group_size=group_size, lora_quant=True)
        elif first_transform == "hla":
            size_m = next_power_of_2(d)
            pad_size = size_m - d
            padded_input = F.pad(input_float, (0, pad_size))
            H = torch.from_numpy(transform_cache.get_or_register("low_rank", d, r=hla_rank, freq="high")).to(input.device, dtype=input_float.dtype)
            H_input = (padded_input @ H.mT)
            input_lr_padded = (H_input @ H) / size_m
            output = input_lr_padded[:, :, :, :d]
        elif first_transform == "noise":
            # Keep only what a rank-(d - rank) approximation fails to capture.
            output_low_freq = fake_poweriteration_group(input_float.clone(), loop, d - rank, input.device, None, None)
            output = input_float - output_low_freq
        elif first_transform is None:
            output = fake_poweriteration_group(input_float.clone(), loop, rank, input.device, None, None)
        else:
            raise ValueError(f"Invalid first_transform: {first_transform}")
    else:
        raise ValueError(f"Invalid first_method: {first_method}")

    error = input_float - output
    
    if second_method == "quantization":
        if second_transform == "hadamard":
            error = gears_hadamard_channelQ(error.clone(), quantize_bit, group_size, sparsity) if input_tensor == "key" else gears_hadamard_tokenQ(error.clone(), quantize_bit, group_size, sparsity)
        elif second_transform is None:
            error = gears_channelQ(error.clone(), quantize_bit, group_size, sparsity) if input_tensor == "key" else gears_tokenQ(error.clone(), quantize_bit, group_size, sparsity)
        else:
            raise ValueError(f"Invalid second_transform: {second_transform}")
    elif second_method == "low_rank":
        if second_transform == "hadamard":
            size_m = next_power_of_2(d)
            pad_size = size_m - d
            padded_error = F.pad(error, (0, pad_size))
            H = torch.from_numpy(transform_cache.get_or_register("hadamard", d, r=rank)).to(input.device, dtype=input_float.dtype)
            H_error = (padded_error @ H)
            H_error_lr = fake_poweriteration_group(H_error, loop, rank, input.device, None, None)
            padded_error_lr = (H_error_lr @ H) / size_m
            error = padded_error_lr[:, :, :, :d]
        elif second_transform == "cov":
            # The whitening statistics come from the original input, not from the residual.
            try:
                if input_tensor == "key":
                    cov_matrix = input_float.mT @ input_float
                    jitter = 1e-5 * torch.eye(d, device=error.device, dtype=error.dtype)
                    cov_matrix += jitter

                    S = torch.linalg.cholesky(cov_matrix)
                    S_inv = torch.linalg.inv(S)
                    whitened_error = error @ S_inv
                    whitened_error_lr = fake_poweriteration_group(whitened_error, loop, rank, input.device, None, None)
                    error = whitened_error_lr @ S
                else:
                    # Any non-"key" tensor is handled as "value". An explicit
                    # `== "value"` test here used to skip compression of the
                    # residual silently for other names.
                    cov_matrix = input_float @ input_float.mT
                    jitter = 1e-5 * torch.eye(s, device=error.device, dtype=error.dtype)
                    cov_matrix += jitter

                    S = torch.linalg.cholesky(cov_matrix)
                    S_inv = torch.linalg.inv(S)
                    whitened_error = S_inv @ error
                    whitened_error_lr = fake_poweriteration_group(whitened_error, loop, rank, input.device, None, None)
                    error = S @ whitened_error_lr
            except torch._C._LinAlgError as e:
                print(f"compress_sandbox second_transform: Cholesky decomposition failed. Error: {e}")
                # Fallback to non-whitened LR approximation on error
                error = fake_poweriteration_group(error, loop, rank, input.device, None, None)
        elif second_transform == "outlier":
            error = outlier_removal_svd_channel(error.clone(), sparsity, loop, rank, transform="none") if input_tensor == "key" else outlier_removal_svd_token(error.clone(), sparsity, loop, rank, transform="none")
        elif second_transform == "hadamard_outlier":
            error = outlier_removal_svd_channel(error.clone(), sparsity, loop, rank, transform="hadamard") if input_tensor == "key" else outlier_removal_svd_token(error.clone(), sparsity, loop, rank, transform="hadamard")
        elif second_transform == "pca-basis-none":
            error = outlier_removal_svd_channel(error.clone(), sparsity, loop, rank, transform="pca-basis-none", group_size=group_size, lora_quant=True) if input_tensor == "key" else outlier_removal_svd_token(error.clone(), sparsity, loop, rank, transform="pca-basis-none", group_size=group_size, lora_quant=True)
        elif second_transform == "pca-basis-hadamard":
            error = outlier_removal_svd_channel(error.clone(), sparsity, loop, rank, transform="pca-basis-hadamard", group_size=group_size, lora_quant=True) if input_tensor == "key" else outlier_removal_svd_token(error.clone(), sparsity, loop, rank, transform="pca-basis-hadamard", group_size=group_size, lora_quant=True)
        elif second_transform == "happi":
            error = outlier_removal_svd_channel(error.clone(), sparsity, loop, rank, transform="happi", group_size=group_size, lora_quant=True) if input_tensor == "key" else outlier_removal_svd_token(error.clone(), sparsity, loop, rank, transform="happi", group_size=group_size, lora_quant=True)
        elif second_transform == "happi-v2":
            error = outlier_removal_svd_channel(error.clone(), sparsity, loop, rank, transform="happi-v2", group_size=group_size, lora_quant=True) if input_tensor == "key" else outlier_removal_svd_token(error.clone(), sparsity, loop, rank, transform="happi-v2", group_size=group_size, lora_quant=True)
        elif second_transform == "hla":
            size_m = next_power_of_2(d)
            pad_size = size_m - d
            padded_error = F.pad(error, (0, pad_size))
            H = torch.from_numpy(transform_cache.get_or_register("low_rank", d, r=hla_rank, freq="high")).to(input.device, dtype=input_float.dtype)
            H_error = (padded_error @ H.mT)
            padded_error_lr = (H_error @ H) / size_m
            error = padded_error_lr[:, :, :, :d]
        elif second_transform == "hla_outlier":
            error = outlier_removal_hla_channel(error, quantize_bit, group_size, sparsity, hla_rank) if input_tensor == "key" else outlier_removal_hla_token(error, quantize_bit, group_size, sparsity, hla_rank)
        elif second_transform == "noise":
            # Keep only what a rank-(d - rank) approximation fails to capture.
            error_low_freq = fake_poweriteration_group(error.clone(), loop, d - rank, input.device, None, None)
            error = error - error_low_freq
        elif second_transform is None:
            error = fake_poweriteration_group(error, loop, rank, input.device, None, None)
        else:
            raise ValueError(f"Invalid second_transform: {second_transform}")
    else:
        raise ValueError(f"Invalid second_method: {second_method}")
    
    return (output + error).half()

def compress_insert_function(
    previous_key,
    previous_value,
    compress_config,
    layer_idx,
    pbase1=None,
    qbase1=None,
    pbase2=None,
    qbase2=None,
    prefill=None,
):
    batch, num_head, seq_len, sep_dim = previous_key.shape
    if compress_config.token_preserving[layer_idx] == True:
        starting_idx = int(compress_config.start_saving[layer_idx] * seq_len)
        locality_idx = int(compress_config.locality_saving[layer_idx] * seq_len)
    else:
        starting_idx = int(0)
        locality_idx = -seq_len
    # print("starting_idx:", starting_idx, "locality_idx:", locality_idx,compress_config.token_preserving[layer_idx],batch, num_head, seq_len, sep_dim)
    
    if compress_config.compress_method[layer_idx] == "KCVT":
        previous_key[:, :, starting_idx:-locality_idx, :] = fake_groupwise_channel_asymmetric_quantization(
            previous_key[:, :, starting_idx:-locality_idx, :],
            compress_config.quantize_bit[layer_idx],
            seq_len,
        )
        if previous_value is not None:
            previous_value[:, :, starting_idx:-locality_idx, :] = fake_groupwise_token_asymmetric_quantization(
                previous_value[:, :, starting_idx:-locality_idx, :],
                compress_config.quantize_bit[layer_idx],
                int(num_head * sep_dim),
            )

    if compress_config.compress_method[layer_idx] == "KIVI_V2":
        previous_key[:, :, starting_idx:-locality_idx, :] = fake_groupwise_channel_asymmetric_quantization(
            previous_key[:, :, starting_idx:-locality_idx, :],
            compress_config.quantize_bit[layer_idx],
            compress_config.group_size[layer_idx]
        )
        previous_value[:, :, starting_idx:-locality_idx, :] = fake_groupwise_token_asymmetric_quantization(
            previous_value[:, :, starting_idx:-locality_idx, :],
            compress_config.quantize_bit[layer_idx],
            compress_config.group_size[layer_idx]
        )

    if compress_config.compress_method[layer_idx] == "GEAR":
        prefill_rank = int(compress_config.prefill_rank[layer_idx])
        prefill_rankv = int(compress_config.prefill_rankv[layer_idx])
        rank = int(compress_config.rank[layer_idx])
        rankv = int(compress_config.rankv[layer_idx])
        if prefill is True:
            rank_used = prefill_rank
            rankv_used = prefill_rankv
        else:
            rank_used = rank
            rankv_used = rankv
        previous_key = gearslkivi_channelQ_new(
            previous_key,
            compress_config.quantize_bit[layer_idx],
            compress_config.group_size[layer_idx],
            compress_config.left[layer_idx],
            rank_used,
            compress_config.loop[layer_idx]
            
        )
        previous_key = previous_key.half()
        previous_value = gearslkivi_tokenQ_new(
            previous_value,
            compress_config.quantize_bit[layer_idx],
            compress_config.group_size[layer_idx],
            compress_config.left[layer_idx],
            rankv_used,
            compress_config.loop[layer_idx]
        )
        previous_value = previous_value.half()
    if compress_config.compress_method[layer_idx] == "GEAR-KCVT":
        prefill_rank = int(compress_config.prefill_rank[layer_idx])
        prefill_rankv = int(compress_config.prefill_rankv[layer_idx])
        rank = int(compress_config.rank[layer_idx])
        rankv = int(compress_config.rankv[layer_idx])
        if prefill is True:
            rank_used = prefill_rank
            rankv_used = prefill_rankv
        else:
            rank_used = rank
            rankv_used = rankv
        previous_key = gearslkivi_channelQ_new(
            previous_key,
            compress_config.quantize_bit[layer_idx],
            seq_len,
            compress_config.left[layer_idx],
            rank_used,
            compress_config.loop[layer_idx]
            
        )
        previous_key = previous_key.half()
        previous_value = gearslkivi_tokenQ_new(
            previous_value,
            compress_config.quantize_bit[layer_idx],
            int(num_head * sep_dim),
            compress_config.left[layer_idx],
            rankv_used,
            compress_config.loop[layer_idx]
        )
        previous_value = previous_value.half()
    if compress_config.compress_method[layer_idx] == "GEARL":

        prefill_rank = int(compress_config.prefill_rank[layer_idx])
        prefill_rankv = int(compress_config.prefill_rankv[layer_idx])
        rank = int(compress_config.rank[layer_idx])
        rankv = int(compress_config.rankv[layer_idx])
        if prefill is True:
            rank_used = prefill_rank
            rankv_used = prefill_rankv
        else:
            rank_used = rank
            rankv_used = rankv
        previous_key = tokenwise_gearlkivi_channelQ(
            previous_key,
            compress_config.quantize_bit[layer_idx],
            compress_config.group_size[layer_idx],
            rank_used,
            compress_config.loop[layer_idx],

            
        )
        previous_value = tokenwise_gearlkivi_tokenQ(
            previous_value,
            compress_config.quantize_bit[layer_idx],
            compress_config.group_size[layer_idx],
            rankv_used,
            compress_config.loop[layer_idx],
 
        )
    if compress_config.compress_method[layer_idx] == "GEARL-KCVT":
        prefill_rank = int(compress_config.prefill_rank[layer_idx])
        prefill_rankv = int(compress_config.prefill_rankv[layer_idx])
        rank = int(compress_config.rank[layer_idx])
        rankv = int(compress_config.rankv[layer_idx])
        if prefill is True:
            rank_used = prefill_rank
            rankv_used = prefill_rankv
        else:
            rank_used = rank
            rankv_used = rankv
        previous_key = tokenwise_gearlkivi_channelQ(
            previous_key,
            compress_config.quantize_bit[layer_idx],
            seq_len,
            rank_used,
            compress_config.loop[layer_idx],
            
            
        )
        previous_value = tokenwise_gearlkivi_tokenQ(
            previous_value,
            compress_config.quantize_bit[layer_idx],
            int(num_head * sep_dim),
            rankv_used,
            compress_config.loop[layer_idx],
            
        )
    
    if compress_config.compress_method[layer_idx] == "SANDBOX":
        prefill_rank = int(compress_config.prefill_rank[layer_idx])
        prefill_rankv = int(compress_config.prefill_rankv[layer_idx])
        rank = int(compress_config.rank[layer_idx])
        rankv = int(compress_config.rankv[layer_idx])
        if prefill is True:
            rank_used = prefill_rank
            rankv_used = prefill_rankv
        else:
            rank_used = rank
            rankv_used = rankv

        previous_key = compress_sandbox(
            previous_key,
            compress_config.quantize_bit[layer_idx],
            seq_len,
            compress_config.left[layer_idx],
            rank_used,
            compress_config.loop[layer_idx],
            first_method=compress_config.first_method_list[layer_idx],
            first_transform=compress_config.first_transform_list[layer_idx],
            second_method=compress_config.second_method_list[layer_idx],
            second_transform=compress_config.second_transform_list[layer_idx],
            input_tensor="key",
            hla_rank=compress_config.hla_rank_list[layer_idx],
        )
        previous_value = compress_sandbox(
            previous_value,
            compress_config.quantize_bit[layer_idx],
            int(num_head * sep_dim),
            compress_config.left[layer_idx],
            rankv_used,
            compress_config.loop[layer_idx],
            first_method=compress_config.first_method_list[layer_idx],
            first_transform=compress_config.first_transform_list[layer_idx],
            second_method=compress_config.second_method_list[layer_idx],
            second_transform=compress_config.second_transform_list[layer_idx],
            input_tensor="value",
            hla_rank=compress_config.hla_rank_list[layer_idx],
        )

    if compress_config.compress_method[layer_idx] == "PCC_COV_OUTLIER":
        prefill_rank = int(compress_config.prefill_rank[layer_idx])
        prefill_rankv = int(compress_config.prefill_rankv[layer_idx])
        rank = int(compress_config.rank[layer_idx])
        rankv = int(compress_config.rankv[layer_idx])
        if prefill is True:
            rank_used = prefill_rank
            rankv_used = prefill_rankv
        else:
            rank_used = rank
            rankv_used = rankv

        previous_key = compress_sandbox(
            previous_key,
            compress_config.quantize_bit[layer_idx],
            compress_config.group_size[layer_idx],
            compress_config.left[layer_idx],
            rank_used,
            compress_config.loop[layer_idx],
            first_method="low_rank",
            first_transform="pca-basis-hadamard",
            second_method="quantization",
            second_transform=None,
            input_tensor="key",
            hla_rank=compress_config.hla_rank_list[layer_idx],
        )
        previous_value = compress_sandbox(
            previous_value,
            compress_config.quantize_bit[layer_idx],
            compress_config.group_size[layer_idx],
            compress_config.left[layer_idx],
            rankv_used,
            compress_config.loop[layer_idx],
            first_method="low_rank",
            first_transform="pca-basis-hadamard",
            second_method="quantization",
            second_transform=None,
            input_tensor="value",
            hla_rank=compress_config.hla_rank_list[layer_idx],
        )
    if compress_config.compress_method[layer_idx] == "PCC_COV_OUTLIER_COMPACT":
        prefill_rank = int(compress_config.prefill_rank[layer_idx])
        prefill_rankv = int(compress_config.prefill_rankv[layer_idx])
        rank = int(compress_config.rank[layer_idx])
        rankv = int(compress_config.rankv[layer_idx])
        if prefill is True:
            rank_used = prefill_rank
            rankv_used = prefill_rankv
        else:
            rank_used = rank
            rankv_used = rankv
        previous_key = pcc_channelQ(
            previous_key,
            compress_config.quantize_bit[layer_idx],
            compress_config.group_size[layer_idx],
            compress_config.left[layer_idx],
            rank_used,
            compress_config.loop[layer_idx],
            kv_transform=compress_config.kv_transform_list[layer_idx]
        )
        previous_value = pcc_tokenQ(
            previous_value,
            compress_config.quantize_bit[layer_idx],
            compress_config.group_size[layer_idx],
            compress_config.left[layer_idx],
            rankv_used,
            compress_config.loop[layer_idx],
            kv_transform=compress_config.kv_transform_list[layer_idx]
        )

    if compress_config.compress_method[layer_idx] == "PALU_50":
        prefill_rank = int(compress_config.prefill_rank[layer_idx])
        prefill_rankv = int(compress_config.prefill_rankv[layer_idx])
        rank = int(compress_config.rank[layer_idx])
        rankv = int(compress_config.rankv[layer_idx])
        if prefill is True:
            rank_used = prefill_rank
            rankv_used = prefill_rankv
        else:
            rank_used = rank
            rankv_used = rankv
        previous_key = palu_channelQ(
            previous_key,
            compress_config.quantize_bit[layer_idx],
            compress_config.group_size[layer_idx],
            compress_config.left[layer_idx],
            rank_used,
            compress_config.loop[layer_idx]
        )
        previous_value = palu_tokenQ(
            previous_value,
            compress_config.quantize_bit[layer_idx],
            compress_config.group_size[layer_idx],
            compress_config.left[layer_idx],
            rankv_used,
            compress_config.loop[layer_idx]
        )

    return previous_key, previous_value