# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
import sys
import logging
import argparse
import random
import torch
# import resource
import torch.nn as nn
from datetime import datetime
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.utils.checkpoint import checkpoint
import torch.nn.functional as F
from transformers import  AutoTokenizer
from typing import Dict, Optional, List
from tqdm import tqdm
from torch.cuda.amp import autocast, GradScaler
from argparse import Namespace 
from concurrent.futures import ThreadPoolExecutor, as_completed
from src.activation_utils import get_inps_llama_by_linear
from src.data_utils import get_dataset_by_length
from src.quant_utils import generate_compressed_matrix,generate_decompressed_matrix,generate_compressed_dense_matrix_g,optimize_compression_params,fix_random_seed,adjust_matrix, generate_decompressed_dense_matrix_g
from src.model_utils import  load_safetensors_model,save_model
from src.eval_utils import eval,eval_only_ppl,eval_only_ppl_data
from src.activation_utils import get_inps_llama_by_linear,get_outs_llama_by_linear
from src.compensation_utils import  update_groupwise,compensation_groupwise
from src.activation_utils import get_inps_llama_by_linear
from src.data_utils import get_dataset_by_length
# Expose only GPU 0 to this process. The assignment runs at import time and
# overwrites any CUDA_VISIBLE_DEVICES value inherited from the environment.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"


def get_Linears(model):
    """Collect every ``torch.nn.Linear`` submodule of ``model``.

    Each linear module found is reported on stdout and returned together with
    the integer parsed from the name segment that follows ``'layers.'``
    (``None`` when the qualified name contains no ``'layers.'``).

    Args:
        model: Module whose ``named_modules()`` is traversed.

    Returns:
        list: ``(layer_index, name, module)`` tuples in traversal order.
    """
    found = []
    for qualified_name, module in model.named_modules():
        if not isinstance(module, torch.nn.Linear):
            continue

        # e.g. "model.layers.12.mlp.up_proj" -> 12
        block_idx = None
        if 'layers.' in qualified_name:
            after_marker = qualified_name.split('layers.')[1]
            block_idx = int(after_marker.split('.')[0])

        print(
            f"Layer Index: {block_idx}, Layer Name: {qualified_name}, Input Features: {module.in_features}, Output Features: {module.out_features}")
        found.append((block_idx, qualified_name, module))
    return found


def compute_quantization_mrtric(inps: torch.Tensor,
                               w: torch.Tensor,
                               weight1: torch.Tensor,
                               candidate: torch.Tensor) -> torch.Tensor:
    """
    Mean squared error between the outputs of the original weight and of the
    quantized weight with an additive correction applied.

    Both weights are transposed before the product, so they must be laid out
    as (output_dim, input_dim) for ``inps`` of shape (..., input_dim).

    Args:
        inps (torch.Tensor): Activations fed to the layer.
        w (torch.Tensor): Original weight.
        weight1 (torch.Tensor): Quantized (reconstructed) weight.
        candidate (torch.Tensor): Correction that is ADDED to ``weight1``.

    Returns:
        torch.Tensor: Scalar MSE of ``inps @ (weight1 + candidate).T``
        against ``inps @ w.T``.
    """
    corrected_weight = weight1 + candidate

    reference_out = inps @ w.T
    corrected_out = inps @ corrected_weight.T

    return F.mse_loss(corrected_out, reference_out)


def find_top1percent_max_error_coords(
    output_e: torch.Tensor,
    mode: str = 'top'  # 'top': columns with the largest sums; 'random': uniformly drawn columns
) -> List[int]:
    """
    Pick 1% of the columns of ``output_e`` (at least one column).

    Args:
        output_e (torch.Tensor): 2D tensor of shape (N, M).
        mode (str): 'top' ranks columns by their plain (signed) sum over rows
            and keeps the largest; 'random' draws columns from a random
            permutation.

    Returns:
        List[int]: Selected column indices, de-duplicated and in ascending order.
    """
    assert output_e.dim() == 2, "Input tensor must be 2D"
    assert mode in ['top', 'random'], "mode must be 'top' or 'random'"

    _, n_cols = output_e.shape
    n_select = max(1, n_cols // 100)  # integer 1% of the columns, never zero

    if mode != 'top':
        # Random draw: first n_select entries of a permutation of all columns
        chosen = torch.randperm(n_cols, device=output_e.device)[:n_select]
    else:
        # Rank columns by their sum over rows and keep the largest ones
        per_column_total = output_e.sum(dim=0)
        chosen = torch.topk(per_column_total, n_select, largest=True).indices

    return sorted(set(chosen.cpu().tolist()))


def sample_sparse_params_by_columns(
    matrix: torch.Tensor,
    rows: list,
    sample_ratio=0.01,
    mode='random'  # 'random' or 'top'
) -> torch.Tensor:
    """
    Build a sparse copy of ``matrix`` that keeps a fraction of the entries of
    the listed rows and is zero everywhere else.

    For every index in ``rows``, ``max(1, int(row_length * sample_ratio))``
    entries of that row are kept: a random subset when mode='random', or the
    entries with the largest absolute value when mode='top'.

    Args:
        matrix (torch.Tensor): Input 2D tensor.
        rows (list): Row indices to sample from.
        sample_ratio (float): Fraction of each listed row to keep.
        mode (str): 'random' or 'top'.

    Returns:
        torch.Tensor: Tensor shaped like ``matrix`` holding only the kept entries.

    Raises:
        ValueError: If ``mode`` is neither 'random' nor 'top' (checked per row).
    """
    candidate = torch.zeros_like(matrix)
    print("candidate.shape", candidate.shape)

    n_cols = matrix.shape[1]  # value unused; the lookup makes a 1-D input fail before the loop

    for row_idx in rows:
        source_row = matrix[row_idx, :]
        row_len = source_row.numel()
        keep = max(1, int(row_len * sample_ratio))

        if mode == 'top':
            # Positions of the `keep` largest-magnitude entries
            picked = torch.topk(source_row.abs(), keep, largest=True, sorted=False).indices
        elif mode == 'random':
            # Sampling without replacement via a truncated permutation
            picked = torch.randperm(row_len)[:keep]
        else:
            raise ValueError(f"Unsupported mode: {mode}. Use 'random' or 'top'.")

        candidate[row_idx, picked] = source_row[picked]

    return candidate


def extract_weight_params(weight: torch.Tensor, problem_channels: list, dim: int = 1):
    """
    Pair each requested channel index with the matching slice of ``weight``.

    Args:
        weight (torch.Tensor): 2D weight tensor, e.g. [out_features, in_features].
        problem_channels (list[int]): Channel indices to pull out.
        dim (int): 0 to take rows (``weight[ch, :]``),
                   1 to take columns (``weight[:, ch]``).

    Returns:
        List[Tuple[int, torch.Tensor]]: ``(channel index, slice)`` pairs in the
        order given. The slices are produced by basic indexing of ``weight``.

    Raises:
        ValueError: If ``dim`` is not 0 or 1 (only checked when there is at
            least one channel to extract).
    """
    def _slice_for(channel):
        if dim == 1:
            return weight[:, channel]  # column vector
        if dim == 0:
            return weight[channel, :]  # row vector
        raise ValueError("dim must be 0 or 1")

    return [(channel, _slice_for(channel)) for channel in problem_channels]


def replace_params_in_weight(w: torch.Tensor, i_channel: list[tuple[int, torch.Tensor]], dim: int = 1):
    """
    Overwrite rows or columns of ``w`` in place with the supplied vectors.

    Args:
        w (torch.Tensor): Weight matrix, e.g. [out_features, in_features]; modified in place.
        i_channel (list[tuple[int, torch.Tensor]]): ``(channel index, vector)`` pairs.
        dim (int): 0 writes each vector into row ``index``,
                   1 writes it into column ``index``.

    Raises:
        ValueError: If ``dim`` is not 0 or 1, or a vector's length does not
            match the row/column it is meant to replace. Entries processed
            before the failing one have already been written.
    """
    for channel, values in i_channel:
        if dim not in (0, 1):
            raise ValueError("dim must be 0 or 1")

        if dim == 0:
            # A row has as many entries as w has columns
            if w.shape[1] != values.shape[0]:
                raise ValueError(f"Weight row dimension {w.shape[1]} and parameter length {values.shape[0]} mismatch")
            w[channel, :] = values
            continue

        # dim == 1: a column has as many entries as w has rows
        if w.shape[0] != values.shape[0]:
            raise ValueError(f"Weight column dimension {w.shape[0]} and parameter length {values.shape[0]} mismatch")
        w[:, channel] = values


def get_metric_from_beta(file_path, layer_index):
    """
    Look up the metric stored on line ``layer_index`` of a text file.

    The metric is the last whitespace-separated token of that line, e.g.
    ``179.620026`` in ``"Layer 0 179.620026"``. Every failure is reported on
    stdout and turned into a ``None`` result instead of an exception.

    Args:
        file_path (str): Path to the metric file.
        layer_index (int): Zero-based line number to read.

    Returns:
        float or None: The parsed metric, or None when the file is missing,
        the index is outside the file, or the line cannot be parsed.
    """
    try:
        with open(file_path, 'r') as handle:
            rows = list(handle)

        if not (0 <= layer_index < len(rows)):
            print(f"layer_index {layer_index} out of file line range")
            return None

        # The metric is the trailing token of the selected line
        last_token = rows[layer_index].strip().split()[-1]
        metric_value = float(last_token)
        print("beta", metric_value)
        return metric_value
    except FileNotFoundError:
        print(f"File not found: {file_path}")
        return None
    except Exception as e:
        # Best-effort lookup: blank lines, non-numeric tokens, etc. yield None
        print(f"Error reading or converting metric: {e}")
        return None



def update_layers(model, linear_layers, data, args: Namespace):
    """Quantize selected linear layers and write the reconstructed weights back.

    For every ``(layer_index, name, layer)`` entry that passes the filter
    (currently only ``self_attn.q_proj`` of block 0) the function:

    1. compresses the weight with ``generate_compressed_dense_matrix_g``, or
       loads a previously saved result from ``<cwd>/save``;
    2. wraps the result in a ``BitMatrixLayer`` and, when ``args.q_update`` is
       set, refines it with ``update_groupwise`` on captured activations;
    3. decompresses the weight and, when ``args.q_compensate`` is set and the
       gating on the metric read from ``beta.txt`` allows it, subtracts a
       sparse slice of the quantization error;
    4. copies the final weight into ``layer.weight`` in place.

    Args:
        model: Model that owns the layers; handed to the activation-capture helpers.
        linear_layers: Iterable of ``(layer_index, name, layer)`` tuples, as
            produced by ``get_Linears``.
        data: Samples used to capture activations for the compensation step.
            Replaced by a WikiText-2 subset when ``args.q_update`` is set.
        args (Namespace): Reads ``r``, ``d``, ``q``, ``b``, ``c``,
            ``q_update``, ``q_compensate`` and, depending on the branch,
            ``model_path``, ``gate1``, ``gate2``.

    Raises:
        ValueError: If compensation is requested but ``beta.txt`` yields no
            metric for the layer.
    """
    print(f"Current working directory: {os.getcwd()}")

    # Position of each projection inside a transformer block (1-based);
    # anything else maps to 0.
    layer_attribute_map = {
        "self_attn.q_proj": 1,
        "self_attn.k_proj": 2,
        "self_attn.v_proj": 3,
        "self_attn.o_proj": 4,
        "mlp.gate_proj": 5,
        "mlp.up_proj": 6,
        "mlp.down_proj": 7,
    }

    for layer_index, name, layer in linear_layers:
        if 'lm_head' in name:
            print(f"Ignoring layer: {name}")
            continue

        attribute_name = '.'.join(name.split('.')[-2:])  # e.g., self_attn.q_proj
        layer_attribute = layer_attribute_map.get(attribute_name, 0)

        # Only q_proj of the first block is processed; everything else is skipped.
        if not (layer_index == 0 and layer_attribute == 1):
            continue
        print(f"Replacing layer at index {layer_index} with attribute {layer_attribute}: {name}")

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        layer_shape = layer.weight.shape
        w = layer.weight.clone().detach()

        # Work on a CPU copy, with bfloat16 converted to float16
        if w.is_cuda:
            w = w.cpu()
        if w.dtype == torch.bfloat16:
            w = w.to(torch.float16)

        # Compression parameters - update here to change compression behavior
        r = args.r
        d = args.d
        q = args.q
        b = args.b
        c = args.c
        # The group count is pinned to 1. An earlier per-projection choice
        # (128 for q/o, 32 for k/v, 256 otherwise) was always overwritten by
        # this value before it was used, so only the effective setting is kept.
        g = 1

        dense_matrix, e = adjust_matrix(w, d)

        # Compressed tensors are cached under <cwd>/save
        output_directory = os.path.join(os.getcwd(), 'save')
        os.makedirs(output_directory, exist_ok=True)

        if layer_index >= 0:
            file_path = f'{output_directory}/layer{layer_index}_{layer_attribute}.pt'
        else:
            file_path = f'{output_directory}/layer{layer_index}_{layer_attribute}__new.pt'

        if os.path.exists(file_path):
            # Reuse the compression result from an earlier run
            loaded_tensor_dict = torch.load(file_path)
            compress_code_tensor = loaded_tensor_dict['compress_code_tensor'].to(device)
            rq_centroids_tensor = loaded_tensor_dict['rq_centroids_tensor'].to(device)
            original_shape_e = loaded_tensor_dict['original_shape_e']
            if torch.is_tensor(original_shape_e):
                original_shape_e = original_shape_e.to(device)

            print("Tensors loaded successfully.")
        else:
            print("File not found, performing compression...")

            compress_code_tensor, rq_centroids_tensor, original_shape_e = generate_compressed_dense_matrix_g(
                dense_matrix, q, r, c, b, d, g
            )

            tensor_dict = {
                'compress_code_tensor': compress_code_tensor,
                'rq_centroids_tensor': rq_centroids_tensor,
                'original_shape_e': original_shape_e
            }
            torch.save(tensor_dict, file_path)
            print("Compressed tensors saved successfully.")

        # Columns 102 and 104 of the original weight are handed to the
        # compressed layer unchanged.
        random_channels = [102, 104]
        extracted_random_channels = extract_weight_params(w, random_channels)

        bias = layer.bias
        in_features = layer.in_features
        out_features = layer.out_features

        # This is the quantized compressed layer
        new_layer = BitMatrixLayer(
            extracted_random_channels, rq_centroids_tensor, compress_code_tensor, original_shape_e, e, layer_shape,
            q, d, in_features, out_features, layer_index, g, use_checkpoint=False, bias=bias
        )

        if args.q_update:
            print("Starting update")
            # Flat index of this linear layer: 7 projections per block
            sequence = layer_index * 7 + layer_attribute - 1
            print(f"Current sequence number: {sequence}")

            if layer_index >= 0:
                # Capture activations on a fixed subset of WikiText-2
                dataset = "wikitext-2"
                model_path = args.model_path
                tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
                data = get_dataset_by_length(dataset, tokenizer)
                sample_0_to_5 = data[0:6]      # samples 0..5
                sample_16_to_23 = data[16:24]  # samples 16..23
                data = torch.cat((sample_0_to_5, sample_16_to_23), dim=0)

                print("Shape:", data.shape)
                seq_len = data.size(1)
                inps, _ = get_inps_llama_by_linear(model, data, seq_len, device, False, sequence)
                print("inps captured:", inps.shape)
                outs, _ = get_outs_llama_by_linear(model, data, seq_len, device, False, sequence)
                print("outs captured:", outs.shape)
            else:
                # Load precomputed activations from <cwd>/activation
                current_dir = os.getcwd()
                inps_file = os.path.join(current_dir, 'activation', 'linear_inps', f'inps_linear_{sequence}.pt')
                inps = torch.load(inps_file)
                # Outputs are stored as outs_linear_<sequence>.pt; reading
                # inps_linear_* from the output directory was a copy-paste slip.
                outs_file = os.path.join(current_dir, 'activation', 'linear_out', f'outs_linear_{sequence}.pt')
                outs = torch.load(outs_file)

            # Hold out a random 10% of the samples as the validation split
            num_samples = inps.size(0)
            num_samples_to_select = num_samples // 10
            indices = random.sample(range(num_samples), num_samples_to_select)
            selected_inps = inps[indices]
            selected_outs = outs[indices]

            new_layer, is_better = update_groupwise(
                layer=new_layer,
                train_inps=inps,
                train_outs=outs,
                args=args,
                valid_inps=selected_inps,
                valid_outs=selected_outs,
                verbose=True
            )

            if is_better:
                tensor_dict = {
                    'compress_code_tensor': new_layer.q_weight,
                    'rq_centroids_tensor': new_layer.rq_centroids,
                    'original_shape_e': original_shape_e
                }
                # NOTE(review): this "_new.pt" file is not one of the paths the
                # cache lookup above reads - confirm who consumes it.
                torch.save(tensor_dict, f'{output_directory}/layer{layer_index}_{layer_attribute}_new.pt')

        # Reconstruct the dense weight from the (possibly updated) compressed layer
        weight1 = generate_decompressed_dense_matrix_g(
            new_layer.q_weight,
            new_layer.rq_centroids.to(torch.float16),
            new_layer.eshape,
            new_layer.e,
            new_layer.original_shape,
            new_layer.d
        )

        if args.q_compensate:
            # Quantization error, quantized minus original
            delta_W = weight1.to(w.device).to(w.dtype) - w
            sequence = layer_index * 7 + layer_attribute - 1
            print(f"Current sequence number: {sequence}")

            metric_value = get_metric_from_beta('beta.txt', layer_index)
            print("beta", metric_value)
            if metric_value is None:
                # Fail with a clear message instead of a TypeError on the
                # comparisons below.
                raise ValueError(
                    f"No metric available in beta.txt for layer_index {layer_index}; "
                    "cannot decide whether to compensate"
                )

            gate1 = args.gate1
            gate2 = args.gate2
            if metric_value < gate1 and layer_attribute != 7:
                # Below gate1 only down_proj (attribute 7) is compensated
                print("model")

            elif metric_value > gate1 and metric_value < gate2 and (layer_attribute != 4 and layer_attribute != 7):
                # Between the gates only o_proj (4) and down_proj (7) are compensated
                print("transformer")

            else:
                # Capture the inputs of this linear layer
                inps_all, _ = get_inps_llama_by_linear(model, data, 2048, device, False, sequence)
                inps = inps_all[2].squeeze(0)  # uses the sample at index 2

                # Run the comparison on whatever device the activations live on
                device = inps.device
                w = w.to(device)
                weight1 = weight1.to(device)

                output = torch.matmul(inps, w.T)
                output_q = torch.matmul(inps, weight1.T)
                output_e = output - output_q

                # Output channels with the largest summed error; these index rows of W
                problem_channels = find_top1percent_max_error_coords(output_e, 'top')
                best_metric = float('inf')
                best_candidate = None

                # NOTE(review): with mode 'top' the sampler is deterministic, so
                # the three candidates are identical. Kept as-is so the loop
                # becomes meaningful again if the mode is switched to 'random'.
                for i in range(3):
                    candidate = sample_sparse_params_by_columns(delta_W, problem_channels, 0.16, 'top')
                    candidate = candidate.to(weight1.device)

                    # The candidate is SUBTRACTED from weight1 below, so it is
                    # scored with the same sign (the metric adds its last argument).
                    metric = compute_quantization_mrtric(inps, w, weight1, -candidate)
                    # Scale as a Python float: multiplying the tensor could
                    # overflow to inf in half precision.
                    score = metric.item() * 1_000_000
                    print(f"Candidate compensation score {i+1}: {score:.6f}")

                    if score < best_metric:
                        best_metric = score
                        best_candidate = candidate

                if best_candidate is None:
                    # Every score was inf/NaN; leave the quantized weight untouched
                    print("No finite compensation score; skipping compensation for this layer")
                else:
                    print(f"Selected best compensation MSE: {best_metric:.6f}")
                    weight1 = weight1 - best_candidate.to(weight1.device)

        # Fake quantization: keep the original nn.Linear and overwrite its
        # weight with the reconstructed values (new_layer is not installed).
        weight_to_use = weight1.to(layer.weight.device).to(layer.weight.dtype)
        with torch.no_grad():
            layer.weight.data.copy_(weight_to_use)

        torch.cuda.empty_cache()
        print(f"Completed layer replacement: {name}")




def update_layers_by_transformer(model, linear_layers, args: Namespace):
    """Replace the linear layers of block 0 with compressed ``BitMatrixLayer`` modules.

    For every ``(layer_index, name, layer)`` entry with ``layer_index == 0``
    (``lm_head`` excluded) the function compresses the weight (or loads a
    saved result), reports the reconstruction error, attaches the outlier
    entries listed in ``topk_weight_indices.pt``, and installs the new layer
    in ``model``. For ``mlp.down_proj`` it additionally runs
    ``compensation_groupwise`` on stored activations and rewrites
    ``topk_weight_indices.pt`` with the result.

    Args:
        model: Model whose submodules are replaced in place.
        linear_layers: Iterable of ``(layer_index, name, layer)`` tuples, as
            produced by ``get_Linears``.
        args (Namespace): Reads ``r``, ``d``, ``g``, ``q``, ``b``, ``c``; also
            forwarded to ``compensation_groupwise`` / ``update_groupwise``.

    Note:
        Several paths are ``'...'`` placeholders and must be filled in before
        the function can run. ``topk_weight_indices.pt`` is read for every
        processed layer, so it has to exist before the first call.
    """
    print(f"Current working directory: {os.getcwd()}")

    # Position of each projection inside a transformer block (1-based);
    # anything else maps to 0.
    layer_attribute_map = {
        "self_attn.q_proj": 1,
        "self_attn.k_proj": 2,
        "self_attn.v_proj": 3,
        "self_attn.o_proj": 4,
        "mlp.gate_proj": 5,
        "mlp.up_proj": 6,
        "mlp.down_proj": 7,
    }
    # Resolved once; used for every tensor move below so that a CPU-only
    # machine is not forced through .cuda().
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    for layer_index, name, layer in linear_layers:
        if 'lm_head' in name:
            print(f"Ignoring layer: {name}")
            continue
        print(f"Attribute layer: {name}")

        attribute_name = '.'.join(name.split('.')[-2:])  # e.g. self_attn.q_proj
        layer_attribute = layer_attribute_map.get(attribute_name, 0)

        # Only layers of block 0 are processed
        if layer_index != 0:
            continue
        print(f"Replacing layer at index {layer_index} with attribute {layer_attribute}: {name}")

        layer_shape = layer.weight.shape
        w = layer.weight.clone().detach()
        print("Layer original precision", w.dtype)

        # Compress from a CPU copy, with bfloat16 converted to float16
        if w.is_cuda:
            w = w.cpu()
        if w.dtype == torch.bfloat16:
            w = w.to(torch.float16)

        # Compression process - change parameters here if needed
        r = args.r
        d = args.d
        g = args.g
        q = args.q
        b = args.b
        c = args.c  # read for parity with update_layers; the call below passes 256
        dense_matrix, e = adjust_matrix(w, d)

        output_directory = '...'  # Replace personal path with placeholder or desired path
        file_path = f'{output_directory}/layer{layer_index}_{layer_attribute}.pt'

        if os.path.exists(file_path):
            # Reuse the compression result from an earlier run
            loaded_tensor_dict = torch.load(file_path)
            compress_code_tensor = loaded_tensor_dict['compress_code_tensor']
            rq_centroids_tensor = loaded_tensor_dict['rq_centroids_tensor']
            original_shape_e = loaded_tensor_dict['original_shape_e']
            print("Tensors loaded successfully.")
        else:
            print("File not found, performing compression...")

            compress_code_tensor, rq_centroids_tensor, original_shape_e = generate_compressed_dense_matrix_g(
                dense_matrix, q, r, 256, b, d, g
            )

            tensor_dict = {
                'compress_code_tensor': compress_code_tensor,
                'rq_centroids_tensor': rq_centroids_tensor,
                'original_shape_e': original_shape_e
            }
            torch.save(tensor_dict, file_path)
            print("Compressed tensors saved successfully.")

        # Columns 102 and 104 of the original weight are handed to the
        # compressed layer unchanged.
        random_channels = [102, 104]
        extracted_random_channels = extract_weight_params(w, random_channels)

        # Decompress and compare against the original on a common device
        r_w = generate_decompressed_dense_matrix_g(compress_code_tensor, rq_centroids_tensor, original_shape_e, e, layer_shape, d)

        w = w.to(device)
        r_w = r_w.to(device)

        mse = torch.mean((w - r_w) ** 2)
        print("MSE before and after compression:", mse.item())
        r_w = r_w.to(torch.bfloat16)
        max_diff = torch.max(torch.abs(w - r_w))
        print(f"Maximum difference between original and decompressed weights: {max_diff.item()}")

        bias = layer.bias
        in_features = layer.in_features
        out_features = layer.out_features

        # Outlier coordinates: one (row, col) pair per line
        topk_weight_indices = torch.load("topk_weight_indices.pt")
        print("Shape:", topk_weight_indices.shape)
        print("topk_weight_indices dtype:", topk_weight_indices.dtype)

        topk_weight_indices = topk_weight_indices.to(w.device)

        rows = topk_weight_indices[:, 0]
        cols = topk_weight_indices[:, 1]

        num_rows, num_cols = w.shape

        # Flatten (row, col) to a row-major offset, then narrow to int32
        linear_indices = rows * num_cols + cols
        linear_indices = linear_indices.to(torch.int32)

        # Gather the original weight values at those offsets
        w_flat = w.view(-1)
        extracted_values = w_flat[linear_indices]

        # One value per row: shape [K, 1]
        extracted_values = extracted_values.unsqueeze(1)

        outliers = {
            'linear_indices': linear_indices,      # shape [K], dtype int32
            'extracted_values': extracted_values   # shape [K, 1], dtype of w
        }

        new_layer = BitMatrixLayer(
            extracted_random_channels, rq_centroids_tensor, compress_code_tensor, original_shape_e,
            e, layer_shape, q, d, in_features, out_features, layer_index, g,
            use_checkpoint=False, bias=bias, outliers=outliers
        )

        # Compensation is only run for mlp.down_proj
        if layer_attribute == 7:

            print("Starting update")
            # Flat index of this linear layer: 7 projections per block
            sequence = layer_index * 7 + layer_attribute - 1
            print(f"Current sequence number: {sequence}")

            # 99 is not a value layer_attribute_map can produce, so this live
            # capture path is currently switched off and the stored
            # activations below are always used.
            if layer_attribute == 99 and layer_index == 4:

                dataset = "wikitext-2"
                model_path = '...'  # Replace personal path with placeholder
                tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
                data = get_dataset_by_length(dataset, tokenizer)
                # Keep only the samples whose indices are listed in the file
                filename = '.../greater_than_0.03.txt'
                with open(filename, 'r') as f:
                    lines = f.readlines()
                indexes = [int(line.strip()) for line in lines]
                selected_tensors = []
                for i in indexes:
                    if 0 <= i < len(data):
                        sample = data[i]
                        selected_tensors.append(sample)
                data = torch.stack(selected_tensors, dim=0)

                print("Shape:", data.shape)
                seq_len = data.size(1)
                inps, forward_args = get_inps_llama_by_linear(model, data, seq_len, device, False, 34)
            else:
                inps_file = f'.../activation/linear_inps/inps_linear_{sequence}.pt'
                inps = torch.load(inps_file)
            outs_file = f'.../activation/linear_out/outs_linear_{sequence}.pt'
            outs = torch.load(outs_file)

            # A 10% subset used to be drawn here but validation below runs on
            # the full set. Only the draw itself is kept, so the state of the
            # global `random` generator after this point is unchanged.
            num_samples = inps.size(0)
            random.sample(range(num_samples), num_samples // 10)

            topk_weight_indices = compensation_groupwise(
                layer=new_layer,
                train_inps=inps,
                train_outs=outs,
                args=args,
                valid_inps=inps,
                valid_outs=outs,
                verbose=True
            )

            # Persist the outlier coordinates; they are read back at the top
            # of the next processed layer.
            torch.save(topk_weight_indices, "topk_weight_indices.pt")

        # Install the compressed layer in place of the original nn.Linear
        parent_module = dict(model.named_modules())[name.rsplit('.', 1)[0]]
        setattr(parent_module, name.split('.')[-1], new_layer)
        torch.cuda.empty_cache()
        print(f"Completed layer replacement: {name}")

        # Whole-block update; switched off by the same 99 sentinel as above.
        if layer_attribute == 99:
            print("All linear layers quantized, starting to update specified transformer layers.")

            inps_file = f'.../activation/transformer_inps/inps_transformer_{layer_index}.pt'
            inps = torch.load(inps_file)
            outs_file = f'.../activation/transformer_out/outs_transformer_{layer_index}.pt'
            outs = torch.load(outs_file)

            # Assume the transformer layer name to update is like this
            transformer_layer_name = f"model.layers.{layer_index}"
            transformer_layer = dict(model.named_modules())[transformer_layer_name]

            # update_groupwise returns (layer, is_better); only the layer is
            # installed. Assigning the whole tuple would fail in setattr.
            transformer_layer, _ = update_groupwise(
                layer=transformer_layer,
                train_inps=inps,
                train_outs=outs,
                args=args,
                valid_inps=inps,
                valid_outs=outs,
                verbose=True
            )

            parent_module = dict(model.named_modules())[transformer_layer_name.rsplit('.', 1)[0]]
            setattr(parent_module, transformer_layer_name.split('.')[-1], transformer_layer)


def quantize_and_compress(layer_index, name, layer, best_params_dir, log_dir):
    """Search compression parameters for one linear layer and compress its weight.

    The weight is compressed with the parameters returned by
    ``optimize_compression_params``. If the reported MSE exceeds
    ``original_mean * 10000`` (``layer_index < 6``) or ``original_mean * 1000``
    (all other layers), the weight is compressed again with fixed fallback
    settings chosen by the layer's attribute (q/k/v/o/gate/up/down projection).

    Args:
        layer_index: Index used to name the parameter/log files and to pick the
            retry threshold.
        name: Dotted module name, e.g. ``model.layers.0.self_attn.q_proj``.
        layer: Module exposing a ``weight`` tensor (expected to be ``nn.Linear``).
        best_params_dir: Directory for ``best_params_layer_*_attr_*.json``.
        log_dir: Directory for ``log_data_layer_*_attr_*.json``.

    Returns:
        ``None`` when ``name`` contains ``lm_head`` (the layer is skipped).
        Otherwise the tuple ``(name, layer, rq_centroids, pca_layer,
        original_shape, q, status)`` where ``status`` is ``"Success"``, or
        ``"Failed"`` with the four middle entries set to ``None`` when any
        exception was raised while compressing.
    """
    # Skip layer if lm_head
    if 'lm_head' in name:
        print(f"Ignoring layer: {name}")
        return None

    print(f"Attribute layer: {name}")

    layer_attribute_map = {
        "self_attn.q_proj": 1,
        "self_attn.k_proj": 2,
        "self_attn.v_proj": 3,
        "self_attn.o_proj": 4,
        "mlp.gate_proj": 5,
        "mlp.up_proj": 6,
        "mlp.down_proj": 7,
    }

    # Fallback (q, r, c, b, d) arguments for generate_compressed_matrix, keyed
    # by layer attribute. Attribute 0 (unknown layer type) has no fallback.
    narrow_fallback = (0, 20, 256, 1, 16)
    wide_fallback = (0, 20, 256, 1, 32)
    fallback_params = {
        1: wide_fallback,
        2: narrow_fallback,
        3: narrow_fallback,
        4: wide_fallback,
        5: wide_fallback,
        6: wide_fallback,
        7: wide_fallback,
    }

    # Last two name components, e.g. self_attn.q_proj. Joining the slice keeps
    # a dotless name from raising IndexError; it simply maps to attribute 0.
    attribute_name = '.'.join(name.split('.')[-2:])
    layer_attribute = layer_attribute_map.get(attribute_name, 0)

    w = layer.weight.clone().detach()
    original_mean = torch.mean(w)
    original_median = torch.median(w)

    if w.is_cuda:
        w = w.cpu()
    if w.dtype == torch.bfloat16:
        w = w.to(torch.float16)

    best_params_filename = os.path.join(best_params_dir, f'best_params_layer_{layer_index}_attr_{layer_attribute}.json')
    log_data_filename = os.path.join(log_dir, f'log_data_layer_{layer_index}_attr_{layer_attribute}.json')
    print(f"Current memory usage: {torch.cuda.memory_allocated() / (1024 ** 2)} MB")
    try:
        q_best, r_best, c_best, b_best, d_best = optimize_compression_params(w, layer_index, layer_attribute, best_params_filename, log_data_filename)
        rq_centroids, pca_layer, original_shape, q, mse_temp = generate_compressed_matrix(w, q_best, r_best, c_best, b_best, d_best)

        # Check MSE and regenerate compression matrix if needed
        print("Median value", original_median)
        print("Mean value", original_mean)

        # NOTE(review): the threshold scales with the *signed* mean of the
        # weight, so a negative mean makes the retry fire for any MSE above it
        # — confirm this is intended.
        retry_factor = 10000 if layer_index < 6 else 1000
        if mse_temp > original_mean * retry_factor:
            torch.cuda.empty_cache()  # Clear cache
            fallback = fallback_params.get(layer_attribute)
            if fallback is not None:
                rq_centroids, pca_layer, original_shape, q, mse_temp = generate_compressed_matrix(w, *fallback)

        # Both messages are kept for log compatibility; no further compression
        # attempt happens between them.
        if mse_temp > original_mean * 10000:
            print(f"Layer index: {layer_index}, Layer attribute: {layer_attribute}, MSE: {mse_temp} MSE not improved first time")
            print(f"Layer index: {layer_index}, Layer attribute: {layer_attribute}, MSE: {mse_temp} MSE not improved second time")

        # Decompress data to report the reconstruction error
        r_w = generate_decompressed_matrix(rq_centroids, pca_layer, original_shape, q)
        print(f"Current memory usage: {torch.cuda.memory_allocated() / (1024 ** 2)} MB")
        max_value = pca_layer.max()
        min_value = pca_layer.min()
        print(f"Max value of pca_layer: {max_value.item()}")
        print(f"Min value of pca_layer: {min_value.item()}")

        r_w = r_w.to(torch.bfloat16)
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        max_diff = torch.max(torch.abs(w.to(device) - r_w.to(device)))
        print(f"Layer {name} processed, max difference: {max_diff.item()}")

        # Delete r_w and clear cache
        del r_w
        torch.cuda.empty_cache()
        return (name, layer, rq_centroids, pca_layer, original_shape, q, "Success")

    except Exception as e:
        # Best-effort worker: report the failure and let the caller skip this layer.
        print(f"Error processing layer {name}: {e}")
        return (name, layer, None, None, None, None, "Failed")



def update_layers_parallel(model, linear_layers):
    """Compress linear layers on a two-worker thread pool and splice them into ``model``.

    Args:
        model: Module tree whose linear layers are replaced in place.
        linear_layers: Iterable of ``(layer_index, name, layer)`` triples.

    Returns:
        None. Layers that were skipped (``quantize_and_compress`` returned
        ``None``), failed, or raised inside the worker are left untouched.
    """
    print(f"Current working directory: {os.getcwd()}")
    log_dir = 'params_log'
    best_params_dir = 'params_best'
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(best_params_dir, exist_ok=True)

    # Each entry pairs a layer's own index with its worker result, so the
    # replacement loop below does not depend on the completion order.
    results = []
    with ThreadPoolExecutor(max_workers=2) as executor:
        futures = {
            executor.submit(quantize_and_compress, layer_index, name, layer, best_params_dir, log_dir): (layer_index, name)
            for layer_index, name, layer in linear_layers
        }

        for future in as_completed(futures):
            layer_index, layer_name = futures[future]
            try:
                result = future.result()
            except Exception as e:
                print(f"Error during quantization of layer {layer_name}: {e}")
                continue
            if result is None:
                continue
            results.append((layer_index, result))
            # Print each thread's processing status
            print(f"Layer {layer_name} processing status: {result[-1]}")  # Print status info

    # Replace layers in the model
    for layer_index, (name, layer, rq_centroids, pca_layer, original_shape, q, _) in results:
        if rq_centroids is None:  # Skip failed layers
            continue

        bias = layer.bias
        in_features = layer.in_features
        out_features = layer.out_features
        # NOTE(review): this positional call does not match the
        # BitMatrixLayer.__init__ signature defined in this file, which also
        # expects i_channel, compress_code_tensor, original_shape_e, e, d and
        # group_num. As written it raises TypeError — TODO reconcile the
        # arguments before relying on this code path.
        new_layer = BitMatrixLayer(rq_centroids, pca_layer, original_shape, q,
                                   in_features, out_features, layer_index,
                                   use_checkpoint=False, bias=bias)

        # Replace layer on its parent module
        parent_module = dict(model.named_modules())[name.rsplit('.', 1)[0]]
        setattr(parent_module, name.split('.')[-1], new_layer)
        torch.cuda.empty_cache()
        print(f"Completed layer replacement: {name}")


class BitMatrixLayer(nn.Module):
    """
    Replacement for a linear layer that keeps its weight in compressed form.

    The dense weight is rebuilt on every forward pass by
    ``generate_decompressed_dense_matrix_g`` and applied with ``F.linear``.

    Parameters:
    i_channel : list
        Outlier channels. Stored on the module; not read by the forward pass.
    rq_centroids_tensor : torch.Tensor
        Codebook tensor. Kept as the trainable float32 parameter
        ``rq_centroids`` and cast to float16 for decompression.
    compress_code_tensor : torch.Tensor
        Compressed codes, kept as the frozen parameter ``q_weight``.
    original_shape_e : tuple
        Shape stored as ``eshape`` and passed to the decompression routine
        together with ``e``.
    e : dict
        Residual vector data passed to the decompression routine.
    original_shape : tuple
        Original matrix shape, passed to the decompression routine.
    q : int
        Stored on the module; not read by the forward pass.
    d : int
        Passed to the decompression routine.
    in_features : int
        Number of input features.
    out_features : int
        Number of output features.
    layer_index : int
        Index of the layer this module replaces.
    group_num : int
        Stored on the module; not read by the forward pass.
    bias : Optional[nn.Parameter]
        Optional bias parameter, defaults to None.
    outliers : Optional[Dict[str, torch.Tensor]]
        Dictionary of outliers containing ``linear_indices`` and
        ``extracted_values``. Stored only; restoring them into the weight is
        still a TODO.
    use_checkpoint : bool
        Whether to use checkpointing to save memory.
    """

    def __init__(self,
                 i_channel: list,
                 rq_centroids_tensor,
                 compress_code_tensor,
                 original_shape_e: tuple,
                 e: dict,
                 original_shape: tuple,
                 q: int,
                 d: int,
                 in_features: int,
                 out_features: int,
                 layer_index: int,
                 group_num: int,
                 bias: Optional[nn.Parameter] = None,
                 outliers: Optional[Dict[str, torch.Tensor]] = None,
                 use_checkpoint: bool = False):
        super().__init__()

        # Residual vector
        self.e = e
        self.eshape = original_shape_e
        # Outlier channels
        self.i_channel = i_channel
        self.bias = bias
        self.q = q
        self.d = d
        self.out_features = out_features  # number of output features
        self.in_features = in_features  # number of input features
        self.use_checkpoint = use_checkpoint  # whether to use checkpoint
        self.group_num = group_num

        self.layer_index = layer_index
        self.original_shape = original_shape
        self.outliers = outliers  # dictionary of outliers containing linear_indices and extracted_values

        # The codebook is held in float32 so it can be updated by training;
        # the codes are frozen.
        self.rq_centroids = nn.Parameter(rq_centroids_tensor.float(), requires_grad=True)
        self.q_weight = nn.Parameter(compress_code_tensor, requires_grad=False)

    def forward(self, input: torch.Tensor):
        """Run the layer, under activation checkpointing when enabled and grads are on."""
        if self.use_checkpoint and torch.is_grad_enabled():
            return checkpoint(self._forward, input, use_reentrant=False)
        return self._forward(input)

    def _forward(self, input: torch.Tensor):
        """Decompress the weight, apply ``F.linear`` and validate the output.

        Raises:
            ValueError: If the output contains NaN or Inf.
        """
        weight = generate_decompressed_dense_matrix_g(self.q_weight,
                                                      self.rq_centroids.to(torch.float16),
                                                      self.eshape,
                                                      self.e,
                                                      self.original_shape,
                                                      self.d)
        # TODO: if self.outliers is set, write self.outliers['extracted_values']
        # into the flattened weight at self.outliers['linear_indices'].
        # TODO: the i_channel (outlier channel) replacement along dim=1 of the
        # weight is likewise not applied yet.

        if input.dtype != weight.dtype:
            input = input.to(weight.dtype)  # convert input to weight dtype

        # F.linear needs input, weight and bias in one dtype, so the bias gets
        # the same treatment as the input.
        bias = self.bias
        if bias is not None and bias.dtype != weight.dtype:
            bias = bias.to(weight.dtype)

        # Perform linear transformation using weight
        output = F.linear(input, weight, bias)
        if not torch.isfinite(output).all():
            # Print input and weight info
            print(f"Input shape: {input.shape}, dtype: {input.dtype}, content: {input}")
            print(f"Weight shape: {weight.shape}, dtype: {weight.dtype}, content: {weight}")

            raise ValueError("Output contains NaN or Inf.")

        # Immediately free weight after use
        del weight  # delete weight tensor
        torch.cuda.empty_cache()  # clear unused cache (optional)

        return output


# def print_memory_usage():
#     # Print system memory usage
#     mem_usage = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024  # Convert to MB
#     print(f"Process Memory Usage: {mem_usage:.2f} MB")


def _str_to_bool(value):
    """Parse a command-line boolean (``bool("False")`` is True, so ``type=bool`` is unusable)."""
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("true", "t", "yes", "y", "1", "on"):
        return True
    if text in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def _build_arg_parser():
    """Build the command-line parser for the quantization/update script."""
    parser = argparse.ArgumentParser(add_help=True)

    # update related parameters
    parser.add_argument(
        "--update_max_epochs",
        type=int,
        default=10,
        help="Run this many passes over training data during update; no update if set to 0."
    )
    parser.add_argument(
        "--update_early_stop",
        type=int,
        default=5,
        help="Stop update if loss doesn't improve after this many epochs."
    )
    parser.add_argument(
        "--update_lr",
        type=float,
        default=1e-5,
        help="update learning rate."
    )
    parser.add_argument(
        "--update_batch_size",
        type=int,
        default=1,
        help="(update only) train on batches of this many sequences globally across all GPUs."
    )
    parser.add_argument(
        "--update_adam_beta1",
        type=float,
        default=0.9,
        help="update Adam optimizer beta1 parameter."
    )
    parser.add_argument(
        "--update_adam_beta2",
        type=float,
        default=0.95,
        help="update Adam optimizer beta2 parameter."
    )
    parser.add_argument("--update_keep_best", action="store_true", help="Keep best model parameters during update.")
    parser.add_argument(
        "--local_batch_size",
        type=int,
        default=4,
        help="(update only) Per-device and per-forward-pass batch size to accumulate global --update_batch_size."
    )
    parser.add_argument(
        "--val_size",
        type=int,
        default=128,
        help="Number of validation sequences."
    )
    parser.add_argument(
        "--print_frequency",
        type=int,
        default=10,
        help="Print Adam progress every print_frequency updates."
    )

    # Additional integer parameters (defaults can be changed as needed)
    for flag, default in (("q", 0), ("r", 2), ("c", 256), ("b", 1), ("d", 8), ("g", 256)):
        parser.add_argument(
            f"--{flag}",
            type=int,
            default=default,
            help=f"Parameter {flag} description."
        )

    # Additional float parameters
    for flag in ("gate1", "gate2"):
        parser.add_argument(
            f"--{flag}",
            type=float,
            default=0.0,
            help=f"Parameter {flag} description."
        )

    parser.add_argument(
        '--model_path',
        type=str,
        default='...',  # Replace personal path with placeholder
        help='Path to the model directory'
    )
    # Whether to update (if argument passed, True; default False)
    parser.add_argument(
        '--q_update',
        action='store_true',
        help='Whether to perform update (default: False)'
    )

    # Whether to compensate (if argument passed, True; default False)
    parser.add_argument(
        '--q_compensate',
        action='store_true',
        help='Whether to perform compensation (default: False)'
    )

    # Accepts an explicit value ("--offload_activations false") or the bare
    # flag (treated as True). type=bool would turn any non-empty string,
    # including "False", into True.
    parser.add_argument(
        '--offload_activations',
        type=_str_to_bool,
        nargs='?',
        const=True,
        default=False,
        help='Offload activations to CPU'
    )
    return parser


def main():
    """Parse arguments, load the model, replace its linear layers and evaluate it."""
    torch.set_num_threads(min(16, torch.get_num_threads()))
    # Keep matmul/cuDNN in full float32 precision
    torch.backends.cudnn.allow_tf32 = False
    torch.backends.cuda.matmul.allow_tf32 = False

    # Parse arguments
    args = _build_arg_parser().parse_args()
    args.devices = [torch.device(f"cuda:{i}") for i in range(torch.cuda.device_count())]

    # Load model
    print("Loading model")
    model_path = args.model_path
    model = load_safetensors_model(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
    linear_layers = get_Linears(model)

    # Evaluation data
    input_text = "Hello, how are you?"
    dataset = "wikitext-2"
    data = get_dataset_by_length(dataset, tokenizer)
    sample = data[0:4]
    # TODO: print shape of data
    # Compute metric 1 during quantization and write to file beta.txt

    print("Start replacing model layers")
    update_layers(model, linear_layers, sample, args)

    print("Evaluation after quantization")
    print("Before evaluation:")
    # print_memory_usage()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    eval(model, tokenizer, input_text, dataset)




# new_directory is the process's current working directory, so the chdir below
# does not move the process; the error handling reports why if it fails.
new_directory = os.getcwd()

try:
    # Change current working directory
    os.chdir(new_directory)
    print(f"Current working directory changed to: {os.getcwd()}")
except FileNotFoundError:
    print(f"Directory '{new_directory}' does not exist. Please check the path.")
except PermissionError:
    print(f"No permission to access directory '{new_directory}'.")
except Exception as e:
    print(f"Error occurred while changing directory: {e}")
torch.cuda.empty_cache()


# Configure logging: one timestamped file per run under ./LOG
log_dir = os.path.join(os.getcwd(), 'LOG')
# exist_ok avoids the check-then-create race and matches the other makedirs
# calls in this file.
os.makedirs(log_dir, exist_ok=True)

# Get current time string, e.g. 20240610_153045
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

# Log file path with timestamp
log_file = os.path.join(log_dir, f'llama_{timestamp}.log')
logging.basicConfig(
    level=logging.INFO,  # Set log level
    format='%(asctime)s - %(levelname)s - %(message)s',  # Log format
    handlers=[
        logging.FileHandler(log_file),  # Write logs to file
        logging.StreamHandler(sys.stdout)  # Also output to console
    ]
)

# Redirect print() to logging
class Logger(object):
    """File-like object that copies everything written to it to the console and the log file."""

    def __init__(self):
        self.terminal = sys.stdout  # the stream that was active before redirection
        self.log = open(log_file, "a")  # Open log file in append mode

    def write(self, message):
        """Write ``message`` to both the console and the log file."""
        self.terminal.write(message)  # Output to console
        self.log.write(message)  # Output to log file

    def flush(self):
        """Flush both targets so ``print(..., flush=True)`` actually reaches them."""
        self.terminal.flush()
        self.log.flush()

# Redirect sys.stdout to Logger so every later print() also lands in the log file.
# NOTE: this assignment runs whenever the module is loaded, not only when it is
# executed as a script.
sys.stdout = Logger()


if __name__ == "__main__":
    # Fixed seed 2002 — presumably seeds the RNGs for reproducible runs; see
    # fix_random_seed in src.quant_utils.
    fix_random_seed(2002)
    main()
