import csv
import json
import os
import shutil
import subprocess
import warnings

import hydra
import torch
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig

from metrics import get_all_evals, get_dataloader, get_eval_results
from utils import get_model_identifiers_from_yaml, set_random_seed

# Install a catch-all "ignore" filter at the front of the warnings filter list,
# so no warning of any category is displayed for the rest of the process.
warnings.simplefilter('ignore')


def summary_results(eval_dir):
    """Summarise the lm-eval downstream results found under ``eval_dir``.

    ``eval_dir`` is searched recursively for JSON files whose name contains
    ``results``.  When several exist (e.g. left over from repeated runs), the
    most recently modified one is used.  The tracked metrics are written to
    ``downstream_task_results.txt`` (overwritten) and appended to
    ``downstream_task_results.csv``, both in the parent directory of
    ``eval_dir``.  The CSV header is written only when that file is new or empty.

    Args:
        eval_dir: Directory under which the lm-eval results JSON was written.

    Returns:
        dict mapping a display name (e.g. ``'TriviaQA'``) to its metric value.

    Raises:
        FileNotFoundError: no matching results JSON exists under ``eval_dir``.
        KeyError: the results JSON lacks a tracked task or metric.
    """
    print(eval_dir)
    candidates = [
        os.path.join(dirpath, filename)
        for dirpath, _dirnames, filenames in os.walk(eval_dir)
        for filename in filenames
        if filename.endswith('json') and 'results' in filename
    ]
    if not candidates:
        raise FileNotFoundError(f"No lm-eval results JSON found under {eval_dir}")
    # Repeated runs leave several results files; assume the newest one is current.
    results_path = max(candidates, key=os.path.getmtime)

    print(results_path)
    with open(results_path, 'r') as f:
        results = json.load(f)['results']

    results_dict = {
        # 'ARC-C': results['arc_challenge']['acc_norm,none'],
        # 'MMLU': results['mmlu']['acc,none'],
        # 'TruthfulQA(mc1)': results['truthfulqa_mc1']['acc,none'],
        'TriviaQA': results['triviaqa']['exact_match,remove_whitespace'],
        # 'GSM8k': results['gsm8k']['exact_match,flexible-extract'],
    }

    with open(os.path.join(eval_dir, "../downstream_task_results.txt"), 'w') as txtfile:
        for key, value in results_dict.items():
            txtfile.write(f"{key}: {value}\n")

    save_file = os.path.join(eval_dir, "../downstream_task_results.csv")
    # Append-mode CSV: emit the header only once so repeated runs stay parseable.
    write_header = not os.path.exists(save_file) or os.path.getsize(save_file) == 0
    with open(save_file, 'a', newline='') as f:
        w = csv.DictWriter(f, results_dict.keys())
        if write_header:
            w.writeheader()
        w.writerow(results_dict)

    return results_dict


def general_eval(
        cfg,
        model_name,
        task_list=(
            "arc_challenge",  # ARC-c
            "truthfulqa",
            "triviaqa",
            "mmlu",
            "gsm8k",
        ),
        output_dir=".",
):
    """Run lm-eval downstream tasks via ``accelerate launch`` (earlier configuration).

    NOTE(review): ``general_eval`` is defined again further down in this module.
    That later ``def`` rebinds the name when the module is imported, so this
    version is never what callers get.  It appears to be kept as a record of the
    earlier settings (``--batch_size auto:4``, ``max_batch_size=16`` in the LoRA
    model args); consider deleting it.

    Args:
        cfg: config; reads ``cfg.use_LoRA`` and, for LoRA runs, ``cfg.model_path``.
        model_name: model path/id, or the LoRA adapter directory when
            ``cfg.use_LoRA``.  If it equals ``cfg.model_path``, the base model is
            evaluated without an adapter.
        task_list: lm-eval task names (joined with commas).
        output_dir: lm-eval writes to ``{output_dir}/downstream_tasks``.

    Returns:
        The dict produced by ``summary_results(output_dir)``.
    """
    command = "accelerate"
    tasks = ",".join(task_list)
    # Evaluating the base model itself (model_name == cfg.model_path) means there is
    # no adapter; passing it as peft= would ask for an adapter from the base dir.
    if cfg.use_LoRA and model_name != cfg.model_path:
        model_args = f"pretrained={cfg.model_path},peft={model_name},add_bos_token=True,max_batch_size=16"
    else:
        model_args = f"pretrained={model_name},add_bos_token=True"

    args = [
        "launch",
        "-m", "lm_eval",
        "--model", "hf",
        "--model_args", model_args,
        "--tasks", tasks,
        "--batch_size", "auto:4",
        "--output_path", f"{output_dir}/downstream_tasks",
        "--gen_kwargs", "max_new_tokens=256",
    ]
    full_command = [command] + args
    print(full_command)
    try:
        subprocess.run(full_command, check=True)
    except subprocess.CalledProcessError as e:
        # Best-effort: a failed run is only reported.  summary_results may then
        # read a results file left under output_dir by an earlier run.
        print(f"An error occurred: {e}")

    return summary_results(output_dir)



def general_eval(
        cfg,
        model_name,
        task_list=(
            "arc_challenge",  # ARC-c
            "truthfulqa",
            "triviaqa",
            "mmlu",
            "gsm8k",
        ),
        output_dir=".",
):
    """Run lm-evaluation-harness downstream tasks on a model via ``accelerate launch``.

    Args:
        cfg: config; reads ``cfg.use_LoRA`` (treated as False if absent) and, for
            LoRA runs, ``cfg.model_path`` (the base weights).
        model_name: model path/id to evaluate.  For LoRA runs this is the adapter
            directory, loaded on top of ``cfg.model_path``.  If it equals
            ``cfg.model_path`` (as ``main`` passes for unlearning step 0), the base
            model is evaluated without an adapter.
        task_list: lm-eval task names (joined with commas).
        output_dir: lm-eval writes into ``{output_dir}/downstream_tasks``.

    Returns:
        The dict produced by ``summary_results(output_dir)``.

    The subprocess runs with ``TOKENIZERS_PARALLELISM=false``.  It also runs with
    ``CUDA_VISIBLE_DEVICES=0,1`` unless that variable is already set.
    """
    command = "accelerate"
    tasks = ",".join(task_list)

    # Common speed-ups for HF backend
    common = "add_bos_token=True,dtype=bfloat16,attn_implementation=flash_attention_2,trust_remote_code=True"
    # Only attach an adapter when model_name is distinct from the base weights.
    # main() loads cfg.model_path as full weights, so it presumably has no
    # adapter config.
    use_adapter = getattr(cfg, "use_LoRA", False) and model_name != cfg.model_path
    if use_adapter:
        # Keep PEFT; for even more speed, merge LoRA offline and point pretrained to the merged folder
        model_args = f"pretrained={cfg.model_path},peft={model_name},{common}"
    else:
        model_args = f"pretrained={model_name},{common}"

    # Cache so re-runs don't recompute
    # os.makedirs(output_dir, exist_ok=True)
    # cache_dir = os.path.join(output_dir, ".lm_cache")
    # os.makedirs(cache_dir, exist_ok=True)

    args = [
        "launch",
        "--mixed_precision", "bf16",          # use bf16 on H100
        # Optional: use both GPUs via accelerate config, or uncomment next line:
        # "--num_processes", "2",
        "-m", "lm_eval",
        "--model", "hf",
        "--model_args", model_args,
        "--tasks", tasks,
        # auto:N = lm-eval picks the largest batch size that fits, re-probing it N times
        "--batch_size", "auto:16",
        "--max_batch_size", "128",            # upper bound for the automatic batch-size search
        "--gen_kwargs", "max_new_tokens=256,temperature=0",  # safe for TruthfulQA/TriviaQA/gsm8k (non-CoT)
        # "--use_cache", cache_dir,
        # "--cache_requests", "true",
        "--output_path", f"{output_dir}/downstream_tasks"
    ]

    full_command = [command] + args
    print(full_command)

    # Small runtime tweak: avoid tokenizer thread storms
    env = os.environ.copy()
    env["TOKENIZERS_PARALLELISM"] = "false"
    # Ensure both GPUs are visible (accelerate will decide how many to use)
    env.setdefault("CUDA_VISIBLE_DEVICES", "0,1")

    try:
        subprocess.run(full_command, check=True, env=env)
    except subprocess.CalledProcessError as e:
        # Best-effort: a failed run is only reported.  summary_results may then
        # read a results file left under output_dir by an earlier run.
        print(f"An error occurred: {e}")

    results_dict = summary_results(output_dir)
    return results_dict




def model_eval(cfg, model, tokenizer, save_dir, eval_unlearn_step=None):
    """Evaluate a model on every configured unlearning task and record the aggregate.

    Each task is described by one entry in each of the parallel lists in
    ``cfg.eval``: ``data_path``, ``split_list``, ``question_key``, ``answer_key``,
    ``eval_task``, ``base_answer_key`` and ``perturbed_answer_key``.  If
    ``{save_dir}/{eval_task}.json`` already exists, the per-task logs are reloaded
    from it.  Otherwise they are computed with ``get_dataloader`` +
    ``get_all_evals`` and cached to that file.

    Files written:
        * ``{save_dir}/eval_log_aggregated.json``: all per-task logs.
        * ``{save_dir}/unlearning_results.txt``: ``key: value`` lines (overwritten).
        * ``{save_dir}/unlearning_results.csv``: one appended row.
        * ``{cfg.save_dir}/all_unlearning_results.csv``: one appended row.
        CSV headers are written only when the file is missing or empty.

    Args:
        cfg: config; reads ``cfg.eval``, ``cfg.save_dir``, and the hyper-parameters
            copied into the stats (split, forget_loss, forget_coeff,
            regularization_coeff, lr, num_epochs, fix_ref_model, mask).
        model: model passed through to ``get_all_evals``.
        tokenizer: tokenizer passed to ``get_dataloader`` / ``get_all_evals``.
        save_dir: directory for this run's evaluation artefacts (created if needed).
        eval_unlearn_step: label stored under ``'unlearn_step'``; ``None`` -> ``'last'``.

    Returns:
        dict: the ``get_eval_results`` metrics plus the hyper-parameters above.
    """
    if eval_unlearn_step is None:
        eval_unlearn_step = 'last'
    # Create once up front so the aggregate files can be written even when
    # the cfg.eval task lists are empty.
    os.makedirs(save_dir, exist_ok=True)

    eval_cfg = cfg.eval
    task_specs = zip(eval_cfg.data_path, eval_cfg.split_list, eval_cfg.question_key,
                     eval_cfg.answer_key, eval_cfg.eval_task, eval_cfg.base_answer_key,
                     eval_cfg.perturbed_answer_key)

    aggregated_eval_logs = {}
    for (folder, split, question_key, answer_key, eval_task,
         base_answer_key, perturbed_answer_key) in task_specs:
        save_filename = os.path.join(save_dir, f"{eval_task}.json")

        if os.path.exists(save_filename):
            print(
                f"Skipping {eval_task} because {save_filename} already exists")
            with open(save_filename, 'r') as f:
                eval_logs = json.load(f)
        else:
            eval_dataloader, base_eval_dataloader, perturb_dataloader = get_dataloader(
                eval_cfg, eval_task, tokenizer, folder, split, question_key, answer_key,
                base_answer_key, perturbed_answer_key)

            eval_logs = get_all_evals(eval_cfg, model, tokenizer, folder, split, eval_task,
                                      eval_dataloader, base_eval_dataloader,
                                      perturb_dataloader, False)

            with open(save_filename, "w") as f:
                json.dump(eval_logs, f, indent=4)

        aggregated_eval_logs[f'{eval_task}.json'] = eval_logs

    aggregated_eval_log_filename = os.path.join(save_dir, "eval_log_aggregated.json")
    with open(aggregated_eval_log_filename, "w") as f:
        json.dump(aggregated_eval_logs, f, indent=4)

    aggregate_stat = {**get_eval_results(aggregated_eval_logs)}
    print(aggregate_stat)

    aggregate_stat['split'] = cfg.split
    aggregate_stat['forget_loss'] = cfg.forget_loss
    aggregate_stat['forget_coeff'] = cfg.forget_coeff
    aggregate_stat['regularization_coeff'] = cfg.regularization_coeff
    aggregate_stat['learning_rate'] = cfg.lr
    aggregate_stat['epochs'] = cfg.num_epochs
    aggregate_stat['fix_ref_model'] = cfg.fix_ref_model
    aggregate_stat['mask'] = cfg.mask
    aggregate_stat['unlearn_step'] = eval_unlearn_step

    with open(os.path.join(save_dir, "unlearning_results.txt"), 'w') as txtfile:
        for key, value in aggregate_stat.items():
            txtfile.write(f"{key}: {value}\n")

    def _append_csv_row(path):
        # Append one stats row; write the header only for a new/empty file so
        # repeated runs do not interleave header lines with data.
        write_header = not os.path.exists(path) or os.path.getsize(path) == 0
        with open(path, 'a', newline='') as f:
            writer = csv.DictWriter(f, aggregate_stat.keys())
            if write_header:
                writer.writeheader()
            writer.writerow(aggregate_stat)

    _append_csv_row(os.path.join(save_dir, "unlearning_results.csv"))
    _append_csv_row(os.path.join(cfg.save_dir, "all_unlearning_results.csv"))

    return aggregate_stat


def _load_saved_eval_results(csv_path):
    """Return the last data row of a previously written ``unlearning_results.csv``.

    Used when ``model_eval`` is skipped because its results already exist, so the
    downstream summary can still merge them.  Values are the strings stored in the
    CSV, because the csv module does not keep Python types.

    Raises:
        ValueError: the file contains no data rows.
    """
    with open(csv_path, 'r', newline='') as f:
        rows = list(csv.DictReader(f))
    if not rows:
        raise ValueError(f"{csv_path} contains no result rows.")
    return rows[-1]


@hydra.main(version_base=None, config_path="config", config_name="forget")
def main(cfg):
    """Evaluate one unlearning checkpoint and optionally run downstream benchmarks.

    Steps:
      1. Resolve the checkpoint for ``cfg.eval_unlearn_step``.  Step 0 uses
         ``cfg.model_path``; other steps use ``{cfg.save_dir}/checkpoint-{step}``.
         If that directory is missing, return early.
      2. Unless ``eval_results-{step}/unlearning_results.csv`` already exists, load
         the model (LoRA checkpoints are merged into ``cfg.model_path``) and run
         ``model_eval``.
      3. If ``cfg.downstream`` is set and ``{cfg.save_dir}/downstream_task_results.csv``
         does not exist, run ``general_eval``.  Then write the merged results to
         ``aggr_results.csv`` / ``aggregate_stat.txt``.
      4. If ``cfg.save_checkpoint`` is false, delete the evaluated checkpoint
         directory (never for step 0).
    """
    import gc

    set_random_seed(cfg.seed)

    model_cfg = get_model_identifiers_from_yaml(cfg.model_family)
    model_id = model_cfg["hf_key"]

    curr_save_dir = cfg.save_dir
    if cfg.eval_unlearn_step == 0:
        # Step 0 evaluates the model at cfg.model_path directly.
        curr_checkpoint_dir = cfg.model_path
    else:
        curr_checkpoint_dir = os.path.join(curr_save_dir, f"checkpoint-{cfg.eval_unlearn_step}")
        if not os.path.exists(curr_checkpoint_dir):
            print(f'{curr_checkpoint_dir} does not exist.')
            # return rather than exit(): exit() raises SystemExit, which would also
            # stop anything else running in this process (e.g. a sequential multirun).
            return

    curr_eval_dir = os.path.join(curr_save_dir, f'eval_results-{cfg.eval_unlearn_step}')
    unlearning_csv = os.path.join(curr_eval_dir, 'unlearning_results.csv')

    eval_results = None
    if os.path.exists(unlearning_csv):
        print(f'{curr_eval_dir} already evaluated.')
    else:
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        tokenizer.pad_token = tokenizer.eos_token

        config = AutoConfig.from_pretrained(model_id)

        # LoRA checkpoints are loaded as PEFT adapters on top of cfg.model_path.
        # At step 0 curr_checkpoint_dir *is* cfg.model_path, so there is no adapter.
        load_adapter = cfg.use_LoRA and cfg.eval_unlearn_step != 0
        weights_path = cfg.model_path if cfg.use_LoRA else curr_checkpoint_dir
        model = AutoModelForCausalLM.from_pretrained(
            weights_path,
            config=config,
            attn_implementation='flash_attention_2',
            torch_dtype=torch.bfloat16,
            device_map='auto'
        )
        if load_adapter:
            model = PeftModel.from_pretrained(model, curr_checkpoint_dir)
            model = model.merge_and_unload()
        model = model.eval()

        eval_results = model_eval(cfg, model, tokenizer, curr_eval_dir, cfg.eval_unlearn_step)
        print('After Unlearn Step %s,  Model Uility %.6f, Forget Efficacy %.6f' %
              (cfg.eval_unlearn_step, eval_results['Model Utility'], eval_results['Forget Efficacy Hmean']))

        # Free the model and return cached GPU memory to the driver before the
        # downstream eval launches a separate lm-eval process on the same GPUs.
        del model
        gc.collect()
        torch.cuda.empty_cache()

    task_lists = [
        # "mmlu",
        # "arc_challenge",
        # "truthfulqa",
        "triviaqa",
        # "gsm8k",   # last (heaviest if you ever switch to CoT)
    ]

    if os.path.exists(os.path.join(curr_save_dir, 'downstream_task_results.csv')):
        print(f'{curr_save_dir} already evaluated.')

    elif cfg.downstream:
        if eval_results is None:
            # model_eval was skipped above; reuse the row it saved earlier.
            eval_results = _load_saved_eval_results(unlearning_csv)

        general_results = general_eval(cfg, curr_checkpoint_dir, task_lists, curr_eval_dir)
        print(general_results)
        all_results = {**general_results, **eval_results}

        aggr_csv = os.path.join(curr_eval_dir, "aggr_results.csv")
        write_header = not os.path.exists(aggr_csv) or os.path.getsize(aggr_csv) == 0
        with open(aggr_csv, 'a', newline='') as f:
            w = csv.DictWriter(f, all_results.keys())
            if write_header:
                w.writeheader()
            w.writerow(all_results)

        with open(os.path.join(curr_eval_dir, "aggregate_stat.txt"), 'w') as txtfile:
            for key, value in all_results.items():
                txtfile.write(f"{key}: {value}\n")

    if not cfg.save_checkpoint:
        # Remove the evaluated checkpoint to save disk; step 0 points at
        # cfg.model_path (the base model) and is never deleted.
        if cfg.eval_unlearn_step != 0 and os.path.exists(curr_checkpoint_dir):
            shutil.rmtree(curr_checkpoint_dir)


if __name__ == "__main__":
    # main() is called without arguments: the @hydra.main decorator composes the
    # "forget" config from the "config" directory and passes it in as cfg.
    main()
