# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import os
import json
import time
import random
import torch
from concurrent.futures import ThreadPoolExecutor, as_completed
from memory_profiler import memory_usage  # Make sure memory_profiler is installed
# Try to import the modules
import json
import numpy as np
import time
import faiss
import torch
import math
import torch.distributed
from faiss.contrib.inspect_tools import get_additive_quantizer_codebooks
from src.data_utils import load_dataset, extract_weight_matrix
from multiprocessing import Pool, cpu_count
from memory_profiler import memory_usage
import random
import numpy as np
import torch
import time
import faiss
from sklearn.metrics import mean_squared_error

def fix_random_seed(seed):
    """Seed every random number generator used by this module.

    Python's ``random``, NumPy and PyTorch are seeded with the same value.
    When CUDA is available the CUDA generators are seeded as well and cuDNN
    is configured for deterministic execution.

    Args:
        seed (int): Seed value applied to all generators.
    """
    # Python built-in, NumPy and PyTorch generators, in that order.
    for seed_fn in (random.seed, np.random.seed, torch.random.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return

    # Seed the current CUDA device, then every CUDA device.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Deterministic cuDNN kernels; benchmarking is disabled because it picks
    # algorithms adaptively, which would defeat determinism.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def generate_random_data():
    """Export one weight matrix from the dataset file as ``random_data.npy``.

    Despite the name, nothing here is random: the function loads the dataset
    from a fixed path, extracts the matrix at layer index 20 / attribute 4,
    keeps its first 524288 * 32 values reshaped to (524288, 32), multiplies
    them by 1000 and saves the result with ``np.save``.

    Returns:
        The extracted weight matrix exactly as returned by
        ``extract_weight_matrix`` (not reshaped, not scaled), or ``None``
        when ``load_dataset`` returns ``None``.

    Raises:
        ValueError: If the matrix has fewer than 4096 * 4096 elements.
    """
    dataset_file = '.../Llama-3.1-8B_dataset.pth'  # Placeholder path

    dataset = load_dataset(dataset_file)
    if dataset is None:
        return None

    # Which matrix to pull out of the dataset; change as needed.
    layer_index, layer_attribute = 20, 4
    weight_matrix = extract_weight_matrix(dataset, layer_index, layer_attribute)

    as_numpy = weight_matrix.detach().cpu().numpy()
    print("Original weight matrix shape:", as_numpy.shape)

    # 524288 * 32 == 4096 * 4096, so the size check below guarantees the
    # truncated flat view can be reshaped.
    rows, cols = 524288, 32
    if as_numpy.size < 4096 * 4096:
        raise ValueError("The original weight matrix does not have enough elements to reshape to 65536x256.")

    reshaped = as_numpy.flatten()[:rows * cols].reshape(rows, cols)
    print("Reshaped weight matrix to", reshaped.shape)

    # Scale each value by 1000 before saving.
    scaled = reshaped * 1e3

    np.save("random_data.npy", scaled)
    print("Weight matrix of shape", scaled.shape, "saved to 'random_data.npy'")
    return weight_matrix


def train_rq(M, K, rq_beam_size, xt, xval):
    """Fit a FAISS residual quantizer on ``xt`` and return its codebooks.

    Args:
        M (int): Number of quantization stages passed to
            ``faiss.ResidualQuantizer``.
        K (int): Codebook size per stage. The bits per code are
            ``int(log2(K))``, so a K that is not a power of two is rounded
            down to the previous power of two.
        rq_beam_size (int): Value assigned to the quantizer's
            ``max_beam_size``.
        xt (np.ndarray): Training vectors of shape (n, d); cast to float32
            if they have another dtype.
        xval (np.ndarray): Validation vectors with the same feature
            dimension; only used to print a validation MSE.

    Returns:
        tuple: ``(rq_centroids, MSE)`` where ``rq_centroids`` is the array
        built from ``get_additive_quantizer_codebooks`` (presumably shape
        (M, K, d) -- verify against FAISS) and ``MSE`` is the mean squared
        reconstruction error on the training set.
    """
    bits_per_code = int(np.log2(K))
    print(f"training RQ {M}x{bits_per_code}, beam_size={rq_beam_size}")
    started = time.time()

    quantizer = faiss.ResidualQuantizer(xt.shape[1], M, bits_per_code)
    quantizer.max_beam_size = rq_beam_size

    # The quantizer is fed float32 data; convert only when necessary.
    xt = xt if xt.dtype == np.float32 else xt.astype(np.float32)
    xval = xval if xval.dtype == np.float32 else xval.astype(np.float32)

    quantizer.train(xt)
    print(f"[{time.time() - started:.2f} s] training done")

    def reconstruction_mse(data):
        # Encode then decode, and compare with the input.
        return mean_squared_error(quantizer.decode(quantizer.compute_codes(data)), data)

    MSE = reconstruction_mse(xt)
    MSE_val = reconstruction_mse(xval)
    print(f"train set MSE={MSE:g} validation MSE={MSE_val:g}")

    rq_centroids = np.array(get_additive_quantizer_codebooks(quantizer))
    print(f"RQ centroids size {rq_centroids.shape}")
    return rq_centroids, MSE


def compute_batch_distances(a, b):
    """Return batched pairwise squared L2 distances.

    Uses the expansion ``|x - y|^2 = |x|^2 + |y|^2 - 2 * x.y`` so the cross
    term is a single batched matrix product.

    Args:
        a (torch.Tensor): Shape [n, a, d].
        b (torch.Tensor): Shape [n, b, d].

    Returns:
        torch.Tensor: Shape [n, a, b]; entry [i, j, k] is the squared
        distance between ``a[i, j]`` and ``b[i, k]``.
    """
    sq_norm_a = a.pow(2).sum(dim=-1, keepdim=True)  # [n, a, 1]
    sq_norm_b = b.pow(2).sum(dim=-1).unsqueeze(1)   # [n, 1, b]
    cross = torch.bmm(a, b.transpose(1, 2))         # [n, a, b]
    return sq_norm_a + sq_norm_b - 2 * cross


def assign_batch_multiple(x, zqs):
    """Pick, for every vector, the closest of its own candidate vectors.

    NOTE: a later ``def assign_batch_multiple`` in this file rebinds the
    name, so after import this implementation is no longer the one callers
    reach.

    Args:
        x (torch.Tensor): Shape [bs, d].
        zqs (torch.Tensor): Candidate quantization vectors per batch
            element, shape [bs, ksq, d].

    Returns:
        codes (torch.int64): Index of the selected candidate per batch
            element, shape [bs].
        quantized (torch.Tensor): The selected candidates, shape [bs, d].
    """
    # Both unpackings double as rank checks; the feature size used below is
    # the one taken from zqs.
    _, _ = x.shape
    _, _, dim = zqs.shape

    # Squared distances from each vector to its candidates: [bs, ksq].
    dists = compute_batch_distances(x.unsqueeze(1), zqs).squeeze(1)
    nearest = dists.argmin(dim=1)  # [bs]

    # Broadcast the winning index over the feature axis for gather: [bs, 1, d].
    gather_index = nearest.view(-1, 1, 1).expand(-1, 1, dim)
    chosen = zqs.gather(1, gather_index)  # [bs, 1, d]
    return nearest, chosen.squeeze(1)


def generate_compressed_matrix(weight_matrix, q, r, c, b, d):
    """
    Compress a 2D weight matrix with a residual quantizer.

    The matrix is flattened row-major into ``n = numel / d`` vectors of
    length ``d``, scaled by ``10 ** q``, used to train a residual quantizer
    via ``train_rq`` and then encoded greedily: at each stage the nearest
    centroid of the current residual is selected and subtracted.

    Args:
        weight_matrix (torch.Tensor): Input weight matrix from a deep learning model.
        q (float): Scaling exponent; values are multiplied by ``10 ** q``.
        r (int): Number of quantizers.
        c (int): Number of centroids per quantizer.
        b (int): Beam search size used while training.
        d (int): Number of features per quantized vector.

    Returns:
        rq_centroids_tensor (torch.Tensor): Trained centroids as float32,
            built from the codebooks returned by ``train_rq``.
        compressed_codes (torch.Tensor): Compressed codes, shape [n, r],
            cast to ``torch.uint8`` (NOTE(review): indices >= 256 do not fit
            in uint8, so this assumes c <= 256).
        original_shape (torch.Size): Original shape of the weight matrix.
        q (float): The scaling exponent used.
        mse (float): Mean squared error between the scaled input and its
            reconstruction (sum of the selected centroids over all stages),
            measured in the scaled space.

    Raises:
        ValueError: If the input is not 2D or its element count is not
            divisible by ``d``.
    """
    if weight_matrix.ndim != 2:
        raise ValueError("Input must be a 2D tensor.")

    weight_matrix_np = weight_matrix.detach().cpu().numpy()
    print("Original weight matrix shape:", weight_matrix_np.shape)

    num_samples, num_features = weight_matrix_np.shape
    total_elements = num_samples * num_features
    if total_elements % d != 0:
        raise ValueError("Total elements must be divisible by the feature dimension d.")

    # Row-major regrouping of all values into n vectors of length d.
    n = total_elements // d
    weight_matrix_np = weight_matrix_np.reshape(n, d)
    print("Reshaped weight matrix to", weight_matrix_np.shape)

    # Scale matrix values.
    xt = weight_matrix_np * (10 ** q)

    # Up to 64 random rows serve as validation data; capped at n so that
    # small matrices do not make random.sample raise.
    xval_indices = random.sample(range(n), min(64, n))
    xval = xt[xval_indices]

    rq_centroids, _ = train_rq(r, c, b, xt, xval)

    xt_tensor = torch.tensor(xt, dtype=torch.float32)
    rq_centroids_tensor = torch.tensor(rq_centroids, dtype=torch.float32)

    # The nearest-centroid search is done here, one shared [c, d] codebook
    # per stage and in row chunks, so memory stays at chunk * c distances
    # and the result does not depend on which of the file's two
    # assign_batch_multiple definitions is bound at call time.
    chunk_rows = 65536
    residual = xt_tensor
    reconstruction = torch.zeros_like(xt_tensor)
    code_columns = []

    for stage_centroids in rq_centroids_tensor:  # each [c, d]
        codes = torch.cat([
            torch.cdist(chunk, stage_centroids).argmin(dim=1)
            for chunk in residual.split(chunk_rows)
        ])  # [n]
        quantized = stage_centroids[codes]  # [n, d]

        code_columns.append(codes.unsqueeze(1))  # [n, 1]
        # The reconstruction is the sum of the centroids chosen at every
        # stage, not just the last one.
        reconstruction += quantized
        residual = residual - quantized

    # Concatenate codes into matrix [n, r].
    compressed_codes = torch.cat(code_columns, dim=1)

    mse = mean_squared_error(xt_tensor.numpy(), reconstruction.numpy())
    print("MSE after compression:", mse)

    return rq_centroids_tensor, compressed_codes.to(torch.uint8), weight_matrix.shape, q, mse

def generate_decompressed_matrix(rq_centroids, compressed_codes, weight_matrix_shape, q):
    """
    Decompression function: input codebook and compressed codes, output reconstructed matrix.

    Each row is rebuilt as the sum, over all quantizers, of the centroid
    selected by that row's code. The sum and the rescaling are carried out
    in float32 and only the final result is cast to float16, so rounding
    happens once instead of after every stage.

    Args:
        rq_centroids (torch.Tensor): Codebook of shape (r, c, d), where r is the number of quantizers,
                                    c is the number of centroids per quantizer, d is feature dimension.
        compressed_codes (torch.Tensor): Compressed codes of shape (n, r), where n is number of samples,
                                        r is number of quantizers.
        weight_matrix_shape (tuple): Shape to restore; must contain n * d elements.
        q (float): Scaling exponent; the reconstruction is divided by ``10 ** q``.

    Returns:
        decompressed_matrix (torch.Tensor): float16 tensor of shape
            ``weight_matrix_shape``, located on CUDA when available and on
            the CPU otherwise.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # float32 codebooks for accumulation; int64 codes because that is the
    # index dtype advanced indexing accepts on every PyTorch version.
    codebooks = rq_centroids.to(device=device, dtype=torch.float32)
    codes = compressed_codes.to(device=device, dtype=torch.int64)

    r, _, d = codebooks.shape  # r: number of quantizers, d: feature dim
    n = codes.shape[0]         # number of samples

    accumulated = torch.zeros((n, d), dtype=torch.float32, device=device)
    for stage in range(r):
        # Look up this stage's centroid for every sample and add it in.
        accumulated += codebooks[stage][codes[:, stage]]

    # Restore the requested shape, undo the 10^q scaling, then downcast.
    restored = accumulated.view(weight_matrix_shape) / (10 ** q)
    return restored.to(torch.float16)



import torch

def generate_compressed_matrix_g(weight_matrix, q, r, c, b, d, g):
    """
    Compress input weight matrix, split by column blocks.

    The columns are divided into ``g`` equally wide blocks and each block is
    compressed independently with ``generate_compressed_matrix``.

    Args:
        weight_matrix (torch.Tensor): Input weight matrix from deep learning model.
        q (float): Scaling factor for adjusting matrix values.
        r (int): Number of quantizers.
        c (int): Number of centroids per quantizer.
        b (int): Beam search size.
        d (int): Number of features.
        g (int): Number of column blocks to split into.

    Returns:
        rq_centroids_tensor_list (list): One codebook tensor per block, in block order.
        compressed_codes (torch.Tensor): Per-block codes concatenated along
            dim 1 (r columns per block), cast to uint8.
        original_shape (tuple): Original shape of weight matrix.
        q (float): Scaling factor.
        mse0 (float): Mean squared error averaged over all blocks.

    Raises:
        ValueError: If the input is not 2D, ``g`` is not positive, or the
            number of columns is not divisible by ``g``.
    """
    if weight_matrix.ndim != 2:
        raise ValueError("Input must be a 2D tensor.")
    if g < 1:
        # Without this check g == 0 surfaces as a ZeroDivisionError below.
        raise ValueError("g must be a positive number of column blocks.")

    # Read the shape from the tensor itself; converting the whole matrix to
    # NumPy just for its shape would copy it to host memory for nothing.
    original_shape = tuple(weight_matrix.shape)
    num_features = original_shape[1]

    if num_features % g != 0:
        raise ValueError("Number of features must be divisible by g.")

    g_size = num_features // g  # Number of columns per block

    rq_centroids_tensor_list = []
    compressed_codes_list = []
    total_mse = 0

    for block_index in range(g):
        start_col = block_index * g_size
        weight_block = weight_matrix[:, start_col:start_col + g_size]

        # Shape and q returned by the per-block call are not needed here.
        rq_centroids_tensor, block_codes, _, _, block_mse = generate_compressed_matrix(
            weight_block, q, r, c, b, d
        )

        rq_centroids_tensor_list.append(rq_centroids_tensor)
        compressed_codes_list.append(block_codes)
        total_mse += block_mse

    # Concatenate compressed codes along columns: [n, r*g].
    compressed_codes = torch.cat(compressed_codes_list, dim=1)

    # Average MSE over blocks.
    mse0 = total_mse / g

    return rq_centroids_tensor_list, compressed_codes.to(torch.uint8), original_shape, q, mse0



def generate_decompressed_matrix_g(rq_centroids_list, compressed_codes, weight_matrix_shape, q, g):
    """
    Rebuild a matrix that was compressed block-by-block along its columns.

    Block ``i`` is decoded with ``generate_decompressed_matrix`` from the
    ``i``-th codebook and the ``i``-th group of code columns, then written
    into the corresponding column range of the output.

    Args:
        rq_centroids_list (list): Codebook tensors, one per block; the first
            dimension of each gives the number of code columns of its block.
        compressed_codes (torch.Tensor): Codes of all blocks side by side.
        weight_matrix_shape (tuple): (rows, columns) of the full matrix.
        q (float): Scaling factor used during compression.
        g (int): Number of column blocks.

    Returns:
        decompressed_matrix (torch.Tensor): float16 tensor of shape
            (rows, columns), on CUDA when available, otherwise on the CPU.
    """
    # Not used below; reading it fails fast when the list is empty or the
    # first codebook is not 3-dimensional.
    feature_dim = rq_centroids_list[0].shape[2]

    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'
    codebooks = [book.to(device) for book in rq_centroids_list]

    n_rows, n_cols = weight_matrix_shape[0], weight_matrix_shape[1]
    output = torch.zeros((n_rows, n_cols), dtype=torch.float16, device=device)

    for block in range(g):
        book = codebooks[block]
        stages = book.shape[0]

        # Code columns that belong to this block.
        block_codes = compressed_codes[:, block * stages:(block + 1) * stages]

        # Width of one column block of the full matrix.
        block_cols = n_cols // g
        block_values = generate_decompressed_matrix(book, block_codes, (n_rows, block_cols), q)

        # Place the decoded block at its column offset.
        col_slice = slice(block * block_cols, (block + 1) * block_cols)
        output[:, col_slice] += block_values

    return output



def compute_outcome(weight_matrix, r, c, b, q, d):
    """
    Score one parameter set by reconstruction error and compression ratio.

    The matrix is regrouped into ``n = numel / d`` vectors, scaled by
    ``10 ** q`` and used to train a residual quantizer via ``train_rq``.

    Args:
        weight_matrix (torch.Tensor): Original 2D weight matrix.
        r (int): Number of quantizers.
        c (int): Number of centroids per quantizer.
        b (int): Beam search size.
        q (float): Scaling exponent.
        d (int): Number of features per quantized vector.

    Returns:
        outcome (float): ``0.1 * outcome_accuracy + 1 * outcome_compress``.
        outcome_accuracy (float): Training MSE mapped back to the unscaled
            value range (divided by ``(10 ** q) ** 2``).
        outcome_compress (float): (code bytes + codebook bytes) divided by
            the byte size of the original tensor.

    Raises:
        ValueError: If the input is not 2D or its element count is not
            divisible by ``d``.
    """
    if weight_matrix.ndim != 2:
        raise ValueError("Input must be a 2D tensor.")

    # Byte size of the uncompressed tensor.
    memory_size = weight_matrix.numel() * weight_matrix.element_size()

    weight_matrix_np = weight_matrix.detach().cpu().numpy()
    print("Original weight matrix shape:", weight_matrix_np.shape)
    num_samples, num_features = weight_matrix_np.shape
    total_elements = num_samples * num_features
    print(f"Total elements: {total_elements}, d value: {d}")

    if total_elements % d != 0:
        raise ValueError("Total elements must be divisible by feature dimension d.")

    n = total_elements // d
    weight_matrix_np = weight_matrix_np.reshape(n, d)
    print("Reshaped weight matrix to", weight_matrix_np.shape)

    xt = weight_matrix_np * (10 ** q)

    # Up to 16 random rows for validation; capped at n so that small
    # matrices do not make random.sample raise.
    xval_indices = random.sample(range(n), min(16, n))
    xval = xt[xval_indices]

    rq_centroids, mse = train_rq(r, c, b, xt, xval)

    # Codebook storage in bytes.
    codebook_size = rq_centroids.size * rq_centroids.itemsize

    # One code per vector and per quantizer stage. The stage count is the
    # FIRST axis of the codebook array; rq_centroids[0].shape[0] would be the
    # number of centroids per stage instead. Counts one byte per code,
    # matching the uint8 codes produced by the compression functions in
    # this file.
    compressed_size = n * rq_centroids.shape[0]

    # MSE was measured on scaled data; undo the scaling (squared).
    outcome_accuracy = mse / (10 ** q) ** 2
    outcome_compress = (compressed_size + codebook_size) / memory_size

    # Weights of the accuracy and compression terms.
    a1 = 0.1
    a2 = 1

    outcome = a1 * outcome_accuracy + a2 * outcome_compress

    return outcome, outcome_accuracy, outcome_compress




def optimize_compression_params(weight_matrix: torch.Tensor, layer_index: int, layer_attribute: int, output_file: str, log_file: str) -> tuple:
    """
    Look up the logged parameter set with the smallest ``outcome_accuracy``.

    No optimisation is run here: the function only reads the JSON log for
    the given layer/attribute. The log may hold a single record (dict) or a
    list of records; non-dict list items and records without
    ``outcome_accuracy`` are ignored. On ties the earliest record wins.

    Args:
        weight_matrix (torch.Tensor): Weight matrix for analysis (not used in this function).
        layer_index (int): Index of the layer; part of the log file name.
        layer_attribute (int): Attribute index of the layer; part of the log file name.
        output_file (str): Path to save best parameters (not used in this function).
        log_file (str): Path to save logs (not used in this function).

    Returns:
        tuple or None: ``(q, r, c, b, d)`` of the best record, or None when
        the log is missing, has an unexpected format, contains no usable
        record, or the best record has no ``q``.
    """
    # Log file path (placeholder directory).
    log_file_path = f".../params_log/log_data_layer_{layer_index}_attr_{layer_attribute}.json"

    # Normalise whatever the log contains into a list of dict records.
    if not os.path.exists(log_file_path):
        print(f"Log file {log_file_path} does not exist.")
        records = []
    else:
        with open(log_file_path, 'r') as file:
            data = json.load(file)
        if isinstance(data, list):
            records = [item for item in data if isinstance(item, dict)]
        elif isinstance(data, dict):
            records = [data]
        else:
            print("Data format is not as expected.")
            records = []

    best_score = float('inf')
    best_params = (None, None, None, None, None)
    for record in records:
        score = record.get("outcome_accuracy")
        # Strict comparison keeps the first record among equal scores.
        if score is not None and score < best_score:
            best_score = score
            best_params = tuple(record.get(key) for key in ("q", "r", "c", "b", "d"))

    if best_params[0] is None:
        print("No valid parameter combination found.")
        return None

    return best_params


def calculate_mse_with_numpy(original, decompressed):
    """
    Mean squared error between two tensors, evaluated with NumPy.

    Both tensors are moved to the CPU and detached before conversion, so
    the function accepts tensors on any device and with gradients attached.

    Args:
        original (torch.Tensor): Reference tensor.
        decompressed (torch.Tensor): Tensor to compare against the reference.

    Returns:
        float: Mean of the squared element-wise differences.
    """
    difference = original.cpu().detach().numpy() - decompressed.cpu().detach().numpy()
    return np.mean(difference ** 2)


# import torch

# def quantize_with_rq(xt, rq_centroids_tensor):
#     """
#     Quantize input data with residual quantizer.
#     :param xt: numpy.ndarray, input data shape (n, d).
#     :param rq_centroids_tensor: torch.Tensor, quantizer centroids shape (r, c, d).
#     :return: torch.Tensor, compressed codes shape [n, r].
#     """
#     # Convert input data to PyTorch tensor
#     xt_tensor = torch.tensor(xt, dtype=torch.float16)
#     print("xt_tensor.shape:", xt_tensor.shape)  # Print input shape

#     compressed_codes = []
#     n = xt_tensor.shape[0]  # Sample count
#     current_input = xt_tensor

#     for i in range(rq_centroids_tensor.size(0)):  # Iterate quantizers
#         current_centroids = rq_centroids_tensor[i].unsqueeze(0).repeat(n, 1, 1)  # [bs, 64, 32]
#         print("current_centroids.shape:", current_centroids.shape)  # Print centroid shape

#         codes, quantized = assign_batch_multiple(current_input, current_centroids)  # [n]
#         print("codes.shape:", codes.shape)  # Print codes shape
#         print("quantized.shape:", quantized.shape)  # Print quantized shape

#         compressed_codes.append(codes.unsqueeze(1))  # Reshape [bs, 1]
#         print("compressed_codes.shape:", compressed_codes[-1].shape)  # Print compressed code shape

#         if i < rq_centroids_tensor.size(0) - 1:
#             residual = current_input - quantized  # Compute residual
#             assert residual.shape == current_input.shape, \
#                 f"Expected residual shape: {current_input.shape}, got {residual.shape}"
#             current_input = residual

#     compressed_codes_tensor = torch.cat(compressed_codes, dim=1)  # Shape: [n, r]
#     print("shape of compressed_codes:", compressed_codes_tensor.shape)  # Print final compressed codes shape

#     return compressed_codes_tensor


def assign_batch_multiple(current_input, current_centroids):
    """
    Assign each input vector to its nearest centroid.

    This definition rebinds the name of the earlier ``assign_batch_multiple``
    in this file, so it is the one every caller reaches. It therefore
    accepts both calling conventions used in the file:

    * a shared codebook ``[1, c, d]`` (no data duplication), and
    * per-sample candidates ``[batch_size, c, d]``, as produced by repeating
      a codebook along the batch axis.

    Distances are computed in at least float32: half-precision inputs are
    upcast for the distance computation only, which avoids relying on
    ``torch.cdist`` support for float16 and gives a more reliable argmin.

    Args:
        current_input (torch.Tensor): shape [batch_size, d]
        current_centroids (torch.Tensor): shape [1, c, d] or [batch_size, c, d]

    Returns:
        codes (torch.Tensor): [batch_size] int64 centroid indices per sample
        quantized (torch.Tensor): [batch_size, d] selected centroids, in the
            dtype of ``current_centroids``
    """
    # float16/bfloat16 -> float32; float32 and float64 are left as they are.
    compute_dtype = torch.promote_types(
        torch.promote_types(current_input.dtype, current_centroids.dtype),
        torch.float32,
    )
    queries = current_input.to(compute_dtype)

    if current_centroids.shape[0] == 1:
        # One codebook shared by the whole batch.
        codebook = current_centroids[0]  # [c, d]
        distances = torch.cdist(queries, codebook.to(compute_dtype))  # [batch_size, c]
        codes = torch.argmin(distances, dim=1)  # [batch_size]
        quantized = codebook[codes]  # [batch_size, d]
        return codes, quantized

    # Each sample has its own candidate set: compare sample i only with
    # current_centroids[i].
    distances = torch.cdist(
        queries.unsqueeze(1), current_centroids.to(compute_dtype)
    ).squeeze(1)  # [batch_size, c]
    codes = torch.argmin(distances, dim=1)  # [batch_size]
    rows = torch.arange(codes.shape[0], device=codes.device)
    quantized = current_centroids[rows, codes]  # [batch_size, d]
    return codes, quantized


def quantize_with_rq(xt, rq_centroids_tensor, batch_size=65536):
    """
    Encode vectors stage by stage against residual-quantizer codebooks.

    At every stage each (residual) vector is assigned to a centroid via
    ``assign_batch_multiple``; the chosen centroid is subtracted before the
    next stage. Rows are processed in slices of ``batch_size`` to bound the
    size of the intermediate distance computation.

    Args:
        xt (np.ndarray): Input data shape (n, d); converted to a float16 tensor.
        rq_centroids_tensor (torch.Tensor): Quantizer centroids shape (r, c, d).
        batch_size (int): Number of rows assigned per call.

    Returns:
        torch.Tensor: Compressed codes shape [n, r].
    """
    # Work where the codebooks live if that is a CUDA device, else on the CPU.
    if rq_centroids_tensor.is_cuda:
        device = rq_centroids_tensor.device
    else:
        device = torch.device('cpu')

    xt_tensor = torch.tensor(xt, dtype=torch.float16, device=device)
    print("xt_tensor.shape:", xt_tensor.shape)

    n = xt_tensor.shape[0]
    r = rq_centroids_tensor.size(0)

    code_columns = []
    remaining = xt_tensor  # residual still to be encoded

    for stage in range(r):
        stage_centroids = rq_centroids_tensor[stage].unsqueeze(0)  # [1, c, d]

        # One (codes, quantized) pair per row slice, in row order.
        assigned = [
            assign_batch_multiple(remaining[lo:min(lo + batch_size, n)], stage_centroids)
            for lo in range(0, n, batch_size)
        ]

        codes_cat = torch.cat([codes for codes, _ in assigned], dim=0)  # [n]
        quantized_cat = torch.cat([quantized for _, quantized in assigned], dim=0)  # [n, d]

        print(f"Layer {stage}: codes.shape={codes_cat.shape}, quantized.shape={quantized_cat.shape}")

        code_columns.append(codes_cat.unsqueeze(1))  # [n, 1]

        # The last stage has no successor, so no residual is needed for it.
        if stage < r - 1:
            residual = remaining - quantized_cat
            assert residual.shape == remaining.shape, \
                f"Residual shape mismatch: expected {remaining.shape}, got {residual.shape}"
            remaining = residual

    compressed_codes_tensor = torch.cat(code_columns, dim=1)  # [n, r]
    print("shape of compressed_codes:", compressed_codes_tensor.shape)

    return compressed_codes_tensor


import torch
import numpy as np
import math

def generate_compressed_dense_matrix_g(dense_matrix, q, r, c, b, d, g):
    """
    Compress a dense tensor with one residual quantizer per group of rows.

    The tensor is flattened, zero-padded to a multiple of ``d`` and viewed
    as ``n_samples`` vectors of length ``d``. These are split into ``g``
    consecutive groups of ``ceil(n_samples / g)`` vectors (short or empty
    groups are zero-padded so all groups have the same size); each group
    gets its own quantizer trained by ``train_rq`` and is encoded with
    ``quantize_with_rq``.

    :param dense_matrix: Input dense tensor, e.g., shape torch.Size([44773095]);
        tensors with more dimensions are flattened first.
    :param q: int
        Scaling factor. Accepted for signature parity with the other
        compression functions; it is not applied in this function.
    :param r: int
        Number of quantizers.
    :param c: int
        Number of centroids per quantizer.
    :param b: int
        Beam search size.
    :param d: int
        Number of features.
    :param g: int
        Number of groups (not the group size).
    :return: compress_code_tensor, rq_centroids_tensor, dense_matrix_shape
        compress_code_tensor: Codes of all groups stacked along a new first
            dimension (one [group_size, r] slice per group)
        rq_centroids_tensor: float16 codebooks of all groups stacked along a
            new first dimension
        dense_matrix_shape: Shape of the original dense tensor
    :raises ValueError: If ``d`` or ``g`` is not positive.
    """
    if d < 1 or g < 1:
        # Otherwise these surface as ZeroDivisionError further down.
        raise ValueError("d and g must be positive integers.")

    # Save original shape; it is needed to undo flattening and padding.
    dense_matrix_shape = dense_matrix.shape
    print("dense_matrix_device:", dense_matrix.device)

    # 1. Flatten, pad to a multiple of d, and regroup into (n_samples, d).
    #    detach() keeps .numpy() working for tensors that require grad.
    flat = dense_matrix.detach().reshape(-1)
    n_elements = flat.numel()
    n_samples = math.ceil(n_elements / d)

    pad_size = n_samples * d - n_elements
    if pad_size > 0:
        flat = torch.cat([flat, torch.zeros(pad_size, dtype=flat.dtype, device=flat.device)])

    xt = flat.cpu().numpy().reshape(n_samples, d)

    # 2. Group processing (g is the number of groups).
    group_size = math.ceil(n_samples / g)
    compress_code_list = []
    rq_centroids_tensor_list = []

    for group_idx in range(g):
        start_idx = group_idx * group_size
        xt_g = xt[start_idx:start_idx + group_size, :]

        # Pad short (or empty) trailing groups with zero rows so that every
        # group yields equally shaped codes and can be stacked at the end.
        missing = group_size - xt_g.shape[0]
        if missing > 0:
            xt_g = np.concatenate([xt_g, np.zeros((missing, d), dtype=xt_g.dtype)], axis=0)

        # 3. Up to 32 random rows of the group form the validation set.
        if xt_g.shape[0] >= 32:
            val_indices = np.random.permutation(xt_g.shape[0])[:32]
            xval_g = xt_g[val_indices]
        else:
            xval_g = xt_g.copy()

        # 4. Train residual quantizer (the returned MSE is not used here).
        rq_centroids, _ = train_rq(r, c, b, xt_g, xval_g)
        rq_centroids_tensor = torch.tensor(rq_centroids, dtype=torch.float16)

        # 5. Compress data with trained quantizer.
        compressed_codes = quantize_with_rq(xt_g, rq_centroids_tensor)
        print("Compressed compressed_codes.shape:", compressed_codes.shape)

        compress_code_list.append(compressed_codes)
        rq_centroids_tensor_list.append(rq_centroids_tensor)

    # All groups share the same shapes, so both lists stack directly.
    compress_code_tensor = torch.stack(compress_code_list)
    rq_centroids_tensor = torch.stack(rq_centroids_tensor_list)

    return compress_code_tensor, rq_centroids_tensor, dense_matrix_shape


# def residual_decompression(compress_code_list, rq_centroids_tensor_list, dense_matrix_shape):
#     # TODO: Convert to tensor operations for acceleration
#     """
#     Define residual decompression function
#     :param compress_code_list: List of compressed codes [group_num, n, r]
#     :param rq_centroids_tensor_list: List of quantizer centroid tensors [group_num, (r, c, d)]
#     :param dense_matrix_shape: Shape of original dense matrix (e.g., torch.Size([44773095]))
#     :return: re_dense_matrix: Decompressed dense matrix (same shape as original input)
#     """
#     # 1. Initialize list to store all decompressed data
#     all_reconstructed = []
#
#     # TODO: Move rq_centroids_tensor_list to GPU for calculation
#     rq_centroids_tensor_list = [code.to('cuda') for code in rq_centroids_tensor_list]  # Move each centroid to GPU
#     compress_code_list = [code.to('cuda') for code in compress_code_list] # Move each compressed code to GPU
#
#     # print("Length of compress_code_list:", len(compress_code_list))  # Print compress_code_list length
#     # print("Shape of compress_code_list[0]:", compress_code_list[0].shape)  # Print first compress_code shape
#     # print("Length of rq_centroids_tensor_list:", len(rq_centroids_tensor_list))  # Print centroid list length
#     # print("Shape of rq_centroids_tensor_list[0]:", rq_centroids_tensor_list[0].shape)  # Print centroid shape
#     # print("dense_matrix_shape:", dense_matrix_shape)  # Print original matrix shape
#
#     # print("Device of first compress_code_list element:", compress_code_list[0].device)
#     # print("Device of first rq_centroids_tensor_list element:", rq_centroids_tensor_list[0].device)  # Print device info
#     # 2. Decompress each group
#     for group_idx in range(len(compress_code_list)):
#         # Get compressed codes and quantizers of current group
#         compressed_codes = compress_code_list[group_idx]  # shape: [n, r]
#         # print("compressed_codes.shape:", compressed_codes.shape)  # e.g. torch.Size([25487, 3])
#         rq_centroids = rq_centroids_tensor_list[group_idx]  # shape: [r, c, d]
#         # print("rq_centroids.shape:", rq_centroids.shape)  # e.g. torch.Size([3, 256, 9])
#
#         # 3. Initialize reconstruction matrix (zeros)
#         n_samples = compressed_codes.shape[0]
#         d = rq_centroids.shape[2]
#         reconstructed = torch.zeros(n_samples, d, dtype=torch.float16, device='cuda')
#
#         # 4. Reconstruct stepwise (sum over quantizers)
#         for i in range(compressed_codes.shape[1]):  # Iterate over quantizers r
#
#             codes = compressed_codes[:, i]  # Codes for current quantizer [n]
#             centroids = rq_centroids[i]  # Remove first dim to get [c, d]
#             codes = codes.to(torch.int64)
#
#             # Gather quantized vectors using codes and sum
#             quantized = centroids[codes]  # [n, d]
#             reconstructed += quantized
#
#         # 5. Append to total list
#         all_reconstructed.append(reconstructed)
#
#     # 6. Concatenate all group reconstructions
#     full_reconstructed = torch.cat(all_reconstructed, dim=0)  # [total_n, d]
#
#     # 7. Restore original shape (remove padding)
#     original_length = dense_matrix_shape.numel()
#     re_dense_matrix = full_reconstructed.flatten()[:original_length]
#
#     # 8. Restore shape and type
#     re_dense_matrix = re_dense_matrix.reshape(dense_matrix_shape)
#     re_dense_matrix = re_dense_matrix.to(dtype=torch.float16)
#     flattened_dense = re_dense_matrix.view(-1)
#     return flattened_dense  # Return decompressed matrix


def residual_decompression(compress_code_list, rq_centroids_tensor_list, dense_matrix_shape):
    """
    Reconstruct a flattened float16 tensor from residual-quantization codes.

    For every group, each of the r quantizer stages looks up one centroid per
    sample, and the r looked-up vectors are summed. The per-group results are
    concatenated along dim 0, flattened, and truncated to the number of
    elements of dense_matrix_shape (which drops any trailing padding).

    :param compress_code_list: Per-group code tensors, each of shape [n, r]
    :param rq_centroids_tensor_list: Per-group codebooks, each of shape [r, c, d]
    :param dense_matrix_shape: Shape of the original dense matrix (must provide numel(), e.g. torch.Size)
    :return: 1-D float16 tensor with dense_matrix_shape.numel() elements; it lives on
             the GPU when CUDA is available and on the CPU otherwise
    :raises ValueError: If a code tensor is not 2-D or a codebook is not 3-D
    """
    # Prefer the GPU, but keep the function usable on CPU-only machines.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    all_reconstructed = []

    for group_idx, compressed_codes in enumerate(compress_code_list):
        rq_centroids = rq_centroids_tensor_list[group_idx].to(device)  # [r, c, d]
        if rq_centroids.dim() != 3:
            raise ValueError(
                f"expected a 3-D codebook [r, c, d], got shape {tuple(rq_centroids.shape)}"
            )

        # Unpacking enforces the 2-D [n, r] layout of the codes.
        _, num_quantizers = compressed_codes.shape

        # Index tensors must be int64: codes are typically stored in a narrow
        # integer type, which cannot be used directly for fancy indexing.
        codes = compressed_codes.to(device=device, dtype=torch.int64)  # [n, r]

        # One centroid lookup per quantizer stage, each giving [n, d].
        stage_vectors = [
            rq_centroids[stage][codes[:, stage]] for stage in range(num_quantizers)
        ]

        # Stack to [n, r, d] and sum over the quantizer axis -> [n, d].
        reconstructed = torch.stack(stage_vectors, dim=1).sum(dim=1).to(dtype=torch.float16)
        all_reconstructed.append(reconstructed)

    full_reconstructed = torch.cat(all_reconstructed, dim=0)  # [total_n, d]

    # Drop the padding; the reshape doubles as a check that enough elements
    # were reconstructed to fill the original shape.
    original_length = dense_matrix_shape.numel()
    re_dense_matrix = full_reconstructed.flatten()[:original_length].reshape(dense_matrix_shape)

    return re_dense_matrix.view(-1)


def generate_sparse_tensor(shape, zero_percent=0.23):
    """
    Generate a sparse tensor with specified zero element ratio.
    """
    total_elements = math.prod(shape)
    zero_elements = int(total_elements * zero_percent)

    # Create random tensor
    dense_tensor = torch.randn(shape, dtype=torch.float16)
    dense_tensor = torch.clamp(dense_tensor, min=0, max=256)

    # Randomly select indices to zero out
    zero_indices = torch.randperm(total_elements)[:zero_elements]
    dense_tensor.view(-1)[zero_indices] = 0

    return dense_tensor


def test_residual_compression():
    """
    End-to-end check of residual compression followed by decompression.

    Generates a sparse tensor, compresses its non-zero elements with
    generate_compressed_dense_matrix_g, reconstructs them with
    residual_decompression, prints shape/error diagnostics, and asserts that
    the reconstruction has the original shape and a mean absolute error
    below 0.1.
    """
    # 1. Generate test data
    # Sparse tensor of shape (4096, 14336) with a requested zero ratio of 0.23
    rows, cols = 4096, 14336
    zero_percent = 0.23
    e = generate_sparse_tensor((rows, cols), zero_percent)

    # Create mask matrix to record zero/nonzero positions
    mask = (e != 0)
    print("mask data type:", mask.dtype)
    # Extract nonzero elements as a 1-D dense matrix
    dense_matrix = e[mask]

    # 2. Set compression parameters
    q = 1    # Scaling factor (not used in this implementation)
    r = 1    # Number of quantizers
    c = 256  # Number of centroids per quantizer
    b = 1    # Beam search size (not used)
    g = 2    # Number of groups
    d = 16

    print(f"Original data shape: {dense_matrix.shape}")

    # 3. Perform compression
    compress_code_list, rq_centroids_list, original_shape = generate_compressed_dense_matrix_g(
        dense_matrix, q, r, c, b, d, g
    )

    # 4. Check compression results
    print(f"Number of compressed groups: {len(compress_code_list)}")
    print(f"First group code shape: {compress_code_list[0].shape}")
    print(f"First group quantizer shape: {rq_centroids_list[0].shape}")

    # 5. Perform decompression
    reconstructed = residual_decompression(
        compress_code_list, rq_centroids_list, original_shape
    )

    # 6. Validate results
    print(f"Reconstructed data shape: {reconstructed.shape}")

    # The reference data was created on the CPU, whereas the decompressor may
    # return its result on another device (e.g. CUDA); subtracting tensors on
    # different devices raises, so align them first. The error statistics are
    # computed in float32 to avoid float16 rounding in the metric itself.
    reference = dense_matrix.to(reconstructed.device)
    diff = reference.float() - reconstructed.float()

    # Calculate error
    error = torch.mean(torch.abs(diff))
    print(f"Mean absolute error: {error.item():.6f}")
    error_mse = torch.mean(diff ** 2)
    print(f"Mean squared error: {error_mse.item():.6f}")

    # Check shape consistency
    assert reconstructed.shape == dense_matrix.shape
    print("Shape validation passed!")

    # Check data approximation (lossy compression, perfect match impossible)
    assert error < 0.1  # Adjust threshold as needed
    print("Data error within acceptable range!")


def adjust_matrix(w, d):
    """
    Flatten a weight matrix, zero-pad it to a multiple of d, and split it into rows of length d.

    :param w: Original weight matrix (torch.Tensor); may be non-contiguous
    :param d: Block size (int)
    :return: (dense_matrix, remainder) where dense_matrix is the padded data
             viewed as [-1, d] and remainder is the zero padding tail (a view
             into the same padded buffer), or None when w.numel() is already
             divisible by d
    """
    total_elements = w.numel()
    # Round the element count up to the next multiple of d.
    padded_size = ((total_elements + d - 1) // d) * d
    padded_matrix = torch.zeros(padded_size, dtype=w.dtype)
    # reshape (not view) so that non-contiguous inputs, e.g. a transposed
    # weight, are flattened instead of raising.
    padded_matrix[:total_elements] = w.reshape(-1)
    print("padded_matrix shape:", padded_matrix.shape)

    blocks = padded_matrix.view(-1, d)
    if total_elements % d == 0:
        return blocks, None

    # Everything past the original data is padding that was never written.
    remainder = padded_matrix[total_elements:]
    print("remainder shape:", remainder.shape)
    return blocks, remainder


def restore_matrix(dense_matrix, e, original_shape, d):
    """
    Restore original matrix from dense blocks and remainder.

    The blocks are flattened, the remainder (if any) is appended, and the
    first prod(original_shape) elements are viewed as original_shape.

    :param dense_matrix: Blocked matrix (torch.Tensor)
    :param e: Remainder (torch.Tensor or None)
    :param original_shape: Original matrix shape (tuple or torch.Size, any number of dimensions)
    :param d: Block size (int); not used by the restoration, kept for interface compatibility
    :return: Restored matrix (torch.Tensor) on the device of dense_matrix
    """
    # reshape (not view) so a non-contiguous input is flattened instead of raising.
    flattened_matrix = dense_matrix.reshape(-1)

    if e is not None:
        # Follow the device of dense_matrix rather than assuming CUDA, so the
        # concatenation never mixes devices and also works on CPU-only hosts.
        flattened_matrix = torch.cat([flattened_matrix, e.to(dense_matrix.device)])

    element_count = math.prod(original_shape)
    restored_matrix = flattened_matrix[:element_count].view(original_shape)

    return restored_matrix


def generate_decompressed_dense_matrix_g(compress_code_tensor, rq_centroids_tensor, original_shape_e, e, layer_shape, d):
    """
    Decompress stacked per-group codes and reshape the result to a layer's shape.

    Both stacked tensors are split along their first dimension into per-group
    lists, passed to residual_decompression, and the flat reconstruction is
    then handed to restore_matrix together with the remainder e.

    :param compress_code_tensor: Stacked code tensors, one slice per group along dim 0
    :param rq_centroids_tensor: Stacked codebooks, one slice per group along dim 0
    :param original_shape_e: Shape forwarded to residual_decompression
    :param e: Remainder forwarded to restore_matrix (torch.Tensor or None)
    :param layer_shape: Target shape forwarded to restore_matrix
    :param d: Block size forwarded to restore_matrix
    :return: Restored weight matrix (torch.Tensor)
    """
    # Split the leading (group) dimension into one tensor per group.
    per_group_centroids = list(torch.unbind(rq_centroids_tensor, dim=0))
    per_group_codes = list(torch.unbind(compress_code_tensor, dim=0))

    flat_reconstruction = residual_decompression(
        per_group_codes, per_group_centroids, original_shape_e
    )

    return restore_matrix(flat_reconstruction, e, layer_shape, d)


# if __name__ == "__main__":
#     fix_random_seed(42)
#     test_residual_compression()
