import os
# Set before transformers/tokenizers are imported below. Presumably this disables
# the HF tokenizers thread pool to avoid its fork-after-parallelism warning
# (NOTE(review): confirm the intent).
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

import logging
import argparse
from argparse import Namespace
import yaml
import json
import os.path as osp

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from lm_eval.models.huggingface import HFLM
from lm_eval import evaluator
from lm_eval.utils import make_table

from model import (
    hack_pruned_mixtral_to_adapt,
    unhack_pruned_mixtral_to_adapt
)

# Module-level logger; output format/level are configured by logging.basicConfig
# in the `if __name__ == '__main__':` block at the bottom of this file.
logger = logging.getLogger(__name__)

def load_config_from_path(config_path: str):
    """Load per-layer expert counts and filename-encoded run metadata.

    Run metadata is parsed from the YAML *file name*, expected to look like
    ``<model>_<opt>_<order>_<prune_metric>_<prune_ratio>_<merge_stat>_<merge_metric>_<merge_ratio>.yaml``
    (e.g. ``mixtral_prune_only_combined_0.5_max_output-cka_0.5.yaml``).
    This parsing does not depend on the file's contents, so it is done even
    when the file cannot be read.

    The per-layer expert counts are read from the top-level ``step1_expert``
    key of the YAML contents.

    Loading is best-effort: read/parse problems are logged and the affected
    values are returned as ``None`` instead of raising.

    Args:
        config_path: Path to the YAML config file.

    Returns:
        Tuple ``(layer_experts, opt_order, prune_metric, prune_ratio,
        merge_stat, merge_metric, merge_ratio)``. ``layer_experts`` is the
        ``step1_expert`` list or ``None``. The remaining fields are strings
        taken from the file name, or ``None`` when the name has fewer than 8
        ``_``-separated parts.
    """
    layer_experts = None
    opt_order = prune_metric = prune_ratio = None
    merge_stat = merge_metric = merge_ratio = None

    # --- Metadata from the file name (independent of the file contents) ---
    yaml_filename = os.path.basename(config_path)
    for suffix in ('.yaml', '.yml'):
        if yaml_filename.endswith(suffix):
            yaml_filename = yaml_filename[:-len(suffix)]
            break

    yaml_parts = yaml_filename.split('_')
    if len(yaml_parts) >= 8:
        model_name = yaml_parts[0]
        # The optimisation order spans two parts (e.g. "prune_only").
        opt_order = '_'.join(yaml_parts[1:3])
        prune_metric, prune_ratio, merge_stat, merge_metric, merge_ratio = yaml_parts[3:8]
        logger.info(f"Parsed config filename: model={model_name}, opt_order={opt_order}, "
                    f"prune_metric={prune_metric}, prune_ratio={prune_ratio}, "
                    f"merge_stat={merge_stat}, merge_metric={merge_metric}, merge_ratio={merge_ratio}")

    # --- Expert counts from the file contents ---
    try:
        with open(config_path, 'r') as f:
            config = yaml.safe_load(f)
    except (OSError, yaml.YAMLError) as e:
        logger.error(f'Error loading config file: {e}')
        config = None
    else:
        # safe_load returns None for an empty file and may return a scalar/list.
        if not isinstance(config, dict):
            logger.error(f'Error loading config file: expected a mapping at the top level of '
                         f'{config_path}, got {type(config).__name__}')
            config = None

    if config is not None:
        if "step1_expert" not in config:
            logger.warning(f'"step1_expert" not found in {config_path}')
        else:
            candidate = config["step1_expert"]
            # The value is enumerated per layer below, so it must be a sequence.
            if isinstance(candidate, (list, tuple)):
                layer_experts = candidate
                logger.info(f'Loaded {len(layer_experts)} layer expert configurations from {config_path}')
                for i, r in enumerate(layer_experts):
                    logger.info(f'Layer {i}: {r} experts')
            else:
                logger.error(f'Expected a list for "step1_expert" in {config_path}, '
                             f'got {type(candidate).__name__}; ignoring it')

    return layer_experts, opt_order, prune_metric, prune_ratio, merge_stat, merge_metric, merge_ratio


def parse_args(argv=None):
    """Parse command-line options for the pruned-model evaluation script.

    Args:
        argv: Argument list to parse, excluding the program name. ``None``
            (the default) parses ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace holding the options declared below.
    """
    parser = argparse.ArgumentParser(description='Evaluation script for pruned models')
    parser.add_argument('--pruned_model_path', type=str, 
                        default="/data1/ldz/HC-MoE/Expert_Sparsity/output/mixtral_layerwise_pruning_c4_prune_only_combined_0.5_std_weight-cosine_0.5",
                        help='Path to pruned model')
    parser.add_argument('--original_model_path', type=str, 
                        default="/data/Mixtral-8x7B-v0.1",
                        help='Path to original model used for initialization')
    parser.add_argument('--config_path', type=str, 
                        default="/data1/ldz/HC-MoE/results/opt/prune_only/mixtral_prune_only_combined_0.5_max_output-cka_0.5.yaml",
                        help='Path to YAML config file containing experts configuration')
    parser.add_argument('--prune_method', type=str, default='',
                        help='Method used for pruning')
    parser.add_argument('--eval_save_path', type=str, default=None,
                        help='Custom name for saving evaluation results')
    parser.add_argument('--batch_size', type=int, default=8,
                        help='Batch size for model inference')
    parser.add_argument('--seed', type=int, default=1234,
                        help='Random seed for reproduction')
    # store_true with default=True means passing the flag changes nothing, so
    # the option could never be disabled. Keep the flag for compatibility and
    # add an explicit opt-out that writes to the same destination; with
    # argparse the last of the two given on the command line wins.
    parser.add_argument('--use_flash_attention_2', action='store_true', default=True,
                        help='If set, Flash Attention 2 will be used')
    parser.add_argument('--no_flash_attention_2', dest='use_flash_attention_2', action='store_false',
                        help='Disable Flash Attention 2 (it is enabled by default)')
    parser.add_argument('--expert_config', type=str, default='step1_expert', 
                        choices=['step1_expert', 'step2_expert'],
                        help='Which expert configuration to use from config file')
    # NOTE(review): main() in this file does not read use_flash_attention_2,
    # eval_ppl or ppl_dataset; they are parsed but currently have no effect.
    parser.add_argument('--eval_ppl', action='store_true', default=False,
                        help='Whether to evaluate perplexity')
    parser.add_argument('--ppl_dataset', type=str, default='wikitext',
                        choices=['wikitext', 'c4'],
                        help='Dataset to use for perplexity evaluation')

    return parser.parse_args(argv)


def _read_layer_experts(config_path: str, key: str):
    """Return the top-level ``key`` entry of the YAML file at ``config_path``, or ``None``.

    Best-effort: read/parse errors, a missing key, or a non-list value are
    logged and yield ``None`` instead of raising.
    """
    try:
        with open(config_path, 'r') as f:
            config = yaml.safe_load(f)
    except (OSError, yaml.YAMLError) as e:
        logger.error(f'Error loading config file: {e}')
        return None
    if not isinstance(config, dict) or key not in config:
        logger.warning(f'"{key}" not found in {config_path}')
        return None
    layer_experts = config[key]
    if not isinstance(layer_experts, (list, tuple)):
        logger.error(f'Expected a list for "{key}" in {config_path}, got {type(layer_experts).__name__}')
        return None
    logger.info(f'Loaded {len(layer_experts)} layer expert configurations ({key}) from {config_path}')
    return layer_experts


def _inject_layer_experts(model_dir: str, layer_experts) -> None:
    """Store ``layer_experts`` as ``layer_num_experts`` in ``<model_dir>/config.json``.

    The file is rewritten in place, so the change persists on disk after this
    script exits. It is skipped (with a warning) when config.json is missing.
    """
    model_config_path = osp.join(model_dir, "config.json")
    if not osp.exists(model_config_path):
        logger.warning(f"{model_config_path} not found; layer_num_experts not written")
        return
    with open(model_config_path, "r", encoding="utf-8") as f:
        pruned_model_config = json.load(f)
    logger.info(f"Load model config from {model_config_path}")

    # NOTE(review): presumably consumed by the classes patched by
    # hack_pruned_mixtral_to_adapt (defined in model.py); confirm there.
    pruned_model_config["layer_num_experts"] = layer_experts
    logger.info(f"Adding layer_num_experts to config.json: {layer_experts}")

    with open(model_config_path, "w", encoding="utf-8") as f:
        json.dump(pruned_model_config, f, indent=2)


def main(args: Namespace):
    """Evaluate a pruned Mixtral checkpoint on a fixed zero-shot lm-eval task suite.

    Steps:
      1. Read expert counts and run metadata from ``args.config_path``.
      2. Write the expert counts into ``<pruned_model_path>/config.json`` (in place).
      3. Apply ``hack_pruned_mixtral_to_adapt`` and load the pruned model in fp16.
      4. Run the benchmark tasks and append the result tables to the save path.

    Args:
        args: Parsed CLI options (see ``parse_args``). ``args.original_model_path``
            is normalised in place by stripping trailing slashes.
    """
    logger.info(f'Arguments: {args}')

    # Strip all trailing slashes so the last path component is the model name.
    args.original_model_path = args.original_model_path.rstrip('/')
    model_name = args.original_model_path.split('/')[-1]

    (layer_experts, opt_order, prune_metric, prune_ratio,
     merge_stat, merge_metric, merge_ratio) = load_config_from_path(args.config_path)

    # load_config_from_path only reads "step1_expert"; honour --expert_config
    # when a different key is requested instead of silently ignoring it.
    if args.expert_config != 'step1_expert':
        layer_experts = _read_layer_experts(args.config_path, args.expert_config)

    if args.eval_save_path is None:
        file_stem = (f'{model_name}_{args.prune_method}_{prune_metric}_{prune_ratio}_'
                     f'{merge_stat}_{merge_metric}_{merge_ratio}')
    else:
        file_stem = args.eval_save_path
    # opt_order is None when the config file name could not be parsed, which
    # yields a "./results/opt/None" directory.
    eval_save_path = os.path.join(f'./results/opt/{opt_order}', f'{file_stem}.txt')

    logger.info(f'Save path: {eval_save_path}')
    os.makedirs(os.path.dirname(eval_save_path), exist_ok=True)

    tokenizer = AutoTokenizer.from_pretrained(args.original_model_path)

    # Must run before from_pretrained below. NOTE(review): what it patches lives
    # in model.py (not visible here). The imported unhack_pruned_mixtral_to_adapt
    # is never called in this file.
    hack_pruned_mixtral_to_adapt()

    if layer_experts:
        logger.info(f"Expert configuration source: {args.config_path} ({args.expert_config})")
        _inject_layer_experts(args.pruned_model_path, layer_experts)
    # Otherwise config.json is left untouched, so a layer_num_experts value
    # written there by an earlier run would still be picked up.

    # NOTE(review): args.use_flash_attention_2 is not forwarded (no
    # attn_implementation argument), so transformers uses its default attention.
    model = AutoModelForCausalLM.from_pretrained(
        args.pruned_model_path,
        torch_dtype=torch.float16,
        device_map="auto",
    )

    logger.info("Starting model evaluation on benchmark tasks")
    lm = HFLM(
        pretrained=model,
        tokenizer=tokenizer,
        batch_size=args.batch_size,
        device="cuda",
    )
    results = evaluator.simple_evaluate(
        model=lm,
        tasks=["wikitext", "winogrande", "arc_challenge", "arc_easy", "boolq", "hellaswag", "mmlu", "openbookqa", "rte"],
        num_fewshot=0,
        batch_size=args.batch_size,
        random_seed=args.seed,
        numpy_random_seed=args.seed,
        torch_random_seed=args.seed,
    )
    logger.info("Evaluation completed successfully")

    # Append mode: repeated runs with the same save path accumulate tables.
    with open(eval_save_path, "a") as f:
        print(make_table(results), file=f)
        if "groups" in results:
            print(make_table(results, "groups"), file=f)

    logger.info(f"Results saved to {eval_save_path}")


if __name__ == '__main__':
    # Configure root logging first so every record from main() is formatted
    # with a timestamp, level and logger name.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # Parse the CLI, then hand off to the evaluation driver.
    args = parse_args()
    main(args)
