import os
import sys
# Configure CUDA device ordering/visibility and the Hugging Face endpoint.
# These assignments run before torch/transformers are imported below because
# those libraries read these variables when they import or initialise
# (e.g. huggingface_hub reads HF_ENDPOINT at import time).
# setdefault() only fills in a value when the variable is not already set, so
# a user can still choose, say, CUDA_VISIBLE_DEVICES=1 or another endpoint
# from the shell; the literals below are fallbacks.
os.environ.setdefault('CUDA_DEVICE_ORDER', 'PCI_BUS_ID')
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0')
os.environ.setdefault('HF_ENDPOINT', 'https://hf-mirror.com')
import random
import numpy as np
import argparse
import json
import logging
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import set_seed
import utils
# Seed every random-number generator the script may touch (Python, NumPy,
# PyTorch and transformers' helper) with one shared value for reproducibility.
_SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, set_seed):
    _seed_fn(_SEED)
if torch.cuda.is_available():
    # Seed CUDA explicitly as well: the current device, then all devices.
    for _seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        _seed_fn(_SEED)


def _parse_args():
    """Parse the command-line options of the pruning script.

    Returns:
        argparse.Namespace with ``model_path``, ``output_path``,
        ``weight_reorder`` and ``pruned_model_config_file``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_path",
        type=str,
        default="meta-llama/Llama-2-7b-hf",
        help="Path to the pre-trained model.",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="pruned_model",
        help="Directory to save the compressed model.",
    )
    parser.add_argument(
        "--weight_reorder",
        action="store_true",
        help="Flag to indicate whether to perform weight reorder.",
    )
    parser.add_argument(
        "--pruned_model_config_file",
        type=str,
        default="pruning_config/pruning_block_config_0.8.json",
        help="Path to the pruned model configuration file.",
    )
    return parser.parse_args()


def _get_groups(model, key):
    """Collect, for each decoder layer, the ``nn.Linear`` children of one sub-module.

    Args:
        model: The loaded causal LM.
        key: Attribute name of the sub-module on each layer, as returned by
            ``utils.get_attn_key`` / ``utils.get_mlp_key``.

    Returns:
        A list with one dict per layer yielded by ``utils.get_layers``, each
        mapping child name -> ``nn.Linear`` module.
    """
    return [
        {
            name: child
            for name, child in getattr(layer, key).named_children()
            if isinstance(child, torch.nn.Linear)
        }
        for layer in utils.get_layers(model)
    ]


def _get_pruned_weights(model, groups, pruned_channels):
    """Slice the weights of every Linear in the selected groups to a new width.

    Args:
        model: The loaded model. It is used only to look up head counts and
            the key/value projection names through ``utils``.
        groups: Output of ``_get_groups`` for one kind of sub-module.
        pruned_channels: Mapping of layer index -> retained width. Keys may be
            strings, since JSON object keys always are.

    Returns:
        Dict mapping each affected ``nn.Linear`` to a sliced view of its weight.
    """
    num_key_value_heads = utils.get_num_kv_heads(model)
    num_attention_heads = utils.get_num_attention_heads(model)
    # Number of query heads sharing each K/V head. K/V projections therefore
    # keep proportionally fewer output rows than the query projection.
    num_key_value_groups = num_attention_heads // num_key_value_heads
    kv_proj_keys = (utils.get_k_key(model), utils.get_v_key(model))

    module_to_weight = {}
    for group_idx, width in pruned_channels.items():
        for name, module in groups[int(group_idx)].items():
            if name in utils.DEPENDENCY_GROUPS:
                # Modules that consume the pruned activations (typically the
                # output projections) lose input columns (dim 1).
                module_to_weight[module] = module.weight[:, :width]
            elif name in kv_proj_keys:
                module_to_weight[module] = module.weight[: width // num_key_value_groups]
            else:
                # Modules that produce the pruned activations lose output rows (dim 0).
                module_to_weight[module] = module.weight[:width]
    return module_to_weight


def _write_pruned_weights(model, state_dict, module_to_weight):
    """Store sliced weights, and matching bias slices, into ``state_dict`` in place.

    Raises:
        KeyError: if an affected Linear has no ``<name>.weight`` entry.
    """
    for name, module in model.named_modules():
        if not isinstance(module, torch.nn.Linear) or module not in module_to_weight:
            continue
        weight_key = f"{name}.weight"
        if weight_key not in state_dict:
            raise KeyError(f"Missing '{weight_key}' in the model state dict.")
        pruned_weight = module_to_weight[module]
        # clone() gives the saved tensor its own compact storage instead of a
        # view into the full-size live parameter.
        state_dict[weight_key] = pruned_weight.clone()
        bias_key = f"{name}.bias"
        if bias_key in state_dict:
            # The bias follows the weight's output dim. That dim is unchanged
            # for modules pruned only along the input dim.
            state_dict[bias_key] = state_dict[bias_key][: pruned_weight.size(0)].clone()


def _prune_modules(state_dict, idx, key):
    """Remove every state-dict entry that belongs to sub-module ``key`` of layer ``idx``.

    Args:
        state_dict: The state dict to modify in place.
        idx: Index of the layer whose sub-module is dropped.
        key: Sub-module attribute name (the attention or MLP key).
    """
    # The trailing "." anchors the match to the sub-module itself. Without it,
    # ".3.self_attn" would also match sibling entries whose names merely
    # start with the key, e.g. ".3.self_attn_layer_norm.weight".
    prefix = f".{idx}.{key}."
    for name in [name for name in state_dict if prefix in name]:
        del state_dict[name]


def main():
    """Prune a causal LM according to a JSON config and save the result.

    Steps:
        1. Parse the CLI arguments.
        2. Load the pruning config.
        3. Load the model and tokenizer.
        4. Optionally reorder weights.
        5. Slice attention/MLP widths.
        6. Drop whole attention/MLP blocks.
        7. Save the pruned state dict and the tokenizer to ``--output_path``.
    """
    args = _parse_args()
    os.makedirs(args.output_path, exist_ok=True)

    # Read the pruning config before the slow model load, so a bad path or
    # malformed JSON fails immediately.
    with open(args.pruned_model_config_file, "r", encoding="utf-8") as f:
        pruned_config = json.load(f)
    logging.info(f"Detect a pruned model config: {pruned_config}")

    # device_map={"": 0} places the whole model on the first visible GPU.
    # trust_remote_code=True allows custom modelling code from the checkpoint.
    # float16 is used for memory efficiency.
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        device_map={"": 0},
        trust_remote_code=True,
        torch_dtype="float16",
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)

    attn_key = utils.get_attn_key(model)
    mlp_key = utils.get_mlp_key(model)

    if args.weight_reorder:
        # Reorder so the retained weights are the most important ones; the
        # width pruning below keeps leading slices. This runs before the
        # state dict is taken.
        # Assumes utils.reorder_* update the layers in place — TODO confirm.
        for layer in utils.get_layers(model):
            utils.reorder_in_attn_block(getattr(layer, attn_key), model=model)
            utils.reorder_in_mlp_block(getattr(layer, mlp_key))

    state_dict = model.state_dict()

    # Width pruning: slice the Linear weights of the layers listed in the config.
    module_to_weight = {}
    if pruned_config.get("pruned_attn_width"):
        attn_groups = _get_groups(model, attn_key)
        module_to_weight.update(
            _get_pruned_weights(model, attn_groups, pruned_config["pruned_attn_width"])
        )
    if pruned_config.get("pruned_mlp_width"):
        mlp_groups = _get_groups(model, mlp_key)
        module_to_weight.update(
            _get_pruned_weights(model, mlp_groups, pruned_config["pruned_mlp_width"])
        )
    _write_pruned_weights(model, state_dict, module_to_weight)

    # Block pruning: drop entire attention / MLP sub-modules at the listed indices.
    for idx in pruned_config.get("pruned_attn_idx") or []:
        _prune_modules(state_dict, idx, attn_key)
    for idx in pruned_config.get("pruned_mlp_idx") or []:
        _prune_modules(state_dict, idx, mlp_key)

    model.save_pretrained(args.output_path, state_dict=state_dict)
    tokenizer.save_pretrained(args.output_path)


if __name__ == "__main__":
    main()
