import os
# Set environment variables for CUDA and the Hugging Face endpoint.
# They are set here, before torch/transformers/lm_eval are imported below,
# because those libraries may read them at import or CUDA-initialization time.
# setdefault() keeps any value the user already exported (for example
# `CUDA_VISIBLE_DEVICES=1 python ...`); plain assignment silently overrode it
# and pinned every run to GPU 0 / the mirror endpoint.
os.environ.setdefault('CUDA_DEVICE_ORDER', 'PCI_BUS_ID')
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0')
os.environ.setdefault('HF_ENDPOINT', 'https://hf-mirror.com')
import argparse
import json
import logging
import os
import random
import torch
import numpy as np
import torch.nn.functional as F
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
import lm_eval
from lm_eval import utils as lm_eval_utils
from lm_eval.api.registry import ALL_TASKS
from lm_eval.models.huggingface import HFLM
from lm_eval.tasks import initialize_tasks
# Note: HFLM is imported twice, which is redundant.
from lm_eval.models.huggingface import HFLM
import utils

# Seed every RNG this script touches so repeated runs are reproducible:
# Python's `random`, NumPy, PyTorch (CPU), and transformers' `set_seed`,
# plus the CUDA generators when a GPU is visible.
_GLOBAL_SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, set_seed):
    _seed_fn(_GLOBAL_SEED)
if torch.cuda.is_available():
    for _seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        _seed_fn(_GLOBAL_SEED)
# Drop the loop variable so it does not linger in the module namespace.
del _seed_fn


# SyncLogger keeps a JSON-serializable dict in memory and persists it to a file.
class SyncLogger:
    """Dict-backed run log that is reloaded from and written to a JSON file.

    On construction the existing file (if any) is loaded into ``self.data``;
    callers mutate ``self.data`` and call :meth:`update` to persist it.
    """

    def __init__(self, filename):
        # Path of the JSON file that backs ``self.data``.
        self.filename = filename
        # In-memory record; mutate it, then call update() to write it out.
        self.data = {}
        self._load()

    def _load(self):
        """Replace ``self.data`` with the file's contents if the file exists.

        Raises:
            json.JSONDecodeError: if the file exists but is not valid JSON.
        """
        if os.path.exists(self.filename):
            with open(self.filename, 'r', encoding='utf-8') as f:
                self.data = json.load(f)

    def update(self):
        """Write ``self.data`` to ``self.filename``, replacing the previous contents.

        The dict is dumped to a sibling ``<filename>.tmp`` file which then
        replaces the target via ``os.replace`` (atomic on POSIX). Opening the
        target directly with ``'w'`` would truncate it first, so a failed dump
        (e.g. a non-serializable value) or an interrupted write would leave a
        corrupt file that makes the next ``_load`` raise.

        Raises:
            TypeError: if ``self.data`` contains a value ``json`` cannot serialize.
        """
        tmp_path = f"{self.filename}.tmp"
        try:
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(self.data, f, indent=4)
            os.replace(tmp_path, self.filename)
        except BaseException:
            # Don't leave a half-written temp file behind; the original file
            # (if any) is untouched at this point.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise


def _parse_args():
    """Build the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser()
    # Model path, output path, weight-reorder flag, and lm_eval batch size.
    parser.add_argument(
        "--model_path",
        type=str,
        default="pruned_model",
        help="Path to the pre-trained model."
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="eval_result",
        help="Directory to save the evaluation results."
    )
    parser.add_argument(
        "--weight_reorder",
        action="store_true",
        help="Flag to indicate whether to perform weight reorder."
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=16,
        help="Batch size for evaluation."
    )
    return parser.parse_args()


def _model_name(model_path):
    """Return the last path component of ``model_path``, ignoring trailing separators."""
    # normpath strips a trailing separator, so "ckpts/llama/" yields "llama";
    # model_path.split('/')[-1] returned "" there and broke the output paths.
    return os.path.basename(os.path.normpath(model_path))


def _evaluate_ppl_on(name, model, tokenizer):
    """Compute the perplexity of ``model`` on the test split of dataset ``name``.

    Uses a sequence length of 2048 and a loader batch size of 8, pads with the
    model's EOS token id, logs the value, and returns it rounded to 2 decimals.
    """
    dataset = utils.get_dataset(name)
    test_loader = utils.prepare_test_dataloader(
        name=name,
        dataset=dataset["test"],
        tokenizer=tokenizer,
        seqlen=2048,
        batch_size=8,
    )
    ppl = utils.evaluate_ppl(
        model=model,
        dataloader=test_loader,
        pad_token_id=model.config.eos_token_id,  # Use EOS token ID for padding.
    )
    ppl = round(ppl, 2)
    logging.info(f'{name} PPL: {ppl}')
    return ppl


def _task_accuracy(result):
    """Return a task's normalized accuracy if reported, otherwise its plain accuracy."""
    # Only read 'acc,none' as a fallback, so a task that reports just
    # 'acc_norm,none' does not raise KeyError.
    if 'acc_norm,none' in result:
        return result['acc_norm,none']
    return result['acc,none']


def _as_percent(fraction):
    """Convert an accuracy in [0, 1] to a percentage rounded to 2 decimals."""
    # Scale before rounding: round(x, 4) * 100 reintroduces float noise,
    # e.g. round(0.57, 4) * 100 == 56.99999999999999.
    return round(fraction * 100, 2)


def main():
    """Evaluate a causal LM and save perplexity and zero-shot accuracy results.

    Steps:
      1. Load the model (bfloat16, ``device_map="auto"``) and tokenizer from
         ``--model_path``; optionally reorder attention/MLP weights.
      2. Compute perplexity on wikitext2 and ptb.
      3. Run zero-shot lm_eval on boolq, piqa, winogrande, hellaswag,
         arc_easy, arc_challenge and openbookqa.
      4. Write ``eval.result.<model_name>.json`` and update ``log.json`` under
         ``<output_path>/<model_name>``.

    Raises:
        RuntimeError: if lm_eval returns no task results.
    """
    args = _parse_args()
    model_path = args.model_path
    model_name = _model_name(model_path)
    # Results for each model go into their own sub-directory.
    output_path = os.path.join(args.output_path, model_name)
    os.makedirs(output_path, exist_ok=True)

    batch_size = args.batch_size
    # Register lm_eval tasks; must run before matching against ALL_TASKS.
    initialize_tasks()

    # Re-seed locally (already seeded at import time).
    set_seed(42)

    sync_logger = SyncLogger(os.path.join(output_path, "log.json"))
    sync_logger.data["args"] = vars(args)

    # use_cache=False: no generation happens here, only loglikelihood/PPL passes.
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map="auto",
        trust_remote_code=True,
        torch_dtype="bfloat16",
        use_cache=False
    ).eval()
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    # store_true guarantees a bool here.
    sync_logger.data["weight_reorder"] = args.weight_reorder
    if args.weight_reorder:
        for layer in utils.get_layers(model):
            utils.reorder_in_attn_block(getattr(layer, utils.get_attn_key(model)), model=model)
            utils.reorder_in_mlp_block(getattr(layer, utils.get_mlp_key(model)))

    logging.info("Start evaluation...")

    wiki_dataset_ppl = _evaluate_ppl_on("wikitext2", model, tokenizer)
    ptb_dataset_ppl = _evaluate_ppl_on("ptb", model, tokenizer)

    # Wrap the already-loaded model so lm_eval does not reload it.
    hflm = HFLM(pretrained=model, tokenizer=tokenizer, batch_size=batch_size)
    task_names = ["boolq", "piqa", "winogrande", "hellaswag", "arc_easy", "arc_challenge", "openbookqa"]
    task_names = lm_eval_utils.pattern_match(task_names, ALL_TASKS)
    logging.info(f"Selected Tasks: {task_names}")

    # num_fewshot=0: zero-shot evaluation.
    results = lm_eval.simple_evaluate(
        hflm, tasks=task_names, num_fewshot=0, batch_size=batch_size, log_samples=False
    )['results']
    if not results:
        raise RuntimeError("lm_eval returned no task results; check the task names against ALL_TASKS.")

    raw_accuracies = {task: _task_accuracy(result) for task, result in results.items()}
    metric_vals = {task: _as_percent(acc) for task, acc in raw_accuracies.items()}
    logging.info(json.dumps(metric_vals, indent=4))

    # Average over the same tasks that were summed, so numerator and
    # denominator always agree.
    acc_avg = _as_percent(sum(raw_accuracies.values()) / len(raw_accuracies))
    logging.info(f"Average accuracy across tasks: {acc_avg}")

    overall_results = {
        "ppl_wikitext2": wiki_dataset_ppl,
        "ppl_ptb": ptb_dataset_ppl,
        # Key name kept for compatibility with existing result files; the value
        # is the mean over every task in `results` (7 tasks by default).
        "5cs_acc_avg": acc_avg,
        **metric_vals,  # Individual task accuracies (percent).
    }
    eval_result_path = os.path.join(output_path, f"eval.result.{model_name}.json")
    with open(eval_result_path, "w") as f:
        json.dump(overall_results, f, indent=4)

    sync_logger.data[f"evaluation_{model_name}"] = overall_results
    sync_logger.update()


if __name__ == "__main__":
    main()

