import argparse
import logging
import os

import torch
import wandb
from datasets import load_metric
from torch.optim import Adam

from zarya import Trainer, get_dataloaders, set_deterministic_mode
from zarya.model.adapter import adapter
from zarya.model.lora import lora
from zarya.model.model_cls import MODEL_CLS


def get_task_type(dataset_name):
    """Map a dataset name to the task-type key used by the model/dataloader registries.

    Returns "multi" for copa, "regression" for stsb, "qa" for squad, and
    "classification" for every other name.
    """
    special_cases = (
        ("copa", "multi"),
        ("stsb", "regression"),
        ("squad", "qa"),
    )
    for candidate, kind in special_cases:
        if dataset_name == candidate:
            return kind
    # Anything not listed above is treated as plain classification.
    return "classification"


# Dataset names loaded from SuperGLUE; every other name is loaded from GLUE.
_SUPER_GLUE_TASKS = ("boolq", "cb", "multirc", "wic", "wsc", "copa", "record")

# Training types whose model class is built around the full backbone (lora is
# built around the full model for simplicity), so no prompt kwargs are passed.
_NON_PROMPT_TRAINING_TYPES = ("full", "lora", "adapter", "bitfit")

# For parameter-efficient training types: substrings of parameter names that
# stay trainable. Any parameter whose name contains none of them is frozen.
_TRAINABLE_NAME_PARTS = {
    "lora": ("_lora", "classifier"),
    "adapter": ("adapter", "classifier", "LayerNorm"),
    "bitfit": ("bias", "classifier"),
}


def _prepare_trainable_parameters(model, training_type, rank, pretrain, device):
    """Inject lora/adapter modules if needed, then freeze parameters not trained by *training_type*.

    Training types without an entry in ``_TRAINABLE_NAME_PARTS`` (e.g. "full"
    or prompt-based types) leave the model's ``requires_grad`` flags untouched.
    """
    if training_type == "lora":
        # NOTE(review): presumably adds ``*_lora`` parameters in place — confirm in zarya.model.lora.
        lora(model, rank, pretrain)
    elif training_type == "adapter":
        # NOTE(review): presumably adds ``adapter`` submodules in place — confirm in zarya.model.adapter.
        adapter(model, rank, device)

    keep_parts = _TRAINABLE_NAME_PARTS.get(training_type)
    if keep_parts is None:
        return
    for name, param in model.named_parameters():
        if not any(part in name for part in keep_parts):
            param.requires_grad = False


def main(args):
    """Fine-tune ``args.pretrain`` on dataset ``args.name`` and log the run to Weights & Biases.

    Builds the dataloaders and metric, and instantiates the model class selected
    by (task type, training type, backbone family). It then freezes parameters
    according to the training type and runs ``zarya.Trainer`` with Adam over
    the trainable parameters.

    Args:
        args: ``argparse.Namespace`` produced by this script's CLI.
    """
    wandb.init(
        project=args.project_name,
        entity=args.entity,
        config=vars(args),
        tags=["turbososal"],
    )

    device = torch.device(args.device)
    # torch.cuda.set_device raises ValueError for non-CUDA devices (e.g. "cpu").
    if device.type == "cuda":
        torch.cuda.set_device(device)

    print(vars(args))

    task = "super_glue" if args.name in _SUPER_GLUE_TASKS else "glue"
    task_type = get_task_type(args.name)
    dataloaders, num_labels = get_dataloaders(
        dataset_name=args.name,
        batch_size=args.batch_size,
        max_seq_len=args.seq_len,
        tokenizer_name=args.pretrain,
        seed=args.seed,
        task_type=task_type,
        task=task,
        cache_dir="/app/logs/",
    )
    # multirc and record are scored with the "cb" metric configuration.
    metric_fn = load_metric(
        task,
        args.name if args.name not in ("multirc", "record") else "cb",
        keep_in_memory=True,
        cache_dir="/app/logs/",
    )

    # Any backbone name that does not contain "roberta" is treated as DeBERTa.
    pretrain = "roberta" if "roberta" in args.pretrain else "deberta"
    # All "zero_fc*" variants share the "zero" model class.
    training_type = (
        "zero" if "zero_fc" in args.training_type else args.training_type
    )
    model_cls = MODEL_CLS[task_type][training_type][pretrain]

    kwargs = {}
    if args.training_type not in _NON_PROMPT_TRAINING_TYPES:
        # Prompt-based model classes receive their configuration at construction.
        kwargs = {
            "prompt_length": args.prompt_length,
            "prompt_rank": args.prompt_rank,
            "training_type": args.training_type,
        }
    if task_type in ("classification", "regression"):
        kwargs["num_labels"] = num_labels

    model = model_cls.from_pretrained(
        args.pretrain,
        hidden_dropout_prob=args.hidden_dropout,
        attention_probs_dropout_prob=args.attention_dropout,
        output_hidden_states=False,
        **kwargs,
    ).to(device)
    _prepare_trainable_parameters(
        model, training_type, args.prompt_rank, pretrain, device
    )

    for name, param in model.named_parameters():
        if param.requires_grad:
            print(name)

    optimize_params = [p for p in model.parameters() if p.requires_grad]
    trainable_params = sum(p.numel() for p in optimize_params)
    all_params = sum(p.numel() for p in model.parameters())
    print(f"ALL PARAMS {all_params}, TRAINABLE PARAMS {trainable_params}")
    optimizer = Adam(
        [{"params": optimize_params, "lr": args.lr, "weight_decay": args.weight_decay}]
    )

    best_model_path = (
        os.path.join(args.log, f"{wandb.run.name}_best.pth")
        if args.save_best_model
        else None
    )

    trainer = Trainer(
        model=model,
        optimizer=optimizer,
        scheduler=None,
        num_patience_steps=args.num_patience_steps,
        fp16=args.fp16,
        device=device,
        trainable_params=trainable_params,
    )
    trainer.train_model(
        dataloaders=dataloaders,
        epochs=args.epochs,
        metric_fn=metric_fn,
        dataset_name=args.name,
        task_type=task_type,
        debug=args.debug,
        best_model_save_path=best_model_path,
    )


if __name__ == "__main__":
    # Silence tokenizer warnings from transformers.
    logging.getLogger("transformers.tokenization_utils").setLevel(logging.ERROR)

    parser = argparse.ArgumentParser()

    parser.add_argument("--project-name", type=str, default="aot")
    parser.add_argument("--entity", type=str, default="aot")
    parser.add_argument("--pretrain", type=str, default="roberta-base")
    parser.add_argument("--seq-len", type=int, default=128)
    parser.add_argument("--name", type=str, required=True)

    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--num-patience-steps", type=int, default=5)
    parser.add_argument("--lr", type=float, default=5e-5)
    parser.add_argument("--weight-decay", type=float, default=0.0)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--hidden-dropout", type=float, default=0.0)

    # main() does substring checks on this value, so it must always be given;
    # a missing value would otherwise crash only after wandb/dataset setup.
    parser.add_argument(
        "--training-type",
        type=str,
        required=True,
        help="e.g. full, lora, adapter, bitfit, or a prompt-based type",
    )
    parser.add_argument("--prompt-length", type=int)
    parser.add_argument(
        "--prompt-rank",
        type=int,
        help="prompt rank; also used as the rank for lora/adapter",
    )

    parser.add_argument("--attention-dropout", type=float, default=0.0)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--device", type=str, default="cuda:0")
    parser.add_argument("--log", type=str, help="path to log file", default="/app/logs")
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--save-best-model", action="store_true")

    parser.add_argument("--debug", action="store_true")

    args = parser.parse_args()

    if not os.path.isdir(args.log):
        raise ValueError(f"Logs dir {args.log} doesn't exist")

    set_deterministic_mode(args.seed)
    main(args)
