import os
import os.path as osp
import logging
import argparse
import yaml
from argparse import Namespace
from datetime import datetime
from typing import Union
import time

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
from lm_eval.models.huggingface import HFLM
from lm_eval import evaluator
from lm_eval.utils import make_table
from accelerate import Accelerator
from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM
from transformers.models.deepseek_v2.modeling_deepseek_v2 import DeepseekV2ForCausalLM

from .method import METHODS
from data import DATASETS, build_calib_loader


# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)


def _str2bool(value):
    """Convert a command-line token into a real ``bool``.

    ``argparse`` with ``type=bool`` simply calls ``bool()`` on the raw string,
    so every non-empty token (including ``"False"``) becomes ``True``. This
    converter interprets the usual spellings explicitly instead.

    Args:
        value: The raw token from the command line, or an actual ``bool``
            (e.g. the ``const`` used when the flag is given without a value).

    Returns:
        ``True`` or ``False``.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean
            spelling.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    # The empty string is kept falsy, matching what ``bool('')`` used to give.
    if token in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(f'Expected a boolean value, got {value!r}')


def parse_args():
    """Build the command-line parser and parse ``sys.argv``.

    Returns:
        argparse.Namespace: The parsed options. ``--method`` is required; all
        other options have defaults.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--method', type=str, required=True,
                        choices=list(METHODS.keys()),
                        help=' '.join(['Supported pruning methods:'] + list(METHODS.keys())))
    parser.add_argument('--r', type=int, default=None,
                        help='Number of experts to preserve')
    parser.add_argument('--calib_set', type=str, default="c4",
                        choices=list(DATASETS.keys()),
                        help=' '.join(['Supported calibration datasets:'] + list(DATASETS.keys())))
    parser.add_argument('--model_path', type=str, default="/data/Mixtral-8x7B-v0.1",
                        help='Path to model to prune')
    parser.add_argument('--output_path', type=str, default='./output',
                        help='Output path (pruned model, pruning results, etc.)')
    parser.add_argument('--n_blocks_for_stat', type=int, default=128,
                        help='Number of sequences in calibration set. If set to 0 or negative, the whole dataset will be used')
    parser.add_argument('--batch_size', type=int, default=8,
                        help='Batch size for model inference')
    parser.add_argument('--num_workers', type=int, default=8,
                        help='Number of workers in dataloader')
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed for reproduction')
    # Boolean options accept an explicit value ("--flag False") or no value at
    # all ("--flag", meaning True); both forms go through _str2bool.
    parser.add_argument('--use_flash_attention_2', type=_str2bool, nargs='?',
                        const=True, default=True,
                        help='If set, Flash Attention 2 will be used')
    parser.add_argument('--layer_index', type=int, default=None,
                        help='Layer index to prune')
    parser.add_argument('--evaluate_model', type=_str2bool, nargs='?',
                        const=True, default=True,
                        help='Whether to evaluate the model')

    parser.add_argument('--config_path', type=str, default=None,
                        help='Path to YAML config file containing experts configuration for different layers')
    # NOTE(review): the default 'prune_expert' is not one of the listed
    # choices (argparse does not validate defaults against ``choices``).
    # Left unchanged because the key used by existing config files is not
    # visible here -- confirm whether it should be 'prune_experts'.
    parser.add_argument('--expert_config', type=str, default='prune_expert', choices=['prune_experts', 'merge_experts'],
                        help='Which expert configuration to use from config file (e.g., prune_experts, merge_experts), but please only use prune_expert')

    return parser.parse_args()


def _load_layer_experts(args: Namespace):
    """Read the per-layer expert configuration from ``args.config_path``.

    Loading is best-effort: any error is logged and the caller falls back to
    ``args.r``. If the config file name (without ``.yaml``) has at least eight
    ``_``-separated parts, those parts are logged as run metadata.

    Returns:
        The value stored under ``args.expert_config`` in the YAML file, or
        ``None`` when no config path is given, the key is missing, or the
        file could not be loaded.
    """
    if args.config_path is None:
        return None

    layer_experts = None
    try:
        with open(args.config_path, 'r') as f:
            config = yaml.safe_load(f)

        yaml_filename = args.config_path.split('/')[-1]
        if yaml_filename.endswith('.yaml'):
            yaml_filename = yaml_filename[:-5]

        yaml_parts = yaml_filename.split('_')
        if len(yaml_parts) >= 8:
            # The optimisation order spans two filename fields.
            opt_order = '_'.join(yaml_parts[1:3])
            logger.info(f"Parsed config filename: model={yaml_parts[0]}, opt_order={opt_order}, "
                      f"prune_metric={yaml_parts[3]}, prune_ratio={yaml_parts[4]}, "
                      f"merge_stat={yaml_parts[5]}, merge_metric={yaml_parts[6]}, merge_ratio={yaml_parts[7]}")

        if args.expert_config in config:
            layer_experts = config[args.expert_config]
            logger.info(f'Loaded {len(layer_experts)} layer expert configurations from {args.config_path}')
            for i, r in enumerate(layer_experts):
                logger.info(f'Layer {i}: {r} experts')
        else:
            logger.warning(f'Expert configuration {args.expert_config} not found in config file. Using default r={args.r}')
    except Exception as e:
        # Deliberately broad: a bad config must not abort the run.
        logger.error(f'Error loading config file: {e}')
        logger.warning(f'Using default r={args.r} for all layers')
    return layer_experts


def _resolve_save_paths(args: Namespace):
    """Choose the model output directory and the evaluation result directory.

    ``args.model_output_dir`` and ``args.result_save_path`` are optional
    attributes (the command-line parser does not define them). When the first
    is absent or ``None`` the model directory falls back to
    ``args.output_path``; when the second is absent or ``None`` evaluation
    results go to the model directory.

    Returns:
        A ``(model_save_path, eval_save_path)`` pair. ``eval_save_path`` is
        ``None`` when ``args.evaluate_model`` is falsy.
    """
    model_save_path = getattr(args, 'model_output_dir', None)
    if model_save_path is not None:
        logger.info(f'Using model_output_dir: {model_save_path}')
    else:
        model_save_path = getattr(args, 'output_path', None) or './output'

    eval_save_path = None
    if args.evaluate_model:
        eval_save_path = getattr(args, 'result_save_path', None)
        if eval_save_path is not None:
            logger.info(f'Using result_save_path: {eval_save_path}')
        else:
            # Resolve this before the expensive evaluation so the run cannot
            # fail afterwards for lack of a destination.
            eval_save_path = model_save_path
    return model_save_path, eval_save_path


def _write_eval_report(results, eval_save_path, args: Namespace, layer_experts) -> None:
    """Append the evaluation tables to ``<eval_save_path>/mmlu_results.txt``."""
    mmlu_results_path = os.path.join(eval_save_path, "mmlu_results.txt")
    os.makedirs(os.path.dirname(mmlu_results_path), exist_ok=True)

    if args.layer_index is not None:
        # Prefer the explicit --r, then the per-layer config entry.
        if args.r:
            kept = args.r
        elif layer_experts and args.layer_index < len(layer_experts):
            kept = layer_experts[args.layer_index]
        else:
            kept = 'unknown'
        header = f"Layer {args.layer_index} - r={kept} MMLU Results:"
    else:
        header = "Full model MMLU Results:"

    # The context manager closes the file even if table rendering fails.
    with open(mmlu_results_path, "a") as f:
        print(header, file=f)
        print(make_table(results), file=f)
        if "groups" in results:
            print(make_table(results, "groups"), file=f)


def adapt_prune(args: Namespace) -> Union[MixtralForCausalLM, DeepseekV2ForCausalLM]:
    """Run the selected method on a model and optionally evaluate the result.

    Steps: load the optional per-layer expert config, resolve output
    directories, obtain tokenizer and model (loaded from ``args.model_path``
    unless supplied on ``args``), build the calibration loader, call
    ``METHODS[args.method]``, and, when ``args.evaluate_model`` is truthy,
    run the lm-eval benchmark tasks and append the tables to
    ``mmlu_results.txt``.

    Args:
        args: Parsed options. Besides the command-line options, these
            optional attributes are honoured when present and not ``None``:
            ``model_output_dir``, ``result_save_path``,
            ``provided_tokenizer`` and ``provided_model``. A single trailing
            ``/`` is stripped from ``args.model_path`` in place.

    Returns:
        The model returned by the selected method.
    """
    logger.info(f'Arguments: {args}')

    if args.model_path.endswith('/'):
        args.model_path = args.model_path[:-1]

    layer_experts = _load_layer_experts(args)
    model_save_path, eval_save_path = _resolve_save_paths(args)

    logger.info(f'Save path: {model_save_path} \n eval_result will be saved in: {eval_save_path}')
    os.makedirs(model_save_path, exist_ok=True)
    set_seed(args.seed)

    tokenizer = getattr(args, 'provided_tokenizer', None)
    if tokenizer is not None:
        logger.info("Using provided tokenizer")
    else:
        tokenizer = AutoTokenizer.from_pretrained(args.model_path)
        logger.info(f"Loaded tokenizer from {args.model_path}")

    model = getattr(args, 'provided_model', None)
    if model is not None:
        logger.info("Using provided model")
    else:
        model = AutoModelForCausalLM.from_pretrained(
            args.model_path,
            device_map='auto',
            torch_dtype=torch.float16,
        )
        logger.info(f"Loaded model from {args.model_path}")

    calib_loader = build_calib_loader(args.calib_set, tokenizer,
                                      args.n_blocks_for_stat, args.batch_size, args.num_workers, args.seed)

    time_start = time.time()
    # The second element returned by the method is not used here.
    model, _ = METHODS[args.method](model, calib_loader, args, layer_experts=layer_experts)
    logger.info(f"Time taken for {args.method}: {time.time() - time_start} seconds")

    if args.evaluate_model:
        print("Evaluating...")
        logger.info("Starting model evaluation on benchmark tasks - this may take a long time")

        lm = HFLM(
            pretrained=model,
            tokenizer=tokenizer,
            batch_size=args.batch_size,
            device_map='auto',
            parallelize=False
        )
        results = evaluator.simple_evaluate(
            model=lm,
            tasks=["winogrande", "arc_challenge", "arc_easy", "boolq", "hellaswag", "mmlu", "openbookqa", "rte"],
            num_fewshot=0,
            batch_size=args.batch_size,
            random_seed=0,
            numpy_random_seed=1234,
            torch_random_seed=1234,
        )
        logger.info("Evaluation completed successfully")

        _write_eval_report(results, eval_save_path, args, layer_experts)

        logger.info("Releasing memory after evaluation...")
        del lm, results
        torch.cuda.empty_cache()

    return model

if __name__ == '__main__':
    # Script entry point: configure INFO-level logging, then parse the
    # command line and run the pipeline with the parsed options.
    logging.basicConfig(
        level=logging.INFO,
        datefmt="%Y-%m-%d %H:%M:%S",
        format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
    )
    args = parse_args()
    adapt_prune(args)
