import os
import copy
import json
import csv
import logging
from functools import partial
from argparse import ArgumentParser, Namespace

import numpy as np
import torch
import torch.nn
import torch.optim
import datasets
from torch.utils.data import DataLoader
from transformers import TrainingArguments

import settings
from helper import utils
from helper import eval_mia, eval_tofu, attack
from helper.thirdparty.tofu.dataloader import CustomTrainer
from helper.thirdparty.tofu.data_module import TextDatasetQA, AttackDatasetQA, custom_data_collator, custom_data_collator_with_indices


def get_eval_dataloader(
    data_path: str,
    split: str,
    tokenizer = None,
    model_family: str = None,
    question_key: str = None,
    answer_key: str = None,
    base_answer_key: str = None,
    perturbed_answer_key: str = None,
    eval_batch_size: int = 64,
    max_length: int = 128,
):
    """Build the three evaluation dataloaders for one dataset split.

    Three ``TextDatasetQA`` datasets are created over the same
    ``data_path``/``split``; they share ``question_key`` and differ only in
    which field is used as the answer (``answer_key``, ``base_answer_key``,
    ``perturbed_answer_key``).

    Args:
        data_path: Dataset location forwarded to ``TextDatasetQA``.
        split: Name of the split to load.
        tokenizer: Tokenizer forwarded to ``TextDatasetQA``.
        model_family: Model family identifier forwarded to ``TextDatasetQA``.
        question_key: Field holding the question text.
        answer_key: Field holding the reference answer.
        base_answer_key: Field holding the base (paraphrased) answer.
        perturbed_answer_key: Field holding the perturbed answer.
        eval_batch_size: Batch size of the main evaluation loader. The base
            and perturbed loaders use a quarter of it, but never less than 1.
        max_length: Maximum sequence length forwarded to ``TextDatasetQA``.

    Returns:
        A tuple ``(eval_loader, base_eval_loader, perturb_loader)``.
    """
    def build_dataset(target_answer_key):
        # All datasets share everything except the answer field.
        return TextDatasetQA(
            data_path=data_path,
            split=split,
            tokenizer=tokenizer,
            model_family=model_family,
            max_length=max_length,
            question_key=question_key,
            answer_key=target_answer_key,
        )

    def build_loader(dataset, batch_size):
        return DataLoader(dataset,
                          batch_size=batch_size,
                          collate_fn=custom_data_collator_with_indices)

    eval_data = build_dataset(answer_key)
    base_eval_data = build_dataset(base_answer_key)
    perturb_eval_data = build_dataset(perturbed_answer_key)

    # DataLoader rejects a non-positive batch size, so clamp the reduced size
    # to 1 for callers that pass eval_batch_size < 4.
    reduced_batch_size = max(1, eval_batch_size // 4)

    eval_loader = build_loader(eval_data, eval_batch_size)
    base_eval_loader = build_loader(base_eval_data, reduced_batch_size)
    perturb_loader = build_loader(perturb_eval_data, reduced_batch_size)

    return eval_loader, base_eval_loader, perturb_loader




# Field names used to read question/answer text from the QA records.
question_key = "question"
answer_key = "answer"
# NOTE(review): not referenced anywhere else in the visible part of this file.
wtm_answer_key = "answer_split"
base_answer_key = "paraphrased_answer"
perturbed_answer_key = "perturbed_answer"

# Split loaded as the holdout set for the membership inference attack data.
holdout_split = "real_authors"


if __name__ == '__main__':
    # Script flow:
    #   1. parse CLI arguments and merge the train/unlearn JSON configs,
    #   2. build the model/tokenizer and load the checkpoint weights,
    #   3. build the forget/retain/holdout datasets (MIA itself is disabled),
    #   4. run the TOFU evaluation and write JSON logs plus a CSV summary.
    DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    UNLEARN_METHODS = [
        "original",
        "retraining",
        "ga",
        "random_labels",
        "gd",
        "sgd",
        "cr_newton",
        "scr_newton",
        "scr_newton_gdiff",
        "scrub",
        "dpo",
        "gdiff",
        "npo",
    ]
    parser = ArgumentParser()
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed')
    # The value is opened below as the checkpoint file, so the help text
    # describes it as a file path rather than a directory.
    parser.add_argument("--save_dir", type=str, default="outputs/",
                        help="Path to the trained model checkpoint file to evaluate")
    parser.add_argument("--unlearn_method", type=str, default="retraining",
                        choices=UNLEARN_METHODS,
                        help="Unlearning method")
    parser.add_argument("--output_dir", type=str, default="outputs/eval/",
                        help="Directory to output results")
    parser.add_argument('--retrain_output_dir', type=str, default=None,
                        help='Directory of retraining results')
    args = parser.parse_args()
    utils.set_seed(args.seed)
    os.makedirs(args.output_dir, exist_ok=True)

    if args.retrain_output_dir is None:
        args.retrain_output_dir = args.output_dir
        print('retrain_output_dir is set to output_dir')

    # load configurations
    with open("configs/train/tofu.json", "r") as f:
        train_config = json.load(f)
    with open("configs/unlearn/tofu.json", "r") as f:
        unlearn_config = json.load(f)[args.unlearn_method]

    config = utils.merge_dict(from_dict=train_config, to_dict=unlearn_config)
    config["device"] = DEVICE
    config["model_path"] = args.save_dir
    config["vit"] = False
    config["llama"] = True
    print("config:", config)
    config = Namespace(**config)

    # create and load model and tokenizer
    (train_ds, test_ds), (model, tokenizer), model_config = settings.prepare_data_and_model("tofu", config)

    # load trained model
    # map_location="cpu" lets a checkpoint saved on a GPU be read on any host;
    # the model is moved to DEVICE afterwards.
    with open(config.model_path, "rb") as f:
        checkpoint = torch.load(f, map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"], strict=False)
    del checkpoint
    print(f"Loaded checkpoint from {config.model_path}")
    utils.clear_cache()

    model.to(DEVICE)

    forget_dataset = datasets.load_dataset("locuslab/TOFU", "forget10")["train"]
    retain_dataset = datasets.load_dataset("locuslab/TOFU", "retain90")["train"]

    # NOTE: the four datasets below are only consumed by the MIA section,
    # which is currently disabled; they are kept so it can be re-enabled.
    forget_data = TextDatasetQA(
        forget_dataset,
        tokenizer,
        model_family='llama2-7b',
        max_length=config.max_seq_length,
        question_key='question',
        answer_key='answer',
    )
    retain_data = TextDatasetQA(
        retain_dataset,
        tokenizer,
        model_family='llama2-7b',
        max_length=config.max_seq_length,
        question_key='question',
        answer_key='answer',
    )

    # load holdout data
    mia_forget_data = AttackDatasetQA(
        forget_dataset,
        tokenizer,
        model_family='llama2-7b',
        max_length=config.max_seq_length,
        question_key='question',
        answer_key='answer',
    )

    # NOTE(review): this uses config.generation_max_length while the forget
    # set above uses config.max_seq_length -- confirm the mismatch is intended.
    mia_holdout_data = AttackDatasetQA(
        data_path="locuslab/TOFU",
        split=holdout_split,
        tokenizer=tokenizer,
        model_family="llama2-7b",
        max_length=config.generation_max_length,
        question_key=question_key,
        answer_key=answer_key,
    )
    print('holdout_data_len:', len(mia_holdout_data))

    # # Run Membership Inference Attack
    print("TEMPORARILY DISABLED MIA")
    # print('=' * 10, 'Running MIA', '=' * 10)
    # # mia_auc, mia_log = eval_mia.eval_mia(forget_data,
    # #                                      retain_data,
    # #                                      holdout_data,
    # #                                      model,
    # #                                      tokenizer=tokenizer)

    # mia_scores_forget = attack.mink_plus_plus_attack(model, mia_forget_data, batch_size=config.eval_batch_size).tolist()
    # mia_scores_test = attack.mink_plus_plus_attack(model, mia_holdout_data, batch_size=config.eval_batch_size).tolist()

    # mia_scores_test = [x for x in mia_scores_test if not np.isnan(x) and x != float('inf') and x != float('-inf')]
    # mia_scores_forget = [x for x in mia_scores_forget if not np.isnan(x) and x != float('inf') and x != float('-inf')]
    # mia_features, mia_labels = attack.create_mia_dataset(mia_scores_forget, mia_scores_test)
    # print(len(mia_scores_test), len(mia_scores_forget))
    # print(np.mean(mia_scores_test), np.mean(mia_scores_forget))

    # fpr, tpr, auc_score, thresholds = attack.sweep(mia_features, mia_labels)

    # print('MIA AUC: ', auc_score)
    # path = os.path.join(args.output_dir, 'mia.json')
    # with open(path, 'w') as f:
    #     json.dump(auc_score, f, indent=4)

    # Run TOFU evaluation
    forget_ratio = 0.1  # e.g. forget_ratio = 0.05 -> eval_split = forget05_perturbed
    eval_split = 'forget{:02d}_perturbed'.format(round(forget_ratio * 100))

    # One tuple per evaluation task:
    #   (split, question_key, answer_key, base_answer_key, perturbed_answer_key, eval_task)
    # Keeping every field of a task in a single row prevents the split from
    # being paired with the wrong task when rows are enabled or disabled
    # (zipping parallel lists of different lengths silently truncates).
    eval_specs = [
        # ('retain_perturbed', 'question', 'answer', 'paraphrased_answer', 'perturbed_answer', 'eval_log'),
        # ('real_authors_perturbed', 'question', 'answer', 'answer', 'perturbed_answer', 'eval_real_author_wo_options'),
        # ('world_facts_perturbed', 'question', 'answer', 'answer', 'perturbed_answer', 'eval_real_world_wo_options'),
        (eval_split, 'question', 'answer', 'paraphrased_answer', 'perturbed_answer', 'eval_log_forget'),
    ]

    aggregated_eval_logs = {}
    # Loop variable names are deliberately distinct from the module-level
    # question_key/answer_key/... constants so those are not overwritten.
    for split, q_key, a_key, base_key, perturbed_key, eval_task in eval_specs:
        print(f'Working on eval task {eval_task} with split {split}')

        eval_loader, base_eval_loader, perturb_loader = get_eval_dataloader(
            "locuslab/TOFU",
            split,
            tokenizer,
            "llama2-7b",
            q_key,
            a_key,
            base_key,
            perturbed_key,
            eval_batch_size=config.eval_batch_size,
            max_length=config.max_seq_length,
        )

        # Ground-truth normalization is enabled only for tasks whose name
        # does not contain 'eval_log'.
        normalize_gt = 'eval_log' not in eval_task
        eval_logs = eval_tofu.eval(
            model,
            tokenizer,
            "llama2-7b",
            eval_loader,
            base_eval_loader,
            perturb_loader,
            normalize_gt=normalize_gt,
            config=config,
        )
        aggregated_eval_logs[f'{eval_task}.json'] = eval_logs

        path = os.path.join(args.output_dir, f'tofu_{eval_task}.json')
        with open(path, 'w') as f:
            json.dump(eval_logs, f, indent=4)

    aggregated_path = os.path.join(args.output_dir, 'tofu_eval_log_aggregated.json')
    with open(aggregated_path, 'w') as f:
        json.dump(aggregated_eval_logs, f, indent=4)

    # retain_result is only needed by the forget-quality computation that is
    # commented out below; it is still loaded so a missing retraining log is
    # reported here, as before.
    retain_path = os.path.join(args.retrain_output_dir, 'tofu_eval_log_aggregated.json')
    with open(retain_path) as f:
        retain_result = json.load(f)
    with open(aggregated_path) as f:
        ckpt_result = json.load(f)

    model_utility = eval_tofu.get_model_utility(ckpt_result)
    # forget_quality = eval_tofu.get_forget_quality(ckpt_result, retain_result)
    # model_utility['Forget Quality'] = forget_quality['Forget Quality']

    # dump the model utility to tofu.csv
    path = os.path.join(args.output_dir, 'tofu.csv')
    # newline='' is what the csv module expects for files it writes to.
    with open(path, 'w', newline='') as f:
        w = csv.DictWriter(f, model_utility.keys())
        w.writeheader()
        w.writerow(model_utility)
