# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.utils.checkpoint import checkpoint
import torch.nn.functional as F
from src.activation_utils import get_inps_llama_by_linear,get_inps_llama_by_linear
from src.data_utils import get_dataset_by_length
from typing import Dict, Optional
import os
from src.quant_utils import generate_compressed_matrix,generate_decompressed_matrix,generate_compressed_dense_matrix_g,optimize_compression_params,fix_random_seed,adjust_matrix, generate_decompressed_dense_matrix_g
from tqdm import tqdm
from  src.model_utils import  load_safetensors_model,save_model
from  src.eval_utils import eval
import torch
import torch.nn as nn
from src.compensation_utils import update_groupwise,compensation_groupwise
from torch.cuda.amp import autocast, GradScaler
import argparse
import random
from argparse import Namespace
from concurrent.futures import ThreadPoolExecutor, as_completed
import sys
import torch
import logging

# Pin the process to GPU 0. This only takes effect if it runs before CUDA is first initialized.
os.environ.update(CUDA_VISIBLE_DEVICES="0")


def get_Linears(model):
    """
    Collect every ``torch.nn.Linear`` submodule of ``model`` and print a summary line for each.

    Args:
        model: Any ``torch.nn.Module``; it is walked with ``named_modules()``.

    Returns:
        list[tuple[int | None, str, torch.nn.Linear]]: One ``(layer_index, name, layer)`` entry per
        linear module, in ``named_modules()`` order. ``layer_index`` is the integer found right after
        the first ``"layers."`` in the module name, or ``None`` if the name contains no ``"layers."``.
    """
    found = []
    for module_name, module in model.named_modules():
        if not isinstance(module, torch.nn.Linear):
            continue

        pieces = module_name.split('layers.')
        # More than one piece means the name contained "layers." at least once.
        index = int(pieces[1].split('.')[0]) if len(pieces) > 1 else None

        print(
            f"Layer Index: {index}, Layer Name: {module_name}, Input Features: {module.in_features}, Output Features: {module.out_features}")
        found.append((index, module_name, module))
    return found


def extract_weight_params(weight: torch.Tensor, problem_channels: list, dim: int = 1):
    """
    Pull out selected rows or columns of a weight tensor, paired with their indices.

    Args:
        weight (torch.Tensor): 2-D weight tensor, e.g. ``[out_features, in_features]``.
        problem_channels (list[int]): Channel indices to extract, in the order to return them.
        dim (int): ``0`` selects rows (``weight[ch, :]``), ``1`` selects columns (``weight[:, ch]``).

    Returns:
        list[tuple[int, torch.Tensor]]: ``(channel index, selected vector)`` pairs. The vectors are
        indexing results of ``weight``, not explicit copies.

    Raises:
        ValueError: If ``dim`` is not 0 or 1. The check runs per channel, so an empty
            ``problem_channels`` returns ``[]`` without validating ``dim``.
    """
    def _select(channel):
        if dim == 0:  # channels are rows
            return weight[channel, :]
        if dim == 1:  # channels are columns
            return weight[:, channel]
        raise ValueError("dim parameter must be 0 or 1")

    return [(channel, _select(channel)) for channel in problem_channels]


def replace_params_in_weight(w: torch.Tensor, i_channel: list[tuple[int, torch.Tensor]], dim: int = 1):
    """
    Overwrite rows or columns of ``w`` in place with the given vectors.

    Args:
        w (torch.Tensor): 2-D weight matrix, e.g. ``[out_features, in_features]``; modified in place.
        i_channel (list[tuple[int, torch.Tensor]]): ``(channel index, replacement vector)`` pairs.
        dim (int): ``0`` replaces row ``idx`` (vector length must equal ``w.shape[1]``);
            ``1`` replaces column ``idx`` (vector length must equal ``w.shape[0]``).

    Raises:
        ValueError: On a length mismatch, or if ``dim`` is not 0 or 1 (checked per entry, so an
            empty ``i_channel`` never raises).
    """
    for channel, vector in i_channel:
        if dim == 0:
            if vector.shape[0] != w.shape[1]:
                raise ValueError(
                    f"Weight row dimension {w.shape[1]} does not match parameter length {vector.shape[0]}")
            w[channel, :] = vector
            continue

        if dim == 1:
            if vector.shape[0] != w.shape[0]:
                raise ValueError(
                    f"Weight column dimension {w.shape[0]} does not match parameter length {vector.shape[0]}")
            w[:, channel] = vector
            continue

        raise ValueError("dim parameter must be 0 or 1")


def update_layers(model, linear_layers, args: Namespace):
    """
    Overwrite each linear layer's weight with its decompressed (pseudo-quantized) version.

    For every ``(layer_index, name, layer)`` in ``linear_layers`` except ``lm_head``:

    1. Load the saved compressed tensors for the layer from ``output_directory``. Layers
       without a saved file are skipped, because the compression path below is unreachable.
    2. Wrap them in a ``BitMatrixLayer``.
    3. Fine tune that wrapper with ``update_groupwise`` when ``layer_index > 99``, or when
       ``layer_index == 99`` and the attribute is > 5. If it reports an improvement, save the
       tuned tensors to ``..._new.pt``.
    4. Decompress with ``generate_decompressed_dense_matrix_g`` and copy the result in place into
       ``layer.weight``, cast to that weight's device and dtype.

    Args:
        model: Model that owns ``linear_layers``. It is passed to ``get_inps_llama_by_linear`` when
            capturing fine-tuning inputs, and its linear weights are modified in place.
        linear_layers: Iterable of ``(layer_index, name, layer)`` tuples, e.g. from ``get_Linears``.
        args (Namespace): Must provide ``r``, ``d``, ``g``, ``q``, ``b`` and ``c``. It is also
            forwarded to ``update_groupwise``.

    NOTE(review): ``get_Linears`` yields ``layer_index = None`` for linears whose name has no
    ``layers.`` segment. The ``layer_index >= 0`` comparison below raises TypeError for them.
    """
    print(f"Current working directory: {os.getcwd()}")

    for layer_index, name, layer in linear_layers:
        if 'lm_head' in name:
            print(f"Ignoring layer: {name}")
            continue
        print(f"Attribute layer: {name}")

        # Numeric code per projection type. It is used in file names and in the fine-tuning
        # sequence number below; unrecognized names map to 0.
        layer_attribute_map = {
            "self_attn.q_proj": 1,
            "self_attn.k_proj": 2,
            "self_attn.v_proj": 3,
            "self_attn.o_proj": 4,
            "mlp.gate_proj": 5,
            "mlp.up_proj": 6,
            "mlp.down_proj": 7,
        }
        attribute_name = name.split('.')[-2] + '.' + name.split('.')[-1]  # e.g. self_attn.q_proj
        layer_attribute = layer_attribute_map.get(attribute_name, 0)

        # if layer_index > 32:
        # # if layer_index == 1 and  layer_attribute == 7:
        #     print(f"Replacing layer at index {layer_index} with attribute {layer_attribute}: {name}")

        # else:
        #     continue  # Skip attribute 6 for all layers

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        layer_shape = layer.weight.shape
        # Work on a detached CPU copy; bfloat16 weights are converted to float16.
        w = layer.weight.clone().detach()
        print("Original layer precision", w.dtype)

        if w.is_cuda:
            w = w.cpu()
        if w.dtype == torch.bfloat16:
            w = w.to(torch.float16)
        # The following code is the compression process. Modify here to change compression behavior.

        r = args.r
        d = args.d
        g = args.g
        q = args.q
        b = args.b
        c = args.c
        # `e` is later passed to BitMatrixLayer, which stores it as its residual (`self.e`).
        dense_matrix, e = adjust_matrix(w, d)

        output_directory = '.../compress_weight_gemam4'
        # Define file path.
        # NOTE(review): the negative-index name uses a double underscore ("__new.pt"), while
        # the fine-tuned result below is saved as "_new.pt". Confirm which one is intended.
        if layer_index >= 0:
            file_path = f'{output_directory}/layer{layer_index}_{layer_attribute}.pt'
        else:
            file_path = f'{output_directory}/layer{layer_index}_{layer_attribute}__new.pt'

        # Check if file exists
        if os.path.exists(file_path):
            # Load tensors from file
            loaded_tensor_dict = torch.load(file_path)
            # Extract tensors
            compress_code_tensor = loaded_tensor_dict['compress_code_tensor'].to(device)

            rq_centroids_tensor = loaded_tensor_dict['rq_centroids_tensor'].to(device)

            original_shape_e = loaded_tensor_dict['original_shape_e']
            if torch.is_tensor(original_shape_e):
                original_shape_e = original_shape_e.to(device)
            # TODO: Load all tensors to CUDA device
            print("Tensors loaded successfully.")

        else:
            # File does not exist, perform compression
            print("File not found, performing compression...")
            # NOTE(review): this `continue` skips the layer. Despite the message above, the
            # compression and save code below is unreachable.
            continue
            # Generate compressed matrix
            compress_code_tensor, rq_centroids_tensor, original_shape_e = generate_compressed_dense_matrix_g(
                dense_matrix, q, r, c, b, d, g
            )

            # Create tensor dictionary
            tensor_dict = {
                'compress_code_tensor': compress_code_tensor,
                'rq_centroids_tensor': rq_centroids_tensor,
                'original_shape_e': original_shape_e
            }

            # Save compressed tensors
            torch.save(tensor_dict, file_path)
            print("Compressed tensors saved successfully.")

        # Hard-coded channel indices. extract_weight_params selects columns by default, so
        # in_features must exceed the largest index.
        # NOTE(review): extracted_problem_channels is not used below; only the "random"
        # channels are passed to BitMatrixLayer.
        problem_channels = [378, 491]
        extracted_problem_channels = extract_weight_params(w, problem_channels)
        random_channels = [102, 104]
        extracted_random_channels = extract_weight_params(w, random_channels)

        bias = layer.bias if layer.bias is not None else None
        in_features = layer.in_features
        out_features = layer.out_features


        new_layer = BitMatrixLayer(extracted_random_channels, rq_centroids_tensor, compress_code_tensor,
                                   original_shape_e, e, layer_shape, q, d, in_features, out_features, layer_index, g,
                                   use_checkpoint=False, bias=bias)

        # # Start fine tuning
        # Only layers after layer 99's gate_proj (attribute 5) are fine tuned.
        if layer_index > 99 or (layer_index == 99 and layer_attribute > 5):

            print("Starting fine tuning")
            # Calculate sequence number: the flat index of this linear across the model, which
            # assumes 7 linears per block (attributes 1-7).
            sequence = (layer_index) * 7 + layer_attribute - 1
            print(f"Current sequence number: {sequence}")
            # Always true here (layer_index >= 99), so the cached-inputs `else` branch is unreachable.
            if layer_index >= 0:

                dataset = "wikitext-2"
                model_path = '.../model/gemma2_9B'
                # NOTE(review): the tokenizer and dataset are reloaded for every fine-tuned layer.
                tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
                data = get_dataset_by_length(dataset, tokenizer)
                # sample_0_to_5 = data[0:6]  # Take samples 0 to 5 (include 0, exclude 6)
                # sample_16_to_23 = data[16:24]  # Take samples 16 to 23 (include 16, exclude 24)
                # data = torch.cat((sample_0_to_5, sample_16_to_23), dim=0)
                # data = data[:140]

                print("Shape:", data.shape)
                seq_len = data.size(1)
                inps, forward_args = get_inps_llama_by_linear(model, data, seq_len, device, False, sequence)
                print("Inputs captured:", inps.shape)
            else:
                inps_file = f'.../activation/linear_inps/inps_linear_{sequence}.pt'
                inps = torch.load(inps_file)
            # Target outputs are always read from precomputed files.
            outs_file = f'.../activation/linear_out/outs_linear_{sequence}.pt'
            outs = torch.load(outs_file)
            print("Outputs captured:", outs.shape)

            # Calculate number of samples to select (10%, rounded down)
            num_samples = inps.size(0)
            num_samples_to_select = num_samples // 10

            # Randomly select indices of samples
            indices = random.sample(range(num_samples), num_samples_to_select)

            # Select inputs and outputs using indices.
            # NOTE(review): this validation subset is drawn from the training samples, so it is
            # not held out.
            selected_inps = inps[indices]
            selected_outs = outs[indices]

            new_layer, is_better = update_groupwise(
                layer=new_layer,
                train_inps=inps,
                train_outs=outs,
                args=args,
                valid_inps=selected_inps,
                valid_outs=selected_outs,
                verbose=True  # Optionally print detailed info
            )
            # Check if fine tuning improved, save new if yes
            if is_better:
                tensor_dict = {
                    'compress_code_tensor': new_layer.q_weight,
                    'rq_centroids_tensor': new_layer.rq_centroids,
                    'original_shape_e': original_shape_e
                }
                torch.save(tensor_dict, f'{output_directory}/layer{layer_index}_{layer_attribute}_new.pt')

        # Dequantization:

        # 4. Replace layer in model (optional)
        # Pseudo-quantization: decompress the (possibly fine-tuned) codes and write them back into
        # the existing nn.Linear weight, keeping that weight's device and dtype.
        # NOTE(review): q_weight and rq_centroids are not assigned in the visible part of
        # BitMatrixLayer.__init__; confirm they are set further down in that class.
        weight1 = generate_decompressed_dense_matrix_g(new_layer.q_weight, new_layer.rq_centroids.to(torch.float16),
                                                       new_layer.eshape, new_layer.e, new_layer.original_shape,
                                                       new_layer.d)
        weight_to_use = weight1.to(layer.weight.device).to(layer.weight.dtype)
        with torch.no_grad():
            layer.weight.data.copy_(weight_to_use)
        # Real quantization
        # parent_module = dict(model.named_modules())[name.rsplit('.', 1)[0]]
        # setattr(parent_module, name.split('.')[-1], new_layer)
        torch.cuda.empty_cache()
        print(f"Completed replacing layer: {name}")


def update_layers_by_transformer(model, linear_layers, args: Namespace):
    """
    Compress selected linear layers, attach outlier weights, and swap in ``BitMatrixLayer`` modules.

    Currently only layers with ``layer_index == 0`` are processed. For each one, this function:

    - loads (or computes and saves) the compressed tensors;
    - prints the reconstruction MSE and max absolute error;
    - gathers the weight values at the coordinates stored in ``topk_weight_indices.pt`` as
      "outliers" and builds a ``BitMatrixLayer`` with them;
    - for ``mlp.down_proj`` only, runs ``compensation_groupwise`` and saves its result back to
      ``topk_weight_indices.pt``.

    The new module then replaces the original linear on its parent module.

    Args:
        model: Model whose linear submodules are replaced in place.
        linear_layers: Iterable of ``(layer_index, name, layer)`` tuples, e.g. from ``get_Linears``.
        args (Namespace): Must provide ``r``, ``d``, ``g``, ``q``, ``b`` and ``c``. It is also
            forwarded to ``compensation_groupwise``.

    NOTE(review): this requires CUDA (``w.cuda()`` is unconditional) and a pre-existing
    ``topk_weight_indices.pt`` in the working directory.
    """
    print(f"Current working directory: {os.getcwd()}")

    for layer_index, name, layer in linear_layers:
        if 'lm_head' in name:
            print(f"Ignoring layer: {name}")
            continue
        print(f"Attribute layer: {name}")

        # Numeric code per projection type; unrecognized names map to 0.
        layer_attribute_map = {
            "self_attn.q_proj": 1,
            "self_attn.k_proj": 2,
            "self_attn.v_proj": 3,
            "self_attn.o_proj": 4,
            "mlp.gate_proj": 5,
            "mlp.up_proj": 6,
            "mlp.down_proj": 7,
        }
        attribute_name = name.split('.')[-2] + '.' + name.split('.')[-1]  # e.g. self_attn.q_proj
        layer_attribute = layer_attribute_map.get(attribute_name, 0)

        # Only layer 0 is processed; every other layer is left untouched.
        if layer_index == 0:
        # if layer_index == 0 or layer_index == 4:
            print(f"Replacing layer at index {layer_index} with attribute {layer_attribute}: {name}")
        else:
            continue  # Skip non-target layers

        layer_shape = layer.weight.shape
        w = layer.weight.clone().detach()
        print("Original layer precision", w.dtype)

        # Work on a CPU copy; bfloat16 is converted to float16.
        if w.is_cuda:
            w = w.cpu()
        if w.dtype == torch.bfloat16:
            w = w.to(torch.float16)
        # The following code is the compression process. Change compression parameters here.

        r = args.r
        d = args.d
        g = args.g
        q = args.q
        b = args.b
        c = args.c
        dense_matrix, e = adjust_matrix(w, d)

        output_directory = '.../compress_weight'
        # Define file path
        file_path = f'{output_directory}/layer{layer_index}_{layer_attribute}.pt'

        # Check if file exists
        if os.path.exists(file_path):
            # Load tensors from file
            loaded_tensor_dict = torch.load(file_path)
            # Extract tensors
            compress_code_tensor = loaded_tensor_dict['compress_code_tensor']
            rq_centroids_tensor = loaded_tensor_dict['rq_centroids_tensor']
            original_shape_e = loaded_tensor_dict['original_shape_e']
            print("Tensors loaded successfully.")
        else:
            # File does not exist, perform compression
            print("File not found, performing compression...")

            # Generate compressed matrix.
            # NOTE(review): 256 is hard-coded here, so args.c (read into `c` above) is ignored;
            # update_layers passes `c` in the same argument position.
            compress_code_tensor, rq_centroids_tensor, original_shape_e = generate_compressed_dense_matrix_g(
                dense_matrix, q, r, 256, b, d, g
            )

            # Create tensor dictionary
            tensor_dict = {
                'compress_code_tensor': compress_code_tensor,
                'rq_centroids_tensor': rq_centroids_tensor,
                'original_shape_e': original_shape_e
            }

            # Save compressed tensors
            torch.save(tensor_dict, file_path)
            print("Compressed tensors saved successfully.")

        # Hard-coded channel indices. Columns are selected by default, so in_features must
        # exceed the largest index.
        # NOTE(review): extracted_problem_channels is not used below.
        problem_channels = [378, 491]
        extracted_problem_channels = extract_weight_params(w, problem_channels)
        random_channels = [102, 104]
        extracted_random_channels = extract_weight_params(w, random_channels)

        # Decompression
        r_w = generate_decompressed_dense_matrix_g(compress_code_tensor, rq_centroids_tensor, original_shape_e, e, layer_shape, d)

        # NOTE(review): this unconditional .cuda() fails on machines without CUDA.
        w = w.cuda()

        # 3. Decompress data

        # Always true after the .cuda() call above.
        if w.is_cuda:
            r_w = r_w.cuda()  # Ensure on the same device

        mse = torch.mean((w - r_w) ** 2)
        print("Mean squared error before and after compression:", mse.item())
        # TODO: Calculate mean squared error before and after compression
        r_w = r_w.to(torch.bfloat16)
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # If w is float16 (bfloat16/float16 original) and r_w is bfloat16, PyTorch promotes the
        # difference to float32.
        max_diff = torch.max(torch.abs(w.to(device) - r_w.to(device)))
        print(f"Maximum difference between original and decompressed weights: {max_diff.item()}")

        bias = layer.bias if layer.bias is not None else None
        in_features = layer.in_features
        out_features = layer.out_features

        # Extract outlier coordinates.
        # The file is read from the working directory. It is only written by the down_proj
        # compensation step further below, which runs after this line, so it must already exist
        # the first time this line runs.
        topk_weight_indices = torch.load("topk_weight_indices.pt")  # (row, col) pairs, e.g. shape [2048, 2]
        print("Shape:", topk_weight_indices.shape)  # Print tensor shape
        print("topk_weight_indices dtype:", topk_weight_indices.dtype)  # Should be int64
        # Convert each (row, col) pair into an int32 linear index into the flattened w, then
        # gather the corresponding values of w. Positions match topk_weight_indices.

        # Move topk_weight_indices to the device of w
        topk_weight_indices = topk_weight_indices.to(w.device)

        rows = topk_weight_indices[:, 0]
        cols = topk_weight_indices[:, 1]

        # Shape of w
        num_rows, num_cols = w.shape

        # Calculate row-major linear indices and convert to int32.
        # NOTE(review): older PyTorch releases only accept int64 (long) tensors for advanced
        # indexing. Confirm int32 indexing works on the target version.
        linear_indices = rows * num_cols + cols
        linear_indices = linear_indices.to(torch.int32)

        # Extract corresponding elements
        w_flat = w.view(-1)
        extracted_values = w_flat[linear_indices]

        # Convert to column vector
        extracted_values = extracted_values.unsqueeze(1)  # shape = [N, 1], same dtype as w
        # Indices and values are kept as two separate tensors in a dict rather than combined
        # into one two-column tensor.

        outliers = {
            'linear_indices': linear_indices,  # shape [N] (1-D), dtype int32
            'extracted_values': extracted_values  # shape [N, 1], same dtype as w
        }

        # TODO: Add outliers into here. How to define this class?

        new_layer = BitMatrixLayer(
            extracted_random_channels,
            rq_centroids_tensor,
            compress_code_tensor,
            original_shape_e,
            e,
            layer_shape,
            q,
            d,
            in_features,
            out_features,
            layer_index,
            g,
            use_checkpoint=False,
            bias=bias,
            outliers=outliers
        )

        # Start fine tuning (outlier compensation). Only mlp.down_proj runs this step.
        if layer_attribute == 7:

            print("Starting fine tuning")
            # Calculate sequence number: the flat index of this linear, which assumes 7 linears
            # per block (attributes 1-7).
            sequence = (layer_index) * 7 + layer_attribute - 1
            print(f"Current sequence number: {sequence}")
            # NOTE(review): layer_attribute is at most 7 (see layer_attribute_map), so this branch
            # is never taken and inputs always come from the cached file in the `else` branch.
            if layer_attribute == 99 and layer_index == 4:

                dataset = "wikitext-2"
                model_path = '.../model/gemma2_9B'
                tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
                data = get_dataset_by_length(dataset, tokenizer)
                # TODO: Select data indexed in .../greater_than_0.03.txt
                filename = f".../greater_than_0.03.txt"
                # Open the file and read each line
                with open(filename, 'r') as f:
                    lines = f.readlines()
                # Remove newline character and convert to list of integers
                indexes = [int(line.strip()) for line in lines]
                # TODO: Only take data corresponding to indexes

                selected_tensors = []
                for i in indexes:
                    if 0 <= i < len(data):
                        sample = data[i]
                        selected_tensors.append(sample)
                data = torch.stack(selected_tensors, dim=0)  # Stack the selected samples along a new dim 0

                print("Shape:", data.shape)
                seq_len = data.size(1)
                # NOTE(review): 34 is hard-coded here, whereas update_layers passes `sequence`.
                inps, forward_args = get_inps_llama_by_linear(model, data, seq_len, device, False, 34)
            else:
                inps_file = f'.../activation/linear_inps/inps_linear_{sequence}.pt'
                inps = torch.load(inps_file)
            outs_file = f'.../activation/linear_out/outs_linear_{sequence}.pt'
            outs = torch.load(outs_file)

            # Calculate number of samples to select
            num_samples = inps.size(0)
            num_samples_to_select = num_samples // 10

            # Randomly select sample indices
            indices = random.sample(range(num_samples), num_samples_to_select)

            # Select inputs and outputs by indices.
            # NOTE(review): selected_inps/selected_outs are unused; the full set is passed as
            # both training and validation data below.
            selected_inps = inps[indices]
            selected_outs = outs[indices]

            topk_weight_indices = compensation_groupwise(
                layer=new_layer,
                train_inps=inps,
                train_outs=outs,
                args=args,
                valid_inps=inps,
                valid_outs=outs,
                verbose=True  # Optionally print detailed info
            )
            # The return value is saved as the new outlier coordinates, overwriting the file that
            # was read above.

            torch.save(topk_weight_indices, "topk_weight_indices.pt")

        # 4. Swap the BitMatrixLayer in for the original linear on its parent module.
        parent_module = dict(model.named_modules())[name.rsplit('.', 1)[0]]
        setattr(parent_module, name.split('.')[-1], new_layer)
        torch.cuda.empty_cache()
        print(f"Completed replacing layer: {name}")

        # After quantizing all linear layers, fine tune specified transformer layers.
        # NOTE(review): this block sits inside the per-layer loop, and layer_attribute is at
        # most 7, so it never executes. If it is enabled, note that update_layers unpacks
        # update_groupwise's result as (layer, is_better), whereas here it is assigned directly.
        if layer_attribute == 99:
            print("All linear layers quantized, start fine tuning specified transformer layers.")

            # Load input and output data
            inps_file = f'.../activation/transformer_inps/inps_transformer_{layer_index}.pt'
            inps = torch.load(inps_file)
            outs_file = f'.../activation/transformer_out/outs_transformer_{layer_index}.pt'
            outs = torch.load(outs_file)

            # Assume the transformer layer to fine tune is named `transformer_layer_name`
            transformer_layer_name = f"model.layers.{layer_index}"

            # Get the layer to fine tune
            transformer_layer = dict(model.named_modules())[transformer_layer_name]

            # Perform fine tuning
            transformer_layer = update_groupwise(
                layer=transformer_layer,  # Fine tune only specified layer
                train_inps=inps,
                train_outs=outs,
                args=args,
                valid_inps=inps,  # Can set validation set accordingly
                valid_outs=outs,
                verbose=True
            )

            # Replace the layer in model
            parent_module = dict(model.named_modules())[transformer_layer_name.rsplit('.', 1)[0]]
            setattr(parent_module, transformer_layer_name.split('.')[-1], transformer_layer)

def quantize_and_compress(layer_index, name, layer, best_params_dir, log_dir):
    # Check layer name, skip if lm_head
    if 'lm_head' in name:
        print(f"Ignoring layer: {name}")
        return None

    print(f"Attribute layer: {name}")

    layer_attribute_map = {
        "self_attn.q_proj": 1,
        "self_attn.k_proj": 2,
        "self_attn.v_proj": 3,
        "self_attn.o_proj": 4,
        "mlp.gate_proj": 5,
        "mlp.up_proj": 6,
        "mlp.down_proj": 7,
    }

    attribute_name = name.split('.')[-2] + '.' + name.split('.')[-1]  # self_attn.q_proj
    layer_attribute = layer_attribute_map.get(attribute_name, 0)

    layer_shape = layer.weight.shape
    w = layer.weight.clone().detach()
    original_mean = torch.mean(w)
    original_median = torch.median(w)

    if w.is_cuda:
        w = w.cpu()
    if w.dtype == torch.bfloat16:
        w = w.to(torch.float16)

    best_params_filename = os.path.join(best_params_dir, f'best_params_layer_{layer_index}_attr_{layer_attribute}.json')
    log_data_filename = os.path.join(log_dir, f'log_data_layer_{layer_index}_attr_{layer_attribute}.json')
    print(f"Current CUDA memory usage: {torch.cuda.memory_allocated() / (1024 ** 2)} MB")
    try:
        q_best, r_best, c_best, b_best, d_best = optimize_compression_params(w, layer_index, layer_attribute, best_params_filename, log_data_filename)
        rq_centroids, pca_layer, original_shape, q, mse_temp = generate_compressed_matrix(w, q_best, r_best, c_best, b_best, d_best)

        # Check MSE and regenerate compressed matrix if needed
        print("Median:", original_median)
        print("Mean:", original_mean)

        if layer_index < 6:
            if mse_temp > original_mean * 10000:
                torch.cuda.empty_cache()  # Clear cache
                if layer_attribute in [2]:
                    rq_centroids, pca_layer, original_shape, q, mse_temp = generate_compressed_matrix(w, 0, 20, 256, 1, 16)
                elif layer_attribute in [3]:
                    rq_centroids, pca_layer, original_shape, q, mse_temp = generate_compressed_matrix(w, 0, 20, 256, 1, 16)
                elif layer_attribute in [1, 4]:
                    rq_centroids, pca_layer, original_shape, q, mse_temp = generate_compressed_matrix(w, 0, 20, 256, 1, 32)
                elif layer_attribute in [5, 6, 7]:
                    rq_centroids, pca_layer, original_shape, q, mse_temp = generate_compressed_matrix(w, 0, 20, 256, 1, 32)

            if mse_temp > original_mean * 10000:
                # torch.cuda.empty_cache()  # Clear cache
                print(f"Layer index: {q_best}, Layer attribute: {layer_attribute}, MSE: {mse_temp} - MSE not reduced first time")
                # rq_centroids, pca_layer, original_shape, q, mse_temp = generate_compressed_matrix(w, 0, 16, 256, 1, 32)

            if mse_temp > original_mean * 10000:
                print(f"Layer index: {q_best}, Layer attribute: {layer_attribute}, MSE: {mse_temp} - MSE not reduced second time")
        else:
            if mse_temp > original_mean * 1000:
                torch.cuda.empty_cache()  # Clear cache
                if layer_attribute in [2]:
                    rq_centroids, pca_layer, original_shape, q, mse_temp = generate_compressed_matrix(w, 0, 20, 256, 1, 16)
                elif layer_attribute in [3]:
                    rq_centroids, pca_layer, original_shape, q, mse_temp = generate_compressed_matrix(w, 0, 20, 256, 1, 16)
                elif layer_attribute in [1, 4]:
                    rq_centroids, pca_layer, original_shape, q, mse_temp = generate_compressed_matrix(w, 0, 20, 256, 1, 32)
                elif layer_attribute in [5, 6, 7]:
                    rq_centroids, pca_layer, original_shape, q, mse_temp = generate_compressed_matrix(w, 0, 20, 256, 1, 32)

            if mse_temp > original_mean * 10000:
                # torch.cuda.empty_cache()  # Clear cache
                print(f"Layer index: {q_best}, Layer attribute: {layer_attribute}, MSE: {mse_temp} - MSE not reduced first time")
                # rq_centroids, pca_layer, original_shape, q, mse_temp = generate_compressed_matrix(w, 0, 20, 256, 1, 32)

            if mse_temp > original_mean * 10000:
                print(f"Layer index: {q_best}, Layer attribute: {layer_attribute}, MSE: {mse_temp} - MSE not reduced second time")

        # Decompress data
        r_w = generate_decompressed_matrix(rq_centroids, pca_layer, original_shape, q)
        print(f"Current CUDA memory usage: {torch.cuda.memory_allocated() / (1024 ** 2)} MB")
        max_value = pca_layer.max()
        min_value = pca_layer.min()
        print(f"Max value of pca_layer: {max_value.item()}")
        print(f"Min value of pca_layer: {min_value.item()}")

        r_w = r_w.to(torch.bfloat16)
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        max_diff = torch.max(torch.abs(w.to(device) - r_w.to(device)))
        print(f"Layer {name} processing finished, max difference: {max_diff.item()}")
        # Delete r_w variable and clear cache
        del r_w
        torch.cuda.empty_cache()
        return (name, layer, rq_centroids, pca_layer, original_shape, q, "Success")

    except Exception as e:
        print(f"Error occurred when processing layer {name}: {e}")
        return (name, layer, None, None, None, None, "Failure")


def update_layers_parallel(model, linear_layers):
    """
    Compress all linear layers concurrently, then swap them into ``model``.

    Each ``(layer_index, name, layer)`` entry of ``linear_layers`` is processed by
    ``quantize_and_compress`` on a 2-worker thread pool. Best-parameter and search-log files go
    under ``params_best/`` and ``params_log/``, relative to the current working directory.
    Successfully compressed layers are then wrapped in ``BitMatrixLayer`` and set on their parent
    module in place of the original linear layer.

    Args:
        model: Model whose submodules are replaced in place.
        linear_layers: Iterable of ``(layer_index, name, layer)`` tuples, e.g. from ``get_Linears``.
    """
    print(f"Current working directory: {os.getcwd()}")
    log_dir = 'params_log'
    best_params_dir = 'params_best'
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(best_params_dir, exist_ok=True)

    # Entries are (layer_index, result_tuple). The index travels with its result so that the
    # replacement loop below does not reuse whichever index as_completed() happened to yield last.
    results = []
    with ThreadPoolExecutor(max_workers=2) as executor:
        futures = {
            executor.submit(quantize_and_compress, layer_index, name, layer, best_params_dir, log_dir): (layer_index, name)
            for layer_index, name, layer in linear_layers
        }

        for future in as_completed(futures):
            layer_index, layer_name = futures[future]
            try:
                result = future.result()
            except Exception as e:
                # quantize_and_compress already reports compression errors via a "Failure" status.
                # This catches anything raised outside its try block.
                print(f"Error occurred during quantization of layer {layer_name}: {e}")
                continue
            if result is not None:
                results.append((layer_index, result))
                # Print each thread's processing status
                print(f"Layer {layer_name} processing status: {result[-1]}")  # Print status info

    # Replace layers in the model
    for layer_index, (name, layer, rq_centroids, pca_layer, original_shape, q, _) in results:
        if rq_centroids is None:  # Skip failed layers
            continue

        bias = layer.bias
        in_features = layer.in_features
        out_features = layer.out_features
        # NOTE(review): these arguments do not appear to match BitMatrixLayer.__init__ as defined
        # in this file. That signature is (i_channel, rq_centroids_tensor, compress_code_tensor,
        # original_shape_e, e, original_shape, q, d, in_features, out_features, layer_index,
        # group_num, ...), so this call likely raises TypeError. It is kept unchanged pending an
        # adapter from the (rq_centroids, pca_layer) format returned by quantize_and_compress.
        new_layer = BitMatrixLayer(rq_centroids, pca_layer, original_shape, q, in_features, out_features, layer_index, use_checkpoint=False, bias=bias)

        # Replace layer
        parent_module = dict(model.named_modules())[name.rsplit('.', 1)[0]]
        setattr(parent_module, name.split('.')[-1], new_layer)
        torch.cuda.empty_cache()
        print(f"Completed replacement of layer: {name}")


class BitMatrixLayer(nn.Module):
    """
    Linear layer whose weight is stored compressed and rebuilt on every forward pass.

    The dense weight is reconstructed by ``generate_decompressed_dense_matrix_g``
    from the frozen codes (``q_weight``), the codebook (``rq_centroids``, cast to
    float16), the residual data (``e`` / ``eshape``), ``original_shape`` and ``d``,
    and is then applied with ``F.linear``.

    Parameters:
    i_channel : list
        Abnormal input-channel indices. Stored only; ``forward`` does not use it.
    rq_centroids_tensor : torch.Tensor
        Codebook. Stored as a trainable float32 ``nn.Parameter`` so it can be fine-tuned.
    compress_code_tensor : torch.Tensor
        Compressed codes, stored as the frozen ``nn.Parameter`` ``q_weight``.
    original_shape_e : tuple
        Shape passed to the decompressor together with ``e`` (stored as ``eshape``).
    e : dict
        Residual data passed to the decompressor. It is a plain attribute, not a
        registered buffer — NOTE(review): if it holds tensors, ``.to()``/``.cuda()``
        on this module will not move them.
    original_shape : tuple
        Shape argument passed to the decompressor.
    q : int
        Stored only; ``forward`` does not use it.
    d : int
        Passed to the decompressor.
    in_features : int
        Number of input features. Stored only.
    out_features : int
        Number of output features. Stored only.
    layer_index : int
        Index of this layer in the model, for identification.
    group_num : int
        Stored only; ``forward`` does not use it.
    bias : Optional[nn.Parameter]
        Bias passed to ``F.linear``; default None.
    outliers : Optional[Dict[str, torch.Tensor]]
        Outlier data (expected keys ``linear_indices`` and ``extracted_values``).
        Stored only; re-inserting outliers into the weight is not implemented yet.
    use_checkpoint : bool
        If True and autograd is enabled, ``forward`` runs through
        ``torch.utils.checkpoint`` so intermediates (including the rebuilt weight)
        are recomputed during backward instead of being kept alive.
    """

    def __init__(self, i_channel: list, rq_centroids_tensor, compress_code_tensor, original_shape_e: tuple, e: dict, original_shape: tuple, q: int, d: int,
                 in_features: int, out_features: int, layer_index: int, group_num: int,
                 bias: Optional[nn.Parameter] = None, outliers: Optional[Dict[str, torch.Tensor]] = None, use_checkpoint: bool = False):
        super().__init__()

        # Residual data consumed by the decompressor, and its shape.
        self.e = e
        self.eshape = original_shape_e
        # Abnormal input channels (currently unused in forward).
        self.i_channel = i_channel
        self.bias = bias
        self.q = q
        self.d = d
        self.out_features = out_features
        self.in_features = in_features
        self.use_checkpoint = use_checkpoint
        self.group_num = group_num

        self.layer_index = layer_index
        self.original_shape = original_shape
        self.outliers = outliers

        # Codebook kept in float32 and trainable for fine-tuning; it is cast to
        # float16 only when the weight is rebuilt in _forward.
        self.rq_centroids = nn.Parameter(rq_centroids_tensor.float(), requires_grad=True)
        # Compressed codes are fixed.
        self.q_weight = nn.Parameter(compress_code_tensor, requires_grad=False)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply the layer, through activation checkpointing when enabled and grads are on."""
        if self.use_checkpoint and torch.is_grad_enabled():
            return checkpoint(self._forward, input, use_reentrant=False)
        return self._forward(input)

    def _forward(self, input: torch.Tensor) -> torch.Tensor:
        """Rebuild the dense weight and return ``F.linear(input, weight, self.bias)``.

        ``input`` is cast to the rebuilt weight's dtype, so the output carries the
        weight's dtype rather than the caller's.

        Raises:
            ValueError: if the output contains NaN or Inf.
        """
        weight = generate_decompressed_dense_matrix_g(
            self.q_weight, self.rq_centroids.to(torch.float16), self.eshape, self.e, self.original_shape, self.d
        )

        # TODO: when outlier handling is enabled, write
        # self.outliers['extracted_values'] into the flattened weight at
        # self.outliers['linear_indices'] (both moved to weight.device first).

        if input.dtype != weight.dtype:
            input = input.to(weight.dtype)

        output = F.linear(input, weight, self.bias)

        # A single isfinite pass is equivalent to (isnan or isinf) but scans once.
        if not torch.isfinite(output).all():
            print(f"Input shape: {input.shape}, dtype: {input.dtype}, content: {input}")
            print(f"Weight shape: {weight.shape}, dtype: {weight.dtype}, content: {weight}")
            raise ValueError("Output contains NaN or Inf.")

        # Drop the dense weight right away. empty_cache() on every call returns
        # cached allocator blocks to the device; it bounds reserved memory but
        # can slow down repeated forwards.
        del weight
        torch.cuda.empty_cache()

        return output

import resource

def print_memory_usage():
    """Print and return this process's *peak* resident set size, in MB.

    ``resource.getrusage(RUSAGE_SELF).ru_maxrss`` is the maximum RSS reached so
    far, not the current usage. Its unit is platform dependent: kilobytes on
    Linux, bytes on macOS.

    Returns:
        float: peak RSS in megabytes.
    """
    peak_rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    # macOS reports bytes; Linux reports kilobytes.
    divisor = 1024 * 1024 if sys.platform == "darwin" else 1024
    mem_usage = peak_rss / divisor
    print(f"Process Memory Usage (peak RSS): {mem_usage:.2f} MB")
    return mem_usage

def _str2bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``argparse`` with ``type=bool`` turns every non-empty string (including
    ``"False"``) into True, so an explicit parser is needed.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognized boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    # "" stays False, matching the old bool("") behavior.
    if lowered in ("0", "false", "f", "no", "n", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


def _build_arg_parser():
    """Return the command-line parser used by ``main``."""
    parser = argparse.ArgumentParser(add_help=True)

    # Fine-tuning related parameters
    parser.add_argument(
        "--update_max_epochs",
        type=int,
        default=10,
        help="Run this many passes over training data when doing updating; No updating if set to 0."
    )
    parser.add_argument(
        "--update_early_stop",
        type=int,
        default=5,
        help="Terminate updating if loss doesn't improve after this number of epochs."
    )
    parser.add_argument(
        "--update_lr",
        type=float,
        default=1e-5,
        help="Updating learning rate."
    )
    parser.add_argument(
        "--update_batch_size",
        type=int,
        default=1,
        help="(Updating only) train on batches of this many sequences, globally across all GPUs."
    )
    parser.add_argument(
        "--update_adam_beta1",
        type=float,
        default=0.9,
        help="Updating adam_beta1."
    )
    parser.add_argument(
        "--update_adam_beta2",
        type=float,
        default=0.95,
        help="Updating adam_beta2."
    )
    parser.add_argument("--update_keep_best", action="store_true", help="Keep the best model parameters during updating.")
    parser.add_argument(
        "--local_batch_size",
        type=int,
        default=4,
        help="(Updating only) Per-device and per-forward-pass batch size used to accumulate global --update_batch_size."
    )
    parser.add_argument(
        "--val_size",
        type=int,
        default=128,
        help="Number of validation sequences."
    )
    parser.add_argument(
        "--print_frequency",
        type=int,
        default=10,
        help="Print Adam progress after each print_frequency updates."
    )

    # Compression parameters
    parser.add_argument(
        "--q",
        type=int,
        default=0,
        help="Description for parameter q."
    )
    parser.add_argument(
        "--r",
        type=int,
        default=4,
        help="Description for parameter r."
    )
    parser.add_argument(
        "--c",
        type=int,
        default=1024,
        help="Description for parameter c."
    )
    parser.add_argument(
        "--b",
        type=int,
        default=1,
        help="Description for parameter b."
    )
    parser.add_argument(
        "--d",
        type=int,
        default=10,
        help="Description for parameter d."
    )
    parser.add_argument(
        "--g",
        type=int,
        default=256,
        help="Description for parameter g."
    )

    # Accepts "--offload_activations", "--offload_activations True/False", etc.
    parser.add_argument('--offload_activations', type=_str2bool, nargs='?', const=True, default=False,
                        help='Offload activations to CPU')

    parser.add_argument(
        "--model_path",
        type=str,
        default='.../model/gemma2_9B',
        help="Path of the model (weights and tokenizer) to load."
    )
    return parser


def main():
    """Replace the model's linear layers with compressed ones and evaluate the result.

    Steps: cap CPU threads at 16 and disable TF32, parse CLI arguments, load the
    model and tokenizer from ``--model_path``, replace linear layers with
    ``update_layers``, print memory usage, move the model to CUDA when available
    and run ``eval`` on wikitext-2.
    """
    torch.set_num_threads(min(16, torch.get_num_threads()))
    # Keep fp32 matmuls/convolutions at full precision (no TF32 tensor-core math).
    torch.backends.cudnn.allow_tf32 = False
    torch.backends.cuda.matmul.allow_tf32 = False

    args = _build_arg_parser().parse_args()
    args.devices = [torch.device(f"cuda:{i}") for i in range(torch.cuda.device_count())]

    # Load model
    print("Loading model")
    model_path = args.model_path
    model = load_safetensors_model(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
    linear_layers = get_Linears(model)

    # Evaluation data
    input_text = "Hello, how are you?"
    dataset = "wikitext-2"  # alternative: "ptb_text_only"

    # For a pre-quantization baseline, call eval(model, tokenizer, input_text, dataset) here.

    # Update model layers
    print("Start replacing model layers")
    update_layers(model, linear_layers, args)
    # Alternative: update_layers_by_transformer(model, linear_layers, args)

    print("Post-quantization evaluation")
    print("Before evaluation:")
    print_memory_usage()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    eval(model, tokenizer, input_text, dataset)


# Directory to run from. As written this is the directory the process was
# started in, so the chdir below leaves the working directory unchanged; point
# ``new_directory`` elsewhere to run from a different location.
new_directory = os.getcwd()

try:
    os.chdir(new_directory)
except FileNotFoundError:
    print(f"Directory '{new_directory}' does not exist, please check the path.")
except PermissionError:
    print(f"No permission to access directory '{new_directory}'.")
except Exception as exc:
    print(f"Error occurred when changing directory: {exc}")
else:
    print(f"Current working directory changed to: {os.getcwd()}")

# Release unoccupied memory held by the CUDA caching allocator.
torch.cuda.empty_cache()

# Log file shared by the logging handlers below and the print() tee (Logger).
log_file = 'gemma_bit.log'

# Root logger at INFO with timestamped records, written both to ``log_file``
# and to the ``sys.stdout`` stream that is current when this line runs.
_log_handlers = [
    logging.FileHandler(log_file),
    logging.StreamHandler(sys.stdout),
]
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=_log_handlers,
)

# Redirect print() to log
class Logger(object):
    """Tee for ``sys.stdout``: every write goes to the terminal and to ``log_file``.

    The stream that is ``sys.stdout`` at construction time is kept as
    ``terminal``; ``log_file`` is opened in append mode. Attributes not defined
    here (``isatty``, ``fileno``, ``encoding``, ...) are delegated to
    ``terminal`` so code that inspects ``sys.stdout`` keeps working.
    """

    def __init__(self):
        self.terminal = sys.stdout
        self.log = open(log_file, "a")  # Append so successive runs accumulate.

    def write(self, message):
        """Write ``message`` to both streams; return the number of characters written."""
        self.terminal.write(message)
        self.log.write(message)
        return len(message)

    def flush(self):
        """Flush both streams (previously a no-op, so buffered output could be lost)."""
        self.terminal.flush()
        self.log.flush()

    def __getattr__(self, name):
        # Only called for attributes missing on the instance. Read 'terminal'
        # via __dict__ so a half-initialized instance raises AttributeError
        # instead of recursing forever.
        terminal = self.__dict__.get("terminal")
        if terminal is None:
            raise AttributeError(name)
        return getattr(terminal, name)

# From here on, print() output is tee'd to the console and to the log file.
_stdout_tee = Logger()
sys.stdout = _stdout_tee

if __name__ == "__main__":
    # Seed via the project helper (presumably all RNGs in use) before running.
    fix_random_seed(42)
    main()

