# run_benchmark.py
import os
import shutil
import json
import time
import torch
import random
import logging
import gc
from tqdm import tqdm
import numpy as np

from .benchmarks.utils.eval_acc import run_gsm8k_eval, run_aime_eval, run_livecodebench_eval, run_mmlu_pro_eval, run_longbench_eval, run_longbenchv2_eval, run_longgenbench_eval, run_longwriter_eval, run_pg19_eval
from .benchmarks.narrativeqa import load_narrativeqa_dataset_answer
from .benchmarks.qasper import load_qasper_dataset_answer
from .benchmarks.multifieldqa_en import load_multifieldqa_en_dataset_answer
from .benchmarks.hotpotqa import load_hotpotqa_dataset_answer
from .benchmarks.musique import load_musique_dataset_answer
from .benchmarks._2wikimqa import load_2wikimqa_dataset_answer
from .benchmarks.gov_report import load_gov_report_dataset_answer
from .benchmarks.qmsum import load_qmsum_dataset_answer
from .benchmarks.multi_news import load_multi_news_dataset_answer
from .benchmarks.trec import load_trec_dataset_answer
from .benchmarks.triviaqa import load_triviaqa_dataset_answer
from .benchmarks.samsum import load_samsum_dataset_answer
from .benchmarks.passage_count import load_passage_count_dataset_answer
from .benchmarks.passage_retrieval_en import load_passage_retrieval_en_dataset_answer
from .benchmarks.lcc import load_lcc_dataset_answer
from .benchmarks.repobench_p import load_repobench_p_dataset_answer
from .benchmarks.pg19 import load_pg19_dataset

# Maps each supported benchmark name to the function that loads its dataset.
# main() validates requested benchmark names against the keys of this table
# and lists them in its "Unknown benchmark" error message.
# All loaders except "pg19" are called with no arguments; the "pg19" loader is
# called with tokenizer=, input_len= and max_samples= keyword arguments.
DATASET_LOADER = {
    "narrativeqa": load_narrativeqa_dataset_answer,
    "qasper": load_qasper_dataset_answer,
    "multifieldqa_en": load_multifieldqa_en_dataset_answer,
    "hotpotqa": load_hotpotqa_dataset_answer,
    "2wikimqa": load_2wikimqa_dataset_answer,
    "musique": load_musique_dataset_answer,  
    "gov_report": load_gov_report_dataset_answer,
    "qmsum": load_qmsum_dataset_answer,
    "multi_news": load_multi_news_dataset_answer,
    "trec": load_trec_dataset_answer, 
    "triviaqa": load_triviaqa_dataset_answer, 
    "samsum": load_samsum_dataset_answer,
    "passage_count": load_passage_count_dataset_answer,
    "passage_retrieval_en": load_passage_retrieval_en_dataset_answer,
    "lcc": load_lcc_dataset_answer,
    "repobench_p": load_repobench_p_dataset_answer,
    "pg19": load_pg19_dataset,
}

# Maps each supported benchmark name to the evaluation routine main() runs
# for it. Every benchmark listed in the tuple below shares run_longbench_eval;
# "pg19" is the only entry with a dedicated evaluator.
BENCHMARK_EVALUATORS = dict.fromkeys(
    (
        "narrativeqa",
        "qasper",
        "multifieldqa_en",
        "hotpotqa",
        "2wikimqa",
        "musique",
        "gov_report",
        "qmsum",
        "multi_news",
        "trec",
        "triviaqa",
        "samsum",
        "passage_count",
        "passage_retrieval_en",
        "lcc",
        "repobench_p",
    ),
    run_longbench_eval,
)
BENCHMARK_EVALUATORS["pg19"] = run_pg19_eval

def _mean_attr_or_zero(obj, attr_name):
    """Return ``np.mean`` of ``obj.<attr_name>``, or 0.0 if it is missing or falsy."""
    values = getattr(obj, attr_name, None)
    return np.mean(values) if values else 0.0


def _build_result_record(
    bench_name,
    *,
    tput_mean,
    tput_std,
    tput_excl_target_prefill,
    tput_excl_all_prefill,
    acc_rate_mean,
    acc_rate_std,
    accuracy,
    avg_draft_time,
    avg_target_time,
    peak_mem,
    tacc_judge_value,
    avg_self_attn_latency,
    avg_compresskv_latency,
    avg_criticality_estimation,
):
    """Format one benchmark's metrics as the ``{bench_name: {...}}`` results record.

    Every value is rendered as a string with three decimals. The set of keys
    depends on the benchmark name:

    * names containing ``"longgenbench"``: no prefill-excluded throughput;
      ``accuracy`` is indexed as a mapping (``cr_mean``, ``stic1_mean``,
      ``stic2_mean``, ``wavg_mean``);
    * names containing ``"pg19"``: prefill-excluded throughput, no accuracy;
    * anything else: prefill-excluded throughput and a numeric ``accuracy``.

    Keys are inserted in the same order as the file format has always used.
    """
    is_longgenbench = "longgenbench" in bench_name
    is_pg19 = not is_longgenbench and "pg19" in bench_name

    record = {
        "tput": f"{tput_mean:.3f}",
        "tput_std": f"{tput_std:.3f}",
    }
    if not is_longgenbench:
        record["tput_excl_target_prefill"] = f"{tput_excl_target_prefill:.3f}"
        record["tput_excl_all_prefill"] = f"{tput_excl_all_prefill:.3f}"

    record["Tacc"] = f"{acc_rate_mean:.3f}"
    record["Tacc_std"] = f"{acc_rate_std:.3f}"

    if is_longgenbench:
        record["CR"] = f"{accuracy['cr_mean']:.3f}"
        record["STIC1"] = f"{accuracy['stic1_mean']:.3f}"
        record["STIC2"] = f"{accuracy['stic2_mean']:.3f}"
        record["Accuracy"] = f"{accuracy['wavg_mean']:.3f}"
    elif not is_pg19:
        record["Accuracy"] = f"{accuracy:.3f}"

    record.update({
        "avg_draft_time": f"{avg_draft_time:.3f}",
        "avg_target_time": f"{avg_target_time:.3f}",
        "peak_memory": f"{peak_mem:.3f} GiB",
        "Tacc_judge": f"{tacc_judge_value:.3f}",
        "avg_self_attn_latency": f"{avg_self_attn_latency:.3f}",
        "avg_compresskv_latency": f"{avg_compresskv_latency:.3f}",
        "avg_criticality_estimation": f"{avg_criticality_estimation:.3f}",
    })
    return {bench_name: record}


def _append_results(log_dir, record):
    """Append *record* to ``<log_dir>/results.jsonl`` followed by a newline."""
    # NOTE(review): despite the .jsonl name, records are written with indent=4
    # (multi-line JSON). Kept as-is so existing consumers of the file still work.
    with open(os.path.join(log_dir, "results.jsonl"), 'a+') as f:
        json.dump(record, f, indent=4)
        f.write("\n")


def main(builder, benchmarks=None, max_samples=None):
    """Run the requested benchmarks and append their metrics to results files.

    Args:
        builder: Object providing ``build()`` (returns generator, tokenizer,
            past_kv, draft_past_kv), ``args``, ``batch_size`` and
            ``generator_kwargs``. Its ``generator_profiling`` and
            ``profiling_verbose`` attributes are overwritten here.
            ``builder.args`` must provide ``out_dir``, ``log_dir``,
            ``llm_path``, ``draft_model_path`` and (when a draft model is set)
            ``draft_params.max_depth`` / ``draft_params.topk_len``;
            ``min_length`` is optional and defaults to 10240.
        benchmarks: Comma-separated benchmark names, or None to run nothing.
            Whitespace around each name is ignored.
        max_samples: Optional cap on the number of samples per benchmark.

    Raises:
        ValueError: If a requested benchmark is not a key of DATASET_LOADER, or
            its evaluator is neither run_longbench_eval nor run_pg19_eval.
        KeyError: If ``tokenizer.name_or_path`` has no entry in
            model2maxlen.json.

    Side effects: deletes and recreates ``args.out_dir`` (when set), creates
    one log directory per benchmark and appends to its ``results.jsonl``.
    """
    torch.manual_seed(0)
    random.seed(0)

    # Enable profiling, disable logging profiling results
    builder.generator_profiling = True
    builder.profiling_verbose = False
    generator, tokenizer, past_kv, draft_past_kv = builder.build()
    args = builder.args

    # set logging level by environment variable
    log_level = os.environ.get("LOGLEVEL", "INFO").upper()
    logging.basicConfig(level=log_level)

    # Build bench_list and check if all names are valid. Names are stripped
    # here, once, so the validated name is exactly the one used for dataset
    # lookup and log directories below.
    bench_list = [b.strip() for b in benchmarks.split(",")] if benchmarks is not None else []
    for b in bench_list:
        base_name = "longbench_v2" if b.startswith("longbench_v2") else b
        if base_name not in DATASET_LOADER:
            raise ValueError(f"Unknown benchmark: {b}. Available benchmarks: {list(DATASET_LOADER.keys())}")
    print(f"Benchmarks to run: {bench_list}")

    # Handle output directories
    if args.out_dir is not None:
        shutil.rmtree(args.out_dir, ignore_errors=True)
        print(f"Deleted old {args.out_dir}")
        os.makedirs(args.out_dir, exist_ok=True)

    # Load target model config Ex. maxlen (for longbench testing)
    with open("./run/pipelines/benchmarks/utils/config/model2maxlen.json", "r") as f:
        model2maxlen = json.load(f)
    max_length = model2maxlen[tokenizer.name_or_path]

    # Build base directory name. Trailing slashes are dropped first so that a
    # path such as "/models/foo/" still contributes "foo" rather than "".
    model_name = args.llm_path.rstrip("/").split("/")[-1]
    log_dir_base = os.path.join(args.log_dir, model_name, str(type(builder).__name__))
    if args.draft_model_path is not None:
        draft_tag = (
            args.draft_model_path.rstrip("/").split("/")[-1]
            + "-" + str(args.draft_params.max_depth)
            + "-" + str(args.draft_params.topk_len)
        )
        if builder.generator_kwargs.get("Target_KV_size", None) is not None:
            draft_tag += "_" + str(builder.generator_kwargs["Target_KV_size"])
    else:
        draft_tag = "no_draft"
    log_dir_base = os.path.join(log_dir_base, draft_tag)

    # Run benchmarks
    for bench_name in tqdm(bench_list, desc="Running benchmarks"):
        # fix random seed to 0 for each benchmark for reproducibility
        torch.manual_seed(0)
        random.seed(0)

        is_pg19 = "pg19" in bench_name
        min_length = getattr(args, "min_length", 1024*10)

        # Handle output directories
        if is_pg19:
            # Add batch size and min_length to log directory for pg19 benchmark
            log_dir = os.path.join(log_dir_base, bench_name + "_" + str(builder.batch_size) + "_" + str(min_length))
        else:
            log_dir = os.path.join(log_dir_base, bench_name)
        os.makedirs(log_dir, exist_ok=True)
        print(f"Log directory: {log_dir}")

        # Load dataset
        if is_pg19:
            # pg19 dataset is loaded differently: the loader applies max_samples itself
            dataset = DATASET_LOADER["pg19"](tokenizer=tokenizer, input_len=min_length, max_samples=max_samples)
            num_samples = len(dataset)
        else:
            dataset = DATASET_LOADER[bench_name]()
            num_samples = min(len(dataset), max_samples) if max_samples is not None else len(dataset)
            dataset = dataset[:num_samples]
        print(f"Running benchmark: {bench_name}, samples: {num_samples}")

        torch.cuda.empty_cache()
        gc.collect()
        torch.cuda.reset_peak_memory_stats()

        # Evaluate
        evaluator = BENCHMARK_EVALUATORS[bench_name]
        if evaluator == run_longbench_eval:
            (tput_mean, tput_std, tput_excl_target_prefill, tput_excl_all_prefill,
             acc_rate_mean, acc_rate_std, accuracy,
             avg_draft_time, avg_target_time, peak_mem) = evaluator(
                generator, tokenizer, past_kv, draft_past_kv, args, dataset, log_dir, bench_name, max_length)
        elif evaluator == run_pg19_eval:
            # run_pg19_eval's result is unpacked without an accuracy value.
            accuracy = None
            (tput_mean, tput_std, tput_excl_target_prefill, tput_excl_all_prefill,
             acc_rate_mean, acc_rate_std,
             avg_draft_time, avg_target_time, peak_mem) = evaluator(
                generator, tokenizer, past_kv, draft_past_kv, args, dataset, log_dir)
        else:
            raise ValueError(f"Unknown evaluator for benchmark: {bench_name}")

        torch.cuda.empty_cache()
        gc.collect()
        torch.cuda.reset_peak_memory_stats()

        # Guard against an empty list: np.mean([]) would yield nan plus a
        # RuntimeWarning, whereas the other optional metrics fall back to zero.
        judge_acc_lens = getattr(generator, "judge_acc_len_list", None)
        if judge_acc_lens is not None and len(judge_acc_lens) > 0:
            tacc_judge_value = np.mean(judge_acc_lens)
        else:
            tacc_judge_value = 0

        # Calculate average self_attention latency
        if hasattr(generator, "all_attention_latencies") and generator.all_attention_latencies:
            attention_latencies = torch.tensor(generator.all_attention_latencies)
            attention_latencies = attention_latencies.sum(dim=1)  # sum over layers
            avg_self_attn_latency = np.mean(attention_latencies.cpu().numpy())
        else:
            avg_self_attn_latency = 0.0
        avg_compresskv_latency = _mean_attr_or_zero(generator, "all_compresskv_latencies")
        avg_criticality_estimation = _mean_attr_or_zero(generator, "all_criticality_estimation")

        # Write results to file
        record = _build_result_record(
            bench_name,
            tput_mean=tput_mean,
            tput_std=tput_std,
            tput_excl_target_prefill=tput_excl_target_prefill,
            tput_excl_all_prefill=tput_excl_all_prefill,
            acc_rate_mean=acc_rate_mean,
            acc_rate_std=acc_rate_std,
            accuracy=accuracy,
            avg_draft_time=avg_draft_time,
            avg_target_time=avg_target_time,
            peak_mem=peak_mem,
            tacc_judge_value=tacc_judge_value,
            avg_self_attn_latency=avg_self_attn_latency,
            avg_compresskv_latency=avg_compresskv_latency,
            avg_criticality_estimation=avg_criticality_estimation,
        )
        _append_results(log_dir, record)
