import os
import time
import json
import copy
from argparse import ArgumentParser, Namespace
from collections import Counter

import settings.select_points
import torch
import numpy as np
from transformers import TrainingArguments

import methods
import settings
import utils
import train_func
import train_transformers_func


if __name__ == "__main__":
    # Driver for a continual "learn -> select -> unlearn" experiment.
    #
    # The training set is split into `--divide` shards. For every shard the
    # script (1) trains the model on that shard, (2) selects the points to be
    # forgotten, (3) unlearns them batch by batch with `--unlearn_method`,
    # evaluating after every batch, and (4) writes metadata.json, stats.json
    # and model checkpoints under <save_dir>/<exp_name>/.
    DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    UNLEARN_METHODS = ["original", "retraining", "ga", "random_labels", "gd", "sgd", "fisher_newton", "cr_newton", "scr_newton", "pinv_newton", "damped_newton", "lbfgs", "ntk"]

    parser = ArgumentParser()
    parser.add_argument("--dataset", type=str, default="mnist",
                        help="Dataset name")
    parser.add_argument("--seed", type=int, default=1,
                        help="Random seed")
    parser.add_argument("--divide", type=int, default=2,
                        help="Number of continual rounds")
    parser.add_argument("--save_dir", type=str, default="unlearn_outputs/",
                        help="Directory to output the unlearning results")
    parser.add_argument("--exp_name", type=str, default="retraining",
                        help="Experiment name, will be appended to save_dir")
    # selection arguments
    parser.add_argument("--sel_level", default="instance",
                        choices=["class", "instance"],
                        help="Selection level")
    parser.add_argument("--sel_freq", type=float, default=-1,
                        help="Selection ratio [0,1), selection size > 1, all -1")
    parser.add_argument("--sel_sort_by", type=str, default="random",
                        choices=["random", "lowest_loss_first", "highest_loss_first"],
                        help="Sort selected points by certain criteria")
    # unlearning arguments
    parser.add_argument("--unlearn_method", type=str, default="retraining",
                        choices=UNLEARN_METHODS,
                        help="Unlearning method")
    parser.add_argument("--unlearn_batch_size", type=int, default=1,
                        help="Number of points to unlearn in batch")
    # special models
    parser.add_argument("--vit", action="store_true",
                        help="Whether the model is ViT")
    parser.add_argument('--llama', action="store_true",
                        help="Whether the model is Llama")
    args = parser.parse_args()

    # -1 means "unlearn everything in one batch"; any other value must be a
    # positive batch size, otherwise the unlearning loop below never advances.
    if args.unlearn_batch_size != -1 and args.unlearn_batch_size < 1:
        parser.error("--unlearn_batch_size must be -1 (all at once) or a positive integer")

    utils.set_seed(args.seed)

    os.makedirs(args.save_dir, exist_ok=True)
    args.exp_dir = os.path.join(args.save_dir, args.exp_name)
    os.makedirs(args.exp_dir, exist_ok=True)
    debug_dir = os.path.join(args.exp_dir, "debug")
    os.makedirs(debug_dir, exist_ok=True)

    def _save_checkpoint(model, init_state_dict, path):
        # Store the current trainable weights next to the weights the round
        # started from (the latter are required by the "retraining" method).
        torch.save({
            "init_state_dict": init_state_dict,     # necessary for retraining
            "state_dict": utils.get_trainable_state_dict(model),
            "pytorch-lightning_version": "2.2.1"    # for ViT model
        }, path)

    # load configurations
    with open(f"configs/train/{args.dataset}.json", "r") as f:
        train_config = json.load(f)
    with open(f"configs/unlearn/{args.dataset}.json", "r") as f:
        unlearn_config = json.load(f)[args.unlearn_method]
    unlearn_config = utils.merge_dict(from_dict=train_config,
                                      to_dict=unlearn_config)

    print("train_config:", train_config)
    print("unlearn_config:", unlearn_config)
    train_config = Namespace(**train_config)
    unlearn_config = Namespace(**unlearn_config)
    train_config.device = unlearn_config.device = DEVICE
    train_config.vit = unlearn_config.vit = args.vit
    train_config.llama = unlearn_config.llama = args.llama

    # initialize data and model (and tokenizer if necessary)
    if not (args.llama or args.vit):
        (entire_train_set, test_set), model, model_config = settings.prepare_data_and_model(args.dataset)
    elif args.llama:
        # `config` is only bound inside the round loop below, so the PEFT
        # settings have to be read from `train_config` here.
        (entire_train_set, test_set), (model, tokenizer), model_config = settings.prepare_data_and_model(args.dataset, train_config.peft_config)
    elif args.vit:
        raise NotImplementedError

    # load the targets if necessary; stays None when selection does not need
    # labels (instance-level, random order)
    entire_targets = None
    if args.sel_level == "class" or "loss" in args.sel_sort_by:
        if args.llama:
            entire_targets = np.array([x["label"] for x in entire_train_set])
        else:
            entire_targets = np.array([x[1] for x in entire_train_set])

    all_logs = []
    for i in range(args.divide):
        print(f"========= TRAINING (ROUND {i+1}/{args.divide}) ==========")
        config = train_config

        # load the corresponding shard
        train_set = utils.shard(entire_train_set,
                                num_shards=args.divide,
                                index=i)
        targets = None
        if entire_targets is not None:
            targets = entire_targets[train_set.indices]

        # load model (after unlearned in the previous round)
        if i != 0:
            path = os.path.join(args.exp_dir, f"unlearn_model_round_{i-1}.pth")
            with open(path, "rb") as f:
                checkpoint = torch.load(f)
                model.load_state_dict(checkpoint["state_dict"], strict=False)
                del checkpoint
            utils.clear_cache()

        if args.llama:
            training_args = TrainingArguments(output_dir=args.save_dir,
                                              eval_strategy="steps",
                                              eval_steps=2,
                                              **config.training_arguments)
            peft_config = train_transformers_func.get_peft_config(config.peft_config)
            trainer_init_func = train_transformers_func.get_sft_trainer
            trainer_init_kwargs = Namespace(model=model,
                                            task_type=model_config["task"],
                                            train_dataset=train_set,
                                            test_dataset=test_set,
                                            peft_config=peft_config,
                                            tokenizer=tokenizer,
                                            max_seq_length=config.max_seq_length,
                                            args=training_args)
            trainer = trainer_init_func(**vars(trainer_init_kwargs))  # add peft parameters and freeze the remaining parameters

            init_state_dict = utils.get_trainable_state_dict(model)
            print("no_params (trainable): %d" % sum(p.numel() for p in model.parameters() if p.requires_grad))
            train_results = trainer.train()
            trainer.log_metrics("train", train_results.metrics)
            utils.clear_cache()
            model = trainer.model

        else:
            trainer_init_func = trainer_init_kwargs = None
            init_state_dict = copy.deepcopy(model.state_dict())
            print("no_params (trainable): %d" % sum(p.numel() for p in model.parameters() if p.requires_grad))

            train_loader = utils.get_dataloader(train_set,
                                                shuffle=True,
                                                batch_size=config.train_batch_size)
            test_loader = utils.get_dataloader(test_set,
                                               shuffle=False,
                                               batch_size=config.eval_batch_size)

            loss_fn = getattr(torch.nn, config.loss)()
            optimizer_cls = getattr(torch.optim, config.optimizer)
            optimizer = optimizer_cls(model.parameters(),
                                      lr=config.learning_rate,
                                      weight_decay=config.weight_decay)
            # Bind the interval at creation time: `config` is rebound to the
            # unlearning config later in this iteration.
            lr_scheduler = torch.optim.lr_scheduler.MultiplicativeLR(
                optimizer,
                lambda step, interval=config.lr_update_interval: 0.5 if step % interval == 0 else 1.0)

            metrics = train_func.train(model,
                                       train_loader,
                                       loss_fn,
                                       optimizer,
                                       eval_dataloader=test_loader,
                                       num_epochs=config.num_epochs,
                                       log_frequency=config.log_frequency,
                                       lr_scheduler=lr_scheduler,
                                       device=config.device)
            all_logs.append({"type": "learn", "metrics": metrics})

        path = os.path.join(args.exp_dir, f"train_model_round_{i}.pth")
        _save_checkpoint(model, init_state_dict, path)
        print(f"Saved trained model to {path}")

        print(f"========= SELECT FORGOTTEN POINTS (ROUND {i+1}/{args.divide}) =========")

        # define sample space (indices are relative to the current shard)
        if args.sel_level == "instance":
            tot_samples = len(train_set)
            sample_indices = list(range(tot_samples))
        elif args.sel_level == "class":
            sample_class = 0    # np.random.randint(model_config["out_size"])
            sample_indices = np.where(targets == sample_class)[0]
            tot_samples = len(sample_indices)
            print("sample class: {} ({} samples)".format(sample_class, tot_samples))
        sel_size = settings.select_points.get_sel_size(args.sel_freq, tot_samples)

        # select points within sample space
        if args.sel_sort_by == "random":
            unlearn_order = np.random.choice(sample_indices, size=sel_size, replace=False)
            unlearn_order = unlearn_order.tolist()
        else:
            raise NotImplementedError(
                f"Selection order '{args.sel_sort_by}' is not implemented")

        print("size:", len(unlearn_order))
        if args.sel_level == "class":
            print("label_counter:", Counter(targets[unlearn_order].flatten()))

        # determine unlearning batch size
        if args.unlearn_batch_size == -1:   # unlearn all at once
            unlearn_batch_size = len(unlearn_order)
        else:
            unlearn_batch_size = args.unlearn_batch_size

        print(f"========= UNLEARNING (ROUND {i+1}/{args.divide}) ==========")
        print("Method:", args.unlearn_method)
        config = unlearn_config

        unlearn_round = 0
        accu_forget_ids = []
        unlearn_res = None  # stays None when no point was selected this round
        while len(unlearn_order) > 0:
            unlearn_round += 1
            print("unlearn_round:", unlearn_round)
            forget_ids = unlearn_order[:unlearn_batch_size]
            unlearn_order = unlearn_order[unlearn_batch_size:]
            accu_forget_ids.extend(forget_ids)
            # everything in the shard that has not been forgotten so far
            retain_ids = set(range(len(train_set))).difference(set(accu_forget_ids))
            retain_ids = list(retain_ids)
            print("forget_size:", len(forget_ids))
            print("retain_size:", len(retain_ids))

            forget_set = utils.sample(train_set, forget_ids)
            retain_set = utils.sample(train_set, retain_ids)
            acc_forget_set = utils.sample(train_set, accu_forget_ids)

            start_time = time.time()
            func_args = (model, forget_set, retain_set, config)
            func_kwargs = {
                "trainer_init_func": trainer_init_func,
                "trainer_init_kwargs": trainer_init_kwargs,
            }
            if args.unlearn_method == "original":
                pass    # baseline: keep the trained model untouched
            elif args.unlearn_method == "retraining":
                func_kwargs["init_state_dict"] = init_state_dict
                methods.retrain(*func_args, **func_kwargs)
            elif args.unlearn_method == "gd":
                methods.gd(*func_args, **func_kwargs)
            elif args.unlearn_method == "sgd":
                methods.sgd(*func_args, **func_kwargs)
            elif args.unlearn_method == "ga":
                methods.ga(*func_args, **func_kwargs)
            elif args.unlearn_method == "random_labels":
                methods.random_labels(*func_args, **func_kwargs)
            elif args.unlearn_method == "pinv_newton":
                methods.pinv_newton(*func_args, **func_kwargs)
            elif args.unlearn_method == "damped_newton":
                methods.damped_newton(*func_args, **func_kwargs)
            elif args.unlearn_method == "cr_newton":
                # cr_newton is preceded by a gradient-ascent pass
                methods.ga(*func_args, **func_kwargs)
                alpha = methods.cr_newton(*func_args, **func_kwargs)
            elif args.unlearn_method == "scr_newton":
                methods.scr_newton(*func_args, **func_kwargs)
            elif args.unlearn_method == "ntk":
                methods.ntk(*func_args, **func_kwargs)
            else:
                # accepted by the CLI (e.g. "fisher_newton", "lbfgs") but not
                # wired to an implementation here
                raise NotImplementedError(
                    f"Unlearning method '{args.unlearn_method}' is not implemented")

            running_time = time.time() - start_time
            unlearn_res = {}
            unlearn_res["type"] = "unlearn"
            unlearn_res["continual_round"] = i
            unlearn_res["unlearn_round"] = unlearn_round
            unlearn_res["running_time"] = round(running_time, 3)
            if args.unlearn_method == "cr_newton":
                unlearn_res["alpha"] = alpha

            utils.clear_cache()
            metrics = utils.evaluate_aft_unlearn(model,
                                                 forget_set,
                                                 acc_forget_set,
                                                 retain_set,
                                                 test_set,
                                                 config,
                                                 **func_kwargs)
            unlearn_res.update(metrics)
            print("****** Results ******")
            print(unlearn_res, sep="\n")
            print()

        # Serialize a filtered copy of the config: the torch.device entry is
        # left out of the JSON, and the live namespace is not mutated.
        metadata = {key: value for key, value in vars(config).items() if key != "device"}
        with open(os.path.join(args.exp_dir, "metadata.json"), "w") as f:
            json.dump(metadata, f, indent=2)

        # NOTE(review): only the result of the last unlearning batch of this
        # round is logged; earlier batches are printed but not stored.
        # Confirm whether per-batch logging was intended.
        if unlearn_res is not None:
            all_logs.append(unlearn_res)
        with open(f"{args.exp_dir}/stats.json", "w") as f:
            json.dump(all_logs, f, indent=2)

        path = os.path.join(args.exp_dir, f"unlearn_model_round_{i}.pth")
        _save_checkpoint(model, init_state_dict, path)
        print(f"Saved metadata and unlearned model to {path}")