import argparse
import os

import torch
from huggingface_hub import HfApi, login
from transformers import TrainingArguments

from egu.dataset.collators import DataCollatorQADPO, DataCollatorTDEC
from egu.dataset.tdec import TDEC
from egu.dataset.tofu import TOFU
from egu.models import HFModel, LoraTuneHFModel, SoftPromptHFModel
from egu.trainers import (
    DPOTrainer,
    GATrainer,
    IHLTrainer,
    KTOTrainer,
    NPOTrainer,
    SELUNoCalTrainer,
    SELUNoCeTrainer,
    SELUNoCplTrainer,
    SELUNoPairTrainer,
    SELUSTLTrainer,
    SELUSTTrainer,
    SELUTrainer,
)
from egu.trainers.utils import (
    PushLoRAEachEpochCallback,
    RenameCheckpointCallback,
    TDECExactSuccessStopCallback,
    TDECSuccessStopCallback,
    upload_hf_train_end,
)
from egu.utils.utils import load_yaml

# Training methods whose trainer takes an additional frozen reference model.
NEEDS_REF_MODEL = {"dpo", "npo", "kto"}

# YAML file read for lr, batch_size, gradient_accumulation_steps, num_epochs
# and weight_decay. Every run launched from this script uses this one file.
TRAIN_CONFIG_PATH = "egu/config/tofu_forget_1e-4.yaml"
DEEPSPEED_CONFIG_PATH = "egu/config/ds_config_multi.json"

# Accepted --model_id values -> Hub checkpoint name and model config path.
MODELS = {
    "tofu_llama-2-7b": {
        "hf_name": "open-unlearning/tofu_Llama-2-7b-chat-hf_full",
        "config_path": "egu/config/tofu_model_config",
    },
    "gpt-neo-2.7": {
        "hf_name": "EleutherAI/gpt-neo-2.7B",
        "config_path": "egu/config/tdec_model_config",
    },
    "gpt-neo-1.3": {
        "hf_name": "EleutherAI/gpt-neo-1.3B",
        "config_path": "egu/config/tdec_model_config",
    },
    "gpt-neo-0.125": {
        "hf_name": "EleutherAI/gpt-neo-125M",
        "config_path": "egu/config/tdec_model_config",
    },
}

# Methods this script used to register by class name although the classes are
# not imported by this file (which made the script fail with NameError on every
# run). They are looked up in ``egu.trainers`` at runtime and registered only
# when the class actually exists there.
OPTIONAL_TRAINER_CLASSES = {
    "energy": "EnergyTrainer",
    "energy-redirect": "EnergyRedirectTrainer",
    "energy-fast": "EnergyFastTrainer",
    "selu-better": "SELUBetterTrainer",
    "selu-better-better": "SELUBetterBetterTrainer",
    "selu-forget": "SELUForgetTrainer",
    "sleu": "SLEUTrainer",
    "conseal": "ConSEALTrainer",
    "stowe": "SToWETrainer",
    "tselu": "TSELUTrainer",
}

# Extra keyword arguments passed to the trainer only for the "energy" method.
ENERGY_TRAINER_KWARGS = {
    "lambda_ce": 1.0,
    "lambda_e": 1.0,
    "tau_low": 0.0,
    "tau_high": 1.0,
    "margin": 0.5,
    "sample_forget": False,
}


def build_repo_id(
    model_variant,
    lr,
    train_method,
    dataset_name,
    dataset_split,
    lora_rank=None,
    lora_alpha=None,
):
    """Return the Hub repository name for a run.

    For the ``"lora"`` variant the LoRA rank and alpha are inserted right after
    the variant name; every other variant omits them.
    """
    parts = [model_variant]
    if model_variant == "lora":
        parts += [lora_rank, lora_alpha]
    parts += [lr, train_method, dataset_name, dataset_split]
    return "msc_unlearn_" + "_".join(str(part) for part in parts)


def compute_steps_per_epoch(num_examples, batch_size, grad_accum_steps, num_devices):
    """Return the number of optimizer steps in one epoch (floor division).

    ``num_devices`` is clamped to at least 1 so that a machine without visible
    CUDA devices does not trigger a ``ZeroDivisionError``.
    """
    effective_batch = batch_size * grad_accum_steps * max(1, num_devices)
    return num_examples // effective_batch


def parse_cli_args(
    model_choices, variant_choices, method_choices, dataset_choices, argv=None
):
    """Parse the command line, rejecting values that have no registry entry.

    Args:
        model_choices: accepted values for ``--model_id``.
        variant_choices: accepted values for ``--model_variant``.
        method_choices: accepted values for ``--train_method``.
        dataset_choices: accepted values for ``--dataset``.
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(
        description="Train (unlearn) a causal LM with the selected method."
    )
    parser.add_argument(
        "--model_id",
        type=str,
        required=True,
        choices=model_choices,
        help="Key of the base model in this script's model registry.",
    )
    parser.add_argument(
        "--model_variant",
        type=str,
        required=True,
        choices=variant_choices,
        help="can be lora, softprompt or full",
    )
    parser.add_argument(
        "--lora_rank", type=int, required=False, help="setting for lora"
    )
    parser.add_argument(
        "--lora_alpha", type=int, required=False, help="setting for lora"
    )
    parser.add_argument(
        "--train_method",
        type=str,
        required=True,
        choices=method_choices,
        help="the training method you want say ga etc...",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        required=True,
        choices=dataset_choices,
        help="The dataset you want to unlearn from",
    )
    parser.add_argument(
        "--dataset_split",
        type=str,
        required=True,
        help="the subdataset you want to train on for TOFU it is forget10",
    )
    return parser.parse_args(argv)


def build_trainer_registry():
    """Return the mapping from ``--train_method`` value to trainer class.

    Trainers imported at the top of this file are always registered. The
    entries of ``OPTIONAL_TRAINER_CLASSES`` are added only when ``egu.trainers``
    exposes a class of that name.
    """
    registry = {
        "ga": GATrainer,
        "dpo": DPOTrainer,
        "ihl": IHLTrainer,
        "npo": NPOTrainer,
        "kto": KTOTrainer,
        "selu": SELUTrainer,
        "stl-selu": SELUSTLTrainer,
        "st-selu": SELUSTTrainer,
        "selu-no-cpl": SELUNoCplTrainer,
        "selu-no-pair": SELUNoPairTrainer,
        "selu-no-cal": SELUNoCalTrainer,
        "selu-no-ce": SELUNoCeTrainer,
    }

    import egu.trainers as trainers_pkg  # local: only needed for the lookup

    for method, class_name in OPTIONAL_TRAINER_CLASSES.items():
        trainer_cls = getattr(trainers_pkg, class_name, None)
        if trainer_cls is not None:
            registry[method] = trainer_cls
    return registry


def main():
    """Build the model, data and trainer from the CLI, train, then upload."""
    model_variants = {
        "lora": LoraTuneHFModel,
        "softprompt": SoftPromptHFModel,
        "full": HFModel,
    }
    datasets = {"tofu": TOFU, "tdec": TDEC}
    trainers = build_trainer_registry()

    cli_args = parse_cli_args(
        model_choices=sorted(MODELS),
        variant_choices=sorted(model_variants),
        method_choices=sorted(trainers),
        dataset_choices=sorted(datasets),
    )
    model_name = cli_args.model_id
    model_variant = cli_args.model_variant
    train_method = cli_args.train_method
    dataset_name = cli_args.dataset
    dataset_split = cli_args.dataset_split

    print(model_name)
    print(model_variant)
    print(train_method)
    print(dataset_split)

    selected_model = MODELS[model_name]
    model_hf_name = selected_model["hf_name"]
    model_config = selected_model["config_path"]

    cfg = load_yaml(TRAIN_CONFIG_PATH)
    # The configured learning rate is used as-is (no scaling by device count).
    lr = cfg["lr"]

    # Keyword arguments for the model wrapper; LoRA settings only for "lora".
    base_variant = {"config_path": model_config}
    if model_variant == "lora":
        base_variant["lora_rank"] = cli_args.lora_rank
        base_variant["lora_alpha"] = cli_args.lora_alpha
    repo_id = build_repo_id(
        model_variant,
        lr,
        train_method,
        dataset_name,
        dataset_split,
        lora_rank=cli_args.lora_rank,
        lora_alpha=cli_args.lora_alpha,
    )
    print(base_variant)

    model = model_variants[model_variant](
        model_hf_name,
        **base_variant,
    )

    dataset = datasets[dataset_name](
        formatting_tokens=model.model_config["formatting_tokens"],
        eos_token=model.tokenizer.eos_token,
    )

    if dataset_name == "tdec":
        collator = DataCollatorTDEC(
            tokenizer=model.tokenizer,
            max_length=512,
        )
    elif dataset_name == "tofu":
        collator = DataCollatorQADPO(
            tokenizer=model.tokenizer,
            max_length=512,  # TODO: fix this so that it uses the max pad of current batch and not fixed
            model_configs=model.model_config["formatting_tokens"],
        )
    else:
        # Unreachable while argparse restricts --dataset, but fail loudly
        # rather than leave `collator` unbound if a dataset is added later.
        raise ValueError(f"no data collator registered for dataset {dataset_name!r}")

    train_dataset = dataset.load_dataset_for_training(dataset_split)
    if dataset_name == "tdec":
        val_dataset = dataset.load_dataset_for_validation(dataset_split)
    else:
        # Non-TDEC datasets are evaluated on the training split itself.
        val_dataset = dataset.load_dataset_for_training(dataset_split)

    print(f"my learning rate is {lr}")
    steps_per_epoch = compute_steps_per_epoch(
        len(train_dataset),
        cfg["batch_size"],
        cfg["gradient_accumulation_steps"],
        torch.cuda.device_count(),
    )

    # `lr` is interpolated through a local name: nesting cfg["lr"] with the
    # same quote type inside the f-string is a SyntaxError before Python 3.12.
    output_dir = (
        f"./results/{train_method}_{model_name}_{model_variant}_{lr}"
        f"_{dataset_name}/{dataset_split}"
    )
    training_args = TrainingArguments(
        output_dir,
        per_device_train_batch_size=cfg["batch_size"],
        per_device_eval_batch_size=cfg["batch_size"],
        # number of micro-batches accumulated before each optimizer update
        gradient_accumulation_steps=cfg["gradient_accumulation_steps"],
        logging_steps=10,
        save_strategy="no",  # the Trainer itself writes no checkpoints
        num_train_epochs=cfg["num_epochs"],
        # warm up for one epoch's worth of optimizer steps, and at least one
        warmup_steps=max(1, steps_per_epoch),
        remove_unused_columns=False,
        deepspeed=DEEPSPEED_CONFIG_PATH,
        learning_rate=lr,
        weight_decay=cfg["weight_decay"],
        fp16=True,
        report_to="none",
        # Trainer-managed Hub pushes are off; the hub_* values below only take
        # effect if push_to_hub is switched back on.
        push_to_hub=False,
        hub_model_id=repo_id,
        hub_strategy="end",
        hub_private_repo=True,
        save_on_each_node=False,
        ddp_find_unused_parameters=False,
        gradient_checkpointing=False,
    )

    # NOTE: TDEC success-based early stopping (TDECSuccessStopCallback) is
    # currently not registered.
    list_callbacks = [
        RenameCheckpointCallback(),
        PushLoRAEachEpochCallback(repo_id=repo_id, tokenizer=model.tokenizer),
    ]

    trainer_kwargs = dict(
        model=model.model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        tokenizer=model.tokenizer,
        data_collator=collator,
        callbacks=list_callbacks,
    )

    if train_method in NEEDS_REF_MODEL:
        # Heavy init, so the reference model is only built when required.
        ref_model = model_variants["full"](
            model_hf_name, config_path=model_config, ref=True
        )
        trainer_kwargs["ref_model"] = ref_model.model

    if train_method == "energy":
        trainer_kwargs.update(ENERGY_TRAINER_KWARGS)

    trainer = trainers[train_method](**trainer_kwargs)

    trainer.train()
    upload_hf_train_end(trainer, model, train_method, repo_id)


if __name__ == "__main__":
    main()
