import os
import time
import json
import copy
from argparse import ArgumentParser, Namespace

import torch
import datasets
from transformers import TrainingArguments
from torch.utils.data import Subset 

import methods
import settings
from settings import PoisonedDataset
from settings import train_transformers_func
from helper import utils
from helper.thirdparty.tofu.dataloader import CustomTrainer
from helper.thirdparty.tofu.data_module import (
    TextDatasetQA, TextForgetDatasetQA, 
    ForgetDatasetQA, RetainDatasetQA,
    custom_data_collator, 
)

# Restrict scaled-dot-product attention to the plain math kernel for the whole
# process: the flash and memory-efficient kernels are switched off.
# NOTE(review): presumably required because the second-order (Newton-type)
# unlearners need higher-order gradients through attention — confirm.
_sdp_backends = torch.backends.cuda
_sdp_backends.enable_mem_efficient_sdp(False)
_sdp_backends.enable_flash_sdp(False)
_sdp_backends.enable_math_sdp(True)

if __name__ == "__main__":
    DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    UNLEARN_METHODS = [
        "original", 
        "retraining", 
        "ga", 
        "gd", 
        "sgd", 
        "npo",
        "delete",
        "gdiff",
        "scrub",
        "idk",
        "pinv_newton", 
        "damped_newton", 
        "cr_newton", 
        "scr_newton", 
    ]

    parser = ArgumentParser()
    parser.add_argument("--seed", type=int, default=1, help="Random seed")
    parser.add_argument("--dataset", type=str, help="Dataset name")
    parser.add_argument("--model_path", type=str, default="outputs/original.pt",
                        help="Path to a trained model checkpoint")
    parser.add_argument("--save_dir", type=str, default="unlearn_outputs/",
                        help="Directory to output the unlearning results")
    parser.add_argument("--exp_name", type=str, default="retraining",
                        help="Experiment name, will be appended to save_dir")

    ######################## Training ######################## 
    parser.add_argument("--train_config_path", type=str,
                        help="Path to configuration file for training")
    parser.add_argument("--save_trainable", action="store_true",
                        help="If true, only trainable parameters will be saved. This option should be False if BatchNorm is used.")
    
    ######################## Unlearning ######################## 
    parser.add_argument("--unlearn_config_path", type=str,
                        help="Path to configuration file for unlearning, will override training config")
    parser.add_argument("--load_poison_data", action="store_true",
                        help="Whether the loaded data is poisonous data")
    parser.add_argument("--unlearn_method", type=str, default="retraining",
                        choices=UNLEARN_METHODS, 
                        help="Unlearning method")
    parser.add_argument("--unlearn_batch_size", type=int, default=1, 
                        help="Number of points to unlearn in batch")
    parser.add_argument("--save_every_round", action="store_true",
                        help="Whether to save the model after every unlearning round")
    parser.add_argument("--M", type=float, default=None, help="Lipschitz constant")

    ######################## ImageNet ######################## 
    parser.add_argument("--vit", action="store_true", 
                        help="Whether the model is ViT")

    ######################## TOFU ######################## 
    parser.add_argument("--selected_points", default=None, type=str,
                        help="selected_points for retain set")
    parser.add_argument("--llama", action="store_true", 
                        help="Whether the model is Llama")

    args = parser.parse_args()
    utils.set_seed(args.seed)

    os.makedirs(args.save_dir, exist_ok=True)
    args.exp_dir = os.path.join(args.save_dir, args.exp_name)
    os.makedirs(args.exp_dir, exist_ok=True)
    debug_dir = os.path.join(args.exp_dir, "debug")
    os.makedirs(debug_dir, exist_ok=True)

    # load configuration used in training
    with open(args.train_config_path, "r") as f:
        train_config = json.load(f)

    # load configuration for the selected unlearning method
    with open(args.unlearn_config_path, "r") as f:
        unlearn_config = json.load(f)
        unlearn_config = unlearn_config[args.unlearn_method]

    # load dataset and the order of points to unlearn
    if args.dataset == "tofu":
        unlearn_order = datasets.load_dataset("locuslab/TOFU", "forget10")["train"]
        retain_dataset = datasets.load_dataset("locuslab/TOFU", "retain90")["train"]
        if args.selected_points is not None:
            with open(f"{args.selected_points}", "rb") as f:
                test_ids = torch.load(f)
            retain_ids = set(range(len(retain_dataset))).difference(set(test_ids))
            retain_dataset = retain_dataset.select(retain_ids)
    else:
        with open(f"{args.save_dir}/unlearn_order.pt", "rb") as f:
            unlearn_order = torch.load(f)
    
    print("forget_size:", len(unlearn_order))
    if args.unlearn_batch_size == -1:   # unlearn all at once
        args.unlearn_batch_size = len(unlearn_order) 

    config = utils.merge_dict(from_dict=train_config, to_dict=unlearn_config)
    # config["device"] = DEVICE
    config["model_path"] = args.model_path
    config["vit"] = args.vit
    config["llama"] = args.llama
    config["save_dir"] = args.save_dir
    if args.dataset == "tofu":
        config["tofu"] = True
    else:
        config["tofu"] = False
    print("config:", config)
    config = Namespace(**config)

    # initialize data and model (and tokenizer if necessary)
    if not (args.llama or args.vit):
        (train_ds, test_ds), model, model_config = settings.prepare_data_and_model(args.dataset)
    elif args.llama:
        (train_ds, test_ds), (model, tokenizer), model_config = settings.prepare_data_and_model(args.dataset, config)
    elif args.vit:
        raise NotImplementedError
    
    if args.selected_points is not None:
        train_data = datasets.concatenate_datasets([retain_dataset, unlearn_order])
        train_ds = TextDatasetQA(
            train_data,
            tokenizer,
            model_family="llama2-7b",
            max_length=config.max_seq_length,
            question_key="question",
            answer_key="answer",
        )
    
    if args.load_poison_data:
        if isinstance(train_ds, datasets.Dataset):
            path = f"{args.save_dir}/poisoned_ds"
            train_ds = datasets.load_from_disk(path)
        else:
            path = f"{args.save_dir}/poisoned_ds.pt"
            with open(path, "rb") as f:
                train_ds = torch.load(f)
        print(f"Loaded poisoned train data from {path}")

    # load the trainer if possible
    if args.llama:
        if args.dataset == "ag_news":
            training_args = TrainingArguments(output_dir=args.exp_dir, **config.training_arguments)
            peft_config = train_transformers_func.get_peft_config(config.peft_config)
            trainer_init_func = train_transformers_func.get_sft_trainer
            trainer_init_kwargs = Namespace(
                model=model,
                task_type=model_config["task"],
                train_dataset=train_ds,
                test_dataset=test_ds,
                peft_config=peft_config,
                tokenizer=tokenizer,
                max_seq_length=config.max_seq_length,
                args=training_args,
            )
        elif args.dataset == "tofu":
            training_args = TrainingArguments(output_dir=args.exp_dir, **config.training_arguments)
            trainer_init_func = train_transformers_func.get_tofu_trainer
            trainer_init_kwargs = Namespace(
                model=model,
                train_dataset=train_ds,
                args=training_args,
                data_collator=custom_data_collator,
            )
        trainer = trainer_init_func(**vars(trainer_init_kwargs))  # add peft parameters and freeze the remaining parameters
        model = trainer.model
    else:
        trainer_init_func = trainer_init_kwargs = None

    # load the trained model, must call after initialize trainer for Llama to use peft-Llama
    print("no_params (trainable): %d" % sum(p.numel() for p in model.parameters() if p.requires_grad))

    if args.unlearn_method != "retraining":
        if args.dataset == "tofu":
            utils.load_trainable_model(model, config.model_path)
        else:
            with open(config.model_path, "rb") as f:
                checkpoint = torch.load(f)
                # model.load_state_dict(checkpoint["init_state_dict"])    # avoid randomized initialization on frozen weights (e.g. batch norm, non-peft parameters)
                model.load_state_dict(checkpoint["state_dict"], strict=False)
                del checkpoint
                print(f"Loaded checkpoint from {config.model_path}")
            utils.clear_cache()
        if args.unlearn_method == "ntk":
            init_state_dict = copy.deepcopy(model.state_dict())
    else:
        init_state_dict = copy.deepcopy(model.state_dict())
    
    if args.unlearn_method in ("KL", "scrub") and args.dataset == "tofu":
        (_, _), (oracle_model, _), _ = settings.prepare_data_and_model(args.dataset, config)
        oracle_model.load_state_dict(model.state_dict())
        oracle_model.to(DEVICE)
    
    model.to(DEVICE)

    print("========= SEQUENTIAL UNLEARNING ==========")
    print("Method:", args.unlearn_method)

    unlearn_round = 0
    all_logs = []
    accu_forget_ids = []
    dr_indices = list(range(len(train_ds)))
    while len(unlearn_order) > 0:
        unlearn_round += 1
        print("unlearn_round:", unlearn_round)
        if args.dataset == "tofu":
            accu_forget_ids = datasets.Dataset.from_dict({"question": [], "answer": []})
            forget_ids = unlearn_order[:args.unlearn_batch_size]
            forget_ids = datasets.Dataset.from_dict(forget_ids)
            unlearn_order = unlearn_order[args.unlearn_batch_size:]
            unlearn_order = datasets.Dataset.from_dict(unlearn_order)
            accu_forget_ids = datasets.concatenate_datasets([accu_forget_ids, forget_ids])
            remaining_forget = unlearn_order
            retain_ids = datasets.concatenate_datasets([retain_dataset, remaining_forget])
        else:
            forget_ids = unlearn_order[:args.unlearn_batch_size]
            unlearn_order = unlearn_order[args.unlearn_batch_size:]
            accu_forget_ids.extend(forget_ids)
            retain_ids = list(set(range(len(train_ds))).difference(set(accu_forget_ids)))
        print("forget_size:", len(forget_ids))
        print("retain_size:", len(retain_ids))

        if args.dataset == "tofu":
            if args.unlearn_method in ("idk", "gdiff", "npo"):
                # each sample in merge_forget_retain_set is a pair of (forget_sample, retain_sample)
                merge_forget_retain_set = TextForgetDatasetQA(
                    forget_ids,
                    retain_ids,
                    tokenizer,
                    model_family="llama2-7b",
                    max_length=config.max_seq_length,
                    loss_type=args.unlearn_method,
                )
                forget_set = None
                retain_set = None
            elif args.unlearn_method == "scrub":
                forget_loss = args.unlearn_method
                forget_set = ForgetDatasetQA(forget_ids, tokenizer, model_family="llama2-7b")
                retain_set = RetainDatasetQA(retain_ids, tokenizer, model_family="llama2-7b")
            else:
                get_encoded_dataset = lambda data: TextDatasetQA(
                    data,
                    tokenizer,
                    model_family="llama2-7b",
                    max_length=config.max_seq_length,
                    question_key="question",
                    answer_key="answer"
                )
                forget_set = get_encoded_dataset(forget_ids)
                retain_set = get_encoded_dataset(retain_ids)
                accu_forget_set = get_encoded_dataset(accu_forget_ids)
        else:
            forget_set = Subset(train_ds, forget_ids)
            retain_set = Subset(train_ds, retain_ids)
            accu_forget_set = Subset(train_ds, accu_forget_ids)

        # For debugging purpose, check the labels in the forget set
        if args.dataset != "tofu":
            all_targets = [x[1] for x in forget_set]
            print("Unique forget labels:", set(all_targets))

        start_time = time.time()
        if args.unlearn_method == "retraining":
            model.load_state_dict(init_state_dict)

        match args.unlearn_method:
            case "original":
                unlearner = None
            case "retraining":
                unlearner = methods.retrain
            case "ga":
                unlearner = methods.ga
            case "gd":
                unlearner = methods.gd
            case "gdiff":
                if config.tofu:
                    unlearner = methods.gdiff_tofu
                else:
                    unlearner = methods.gdiff
            case "sgd":
                unlearner = methods.sgd
            case "delete":
                unlearner = methods.delete
            case "ntk":
                unlearner = methods.ntk
            case "scrub":
                if config.tofu:
                    unlearner = methods.scrub_tofu
                else:
                    unlearner = methods.scrub
            case "idk":
                assert config.tofu, f"IDK is only applicable to TOFU dataset."
                unlearner = methods.idk
            case "npo":
                unlearner = methods.npo
            case "pinv_newton":
                unlearner = methods.pinv_newton
            case "damped_newton":
                unlearner = methods.damped_newton
            case "cr_newton":
                unlearner = methods.cr_newton
            case "scr_newton":
                unlearner = methods.scr_newton
            case "scr_newton_gdiff":
                unlearner = methods.scr_newton_gdiff

        unl_logs = {}
        kwargs = {}
        if args.M is not None:
            kwargs = {"M": args.M}
            
        if args.unlearn_method != "original":
            if args.dataset == "tofu" and args.unlearn_method in ("scrub", "gdiff", "idk"):
                unlearner(
                    model, tokenizer, merge_forget_retain_set, config, training_args
                )
            else:
                unlearner(
                    model, forget_set, retain_set, config, 
                    trainer_init_func=trainer_init_func, 
                    trainer_init_kwargs=trainer_init_kwargs, 
                    unl_logs=unl_logs, 
                    device=DEVICE,
                    **kwargs
                )
        
        res = {}
        res["unlearn_round"] = unlearn_round
        res["running_time"] = time.time() - start_time
        if args.unlearn_method == "cr_newton":
            res["alpha"] = unl_logs["alpha"]
        
        if args.dataset == "tofu":
            print("Please run TOFU evaluation script separately.")
        else:
            metrics = utils.eval_unlearning(
                model, accu_forget_set, retain_set, test_ds, config, 
                trainer_init_func=trainer_init_func,
                trainer_init_kwargs=trainer_init_kwargs,
                device=DEVICE,
            )
            res.update(metrics)
            print("****** Results ******")
            print(res, sep="\n")
            print()

        all_logs.append(res)
        with open(os.path.join(args.exp_dir, "stats.json"), "w") as f:
            json.dump(all_logs, f, indent=2)

        with open(os.path.join(args.exp_dir, "metadata.json"), "w") as f:
            metadata = vars(config).copy()
            json.dump(metadata, f, indent=2)

        if args.save_every_round:
            path = os.path.join(debug_dir, f"model_round-{unlearn_round}.pt")
            utils.save_model(model, path, save_trainable=args.save_trainable)
        
        # exit()

    path = os.path.join(args.exp_dir, "model.pt")
    utils.save_model(model, path, save_trainable=args.save_trainable)

