import json
from metrics import get_all_evals, get_dataloader, get_eval_results
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
import hydra
import transformers
from pathlib import Path
from utils import get_model_identifiers_from_yaml, set_random_seed
import os
import pickle
import shutil
import warnings
import csv
import subprocess
from peft import AutoPeftModelForCausalLM, PeftModel
warnings.filterwarnings('ignore')



def summary_results(eval_dir):
    for dirpath, dirnames, filenames in os.walk(eval_dir):
        for file in filenames:
            if file.endswith('json') and 'results' in file:
                results_path = os.path.join(dirpath, file)
    results = json.load(open(results_path, 'r'))['results']
    results_dict =  {
        'ARC-C': results['arc_challenge']['acc_norm,none'],
        'MMLU': results['mmlu']['acc,none'],
        'TruthfulQA(mc1)': results['truthfulqa_mc1']['acc,none'],
        'TriviaQA': results['triviaqa']['exact_match,remove_whitespace'],
        'GSM8k': results['gsm8k']['exact_match,flexible-extract'],
    }
    with open(os.path.join(eval_dir, "../downstream_task_results.txt"), 'w') as txtfile:
        for key, value in results_dict.items():
            txtfile.write(f"{key}: {value}\n")
    save_file = os.path.join(eval_dir, "../downstream_task_results.csv")
    with open(save_file, 'a') as f:
        w = csv.DictWriter(f, results_dict.keys())
        w.writeheader()
        w.writerow(results_dict)

    return results_dict

def general_eval(
    cfg,
    model_name,
    task_list=[
        "arc_challenge", # ARC-c
        "truthfulqa",
        "triviaqa",
        "mmlu",
        "gsm8k",
    ],
    output_dir=".",
):
    command = "accelerate"
    tasks = ",".join(task_list)
    if cfg.use_LoRA:
        model_args = f"pretrained={cfg.model_path},peft={model_name},add_bos_token=True,max_batch_size=16"
    else:
        model_args = f"pretrained={model_name},add_bos_token=True"

    args = [
        "launch", 
        "-m", 
        "lm_eval",
        "--model",
        "hf",
        "--model_args",
        model_args,
        "--tasks",
        f"{tasks}",
        "--batch_size",
        "auto:4",
        "--output_path",
        f"{output_dir}/downstream_tasks"
    ]
    # Combine command and arguments
    full_command = [command] + args
    # Execute the command
    print(full_command)
    try:
        subprocess.run(full_command, check=True)
    except subprocess.CalledProcessError as e:
        print(f"An error occurred: {e}")

    results_dict = summary_results(output_dir)
    return results_dict

def model_eval(cfg, model, tokenizer, save_dir,  eval_unlearn_step=None):
    eval_unlearn_step = 'last' if eval_unlearn_step == None else eval_unlearn_step
    aggregated_eval_logs = {}
    for i, (folder, split, question_key, answer_key, eval_task, base_answer_key, perturbed_answer_key) in enumerate(zip(cfg.eval.data_path, cfg.eval.split_list, cfg.eval.question_key, cfg.eval.answer_key, cfg.eval.eval_task, cfg.eval.base_answer_key, cfg.eval.perturbed_answer_key)):
        os.makedirs(save_dir, exist_ok=True)
        save_filename = os.path.join(save_dir, f"{eval_task}.json")

        if os.path.exists(save_filename):
            print(
                f"Skipping {eval_task} because {save_filename} already exists")
            eval_logs = json.load(open(save_filename, 'r'))
        else:
            eval_dataloader, base_eval_dataloader, perturb_dataloader = get_dataloader(
                cfg.eval, eval_task, tokenizer, folder, split, question_key, answer_key, base_answer_key, perturbed_answer_key)

            eval_logs = get_all_evals(cfg.eval,  model, tokenizer, folder, split, eval_task, eval_dataloader, base_eval_dataloader, perturb_dataloader, False)

            with open(save_filename, "w") as f:
                json.dump(eval_logs, f, indent=4)

        aggregated_eval_logs[f'{eval_task}.json'] = eval_logs

    aggregated_eval_log_filename = os.path.join(
        save_dir, "eval_log_aggregated.json")
    with open(aggregated_eval_log_filename, "w") as f:
        # pretty write json to f
        json.dump(aggregated_eval_logs, f, indent=4)

    eval_results = get_eval_results(aggregated_eval_logs)
    aaggregate_stat = {**eval_results,}
    
    print(aaggregate_stat)

    aaggregate_stat['split'] = cfg.split
    aaggregate_stat['forget_loss'] = cfg.forget_loss
    aaggregate_stat['forget_coeff'] = cfg.forget_coeff
    aaggregate_stat['regularization_coeff'] = cfg.regularization_coeff
    aaggregate_stat['learning_rate'] = cfg.lr
    aaggregate_stat['epochs'] = cfg.num_epochs
    aaggregate_stat['fix_ref_model'] = cfg.fix_ref_model
    aaggregate_stat['mask'] = cfg.mask
    aaggregate_stat['unlearn_step'] = eval_unlearn_step
    
    with open(os.path.join(save_dir, "unlearning_results.txt"), 'w') as txtfile:
        for key, value in aaggregate_stat.items():
            txtfile.write(f"{key}: {value}\n")

    save_file = os.path.join(save_dir, "unlearning_results.csv")
    with open(save_file, 'a') as f:
        w = csv.DictWriter(f, aaggregate_stat.keys())
        w.writeheader()
        w.writerow(aaggregate_stat)
    
    all_task_save_file = os.path.join(cfg.save_dir, "all_unlearning_results.csv")
    if not os.path.exists(all_task_save_file) or os.path.getsize(all_task_save_file) == 0:
        with open(all_task_save_file, 'a') as f:
            w = csv.DictWriter(f, aaggregate_stat.keys())
            w.writeheader()
            w.writerow(aaggregate_stat)
    else:
        with open(all_task_save_file, 'a') as f:
            w = csv.DictWriter(f, aaggregate_stat.keys())
            w.writerow(aaggregate_stat)

    return aaggregate_stat


@hydra.main(version_base=None, config_path="config", config_name="forget")
def main(cfg):
    if os.environ.get('LOCAL_RANK') is not None:
        local_rank = int(os.environ.get('LOCAL_RANK', '0'))
        device_map = {'': local_rank}

    seed = cfg.seed
    set_random_seed(seed)

    model_cfg = get_model_identifiers_from_yaml(cfg.model_family)
    model_id = model_cfg["hf_key"]
    

    curr_save_dir = cfg.save_dir
    curr_checkpoint_dir = os.path.join(curr_save_dir, f"checkpoint-{cfg.eval_unlearn_step}")
    if cfg.eval_unlearn_step == 0:
        curr_checkpoint_dir = cfg.model_path
    else:
        if not os.path.exists(curr_checkpoint_dir):
            print(f'{curr_checkpoint_dir} does not exist.')
            exit()

    curr_eval_dir = os.path.join(curr_save_dir, f'eval_results-{cfg.eval_unlearn_step}')
    
    # if os.path.exists(os.path.join(curr_eval_dir, 'unlearning_results.csv')):
    #     print(f'{curr_eval_dir} already evaluated.')
    #     exit()

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token = tokenizer.eos_token

    config = AutoConfig.from_pretrained(model_id)
    # 
    if cfg.use_LoRA:
        model = AutoModelForCausalLM.from_pretrained(
            cfg.model_path,
            config=config,
            attn_implementation='flash_attention_2',
            torch_dtype=torch.bfloat16,
            device_map='auto'
        )
        model = PeftModel.from_pretrained(model, curr_checkpoint_dir)
        model = model.merge_and_unload()
    else:
        model = AutoModelForCausalLM.from_pretrained(
            curr_checkpoint_dir,
            config=config,
            attn_implementation='flash_attention_2',
            torch_dtype=torch.bfloat16,
            device_map='auto'
        )
    model = model.eval()
    
    eval_results = model_eval(cfg,  model, tokenizer, curr_eval_dir, cfg.eval_unlearn_step)
    print('After Unlearn Step %s,  Model Uility %.6f, Forget Efficacy %.6f' %
        (cfg.eval_unlearn_step, eval_results['Model Utility'], eval_results['Forget Efficacy']))
    
    task_lists = [
        "arc_challenge", # ARC-c
        "mmlu",
        "truthfulqa",
        "triviaqa",
        "gsm8k"
    ]
    
    del model
    
    geneal_results  = general_eval(cfg, curr_checkpoint_dir, task_lists, curr_eval_dir)
    print(geneal_results)
    all_results = {**geneal_results, **eval_results}

    # all_results = {**eval_results}
    with open(os.path.join(curr_eval_dir, "aggr_results.csv"), 'a') as f:
        w = csv.DictWriter(f, all_results.keys())
        w.writeheader()
        w.writerow(all_results)
    
    with open(os.path.join(curr_eval_dir, "aggregate_stat.txt"), 'w') as txtfile:
        for key, value in all_results.items():
            txtfile.write(f"{key}: {value}\n")


    if not cfg.save_checkpoint:
        # last unlearning tasks and do not save checkpoints
        if (os.path.exists(curr_checkpoint_dir)) and (cfg.eval_unlearn_step != 0):
            shutil.rmtree(curr_checkpoint_dir)

            # if len(task_list) > 1: # continual unlearning
            #     shutil.rmtree(curr_checkpoint_dir)
            #     print('Removed %s' % curr_checkpoint_dir)
            # else:
            #     # single unlearning
            #     # if curr_checkpoint_dir != 'checkpoint-last': # save the last checkpoint
            #         shutil.rmtree(curr_checkpoint_dir)
            #         print('Removed %s' % curr_checkpoint_dir)


if __name__ == "__main__":
    main()
