from datasets import load_dataset
from evaluate import load
import torch
import numpy as np
from torch.utils.data import DataLoader
from torch.optim import AdamW
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSequenceClassification, get_linear_schedule_with_warmup

from utils import add_adapters, set_active_task, freeze_adapters_thaw_base, freeze_base_thaw_adapters, instantiate_multi_model
from copy import deepcopy
import argparse

def _validation_split_name(task):
    """Return the name of the GLUE split used to evaluate *task*.

    MNLI has no plain "validation" split; it ships a matched and a
    mismatched validation set. "mnli" is scored on the matched set and
    "mnli-mm" on the mismatched one. Every other task uses "validation".
    """
    if task == "mnli-mm":
        return "validation_mismatched"
    if task == "mnli":
        return "validation_matched"
    return "validation"


def _evaluate_tasks(model, eval_dataloaders, metrics, tasks, actual_tasks, device, epoch):
    """Evaluate *model* on every task and print one metric line per task.

    Args:
        model: Multi-task model; ``set_active_task`` is called on it before
            each task is evaluated.
        eval_dataloaders: One evaluation DataLoader per task.
        metrics: One ``evaluate`` metric object per task.
        tasks: Task names as given on the command line (used for printing).
        actual_tasks: GLUE config names matching ``tasks`` ("mnli-mm" -> "mnli").
        device: Device the batches are moved to.
        epoch: Epoch label printed with the results (-1 before training).
    """
    model.eval()
    for task_idx, eval_dataloader in enumerate(eval_dataloaders):
        set_active_task(model, task_idx)
        # STS-B is a regression task: its raw logits are the predictions, so
        # taking an argmax over them would always yield 0.
        is_regression = actual_tasks[task_idx] == "stsb"
        for batch in tqdm(eval_dataloader):
            batch.to(device)
            with torch.no_grad():
                outputs = model(**batch)
            predictions = outputs.logits
            if not is_regression:
                predictions = predictions.argmax(dim=-1)
            metrics[task_idx].add_batch(
                predictions=predictions,
                references=batch["labels"],
            )
        eval_metric = metrics[task_idx].compute()
        print(f"epoch {epoch}, task {tasks[task_idx]}:", eval_metric)


def main(args):
    """Alternate base-model and per-task adapter updates on GLUE tasks.

    Loads the requested GLUE tasks and metrics, attaches adapters to a
    pretrained sequence classifier, loads the given checkpoints into it,
    evaluates once, and then for every epoch:

    1. calls ``freeze_adapters_thaw_base`` and takes ONE optimizer step on
       gradients accumulated over every batch of every task;
    2. calls ``freeze_base_thaw_adapters`` and runs ``num_inner`` passes over
       every task, stepping the optimizer after each batch;
    3. evaluates every task on its validation split;
    4. saves a checkpoint inside ``args.save``.

    Args:
        args: Parsed command-line namespace providing ``model``,
            ``checkpoints``, ``tasks``, ``batchsz``, ``num_epochs``,
            ``num_inner``, ``lr``, ``lora_dim``, ``lora_alpha``,
            ``lora_dropout``, ``device``, ``save`` and ``seed``.
    """
    import os  # local import: only needed to build the checkpoint path

    # Parse arguments
    model_name = args.model
    checkpoints = args.checkpoints
    tasks = args.tasks
    num_tasks = len(tasks)
    batch_size = args.batchsz
    num_epochs = args.num_epochs
    num_inner = args.num_inner
    lr = args.lr
    lora_dim = args.lora_dim
    lora_alpha = args.lora_alpha
    lora_dropout = args.lora_dropout
    device = args.device
    save_path = args.save
    seed = args.seed

    # Set seed for dataset shuffle
    torch.manual_seed(seed)

    cache_root = "/var/local/nameredacted/.cache/huggingface"

    # Load task data and metrics. "mnli-mm" is not a GLUE config of its own:
    # it is the "mnli" config evaluated on the mismatched validation split.
    actual_tasks = ["mnli" if task == "mnli-mm" else task for task in tasks]
    task_datas = [
        load_dataset("glue", actual_task, cache_dir=f"{cache_root}/datasets")
        for actual_task in actual_tasks
    ]
    metrics = [
        load("glue", actual_task, cache_dir=f"{cache_root}/metrics")
        for actual_task in actual_tasks
    ]

    # Preprocess data
    tokenizer = AutoTokenizer.from_pretrained(
        model_name, padding_side="right", cache_dir=f"{cache_root}/tokenizers"
    )
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # Names of the text column(s) of each task; None marks single-sentence tasks.
    task_to_keys = {
        "cola": ("sentence", None),
        "mnli": ("premise", "hypothesis"),
        "mnli-mm": ("premise", "hypothesis"),
        "mrpc": ("sentence1", "sentence2"),
        "qnli": ("question", "sentence"),
        "qqp": ("question1", "question2"),
        "rte": ("sentence1", "sentence2"),
        "sst2": ("sentence", None),
        "stsb": ("sentence1", "sentence2"),
        "wnli": ("sentence1", "sentence2"),
    }
    sentence_keys = [task_to_keys[task] for task in tasks]

    def make_tokenize_fn(first_key, second_key):
        """Build a batched tokenizer call for one task's text column(s)."""
        def tokenize_fn(examples):
            if second_key is None:
                return tokenizer(examples[first_key], truncation=True)
            return tokenizer(examples[first_key], examples[second_key], truncation=True, max_length=None)
        return tokenize_fn

    tokenized_datasets = []
    for task_data, (first_key, second_key) in zip(task_datas, sentence_keys):
        # Drop the raw text columns and the example index after tokenizing.
        rm_cols = [key for key in (first_key, second_key) if key is not None]
        rm_cols.append("idx")
        tokenized = task_data.map(
            make_tokenize_fn(first_key, second_key), batched=True, remove_columns=rm_cols
        )
        # We also rename the 'label' column to 'labels' which is the expected
        # name for labels by the models of the transformers library
        tokenized_datasets.append(tokenized.rename_column("label", "labels"))

    def collate_fn(examples):
        # Pad every batch to the length of its longest example.
        return tokenizer.pad(examples, padding="longest", return_tensors="pt")

    # Instantiate dataloaders.
    train_dataloaders = [
        DataLoader(t["train"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size)
        for t in tokenized_datasets
    ]
    eval_dataloaders = [
        DataLoader(
            t[_validation_split_name(task)],
            shuffle=False,
            collate_fn=collate_fn,
            batch_size=batch_size,
        )
        for task, t in zip(tasks, tokenized_datasets)
    ]

    # 3 classes for MNLI, a single regression output for STS-B, 2 otherwise.
    num_labels_list = [3 if task.startswith("mnli") else 1 if task == "stsb" else 2 for task in actual_tasks]

    # Build the multi-task model: pretrained base, adapters added by
    # add_adapters, then the per-checkpoint state dicts handed to
    # instantiate_multi_model (both helpers live in utils).
    checkpoint_data = [torch.load(checkpoint, map_location="cpu") for checkpoint in checkpoints]
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name, cache_dir=f"{cache_root}/transformers"
    )
    add_adapters(
        model,
        adapter_dim=lora_dim,
        num_tasks=num_tasks,
        num_labels_list=num_labels_list,
        alpha=lora_alpha,
        p_dropout=lora_dropout,
    )
    instantiate_multi_model(model, [checkpoint["model_state_dict"] for checkpoint in checkpoint_data])

    num_samples = [len(dloader.dataset) for dloader in train_dataloaders]
    optimizer = AdamW(params=model.parameters(), lr=lr)

    model.to(device)

    # Baseline evaluation before any training, reported as epoch -1.
    _evaluate_tasks(model, eval_dataloaders, metrics, tasks, actual_tasks, device, -1)

    for epoch in range(num_epochs):
        model.train()

        # Base update: gradients from every batch of every task are
        # accumulated and applied in a single optimizer step per epoch.
        freeze_adapters_thaw_base(model)
        epoch_losses = np.zeros(num_tasks)
        for task_idx in range(num_tasks):
            set_active_task(model, task_idx)
            for batch in tqdm(train_dataloaders[task_idx]):
                batch.to(device)
                outputs = model(**batch)
                # NOTE(review): outputs.loss is presumably already a batch
                # mean, so this scales by 1/num_samples rather than weighting
                # by batch size - confirm this is the intended normalisation.
                loss = outputs.loss / num_samples[task_idx]
                # .item() yields a Python float, so the NumPy accumulator
                # never receives a (possibly CUDA-resident) tensor.
                epoch_losses[task_idx] += loss.item()
                loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        print(f"epoch {epoch} Base Update Training Loss:", epoch_losses)

        # Adapter update: num_inner passes, one optimizer step per batch.
        freeze_base_thaw_adapters(model)
        for i in range(num_inner):
            epoch_losses = np.zeros(num_tasks)
            for task_idx in range(num_tasks):
                set_active_task(model, task_idx)
                for batch in tqdm(train_dataloaders[task_idx]):
                    batch.to(device)
                    outputs = model(**batch)
                    loss = outputs.loss / num_samples[task_idx]
                    epoch_losses[task_idx] += loss.item()
                    loss.backward()
                    optimizer.step()
                    optimizer.zero_grad()
            print(f"epoch {epoch}, inner step {i} LoRA Update Training Loss:", epoch_losses)

        _evaluate_tasks(model, eval_dataloaders, metrics, tasks, actual_tasks, device, epoch)

        # Save checkpoint. os.path.join keeps the file inside save_path
        # whether or not the directory was given with a trailing separator.
        torch.save({
            'seed': seed,
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'lr': lr,
            'num_inner': num_inner,
            'batchsz': batch_size,
            'lora_dim': lora_dim,
            'lora_alpha': lora_alpha
            }, os.path.join(save_path, f"{model_name}-epoch{epoch}.pt"))

if __name__ == '__main__':
    # Command-line entry point: build the argument parser, parse sys.argv
    # and hand the resulting namespace to main().
    parser = argparse.ArgumentParser(description = "Run meta-lora algorithm using Roberta")

    # Accepted base models: the two RoBERTa sizes plus their
    # "-openai-detector" variants.
    valid_models = [
        "roberta-large",
        "roberta-base"
    ]
    valid_models.extend([m + "-openai-detector" for m in valid_models])

    parser.add_argument('--model', metavar='m', action="store", choices = valid_models, type=str, help = "base model to use", default="roberta-large")
    # Required: main() iterates over the checkpoint list unconditionally, so
    # a missing value is rejected here instead of failing later.
    parser.add_argument('--checkpoints', nargs="+", action="store", type=str, required=True, help = "checkpoints to load adapters")
    parser.add_argument('--tasks', metavar = 't', nargs='+', choices = ["cola", "mnli", "mnli-mm", "mrpc", "qnli", "qqp", "rte", "sst2", "stsb", "wnli"], type=str, help = "glue tasks to use", default = ["cola","sst2"])
    parser.add_argument('--batchsz', metavar='b', action="store", type=int, help= "batch size", default = 32)
    parser.add_argument('--num_epochs', metavar='n', action="store", type=int, help = "number of epochs", default = 20)
    parser.add_argument('--num_inner', action="store", type=int, help = "number of inner epochs", default = 10)
    parser.add_argument('--lr', action="store", type=float, help = "learning rate", default = 3e-5)
    parser.add_argument('--lora_dim', action="store", type=int, help = "lora adapter dimension", default = 8)
    parser.add_argument('--lora_alpha', action="store", type=float, help = "lora alpha for scaling", default = 16)
    parser.add_argument('--lora_dropout', action="store", type=float, help = "lora dropout probability", default = .1)
    parser.add_argument('--device', action="store", type=str, help="device to train on", default = "cuda")
    parser.add_argument('--save', action="store", type=str, help="save model checkpoint path", default = "/var/local/nameredacted/model-checkpoints/")
    parser.add_argument('--seed', action = "store", help="random seed", type = int, default = 613)

    args = parser.parse_args()
    main(args)