import os
import json
from argparse import ArgumentParser, Namespace
from collections import Counter

import numpy as np
import torch
from transformers import Trainer

import train_func
import settings
from helper import utils

def get_sel_size(sel_freq, tot_samples):
    """Translate the ``--sel_freq`` option into a number of points to select.

    Args:
        sel_freq: ``-1`` selects every sample; a value in ``[0, 1)`` is a
            fraction of ``tot_samples``; any value ``>= 1`` is an absolute
            count (truncated to int).
        tot_samples: Size of the pool being selected from.

    Returns:
        The number of samples to select.

    Raises:
        ValueError: If ``sel_freq`` is negative but not ``-1``. A negative
            size would otherwise be accepted as a slice bound downstream and
            silently select "all but the last k" samples.
    """
    if sel_freq == -1:
        return tot_samples
    if sel_freq < 0:
        raise ValueError(
            f"sel_freq must be -1 (all), a ratio in [0, 1), or a size >= 1; got {sel_freq}"
        )
    if sel_freq < 1:
        return int(tot_samples * sel_freq)
    return int(sel_freq)

if __name__ == "__main__":
    DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    parser = ArgumentParser()
    parser.add_argument("--dataset", type=str, default="mnist",
                        help="Dataset name")
    parser.add_argument("--seed", type=int, default=1, 
                        help="Random seed")
    parser.add_argument("--save_dir", type=str, default="unlearn_outputs/",
                        help="Directory to output results")
    parser.add_argument("--model_dir", type=str, default="outputs/",
                        help="Directory that stores the trained model")
    parser.add_argument("--sel_level", default="instance",
                        choices=["class", "instance"],
                        help="Selection level")
    parser.add_argument("--sel_freq", type=float, default=-1,
                        help="Selection ratio [0,1), selection size > 1, all -1")
    parser.add_argument("--sel_sort_by", type=str, default="random",
                        choices=["random", "lowest_loss_first", "highest_loss_first"],
                        help="Sort selected points by certain criteria")
    parser.add_argument("--llama", action="store_true", 
                        help="Whether the model is llama")
    parser.add_argument("--vit", action="store_true", 
                        help="Whether the model is vit")
    args = parser.parse_args()
    utils.set_seed(args.seed)
    os.makedirs(args.save_dir, exist_ok=True)

    # initialize data and model (and tokenizer if necessary)
    if not (args.llama or args.vit):
        (train_set, test_set), model, model_config = settings.prepare_data_and_model(args.dataset)
    elif args.llama or args.hf_trainer:
        (train_set, test_set), (model, tokenizer), model_config = settings.prepare_data_and_model(args.dataset)
    elif args.vit:
        raise NotImplementedError

    # load training configurations
    with open(f"configs/train/{args.dataset}.json", "r") as f:
        config = json.load(f)
        print("train_config:", config)
        config = Namespace(**config)    # so that can access attributes through . operation

    # load the trained model if necessary
    if not args.sel_sort_by == "random":
        print("Loading the trained model...")
        if args.vit:
            from settings.vit_finetune.src.model import ClassificationModel
            model_path = f"{args.model_dir}/version_0/checkpoints/last.ckpt"
            model = ClassificationModel.load_from_checkpoint(model_path)
        else:
            model_path = f"{args.model_dir}/original.pt"
            with open(model_path, "rb") as f:
                checkpoint = torch.load(f)
                model.load_state_dict(checkpoint["init_state_dict"])    # avoid randomized initialization on frozen weights
                model.load_state_dict(checkpoint["state_dict"], strict=False)
                del checkpoint
            utils.clear_cache()
    
    # load the targets if necessary
    if args.sel_level == "class" or "loss" in args.sel_sort_by:
        if args.llama:
            targets = [x["label"] for x in train_set] 
        else: 
            # targets = torch.cat([x[1] for x in train_set])
            targets = [x[1] for x in train_set]
    
    # compute the losses if necessary
    if "loss" in args.sel_sort_by:
        train_loader = utils.get_dataloader(train_set,
                                            shuffle=False,
                                            batch_size=config.eval_batch_size)
        loss_fn = getattr(torch.nn, config.loss)(reduction="none")
        preds = train_func.inference(model, train_loader)
        preds = preds.argmax(-1)
        losses = loss_fn(preds, targets)
    
    
    print("========= SELECT FORGOTTEN POINTS =========")
    if args.sel_level == "instance":
        tot_samples = len(train_set)
        sample_indices = list(range(tot_samples))
        sel_size = get_sel_size(args.sel_freq, tot_samples)

    elif args.sel_level == "class":
        sample_class = np.random.randint(model_config["out_size"])
        # sample_class = 5  # cifar-10
        # sample_class = 1    # ag-news
        # sample_class = 1    # fashion-mnist
        targets = np.array(targets)
        sample_indices = np.where(targets == sample_class)[0]
        tot_samples = len(sample_indices)
        print("sample class: {} ({} samples)".format(sample_class, tot_samples))
        sel_size = get_sel_size(args.sel_freq, tot_samples)

    if args.sel_sort_by == "random":
        sel = np.random.choice(sample_indices, size=sel_size, replace=False)
        sel = sel.tolist()
    else:
        sorted_indices = np.argsort(losses.numpy()[sample_indices])
        if args.sel_sort_by == "highest_loss_first":
            sel = sorted_indices[:sel_size]
        elif args.sel_sort_by == "lowest_loss_first":
            sel = np.flip(sorted_indices)[:sel_size]
        else: raise NotImplementedError
        sel = [sample_indices[i] for i in sel]

    print("size:", len(sel))
    if args.sel_level == "class":
        print("label_counter:", Counter(targets[sel].flatten()))
        
    # if args.dataset != "samsum" and args.dataset !=  "xsum":
    #     print("selection_type:", args.sel_level, args.sel_sort_by)
    #     print("selected_acc:", (preds[sel].eq(targets[sel])).sum().item() / len(sel) * 100)
    #     print("non_selected_acc:", (preds[~sel].eq(targets[~sel])).sum().item() / len(sel) * 100)

    save_path = f"{args.save_dir}/unlearn_order.pt"
    torch.save(sel, save_path)
    print("Saved unlearn order to %s" % save_path)
    exit()

    if args.dataset in ["imdb", "samsum", "xsum", "agnews"]:
        (train_set, test_set), (model, tokenizer), model_cfg = settings.prepare_data_and_model(args.dataset)
        sel_size = get_sel_size(args.sel_freq, len(train_set))

        from adapters import AdapterTrainer
        from transformers import TrainingArguments

        model.load_state_dict(torch.load(open(model_path, "rb"))["state_dict"])
        training_args = TrainingArguments(output_dir=args.model_dir, 
                                          per_device_eval_batch_size=64)
        if args.llama:
            from peft import PromptTuningInit, PromptTuningConfig, TaskType
            from trl import SFTTrainer
            if args.dataset in ["imdb", "samsum"]:
                peft_config = PromptTuningConfig(
                    task_type=TaskType.CAUSAL_LM,
                    prompt_tuning_init=PromptTuningInit.TEXT,
                    num_virtual_tokens=20,
                    prompt_tuning_init_text="Summarize:",
                    tokenizer_name_or_path="meta-llama/Llama-2-7b-hf",
                )
                
                def prompt_formatter(sample):
                    return f"""<s>### Instruction:
                You are a helpful, respectful and honest assistant. \
                Your task is to summarize the following dialogue. \
                Your answer should be based on the provided dialogue only.

                ### Dialogue:
                {sample["dialogue"]}

                ### Summary:
                {sample["summary"]} </s>"""
            else:
                peft_config = PromptTuningConfig(
                    task_type=TaskType.CAUSAL_LM,
                    prompt_tuning_init=PromptTuningInit.TEXT,
                    num_virtual_tokens=20,
                    prompt_tuning_init_text="Classify:",
                    tokenizer_name_or_path="meta-llama/Llama-2-7b-hf",
                )
                
                def prompt_formatter(sample):
                    return f"""<s>### Instruction:
                You are a helpful, respectful and honest assistant. \
                Your task is to classify the following dialogue. \
                Your answer should be based on the provided dialogue only.

                ### Dialogue:
                {sample["text"]}

                ### Classification:
                {sample["label"]} </s>"""
            
            trainer = SFTTrainer(
                model=model,
                train_dataset=train_set,
                eval_dataset=test_set,
                peft_config=peft_config,
                max_seq_length=1024,
                tokenizer=tokenizer,
                packing=True,
                formatting_func=prompt_formatter, 
                args=training_args,
            )
        else:
            # trainer = AdapterTrainer(model=model, tokenizer=tokenizer)
            trainer = Trainer(model=model, tokenizer=tokenizer)
        
        if args.dataset in ["agnews"]:
            train_loader = trainer.get_eval_dataloader(trainer.train_dataset)
            loss_fn = torch.nn.CrossEntropyLoss(reduction="none")

            model.eval()
            targets = []
            logits = []
            with torch.no_grad():
                for inputs in train_set:
                    targets.append(inputs["label"])
            # targets = torch.cat(targets)
    else:
        (train_set, test_set), model, model_cfg = settings.prepare_data_and_model(args.dataset)
        if args.vit:
            from settings.vit_finetune.src.model import ClassificationModel
            model = ClassificationModel.load_from_checkpoint(model_path)
        else:
            model.load_state_dict(torch.load(open(model_path, "rb"))["state_dict"])
        model.to(DEVICE)

        train_loader = utils.get_dataloader(train_set,
                                            shuffle=False,  # so that the index will be the same when unlearned
                                            batch_size=128)
        if "loss" in args.sel_sort_by or args.sel_level == "class":
            print("concatenating targets")
            targets = torch.cat([batch[1] for batch in train_loader])

        if "loss" in args.sel_sort_by:
            print("inferencing and calculating loss")
            logits = train_func.inference(model, train_loader)
            loss_fn = torch.nn.CrossEntropyLoss(reduction="none")
            preds = logits.argmax(dim=1).view(-1)
            loss = loss_fn(logits, targets)
