from importlib.resources import open_binary
from numpy import mean
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from .kmeans_utils import apply_codebook_quantization
# Module-level call counter: the residual-binarization functions below each
# increment it once per call through `global index`.
index = 0

# Default settings for the codebook quantization step. The alternating
# routines below read these keys and forward the first three as keyword
# arguments to `apply_codebook_quantization`.
kmeans_config = {
    'vector_length': 8,  # default; forwarded as `vector_length`
    'num_centroids': 256,  # default; forwarded as `num_centroids`
    'max_iter': 2,  # default; forwarded as `max_iter`
    'use_last_iter_quantization': False  # default; True = quantize only on the last refinement iteration
}

# Function to update config with args
def update_kmeans_config(args):
    """Rebind the module-level ``kmeans_config`` from an argument object.

    ``args`` must expose the attributes ``vector_length``, ``num_centroids``,
    ``max_iter`` and ``use_last_iter_quantization``. A brand-new dict is
    assigned to the global name; the previously bound dict is not mutated.
    """
    global kmeans_config
    option_names = (
        'vector_length',
        'num_centroids',
        'max_iter',
        'use_last_iter_quantization',
    )
    kmeans_config = {name: getattr(args, name) for name in option_names}


@torch.no_grad()
def part_mean(tensor, op='-'):
    """Average ``tensor`` over its last dimension and return a column.

    Each entry is first multiplied by the boolean ``tensor != 0``; the mean is
    then taken over the full last dimension, so zero entries still count in
    the denominator. The result is reshaped with ``view(-1, 1)``.

    ``op`` is accepted but not used.
    """
    is_nonzero = tensor.ne(0)
    kept = tensor.mul(is_nonzero)
    return kept.mean(dim=-1).view(-1, 1)

@torch.no_grad()
def high_order_residual(x, mask, order=2):
    """Approximate the masked part of ``x`` by ``order`` sign/scale/mean terms.

    In each round the not-yet-explained residual of the selected entries is
    centred by its per-row mean, its sign is taken, and the sign is scaled by
    the per-row mean absolute value; the term (plus the row mean) is added to
    the running approximation on the selected entries only.

    Args:
        x: 2-D tensor (reductions run over ``dim=1``).
        mask: boolean tensor selecting the entries of ``x`` to approximate.
        order: number of residual terms to accumulate.

    Returns:
        Tensor shaped like ``x`` holding the accumulated approximation.

    Side effect: increments the module-level ``index`` counter.
    """
    global index
    approx = torch.zeros_like(x)
    target = x.clone() * mask
    index += 1

    nan_fill = torch.tensor(float('nan'))
    for _ in range(order):
        # Unselected entries become NaN so that nanmean ignores them.
        selected = torch.where(mask, target - approx, nan_fill)

        # Rows without any selected entry give NaN statistics; use 0 instead.
        row_mean = torch.nanmean(selected, dim=1)
        row_mean = row_mean.masked_fill(torch.isnan(row_mean), 0)
        centered = selected - row_mean[:, None]

        row_scale = torch.nanmean(centered.abs(), dim=1)
        row_scale = row_scale.masked_fill(torch.isnan(row_scale), 0)

        term = torch.sign(centered) * row_scale[:, None] + row_mean[:, None]
        approx = approx + term * mask

    return approx

@torch.no_grad()
def high_order_residual_rc(x, mask, order=2):
    """Residual sign approximation with row and column means and scales.

    Each round removes the per-row mean and then the per-column mean from the
    selected residual, takes the sign, and scales it by a per-row scale (mean
    absolute value along ``dim=1``) times a per-column scale (mean absolute
    value, along ``dim=0``, of the residual divided by the row scale). Both
    means are added back and the term is accumulated on the selected entries.

    Args:
        x: 2-D tensor.
        mask: boolean tensor selecting the entries of ``x`` to approximate.
        order: number of residual terms to accumulate.

    Returns:
        Tensor shaped like ``x`` holding the accumulated approximation.

    Side effect: increments the module-level ``index`` counter.
    """
    global index
    approx = torch.zeros_like(x)
    target = x.clone() * mask
    index += 1

    nan_fill = torch.tensor(float('nan'))
    for _ in range(order):
        # Unselected entries become NaN so that nanmean ignores them.
        selected = torch.where(mask, target - approx, nan_fill)

        # Row mean, then column mean of what is left.
        row_mean = torch.nanmean(selected, dim=1)
        row_mean = row_mean.masked_fill(torch.isnan(row_mean), 0)
        selected = selected - row_mean[:, None]

        col_mean = torch.nanmean(selected, dim=0)
        col_mean = col_mean.masked_fill(torch.isnan(col_mean), 0)
        selected = selected - col_mean[None, :]

        # Row scale, then column scale of the row-normalised magnitudes.
        # A zero row scale yields NaN/inf quotients; NaN results fall back to 0.
        row_scale = torch.nanmean(selected.abs(), dim=1)
        row_scale = row_scale.masked_fill(torch.isnan(row_scale), 0)

        col_scale = torch.nanmean((selected / row_scale[:, None]).abs(), dim=0)
        col_scale = col_scale.masked_fill(torch.isnan(col_scale), 0)

        term = torch.sign(selected) * row_scale[:, None] * col_scale[None, :]
        term = term + (row_mean[:, None] + col_mean[None, :])
        approx = approx + term * mask

    return approx





@torch.no_grad()
def high_order_residual_alternating_order1_hessian_vq(x, mask, order=2, iter=15, enable_kmeans=True, H=None):
    """Single-term sign approximation refined by alternating updates.

    Phase 1 accumulates ``order`` residual terms (row mean + row scale * sign)
    on the selected entries. Phase 2 keeps one mean, one scale and one sign
    matrix and runs ``iter`` rounds of: shift the row mean by the mean of the
    current reconstruction error, refit the per-row scale by least squares
    against the current signs, recompute the signs, and optionally replace the
    signs with the output of ``apply_codebook_quantization``.

    Args:
        x: 2-D tensor.
        mask: boolean tensor selecting the entries of ``x`` to approximate.
        order: number of terms in phase 1; must be >= 1 because phase 2 starts
            from the mean and signs of the last term.
        iter: number of alternating rounds (0 returns the phase-1 result).
        enable_kmeans: when True, signs are passed through
            ``apply_codebook_quantization`` using the module ``kmeans_config``;
            with ``use_last_iter_quantization`` set this happens only in the
            final round, otherwise in every round.
        H: accepted for signature compatibility; not used by this function.

    Returns:
        Tensor shaped like ``x`` with the refined approximation on the
        selected entries.

    Side effect: increments the module-level ``index`` counter.
    """
    global index
    sum_order = torch.zeros_like(x)
    new_matrix = x.clone() * mask
    index += 1

    nan_fill = torch.tensor(float('nan'))
    for _ in range(order):
        # Unselected entries become NaN so that nanmean ignores them.
        masked_x_tensor = torch.where(mask, new_matrix - sum_order, nan_fill)

        mean_tensor_all = torch.nanmean(masked_x_tensor, dim=1)
        mean_tensor_all = mean_tensor_all.masked_fill(torch.isnan(mean_tensor_all), 0)
        masked_x_tensor = masked_x_tensor - mean_tensor_all[:, None]
        scale_tensor_all = torch.nanmean(torch.abs(masked_x_tensor), dim=1)
        scale_tensor_all = scale_tensor_all.masked_fill(torch.isnan(scale_tensor_all), 0)

        new_binary = torch.sign(masked_x_tensor)
        term = new_binary * scale_tensor_all[:, None] + mean_tensor_all[:, None]
        sum_order = sum_order + term * mask

    # Alternating update, seeded with the last phase-1 term.
    refine_mean = mean_tensor_all.clone()
    sum_order_alternating = sum_order.clone()

    for k in range(iter):
        # 1. Fix alpha and B, update mean.
        masked_error = torch.where(mask, new_matrix - sum_order_alternating, nan_fill)
        mean_shift = torch.nanmean(masked_error, dim=1)
        mean_shift = mean_shift.masked_fill(torch.isnan(mean_shift), 0)
        refine_mean += mean_shift

        centered = new_matrix - refine_mean[:, None] * mask

        # 2. Fix mean and B, update alpha (1e-6 guards rows with no signs).
        masked_binary = new_binary * mask
        new_alpha = 1. / (torch.sum(masked_binary * masked_binary, dim=1) + 1e-6) * torch.sum(masked_binary * centered, dim=1)

        # 3. Fix mean and alpha, update B.
        new_binary = torch.sign(centered)

        # Optional codebook quantization of the signs; skipped when every
        # sign is zero. Only 'quantized_binary' is consumed from the result.
        if enable_kmeans and not torch.all(new_binary == 0):
            last_round_only = kmeans_config['use_last_iter_quantization']
            if not last_round_only or k == iter - 1:
                quant_result = apply_codebook_quantization(
                    new_binary, new_matrix, mask, refine_mean, new_alpha,
                    vector_length=kmeans_config['vector_length'],
                    num_centroids=kmeans_config['num_centroids'],
                    max_iter=kmeans_config['max_iter'],
                )
                new_binary = quant_result['quantized_binary']

        # Reconstruction used as the reference for the next round.
        sum_order_alternating = torch.zeros_like(x) + (new_alpha[:, None] * new_binary + refine_mean[:, None]) * mask

    return sum_order_alternating




@torch.no_grad()
def high_order_residual_alternating_order1(x, mask, order=2, iter=15, enable_kmeans=True):
    """Single-term sign approximation refined by alternating updates.

    First, ``order`` residual terms (row mean + row scale * sign) are
    accumulated on the selected entries. Then ``iter`` refinement rounds are
    run on a single (mean, scale, sign) triple seeded from the last term:
    shift the row mean by the mean reconstruction error, refit the per-row
    scale against the current signs, recompute the signs, and optionally pass
    the signs through ``apply_codebook_quantization``.

    Args:
        x: 2-D tensor.
        mask: boolean tensor selecting the entries of ``x`` to approximate.
        order: number of initial residual terms; must be >= 1.
        iter: number of refinement rounds (0 returns the initial result).
        enable_kmeans: when True, quantize the signs with the settings in the
            module ``kmeans_config`` (every round, or only the last round when
            ``use_last_iter_quantization`` is set).

    Returns:
        Tensor shaped like ``x`` with the refined approximation.

    Side effect: increments the module-level ``index`` counter.
    """
    global index
    approx = torch.zeros_like(x)
    target = x.clone() * mask
    index += 1

    nan_fill = torch.tensor(float('nan'))
    for _ in range(order):
        # Unselected entries become NaN so that nanmean ignores them.
        selected = torch.where(mask, target - approx, nan_fill)

        row_mean = torch.nanmean(selected, dim=1)
        row_mean = row_mean.masked_fill(torch.isnan(row_mean), 0)
        centered = selected - row_mean[:, None]
        row_scale = torch.nanmean(centered.abs(), dim=1)
        row_scale = row_scale.masked_fill(torch.isnan(row_scale), 0)

        signs = torch.sign(centered)
        approx = approx + (signs * row_scale[:, None] + row_mean[:, None]) * mask

    # Refinement starts from the mean and signs of the last term.
    refine_mean = row_mean.clone()
    refined = approx.clone()

    for k in range(iter):
        # Step 1: with scale and signs fixed, move the mean by the mean error.
        leftover = torch.where(mask, target - refined, nan_fill)
        shift = torch.nanmean(leftover, dim=1)
        shift = shift.masked_fill(torch.isnan(shift), 0)
        refine_mean += shift

        shifted_target = target - refine_mean[:, None] * mask

        # Step 2: with mean and signs fixed, least-squares fit of the scale.
        masked_signs = signs * mask
        new_alpha = 1. / (torch.sum(masked_signs * masked_signs, dim=1) + 1e-6) * torch.sum(masked_signs * shifted_target, dim=1)

        # Step 3: with mean and scale fixed, recompute the signs.
        signs = torch.sign(shifted_target)

        # Optional codebook quantization; skipped when every sign is zero.
        if enable_kmeans and not torch.all(signs == 0):
            only_last = kmeans_config['use_last_iter_quantization']
            if not only_last or k == iter - 1:
                quant_result = apply_codebook_quantization(
                    signs, target, mask, refine_mean, new_alpha,
                    vector_length=kmeans_config['vector_length'],
                    num_centroids=kmeans_config['num_centroids'],
                    max_iter=kmeans_config['max_iter'],
                )
                # The result also carries a codebook and indices that could
                # be stored for inference; they are not used here.
                signs = quant_result['quantized_binary']

        refined = torch.zeros_like(x) + (new_alpha[:, None] * signs + refine_mean[:, None]) * mask

    return refined

@torch.no_grad()
def high_order_residual_alternating_order1_novq(x, mask, order=2, iter=15):
    """Single-term sign approximation with alternating refinement, no quantization.

    First, ``order`` residual terms (row mean + row scale * sign) are
    accumulated on the selected entries. Then ``iter`` rounds refine a single
    (mean, scale, sign) triple seeded from the last term: shift the row mean
    by the mean reconstruction error, refit the per-row scale against the
    current signs, and recompute the signs.

    Args:
        x: 2-D tensor.
        mask: boolean tensor selecting the entries of ``x`` to approximate.
        order: number of initial residual terms; must be >= 1.
        iter: number of refinement rounds (0 returns the initial result).

    Returns:
        Tensor shaped like ``x`` with the refined approximation.

    Side effect: increments the module-level ``index`` counter.
    """
    global index
    approx = torch.zeros_like(x)
    target = x.clone() * mask
    index += 1

    nan_fill = torch.tensor(float('nan'))
    for _ in range(order):
        # Unselected entries become NaN so that nanmean ignores them.
        selected = torch.where(mask, target - approx, nan_fill)

        row_mean = torch.nanmean(selected, dim=1)
        row_mean = row_mean.masked_fill(torch.isnan(row_mean), 0)
        centered = selected - row_mean[:, None]
        row_scale = torch.nanmean(centered.abs(), dim=1)
        row_scale = row_scale.masked_fill(torch.isnan(row_scale), 0)

        signs = torch.sign(centered)
        approx = approx + (signs * row_scale[:, None] + row_mean[:, None]) * mask

    # Refinement starts from the mean and signs of the last term.
    refine_mean = row_mean.clone()
    refined = approx.clone()

    for _ in range(iter):
        # Step 1: with scale and signs fixed, move the mean by the mean error.
        leftover = torch.where(mask, target - refined, nan_fill)
        shift = torch.nanmean(leftover, dim=1)
        shift = shift.masked_fill(torch.isnan(shift), 0)
        refine_mean += shift

        shifted_target = target - refine_mean[:, None] * mask

        # Step 2: with mean and signs fixed, least-squares fit of the scale.
        masked_signs = signs * mask
        new_alpha = 1. / (torch.sum(masked_signs * masked_signs, dim=1) + 1e-6) * torch.sum(masked_signs * shifted_target, dim=1)

        # Step 3: with mean and scale fixed, recompute the signs.
        signs = torch.sign(shifted_target)

        refined = torch.zeros_like(x) + (new_alpha[:, None] * signs + refine_mean[:, None]) * mask

    return refined

@torch.no_grad()
def high_order_residual_alternating_order1_x(x, mask, order=2, S=None, iter=15, iter2=15):
    """Single-term sign approximation with a second, ``S``-weighted refinement.

    Stage 1: accumulate ``order`` residual terms (row mean + row scale * sign)
    on the selected entries.
    Stage 2: ``iter`` rounds alternating mean shift, per-row scale refit and
    sign recomputation.
    Stage 3: ``iter2`` rounds that recompute the per-row mean and scale from
    sums weighted by ``S``; the signs are not changed in this stage.

    Args:
        x: 2-D tensor.
        mask: boolean tensor selecting the entries of ``x`` to approximate.
        order: number of stage-1 terms; must be >= 1 because later stages use
            the mean, scale and signs of the last term.
        S: weighting tensor multiplied with (rows, cols, cols) products —
            presumably a (cols, cols) matrix or something broadcastable to
            that shape; TODO confirm against the caller.
            NOTE(review): the default ``None`` is used unconditionally in
            ``S * MM`` below, so calling without ``S`` raises a TypeError.
        iter: number of stage-2 rounds.
        iter2: number of stage-3 rounds.

    Returns:
        Tensor shaped like ``x`` with the refined approximation.

    Side effect: increments the module-level ``index`` counter.
    """
    sum_order = torch.zeros_like(x)
    new_matrix = x.clone()
    new_matrix = new_matrix * mask
    global index
    index += 1
    for od in range(order):
        residual = new_matrix - sum_order
        # Unselected entries become NaN so that nanmean ignores them.
        masked_x_tensor = torch.where(mask, residual, torch.tensor(float('nan')))

        # Rows with no selected entry give NaN statistics; replace with 0.
        mean_tensor_all = torch.nanmean(masked_x_tensor, dim=1)
        mean_tensor_all = torch.where(torch.isnan(mean_tensor_all), torch.zeros_like(mean_tensor_all), mean_tensor_all)
        masked_x_tensor -= mean_tensor_all[:, None]
        scale_tensor_all = torch.nanmean(torch.abs(masked_x_tensor), dim=1)
        scale_tensor_all = torch.where(torch.isnan(scale_tensor_all), torch.zeros_like(scale_tensor_all), scale_tensor_all)

        binary= torch.sign(masked_x_tensor)
        # Keep the raw signs; `binary` itself is scaled and shifted in place.
        new_binary = binary.clone()
        binary *= scale_tensor_all[:, None]
        binary += mean_tensor_all[:, None]
        sum_order = sum_order + binary*mask

    # Alternating update, seeded with the last stage-1 term
    refine_mean = mean_tensor_all.clone()
    sum_order_alternating = sum_order.clone()
    # Seed so that new_alpha is defined even when iter == 0.
    new_alpha = scale_tensor_all.clone()

    for k in range(iter):
        # 1. Fix alpha and B, update mean
        residual = new_matrix - sum_order_alternating
        masked_x_tensor = torch.where(mask, residual, torch.tensor(float('nan')))
        mean_tensor_all = torch.nanmean(masked_x_tensor, dim=1)
        mean_tensor_all = torch.where(torch.isnan(mean_tensor_all), torch.zeros_like(mean_tensor_all), mean_tensor_all)
        refine_mean += mean_tensor_all.clone()
        
        # 2. Fix mean and B, update alpha (1e-6 guards an all-zero sign row)
        new_alpha = 1. / (torch.sum(new_binary * mask * new_binary * mask, dim=1) + 1e-6) * torch.sum(new_binary * mask * (new_matrix - refine_mean[:, None] * mask), dim=1)
        
        # 3. Fix mean and alpha, update B
        new_binary = torch.sign(new_matrix - refine_mean[:, None] * mask)

        # Reconstruction used as the reference for the next round
        sum_order_alternating = torch.zeros_like(x) + (new_alpha[:, None] * new_binary + refine_mean[:, None]) * mask

    # Pairwise mask product along the column axis: (rows, cols, cols) for a 2-D mask.
    MM = mask[:, :, None] * mask[:, None, :]
    # NOTE(review): this denominator is accumulated with dtype=torch.bfloat16
    # while new_alpha_den below uses the default dtype — confirm intentional.
    refine_mean_den = torch.sum(S * MM, dim=(1,2), dtype=torch.bfloat16) + 1e-6
    masked_B = new_binary * mask
    # Both denominators depend only on S, mask and the (now fixed) signs, so
    # they are computed once outside the loop.
    new_alpha_den = torch.sum(S * masked_B[:, :, None] * masked_B[:, None, :], dim=(1,2)) + 1e-6
    # diag_S = torch.diag(S)
    for kk in range(iter2):
        # S-weighted mean update with alpha and B fixed
        refine_mean = torch.sum(S * (new_matrix - new_alpha[:, None] * new_binary * mask)[:, :, None] * MM, dim=(1,2)) / refine_mean_den

        # S-weighted alpha update with mean and B fixed
        new_alpha = torch.sum(S * masked_B[:, :, None] * (new_matrix - refine_mean[:, None] * mask)[:, None, :], dim=(1,2)) / new_alpha_den

    sum_order_alternating = torch.zeros_like(x) + (new_alpha[:, None] * new_binary + refine_mean[:, None]) * mask

    return sum_order_alternating

@torch.no_grad()
def high_order_residual_alternating_order2_rc_nomean(x, mask, order=2, iter=15):
    """Two-term row/column-scaled sign approximation without a mean term.

    Initialisation accumulates ``order`` residual terms, each a sign matrix
    scaled by a per-row scale times a per-column scale. Every refinement round
    then refits the row and column scales of term 0 against what term 1 leaves
    unexplained, does the same for term 1 against term 0, and finally re-picks
    both sign matrices entry-wise as the closest of the four combinations
    ``+-comb0 +- comb1``.

    Args:
        x: 2-D tensor.
        mask: boolean tensor selecting the entries of ``x`` to approximate.
        order: number of initial terms. The refinement indexes terms 0 and 1
            only, so it needs ``order >= 2`` whenever ``iter > 0``.
        iter: number of refinement rounds (0 returns the initial result).

    Returns:
        Tensor shaped like ``x`` with the refined approximation.

    Side effect: increments the module-level ``index`` counter.
    """
    global index
    approx = torch.zeros_like(x)
    target = x.clone() * mask
    index += 1

    nan_fill = torch.tensor(float('nan'))
    signs, row_scales, col_scales = [], [], []
    for _ in range(order):
        # Unselected entries become NaN so that nanmean ignores them.
        selected = torch.where(mask, target - approx, nan_fill)

        row_scale = torch.nanmean(selected.abs(), dim=1)
        row_scale = row_scale.masked_fill(torch.isnan(row_scale), 0)
        col_scale = torch.nanmean((selected / row_scale[:, None]).abs(), dim=0)
        col_scale = col_scale.masked_fill(torch.isnan(col_scale), 0)

        sign = torch.sign(selected)
        row_scales.append(row_scale)
        col_scales.append(col_scale)
        signs.append(sign)
        approx = approx + (sign * row_scale[:, None] * col_scale[None, :]) * mask

    def refit_scales(cur, other):
        # Least-squares refit of term `cur` (rows first, then columns) against
        # the part of the target not explained by term `other`.
        remainder = target - (col_scales[other][None, :] * row_scales[other][:, None] * signs[other]) * mask
        weighted = col_scales[cur][None, :] * signs[cur] * mask
        row_scales[cur] = torch.sum(weighted * remainder, dim=1) / (torch.sum(weighted * weighted, dim=1) + 1e-6)
        weighted = row_scales[cur][:, None] * signs[cur] * mask
        col_scales[cur] = torch.sum(weighted * remainder, dim=0) / (torch.sum(weighted * weighted, dim=0) + 1e-6)

    refined = approx.clone()
    for _ in range(iter):
        refit_scales(0, 1)
        refit_scales(1, 0)

        # Re-pick the signs: candidate order is (-,-), (-,+), (+,-), (+,+).
        comb0 = row_scales[0].reshape(-1, 1) @ col_scales[0].reshape(1, -1)
        comb1 = row_scales[1].reshape(-1, 1) @ col_scales[1].reshape(1, -1)
        candidates = torch.stack(
            [-comb0 - comb1, -comb0 + comb1, comb0 - comb1, comb0 + comb1], dim=2)
        choice = torch.argmin(torch.abs(target.unsqueeze(-1) - candidates), dim=-1)

        # Term 0 is negative for choices 0/1, term 1 for choices 0/2.
        signs[0] = torch.ones_like(choice).masked_fill(choice < 2, -1)
        signs[1] = torch.ones_like(choice).masked_fill(torch.remainder(choice, 2) == 0, -1)

        refined = torch.zeros_like(x) + (
            col_scales[0][None, :] * row_scales[0][:, None] * signs[0]
            + col_scales[1][None, :] * row_scales[1][:, None] * signs[1]) * mask

    return refined



@torch.no_grad()
def high_order_residual_alternating_order2_rc_mean(x, mask, order=2, iter=15):
    """Two-term row/column-scaled sign approximation with a per-row mean.

    Initialisation accumulates ``order`` residual terms; each removes the row
    mean of the residual and approximates the rest by a sign matrix scaled by
    a per-row scale times a per-column scale. Every refinement round then:

    1. shifts the accumulated row mean by the mean reconstruction error,
    2. refits row and column scales of term 0, then of term 1, each against
       the target minus the mean and minus the other term,
    3. re-picks both sign matrices entry-wise as the closest of the four
       combinations ``+-comb0 +- comb1`` to the mean-removed target.

    Step 3 compares against the target with the mean removed, so the sign
    choice minimises the error of the same model (mean + term 0 + term 1)
    that the scale fits and the returned reconstruction use.

    Args:
        x: 2-D tensor.
        mask: boolean tensor selecting the entries of ``x`` to approximate.
        order: number of initial terms. The refinement indexes terms 0 and 1
            only, so it needs ``order >= 2`` whenever ``iter > 0``.
        iter: number of refinement rounds (0 returns the initial result).

    Returns:
        Tensor shaped like ``x`` with the refined approximation.

    Side effect: increments the module-level ``index`` counter.
    """
    global index
    sum_order = torch.zeros_like(x)
    new_matrix = x.clone() * mask
    index += 1

    nan_fill = torch.tensor(float('nan'))
    binary_list, alpha_list_r, alpha_list_c = [], [], []
    # Running sum of all row-mean contributions (initial terms + refinements).
    refine_mean = torch.zeros(x.shape[0], device=x.device)
    for _ in range(order):
        # Unselected entries become NaN so that nanmean ignores them.
        masked_x_tensor = torch.where(mask, new_matrix - sum_order, nan_fill)

        # Row mean; rows with no selected entry give NaN and fall back to 0.
        mean_tensor_all = torch.nanmean(masked_x_tensor, dim=1)
        mean_tensor_all = mean_tensor_all.masked_fill(torch.isnan(mean_tensor_all), 0)
        refine_mean += mean_tensor_all
        masked_x_tensor = masked_x_tensor - mean_tensor_all[:, None]

        # Row scale, then column scale of the row-normalised magnitudes.
        scale_r = torch.nanmean(torch.abs(masked_x_tensor), dim=1)
        scale_r = scale_r.masked_fill(torch.isnan(scale_r), 0)
        scale_c = torch.nanmean(torch.abs(masked_x_tensor / scale_r[:, None]), dim=0)
        scale_c = scale_c.masked_fill(torch.isnan(scale_c), 0)

        binary = torch.sign(masked_x_tensor)
        alpha_list_r.append(scale_r)
        alpha_list_c.append(scale_c)
        binary_list.append(binary)
        term = binary * scale_r[:, None] * scale_c[None, :] + mean_tensor_all[:, None]
        sum_order = sum_order + term * mask

    def refit_scales(cur, other):
        # Least-squares refit of term `cur` (rows first, then columns) against
        # the target minus the mean and minus term `other`.
        w_tilde = (new_matrix
                   - (alpha_list_c[other][None, :] * alpha_list_r[other][:, None] * binary_list[other]) * mask
                   - refine_mean[:, None] * mask)
        alpha_c_b = alpha_list_c[cur][None, :] * binary_list[cur] * mask
        alpha_list_r[cur] = torch.sum(alpha_c_b * w_tilde, dim=1) / (torch.sum(alpha_c_b * alpha_c_b, dim=1) + 1e-6)
        alpha_r_b = alpha_list_r[cur][:, None] * binary_list[cur] * mask
        alpha_list_c[cur] = torch.sum(alpha_r_b * w_tilde, dim=0) / (torch.sum(alpha_r_b * alpha_r_b, dim=0) + 1e-6)

    # Alternating update
    sum_order_alternating = sum_order.clone()
    for _ in range(iter):
        # 1. Fix scales and signs, update mean.
        masked_error = torch.where(mask, new_matrix - sum_order_alternating, nan_fill)
        mean_shift = torch.nanmean(masked_error, dim=1)
        mean_shift = mean_shift.masked_fill(torch.isnan(mean_shift), 0)
        refine_mean += mean_shift

        # 2. Fix mean and signs, update the scales of term 0 then term 1.
        refit_scales(0, 1)
        refit_scales(1, 0)

        # 3. Fix mean and scales, update signs. The mean must be removed
        #    first because the candidates below contain only the two terms.
        centered = (new_matrix - refine_mean[:, None] * mask).unsqueeze(-1)
        comb0 = alpha_list_r[0].reshape(-1, 1) @ alpha_list_c[0].reshape(1, -1)
        comb1 = alpha_list_r[1].reshape(-1, 1) @ alpha_list_c[1].reshape(1, -1)
        # Candidate order is (-,-), (-,+), (+,-), (+,+).
        v = torch.stack([-comb0 - comb1, -comb0 + comb1,
                         comb0 - comb1, comb0 + comb1], dim=2)
        min_indices = torch.argmin(torch.abs(centered - v), dim=-1)

        binary_list[0] = torch.ones_like(min_indices)
        binary_list[0][(min_indices == 0) | (min_indices == 1)] = -1
        binary_list[1] = torch.ones_like(min_indices)
        binary_list[1][(min_indices == 0) | (min_indices == 2)] = -1

        # Reconstruction used as the reference for the next round.
        sum_order_alternating = torch.zeros_like(x) + (
            alpha_list_c[0][None, :] * alpha_list_r[0][:, None] * binary_list[0]
            + alpha_list_c[1][None, :] * alpha_list_r[1][:, None] * binary_list[1]
            + refine_mean[:, None]) * mask

    return sum_order_alternating



@torch.no_grad()
def high_order_residual_alternating_order1_rc_nomean(x, mask, order=2, iter=15):
    """Single-term row/column-scaled sign approximation without a mean term.

    Initialisation accumulates ``order`` residual terms, each a sign matrix
    scaled by a per-row scale times a per-column scale. The refinement keeps
    the signs of the last term fixed and alternately refits the row scale and
    the column scale by least squares against the masked target.

    Args:
        x: 2-D tensor.
        mask: boolean tensor selecting the entries of ``x`` to approximate.
        order: number of initial terms; must be >= 1 because the refinement
            starts from the scales and signs of the last term.
        iter: number of refinement rounds (0 returns the initial result).

    Returns:
        Tensor shaped like ``x`` with the refined approximation.

    Side effect: increments the module-level ``index`` counter.
    """
    global index
    approx = torch.zeros_like(x)
    target = x.clone() * mask
    index += 1

    nan_fill = torch.tensor(float('nan'))
    for _ in range(order):
        # Unselected entries become NaN so that nanmean ignores them.
        selected = torch.where(mask, target - approx, nan_fill)

        row_scale = torch.nanmean(selected.abs(), dim=1)
        row_scale = row_scale.masked_fill(torch.isnan(row_scale), 0)
        col_scale = torch.nanmean((selected / row_scale[:, None]).abs(), dim=0)
        col_scale = col_scale.masked_fill(torch.isnan(col_scale), 0)

        signs = torch.sign(selected)
        approx = approx + (signs * row_scale[:, None] * col_scale[None, :]) * mask

    # Refinement state, seeded from the last term.
    refined = approx.clone()
    new_alpha_r = row_scale.clone()
    new_alpha_c = col_scale.clone()
    masked_signs = signs * mask
    for _ in range(iter):
        # Row scale with column scale and signs fixed.
        weighted = new_alpha_c[None, :] * masked_signs
        new_alpha_r = (weighted * target).sum(dim=1) / ((weighted * weighted).sum(dim=1) + 1e-6)

        # Column scale with row scale and signs fixed.
        weighted = new_alpha_r[:, None] * masked_signs
        new_alpha_c = (weighted * target).sum(dim=0) / ((weighted * weighted).sum(dim=0) + 1e-6)

        refined = torch.zeros_like(x) + new_alpha_c[None, :] * new_alpha_r[:, None] * signs * mask

    return refined

@torch.no_grad()
def high_order_residual_alternating_order1_rc_mean(x, mask, order=2, iter=15):
    """Single-term row/column-scaled sign approximation with a per-row mean.

    Initialisation accumulates ``order`` residual terms; each removes the row
    mean of the residual and approximates the rest by a sign matrix scaled by
    a per-row scale times a per-column scale. The refinement keeps the signs
    of the last term fixed and, per round, shifts the row mean by the mean
    reconstruction error and then refits the row scale and the column scale
    by least squares against the mean-removed target.

    The reconstruction built at the end of every round includes the refined
    mean, so the returned matrix is ``mean + row_scale * col_scale * sign`` on
    the selected entries and the next round's mean shift only corrects the
    remaining error.

    Args:
        x: 2-D tensor.
        mask: boolean tensor selecting the entries of ``x`` to approximate.
        order: number of initial terms; must be >= 1 because the refinement
            starts from the mean, scales and signs of the last term.
        iter: number of refinement rounds (0 returns the initial result).

    Returns:
        Tensor shaped like ``x`` with the refined approximation.

    Side effect: increments the module-level ``index`` counter.
    """
    global index
    sum_order = torch.zeros_like(x)
    new_matrix = x.clone() * mask
    index += 1

    nan_fill = torch.tensor(float('nan'))
    for _ in range(order):
        # Unselected entries become NaN so that nanmean ignores them.
        masked_x_tensor = torch.where(mask, new_matrix - sum_order, nan_fill)

        # Row mean; rows with no selected entry give NaN and fall back to 0.
        mean_tensor_all = torch.nanmean(masked_x_tensor, dim=1)
        mean_tensor_all = mean_tensor_all.masked_fill(torch.isnan(mean_tensor_all), 0)
        masked_x_tensor = masked_x_tensor - mean_tensor_all[:, None]

        # Row scale, then column scale of the row-normalised magnitudes.
        scale_tensor_all_r = torch.nanmean(torch.abs(masked_x_tensor), dim=1)
        scale_tensor_all_r = scale_tensor_all_r.masked_fill(torch.isnan(scale_tensor_all_r), 0)
        scale_tensor_all_c = torch.nanmean(torch.abs(masked_x_tensor / scale_tensor_all_r[:, None]), dim=0)
        scale_tensor_all_c = scale_tensor_all_c.masked_fill(torch.isnan(scale_tensor_all_c), 0)

        new_binary = torch.sign(masked_x_tensor)
        term = new_binary * scale_tensor_all_r[:, None] * scale_tensor_all_c[None, :] + mean_tensor_all[:, None]
        sum_order = sum_order + term * mask

    # Alternating update, seeded with the last term.
    refine_mean = mean_tensor_all.clone()
    sum_order_alternating = sum_order.clone()
    new_alpha_r = scale_tensor_all_r.clone()
    new_alpha_c = scale_tensor_all_c.clone()
    masked_binary = new_binary * mask
    for _ in range(iter):
        # 1-1. Fix alpha row, alpha column, and B, update mean.
        masked_error = torch.where(mask, new_matrix - sum_order_alternating, nan_fill)
        mean_shift = torch.nanmean(masked_error, dim=1)
        mean_shift = mean_shift.masked_fill(torch.isnan(mean_shift), 0)
        refine_mean += mean_shift

        centered = new_matrix - refine_mean[:, None] * mask

        # 1-2. Fix mean, alpha column, and B, update alpha row.
        alpha_c_b = new_alpha_c[None, :] * masked_binary
        new_alpha_r = torch.sum(alpha_c_b * centered, dim=1) / (torch.sum(alpha_c_b * alpha_c_b, dim=1) + 1e-6)

        # 1-3. Fix mean, alpha row, and B, update alpha column.
        alpha_r_b = new_alpha_r[:, None] * masked_binary
        new_alpha_c = torch.sum(alpha_r_b * centered, dim=0) / (torch.sum(alpha_r_b * alpha_r_b, dim=0) + 1e-6)

        # Reconstruction for the next round; the mean is part of the model
        # and must be included, otherwise step 1-1 would add it again.
        sum_order_alternating = torch.zeros_like(x) + (
            new_alpha_c[None, :] * new_alpha_r[:, None] * new_binary + refine_mean[:, None]) * mask

    return sum_order_alternating


@torch.no_grad()
def high_order_residual_alternating_mean(x, mask, order=2, num_iters=15, enable_kmeans=True):
    """Two-term sign approximation with a shared row mean, refined alternately.

    Initialisation accumulates ``order`` residual terms (row mean + row scale
    * sign) and sums the row means into one refined mean. Each refinement
    round then shifts that mean by the mean reconstruction error, refits the
    two per-row scales one after the other, re-picks both sign matrices
    entry-wise as the closest of the four combinations ``+-a0 +- a1`` to the
    mean-removed target, and optionally passes each sign matrix through
    ``apply_codebook_quantization``.

    Args:
        x: 2-D tensor.
        mask: boolean tensor selecting the entries of ``x`` to approximate.
        order: number of initial terms. The refinement indexes terms 0 and 1
            only, so it needs ``order >= 2`` whenever ``num_iters > 0``.
        num_iters: number of refinement rounds (0 returns the initial result).
        enable_kmeans: when True, quantize the signs with the settings in the
            module ``kmeans_config`` (every round, or only the last round when
            ``use_last_iter_quantization`` is set).

    Returns:
        Tensor shaped like ``x`` with the refined approximation.

    Side effect: increments the module-level ``index`` counter.
    """
    global index
    approx = torch.zeros_like(x)
    target = x.clone() * mask
    index += 1

    nan_fill = torch.tensor(float('nan'))
    signs, scales = [], []
    # Running sum of all row-mean contributions (initial terms + refinements).
    refine_mean = torch.zeros(x.shape[0], device=x.device)
    for _ in range(order):
        # Unselected entries become NaN so that nanmean ignores them.
        selected = torch.where(mask, target - approx, nan_fill)

        row_mean = torch.nanmean(selected, dim=1)
        row_mean = row_mean.masked_fill(torch.isnan(row_mean), 0)
        refine_mean += row_mean
        centered = selected - row_mean[:, None]
        row_scale = torch.nanmean(centered.abs(), dim=1)
        row_scale = row_scale.masked_fill(torch.isnan(row_scale), 0)

        sign = torch.sign(centered)
        scales.append(row_scale)
        signs.append(sign)
        approx = approx + (sign * row_scale[:, None] + row_mean[:, None]) * mask

    refined = approx.clone()
    term_pairs = ((0, 1), (1, 0))

    for k in range(num_iters):
        # 1. With scales and signs fixed, move the mean by the mean error.
        leftover = torch.where(mask, target - refined, nan_fill)
        shift = torch.nanmean(leftover, dim=1)
        shift = shift.masked_fill(torch.isnan(shift), 0)
        refine_mean += shift

        mean_part = refine_mean[:, None] * mask

        # 2. With mean and signs fixed, refit scale 0 and then scale 1; the
        #    second fit already sees the updated first scale.
        for cur, other in term_pairs:
            masked_sign = signs[cur] * mask
            fit_target = target - mean_part - scales[other][:, None] * signs[other] * mask
            scales[cur] = 1. / (torch.sum(masked_sign * masked_sign, dim=1) + 1e-6) * torch.sum(masked_sign * fit_target, dim=1)

        # 3. With mean and scales fixed, re-pick the signs. Candidate order
        #    is (-,-), (-,+), (+,-), (+,+).
        shifted_target = (target - mean_part).unsqueeze(-1)
        candidates = torch.stack(
            [-scales[0] - scales[1], -scales[0] + scales[1],
             scales[0] - scales[1], scales[0] + scales[1]], dim=1).unsqueeze(1)
        choice = torch.argmin(torch.abs(shifted_target - candidates), dim=-1)

        # Term 0 is negative for choices 0/1, term 1 for choices 0/2.
        signs[0] = torch.ones_like(choice).masked_fill(choice < 2, -1)
        signs[1] = torch.ones_like(choice).masked_fill(torch.remainder(choice, 2) == 0, -1)

        # Optional codebook quantization: term 0 first, then term 1 against
        # the already-quantized term 0. A term whose signs are all zero is
        # left untouched.
        if enable_kmeans and (not kmeans_config['use_last_iter_quantization'] or k == num_iters - 1):
            for cur, other in term_pairs:
                if torch.all(signs[cur] == 0):
                    continue
                quant_result = apply_codebook_quantization(
                    signs[cur], target - scales[other][:, None] * signs[other], mask, refine_mean, scales[cur],
                    vector_length=kmeans_config['vector_length'],
                    num_centroids=kmeans_config['num_centroids'],
                    max_iter=kmeans_config['max_iter'],
                )
                signs[cur] = quant_result['quantized_binary']

        refined = torch.zeros_like(x) + (
            scales[0][:, None] * signs[0] + scales[1][:, None] * signs[1] + refine_mean[:, None]) * mask

    return refined

@torch.no_grad()
def high_order_residual_alternating_mean_hessian_vq(x, mask, order=2, num_iters=15, enable_kmeans=True, H=None):
    """Approximate the masked entries of ``x`` by a row mean plus two scaled sign matrices.

    A greedy residual pass builds ``order`` sign matrices with per-row scales
    and accumulates the per-row means. ``num_iters`` alternating rounds then
    update, in turn, the row mean, the two row scales and the two sign
    matrices. When ``enable_kmeans`` is true, the sign matrices are also passed
    through ``apply_codebook_quantization`` using the module-level
    ``kmeans_config``: on every round, or only on the last round when
    ``kmeans_config['use_last_iter_quantization']`` is set.

    Args:
        x: 2-D tensor to approximate.
        mask: Boolean tensor shaped like ``x``. Only entries where it is true
            are fitted, and the result is multiplied by it.
        order: Number of greedy residual passes. The alternating rounds index
            the first two sign matrices / scales only.
        num_iters: Number of alternating rounds.
        enable_kmeans: Whether to run the codebook quantization step.
        H: Accepted but never read by this implementation.

    Returns:
        Tensor shaped like ``x`` holding the approximation.

    Side effects:
        Increments the module-level ``index`` counter once per call.
    """
    global index
    index += 1

    nan_fill = torch.tensor(float('nan'))
    target = x.clone() * mask

    # Greedy initialisation: each pass fits mean + scale * sign to what is left.
    approx = torch.zeros_like(x)
    row_mean = torch.zeros(x.shape[0], device=x.device)
    signs = []
    scales = []
    for _ in range(order):
        # Entries outside the mask become NaN so that nanmean ignores them.
        centered = torch.where(mask, target - approx, nan_fill)
        mu = torch.nanmean(centered, dim=1)
        mu = torch.where(torch.isnan(mu), torch.zeros_like(mu), mu)
        row_mean += mu
        centered -= mu[:, None]

        scale = torch.nanmean(torch.abs(centered), dim=1)
        scale = torch.where(torch.isnan(scale), torch.zeros_like(scale), scale)
        sign = torch.sign(centered)
        scales.append(scale)
        signs.append(sign)

        approx = approx + (sign * scale[:, None] + mu[:, None]) * mask

    refined = approx.clone()
    for k in range(num_iters):
        # 1. Signs and scales fixed: fold the mean of the remaining error into the row mean.
        leftover = torch.where(mask, target - refined, nan_fill)
        shift = torch.nanmean(leftover, dim=1)
        row_mean += torch.where(torch.isnan(shift), torch.zeros_like(shift), shift)

        # 2. Mean and signs fixed: least-squares update of each row scale.
        centered_target = target - row_mean[:, None] * mask
        masked_first = signs[0] * mask
        scales[0] = 1. / (torch.sum(masked_first * masked_first, dim=1) + 1e-6) * torch.sum(
            masked_first * (centered_target - scales[1][:, None] * signs[1] * mask), dim=1)
        masked_second = signs[1] * mask
        scales[1] = 1. / (torch.sum(masked_second * masked_second, dim=1) + 1e-6) * torch.sum(
            masked_second * (centered_target - scales[0][:, None] * signs[0] * mask), dim=1)

        # 3. Mean and scales fixed: snap every entry to the nearest of the four
        #    reachable values (-a0-a1, -a0+a1, a0-a1, a0+a1).
        levels = torch.stack(
            [-scales[0] - scales[1], -scales[0] + scales[1],
             scales[0] - scales[1], scales[0] + scales[1]],
            dim=1,
        ).unsqueeze(1)
        nearest = torch.argmin(torch.abs(centered_target.unsqueeze(-1) - levels), dim=-1)
        plus_one = torch.ones_like(nearest)
        signs[0] = torch.where(nearest < 2, -plus_one, plus_one)        # levels 0 and 1 use -a0
        signs[1] = torch.where(nearest % 2 == 0, -plus_one, plus_one)   # levels 0 and 2 use -a1

        # Optional codebook quantization of the sign matrices; kmeans_config is
        # only consulted when the step is enabled.
        if enable_kmeans and (not kmeans_config['use_last_iter_quantization'] or k == num_iters - 1):
            for cur, other in ((0, 1), (1, 0)):
                if torch.all(signs[cur] == 0):
                    continue
                quant_result = apply_codebook_quantization(
                    signs[cur], target - scales[other][:, None] * signs[other], mask, row_mean, scales[cur],
                    vector_length=kmeans_config['vector_length'],
                    num_centroids=kmeans_config['num_centroids'],
                    max_iter=kmeans_config['max_iter']
                )
                signs[cur] = quant_result['quantized_binary']

        refined = torch.zeros_like(x) + (scales[0][:, None] * signs[0] + scales[1][:, None] * signs[1] + row_mean[:, None]) * mask

    return refined

@torch.no_grad()
def high_order_residual_alternating_mean_novq(x, mask, order=2, num_iters=15):
    """Fit ``mean + a0 * B0 + a1 * B1`` to the masked entries of ``x`` without codebook quantization.

    ``B0``/``B1`` are sign matrices, ``a0``/``a1`` per-row scales and ``mean``
    a per-row offset. They are initialised by ``order`` greedy residual passes
    and then refined by ``num_iters`` rounds that update the mean, the scales
    and the signs one after another.

    Args:
        x: 2-D tensor to approximate.
        mask: Boolean tensor shaped like ``x``. Only entries where it is true
            are fitted, and the result is multiplied by it.
        order: Number of greedy residual passes. The refinement rounds index
            the first two sign matrices / scales only.
        num_iters: Number of refinement rounds.

    Returns:
        Tensor shaped like ``x`` holding the approximation.

    Side effects:
        Increments the module-level ``index`` counter once per call.
    """
    global index
    index += 1

    not_a_number = torch.tensor(float('nan'))
    masked_input = x.clone() * mask

    def zero_where_nan(values):
        # A row with no selected entry yields NaN from nanmean; treat it as 0.
        return torch.where(torch.isnan(values), torch.zeros_like(values), values)

    # Greedy initialisation.
    running = torch.zeros_like(x)
    offset = torch.zeros(x.shape[0], device=x.device)
    sign_mats = []
    row_scales = []
    for _ in range(order):
        # Entries outside the mask become NaN so that nanmean skips them.
        work = torch.where(mask, masked_input - running, not_a_number)
        step_mean = zero_where_nan(torch.nanmean(work, dim=1))
        offset += step_mean
        work -= step_mean[:, None]
        step_scale = zero_where_nan(torch.nanmean(torch.abs(work), dim=1))
        step_sign = torch.sign(work)

        row_scales.append(step_scale)
        sign_mats.append(step_sign)
        running = running + (step_sign * step_scale[:, None] + step_mean[:, None]) * mask

    estimate = running.clone()
    for _ in range(num_iters):
        # 1. Scales and signs fixed: absorb the mean of the remaining error.
        error = torch.where(mask, masked_input - estimate, not_a_number)
        offset += zero_where_nan(torch.nanmean(error, dim=1))

        # 2. Mean and signs fixed: least-squares update of each row scale.
        shifted = masked_input - offset[:, None] * mask
        first = sign_mats[0] * mask
        row_scales[0] = 1. / (torch.sum(first * first, dim=1) + 1e-6) * torch.sum(
            first * (shifted - row_scales[1][:, None] * sign_mats[1] * mask), dim=1)
        second = sign_mats[1] * mask
        row_scales[1] = 1. / (torch.sum(second * second, dim=1) + 1e-6) * torch.sum(
            second * (shifted - row_scales[0][:, None] * sign_mats[0] * mask), dim=1)

        # 3. Mean and scales fixed: choose, per entry, the closest of the four
        #    values (-a0-a1, -a0+a1, a0-a1, a0+a1).
        candidates = torch.stack(
            [-row_scales[0] - row_scales[1], -row_scales[0] + row_scales[1],
             row_scales[0] - row_scales[1], row_scales[0] + row_scales[1]],
            dim=1,
        ).unsqueeze(1)
        choice = torch.argmin(torch.abs(shifted.unsqueeze(-1) - candidates), dim=-1)
        unit = torch.ones_like(choice)
        sign_mats[0] = torch.where(choice < 2, -unit, unit)         # candidates 0, 1 carry -a0
        sign_mats[1] = torch.where(choice % 2 == 0, -unit, unit)    # candidates 0, 2 carry -a1

        estimate = torch.zeros_like(x) + (row_scales[0][:, None] * sign_mats[0] + row_scales[1][:, None] * sign_mats[1] + offset[:, None]) * mask

    return estimate



@torch.no_grad()
def high_order_residual_alternating_mean_x(x, mask, order=2, S=None, num_iters=15, iter2=15):
    """Fit ``mean + a0 * B0 + a1 * B1`` to masked ``x``, then re-fit mean and scales weighted by ``S``.

    Stage one is the plain alternating refinement: greedy initialisation over
    ``order`` residual passes followed by ``num_iters`` rounds that update the
    row mean, the two row scales and the two sign matrices. Stage two keeps
    the sign matrices fixed and runs ``iter2`` rounds that recompute the row
    mean and both row scales from sums weighted by ``S``.

    Args:
        x: 2-D tensor to approximate.
        mask: Boolean tensor shaped like ``x``. Only entries where it is true
            are fitted, and the result is multiplied by it.
        order: Number of greedy residual passes. Later stages index the first
            two sign matrices / scales only.
        S: Weight tensor multiplied against (rows, cols, cols) products. It
            defaults to ``None`` but stage two always multiplies by it, so a
            tensor has to be supplied.
        num_iters: Number of stage-one rounds.
        iter2: Number of stage-two rounds.

    Returns:
        Tensor shaped like ``x`` holding the approximation.

    Side effects:
        Increments the module-level ``index`` counter once per call.
    """
    global index
    index += 1

    nan_fill = torch.tensor(float('nan'))
    target = x.clone() * mask

    # Greedy initialisation.
    approx = torch.zeros_like(x)
    row_mean = torch.zeros(x.shape[0], device=x.device)
    signs, scales = [], []
    for _ in range(order):
        # Entries outside the mask become NaN so that nanmean ignores them.
        centered = torch.where(mask, target - approx, nan_fill)
        mu = torch.nanmean(centered, dim=1)
        mu = torch.where(torch.isnan(mu), torch.zeros_like(mu), mu)
        row_mean += mu
        centered -= mu[:, None]

        scale = torch.nanmean(torch.abs(centered), dim=1)
        scale = torch.where(torch.isnan(scale), torch.zeros_like(scale), scale)
        sign = torch.sign(centered)
        scales.append(scale)
        signs.append(sign)
        approx = approx + (sign * scale[:, None] + mu[:, None]) * mask

    # Stage one: unweighted alternating refinement.
    refined = approx.clone()
    for _ in range(num_iters):
        # 1. Fold the mean of the remaining error into the row mean.
        leftover = torch.where(mask, target - refined, nan_fill)
        shift = torch.nanmean(leftover, dim=1)
        row_mean += torch.where(torch.isnan(shift), torch.zeros_like(shift), shift)

        # 2. Least-squares update of each row scale.
        centered_target = target - row_mean[:, None] * mask
        masked_first = signs[0] * mask
        scales[0] = 1. / (torch.sum(masked_first * masked_first, dim=1) + 1e-6) * torch.sum(
            masked_first * (centered_target - scales[1][:, None] * signs[1] * mask), dim=1)
        masked_second = signs[1] * mask
        scales[1] = 1. / (torch.sum(masked_second * masked_second, dim=1) + 1e-6) * torch.sum(
            masked_second * (centered_target - scales[0][:, None] * signs[0] * mask), dim=1)

        # 3. Snap every entry to the nearest of (-a0-a1, -a0+a1, a0-a1, a0+a1).
        levels = torch.stack(
            [-scales[0] - scales[1], -scales[0] + scales[1],
             scales[0] - scales[1], scales[0] + scales[1]],
            dim=1,
        ).unsqueeze(1)
        nearest = torch.argmin(torch.abs(centered_target.unsqueeze(-1) - levels), dim=-1)
        plus_one = torch.ones_like(nearest)
        signs[0] = torch.where(nearest < 2, -plus_one, plus_one)
        signs[1] = torch.where(nearest % 2 == 0, -plus_one, plus_one)

        refined = torch.zeros_like(x) + (scales[0][:, None] * signs[0] + scales[1][:, None] * signs[1] + row_mean[:, None]) * mask

    # Stage two: signs are frozen, so the S-weighted denominators are computed once.
    pair_mask = mask[:, :, None] * mask[:, None, :]
    masked_b0 = signs[0] * mask
    masked_b1 = signs[1] * mask
    mean_den = torch.sum(S * pair_mask, dim=(1, 2)) + 1e-6
    scale0_den = torch.sum(S * masked_b0[:, :, None] * masked_b0[:, None, :], dim=(1, 2)) + 1e-6
    scale1_den = torch.sum(S * masked_b1[:, :, None] * masked_b1[:, None, :], dim=(1, 2)) + 1e-6
    for _ in range(iter2):
        # S-weighted row mean of what the two sign terms leave unexplained
        # (this replaces, rather than increments, the stage-one mean).
        fitted = (scales[0][:, None] * signs[0] + scales[1][:, None] * signs[1]) * mask
        row_mean = torch.sum(S * (target - fitted)[:, :, None] * pair_mask, dim=(1, 2)) / mean_den

        # S-weighted row scales, each computed against the other term.
        without_mean = target - row_mean[:, None] * mask
        scales[0] = torch.sum(
            S * masked_b0[:, :, None] * (without_mean[:, None, :] - (scales[1][:, None] * masked_b1)[:, None, :]),
            dim=(1, 2)) / scale0_den
        scales[1] = torch.sum(
            S * masked_b1[:, :, None] * (without_mean[:, None, :] - (scales[0][:, None] * masked_b0)[:, None, :]),
            dim=(1, 2)) / scale1_den

    return torch.zeros_like(x) + (scales[0][:, None] * signs[0] + scales[1][:, None] * signs[1] + row_mean[:, None]) * mask

@torch.no_grad()
def high_order_residual_alternating_order1_rc_x_nomean(x, mask, order=2, S=None, iter=15, iter2=15):
    """Approximate masked ``x`` by ``row_scale[:, None] * col_scale[None, :] * sign`` (no mean term).

    The greedy pass runs ``order`` times on the running residual, but only the
    sign matrix and the row/column scales produced by the *last* pass feed the
    refinement, so the single-term model is what is returned. The refinement
    then alternates ``iter`` times between a least-squares update of the row
    scales and of the column scales with the signs held fixed.

    Args:
        x: 2-D tensor to approximate.
        mask: Boolean tensor shaped like ``x``. Only entries where it is true
            are fitted, and the result is multiplied by it.
        order: Number of greedy residual passes (must be at least 1).
        S: Accepted for signature compatibility; not used.
        iter: Number of alternating row/column scale updates.
        iter2: Accepted for signature compatibility; not used.

    Returns:
        Tensor shaped like ``x`` holding the approximation.

    Raises:
        ValueError: If ``order`` is smaller than 1.

    Side effects:
        Increments the module-level ``index`` counter once per call.
    """
    global index
    index += 1
    if order < 1:
        raise ValueError("order must be at least 1")

    nan_fill = torch.tensor(float('nan'))
    # The fitting target is kept in float16, as before.
    target = x.clone().to(torch.float16) * mask

    approx = torch.zeros_like(x)
    for _ in range(order):
        # Entries outside the mask become NaN so that nanmean ignores them.
        masked_residual = torch.where(mask, target - approx, nan_fill)

        # Row scale: mean magnitude per row.
        row_scale = torch.nanmean(torch.abs(masked_residual), dim=1)
        row_scale = torch.where(torch.isnan(row_scale), torch.zeros_like(row_scale), row_scale)
        # Column scale: mean magnitude per column after dividing out the row scale.
        col_scale = torch.nanmean(torch.abs(masked_residual / row_scale[:, None]), dim=0)
        col_scale = torch.where(torch.isnan(col_scale), torch.zeros_like(col_scale), col_scale)

        signs = torch.sign(masked_residual)
        approx = approx + signs * row_scale[:, None] * col_scale[None, :] * mask

    for _ in range(iter):
        # Column scales and signs fixed: least-squares row scales.
        col_signed = col_scale[None, :] * signs * mask
        row_scale = torch.sum(col_signed * target, dim=1) / (torch.sum(col_signed * col_signed, dim=1) + 1e-6)

        # Row scales and signs fixed: least-squares column scales.
        row_signed = row_scale[:, None] * signs * mask
        col_scale = torch.sum(row_signed * target, dim=0) / (torch.sum(row_signed * row_signed, dim=0) + 1e-6)

    # The result is built from the fitted sign matrix (the previous code
    # multiplied by an unrelated imported function here and always raised).
    return torch.zeros_like(x) + col_scale[None, :] * row_scale[:, None] * signs * mask

def _rc_x_refit_term(term_target, masked_sign, col_scale, hessian, eps=1e-6):
    """Re-fit the row and column scales of one ``row * col * sign`` term under the ``S``-weighted error.

    The error being reduced is ``sum_r e_r^T S e_r`` with
    ``e_r = term_target[r] - row[r] * col * masked_sign[r]`` and ``S`` the
    symmetric ``hessian``. Row scales are updated first (closed form per row),
    then column scales (one linear solve). Returns ``(row_scale, col_scale)``.
    """
    # Row scales: with u_r = col * sign_r the minimiser is (u_r^T S w_r) / (u_r^T S u_r).
    u = col_scale[None, :] * masked_sign
    u_s = u @ hessian
    row_scale = torch.sum(term_target * u_s, dim=1) / (torch.sum(u * u_s, dim=1) + eps)

    # Column scales: with d_r = row[r] * sign_r the normal equations are
    # (S * (D^T D)) c = sum_r d_r * (S w_r). eps on the diagonal keeps columns
    # with no selected entry solvable (they come out as 0).
    d = row_scale[:, None] * masked_sign
    gram = hessian * (d.t() @ d)
    gram = gram + eps * torch.eye(gram.shape[0], dtype=gram.dtype, device=gram.device)
    rhs = torch.sum(d * (term_target @ hessian), dim=0)
    col_scale = torch.linalg.solve(gram, rhs.unsqueeze(-1)).squeeze(-1)
    return row_scale, col_scale


@torch.no_grad()
def high_order_residual_alternating_order2_rc_x_nomean(x, mask, order=2, S=None, num_iters=15, iter2=15):
    """Approximate masked ``x`` by two ``row_scale * col_scale * sign`` terms (no mean term).

    Three stages:

    1. Greedy initialisation: ``order`` residual passes, each producing a sign
       matrix with a per-row and a per-column scale.
    2. ``num_iters`` unweighted alternating rounds that update the row and
       column scales of both terms by least squares and then re-pick both
       sign matrices by snapping each entry to the nearest of the four
       reachable values.
    3. If ``S`` is given, ``iter2`` rounds that keep the signs fixed and
       re-fit the row and column scales of both terms so as to reduce
       ``sum_r e_r^T S e_r``, where ``e_r`` is row ``r`` of the masked error.
       Only the symmetric part of ``S`` is used. With ``S=None`` this stage
       is skipped.

    Args:
        x: 2-D tensor to approximate.
        mask: Boolean tensor shaped like ``x``. Only entries where it is true
            are fitted, and the result is multiplied by it.
        order: Number of greedy passes; stages 2 and 3 index the first two
            terms only.
        S: Optional (cols, cols) weight matrix for stage 3.
        num_iters: Number of stage-2 rounds.
        iter2: Number of stage-3 rounds.

    Returns:
        Tensor shaped like ``x`` holding the approximation.

    Side effects:
        Increments the module-level ``index`` counter once per call.
    """
    global index
    index += 1

    nan_fill = torch.tensor(float('nan'))
    new_matrix = x.clone() * mask

    # Stage 1: greedy initialisation.
    sum_order = torch.zeros_like(x)
    binary_list = []
    alpha_list_r = []
    alpha_list_c = []
    for _ in range(order):
        # Entries outside the mask become NaN so that nanmean ignores them.
        masked_residual = torch.where(mask, new_matrix - sum_order, nan_fill)

        row_scale = torch.nanmean(torch.abs(masked_residual), dim=1)
        row_scale = torch.where(torch.isnan(row_scale), torch.zeros_like(row_scale), row_scale)
        col_scale = torch.nanmean(torch.abs(masked_residual / row_scale[:, None]), dim=0)
        col_scale = torch.where(torch.isnan(col_scale), torch.zeros_like(col_scale), col_scale)
        sign = torch.sign(masked_residual)

        alpha_list_r.append(row_scale)
        alpha_list_c.append(col_scale)
        binary_list.append(sign)
        sum_order = sum_order + sign * row_scale[:, None] * col_scale[None, :] * mask

    # Stage 2: unweighted alternating refinement.
    for _ in range(num_iters):
        for cur, other in ((0, 1), (1, 0)):
            # Fit term `cur` to what the other term leaves unexplained.
            W_tilde = new_matrix - (alpha_list_c[other][None, :] * alpha_list_r[other][:, None] * binary_list[other]) * mask
            alpha_c_B = alpha_list_c[cur][None, :] * binary_list[cur] * mask
            alpha_list_r[cur] = torch.sum(alpha_c_B * W_tilde, dim=1) / (torch.sum(alpha_c_B * alpha_c_B, dim=1) + 1e-6)
            alpha_r_B = alpha_list_r[cur][:, None] * binary_list[cur] * mask
            alpha_list_c[cur] = torch.sum(alpha_r_B * W_tilde, dim=0) / (torch.sum(alpha_r_B * alpha_r_B, dim=0) + 1e-6)

        # Scales fixed: snap each entry to the nearest of
        # (-s0-s1, -s0+s1, s0-s1, s0+s1), with s_t = row_t[r] * col_t[c].
        comb0 = alpha_list_r[0].reshape(-1, 1) @ alpha_list_c[0].reshape(1, -1)
        comb1 = alpha_list_r[1].reshape(-1, 1) @ alpha_list_c[1].reshape(1, -1)
        v = torch.stack([-comb0 - comb1, -comb0 + comb1, comb0 - comb1, comb0 + comb1], dim=2)
        min_indices = torch.argmin(torch.abs(new_matrix.unsqueeze(-1) - v), dim=-1)

        binary_list[0] = torch.ones_like(min_indices)
        binary_list[0][(min_indices == 0) | (min_indices == 1)] = -1
        binary_list[1] = torch.ones_like(min_indices)
        binary_list[1][(min_indices == 0) | (min_indices == 2)] = -1

    # Stage 3: S-weighted re-fit of the scales with the signs frozen.
    if S is not None:
        # The linear solve needs at least float32.
        calc_dtype = torch.promote_types(new_matrix.dtype, torch.float32)
        hessian = S.to(calc_dtype)
        hessian = 0.5 * (hessian + hessian.transpose(-1, -2))
        target = new_matrix.to(calc_dtype)
        masked_signs = [(b * mask).to(calc_dtype) for b in binary_list[:2]]
        rows = [a.to(calc_dtype) for a in alpha_list_r[:2]]
        cols = [a.to(calc_dtype) for a in alpha_list_c[:2]]

        for _ in range(iter2):
            for cur, other in ((0, 1), (1, 0)):
                term_target = target - rows[other][:, None] * cols[other][None, :] * masked_signs[other]
                rows[cur], cols[cur] = _rc_x_refit_term(term_target, masked_signs[cur], cols[cur], hessian)

        for t in (0, 1):
            alpha_list_r[t] = rows[t].to(alpha_list_r[t].dtype)
            alpha_list_c[t] = cols[t].to(alpha_list_c[t].dtype)

    sum_order_alternating = torch.zeros_like(x) + (alpha_list_c[0][None, :] * alpha_list_r[0][:, None] * binary_list[0] +
                                                   alpha_list_c[1][None, :] * alpha_list_r[1][:, None] * binary_list[1]) * mask
    return sum_order_alternating
              

    # alpha_c_B0 = alpha_list_c[0][None, :] * masked_B0
    # alpha_r_B0 = alpha_list_r[0][:, None] * masked_B0
    # alpha_c_B1 = alpha_list_c[1][None, :] * masked_B1
    # alpha_r_B1 = alpha_list_r[1][:, None] * masked_B1
    
    # # 为行列分解准备分母
    # alpha_r0_den = torch.sum(S * (alpha_c_B0[:, :, None] * alpha_c_B0[:, None, :]), dim=(1,2)) + 1e-6
    # alpha_c0_den = torch.sum(S * (alpha_r_B0[:, None, :] * alpha_r_B0[:, :, None]), dim=(1,2)) + 1e-6
    # alpha_r1_den = torch.sum(S * (alpha_c_B1[:, :, None] * alpha_c_B1[:, None, :]), dim=(1,2)) + 1e-6
    # alpha_c1_den = torch.sum(S * (alpha_r_B1[:, None, :] * alpha_r_B1[:, :, None]), dim=(1,2)) + 1e-6
    
    # for kk in range(iter2):
    #     # 优化均值
    #     refine_mean = torch.sum(S * (new_matrix - ((alpha_list_c[0][None, :] * alpha_list_r[0][:, None] * binary_list[0]) + 
    #                                               (alpha_list_c[1][None, :] * alpha_list_r[1][:, None] * binary_list[1])) * mask)[:, :, None] * 
    #                            MM, dim=(1,2)) / refine_mean_den
        
    #     # 优化 alpha
    #     masked_W_mu = new_matrix - refine_mean[:, None] * mask
        
    #     # 更新 alpha_r0 (行)
    #     W_tilde0 = masked_W_mu - (alpha_list_c[1][None, :] * alpha_list_r[1][:, None] * binary_list[1]) * mask
    #     alpha_list_r[0] = torch.sum(S * (alpha_list_c[0][None, :] * masked_B0)[:, :, None] * 
    #                                W_tilde0[:, None, :], dim=(1,2)) / alpha_r0_den
        
    #     # 更新 alpha_c0 (列)
    #     alpha_list_c[0] = torch.sum(S * (alpha_list_r[0][:, None] * masked_B0)[:, :, None] * 
    #                                W_tilde0[:, None, :], dim=(1,2)) / alpha_c0_den
        
    #     # 更新 alpha_r1 (行)
    #     W_tilde1 = masked_W_mu - (alpha_list_c[0][None, :] * alpha_list_r[0][:, None] * binary_list[0]) * mask
    #     alpha_list_r[1] = torch.sum(S * (alpha_list_c[1][None, :] * masked_B1)[:, :, None] * 
    #                                W_tilde1[:, None, :], dim=(1,2)) / alpha_r1_den
        
    #     # 更新 alpha_c1 (列)
    #     alpha_list_c[1] = torch.sum(S * (alpha_list_r[1][:, None] * masked_B1)[:, :, None] * 
    #                                W_tilde1[:, None, :], dim=(1,2)) / alpha_c1_den
    
    # # 最终结果
    # sum_order_alternating = torch.zeros_like(x) + ((alpha_list_c[0][None, :] * alpha_list_r[0][:, None] * binary_list[0]) + 
    #                                              (alpha_list_c[1][None, :] * alpha_list_r[1][:, None] * binary_list[1]) + 
    #                                              refine_mean[:, None]) * mask
    
    # return sum_order_alternating


def high_order_residual_alternating_order2_rc_x():
    """Placeholder: takes no arguments, does nothing and returns ``None``."""
    return None

def high_order_residual_alternating_order1_rc_x():
    """Placeholder: takes no arguments, does nothing and returns ``None``."""
    return None


@torch.no_grad()
def normal_quantize(x, scale, zero, maxq):
    """Round ``x`` onto a uniform grid and map it back to the original scale.

    ``x / scale`` is rounded, shifted by ``zero`` and clamped to
    ``[0, maxq]``; the returned value is ``scale * (level - zero)``.
    """
    levels = torch.round(x / scale) + zero
    levels = torch.clamp(levels, 0, maxq)
    return scale * (levels - zero)


class Binarization(nn.Module):
    """Select a weight quantizer by name and apply it to a weight block.

    The constructor only records grouping information derived from the weight
    shape. ``self.mean`` starts as the integer ``0`` and ``self.scale`` is not
    created here, yet the "xnor", "sign" and "rtn" methods index both by group,
    so they must be assigned per-group values before ``quantize`` is called
    (presumably by the calling code -- verify against the caller).
    """

    def __init__(self, weight, method="arb", groupsize=-1):
        """Record the quantization method and the column grouping.

        Args:
            weight: 2-D tensor; only its second dimension is used.
            method: Name of the quantizer used by ``quantize``.
            groupsize: Columns per group; ``-1`` means a single group
                spanning the whole second dimension.
        """
        super().__init__()
        # Unpacking into two names still rejects weights that are not 2-D.
        _, in_features = weight.shape
        self.groupsize = in_features if groupsize == -1 else groupsize
        self.n_groups = math.ceil(in_features / self.groupsize)
        self.method = method
        self.mean = 0

    @staticmethod
    def _uniform_quantize(w, bits):
        """Asymmetric min/max quantization of ``w`` with one grid per row.

        Each row's range is widened to include zero; a row whose range
        collapses to exactly zero is given the range [-1, 1] so the step size
        is never zero.
        """
        device = w.device
        maxq = torch.tensor(2 ** bits - 1).to(device)

        rows = w.clone().flatten(1)
        zeros = torch.zeros(rows.shape[0], device=device)
        row_min = torch.minimum(rows.min(1).values, zeros)
        row_max = torch.maximum(rows.max(1).values, zeros)

        degenerate = (row_min == 0) & (row_max == 0)
        row_min[degenerate] = -1
        row_max[degenerate] = +1

        scale = (row_max - row_min) / maxq
        zero = torch.round(-row_min / scale)

        # One (scale, zero) pair per row, broadcast over remaining dims.
        per_row = [-1] + [1] * (w.dim() - 1)
        return normal_quantize(
            w, scale.reshape(per_row), zero.reshape(per_row), maxq
        )

    def quantize(self, w, mask, order=2, groupi=0, S=None, H=None):
        """Quantize ``w`` with the method chosen at construction time.

        Args:
            w: Weight block to quantize.
            mask: Boolean mask forwarded to the residual-based solvers; the
                other methods ignore it.
            order: ``2`` selects the second-order solver of a family, any
                other value selects the first-order one.
            groupi: Index into ``self.mean`` / ``self.scale`` for the
                "xnor", "sign" and "rtn" methods.
            S: Extra input forwarded to the "arb-x" and "arb-rc-x-nomean"
                solvers.
            H: Extra input forwarded to the "arb-hessian_vq" solvers.

        Returns:
            The quantized tensor. A method name that matches no branch
            returns ``w`` itself, unchanged.
        """
        method = self.method
        second_order = order == 2

        if method == "xnor":
            group_mean = self.mean[groupi]
            out = (w - group_mean).sign() * self.scale[groupi]  # oc, ic
            out += group_mean
            return out

        if method == "braq":  # The method used in BiLLM
            return high_order_residual(w, mask, order=order)

        # arb series: the solver name is looked up only once its branch is hit.
        if method == "arb":
            solver = (
                high_order_residual_alternating_mean
                if second_order
                else high_order_residual_alternating_order1
            )
            return solver(w, mask, order=order, enable_kmeans=True)

        if method == "arb-hessian_vq":
            solver = (
                high_order_residual_alternating_mean_hessian_vq
                if second_order
                else high_order_residual_alternating_order1_hessian_vq
            )
            return solver(w, mask, order=order, enable_kmeans=True, H=H)

        if method == "arb-novq":
            solver = (
                high_order_residual_alternating_mean_novq
                if second_order
                else high_order_residual_alternating_order1_novq
            )
            return solver(w, mask, order=order)

        if method == "arb-x":
            solver = (
                high_order_residual_alternating_mean_x
                if second_order
                else high_order_residual_alternating_order1_x
            )
            return solver(w, mask, order=order, S=S)

        if method == "arb-rc":
            solver = (
                high_order_residual_alternating_order2_rc_nomean
                if second_order
                else high_order_residual_alternating_order1_rc_nomean
            )
            return solver(w, mask, order=order)

        if method == "arb-rc-mean":
            solver = (
                high_order_residual_alternating_order2_rc_mean
                if second_order
                else high_order_residual_alternating_order1_rc_mean
            )
            return solver(w, mask, order=order)

        # Disabled branch, kept for reference:
        # if method == 'arb-rc-x':
        #     order 2 -> high_order_residual_alternating_order2_rc_x(w, mask, order=order, S=S)
        #     else    -> high_order_residual_alternating_order1_rc_x(w, mask, order=order, S=S)

        if method == "arb-rc-x-nomean":
            solver = (
                high_order_residual_alternating_order2_rc_x_nomean
                if second_order
                else high_order_residual_alternating_order1_rc_x_nomean
            )
            return solver(w, mask, order=order, S=S)

        if method == "sign":
            positive = (w > 0).float()
            # In-place multiply keeps the float32 dtype produced above.
            positive *= self.scale[groupi]
            return positive

        if method == "rtn":
            clipped = F.relu(w)
            step = (clipped / self.scale[groupi]).round().clamp(0, 1)
            return step * self.scale[groupi]

        if method in ("2bit", "4bit"):
            # The leading character of the method name is the bit width.
            return self._uniform_quantize(w, int(method[0]))

        if method == "prune":
            return torch.zeros_like(w)

        # No branch matched: hand the input back untouched.
        return w
