from datasets import load_dataset, concatenate_datasets
from evaluate import load
import torch
import numpy as np
from torch.utils.data import DataLoader
from torch.optim import AdamW
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSequenceClassification, get_linear_schedule_with_warmup

from utils import add_adapters, set_active_task, freeze_adapters_thaw_base, freeze_base_thaw_adapters, instantiate_multi_model
from copy import deepcopy
import argparse
import os

def _evaluate_tasks(model, eval_dataloaders, metrics, actual_tasks, device):
    """Evaluate the model on every task and return one metric dict per task.

    Args:
        model: model to evaluate; it is switched to eval mode here.
        eval_dataloaders: one evaluation DataLoader per task, in task order.
        metrics: one metric object per task; batches are fed to ``add_batch``
            and the result is read with ``compute``.
        actual_tasks: GLUE config name for each task (used to spot "stsb").
        device: device the batches are moved to before the forward pass.

    Returns:
        List with the result of ``metrics[i].compute()`` for each task ``i``.
    """
    model.eval()
    results = []
    for task_idx, eval_dataloader in enumerate(eval_dataloaders):
        set_active_task(model, task_idx)
        is_regression = actual_tasks[task_idx] == "stsb"
        for batch in tqdm(eval_dataloader):
            batch = batch.to(device)
            with torch.no_grad():
                outputs = model(**batch)
            if is_regression:
                # Flatten so a (batch, 1) regression output lines up with the
                # 1-D labels; this is a no-op if the logits are already 1-D.
                predictions = outputs.logits.reshape(-1)
            else:
                predictions = outputs.logits.argmax(dim=-1)
            metrics[task_idx].add_batch(
                predictions=predictions,
                references=batch["labels"],
            )
        results.append(metrics[task_idx].compute())
    return results


def _prune_and_should_save(epoch, curr_val_metrics, val_metrics, saved_arr, paths):
    """Decide whether to checkpoint ``epoch`` and delete checkpoints it supersedes.

    An epoch is skipped when some earlier epoch is at least as good on every
    task. Otherwise every earlier *saved* epoch that the current one matches or
    beats on every task has its file removed and its ``saved_arr`` flag cleared.

    The domination test runs first on purpose: if the current metrics tie an
    earlier saved epoch exactly, both comparisons hold, and deleting the old
    file while also skipping the new one would leave no checkpoint at all.

    Args:
        epoch: index of the epoch that just finished.
        curr_val_metrics: one validation score per task for this epoch.
        val_metrics: (num_epochs, num_tasks) array of scores; rows ``< epoch``
            are read.
        saved_arr: boolean array marking which epochs currently have a file;
            updated in place for pruned epochs.
        paths: checkpoint file path for each epoch.

    Returns:
        True if the caller should save a checkpoint for ``epoch``.
    """
    curr = np.asarray(curr_val_metrics)
    if any(np.all(val_metrics[i, :] >= curr) for i in range(epoch)):
        return False
    for i in range(epoch):
        if saved_arr[i] and np.all(curr >= val_metrics[i, :]):
            try:
                os.remove(paths[i])
            except FileNotFoundError:
                # Best effort: the file may already have been removed by hand.
                pass
            saved_arr[i] = False
    return True


def main(args):
    """Train one model jointly on several GLUE tasks and evaluate after each epoch.

    The model is evaluated once before training (reported as epoch -1). In each
    training step one batch is drawn from every task, the per-task losses are
    back-propagated in turn so their gradients accumulate, and a single
    optimizer/scheduler step follows. An epoch lasts as many steps as the
    shortest task dataloader. When ``args.save`` is non-empty, a checkpoint is
    written only for epochs that no earlier epoch matches or beats on every
    task, and checkpoints superseded by the new one are deleted.

    Args:
        args: parsed command-line namespace providing ``model``, ``checkpoints``,
            ``tasks``, ``batchsz``, ``num_epochs``, ``lr``, ``lora_dim``,
            ``lora_alpha``, ``lora_dropout``, ``device``, ``save``, ``seed``
            and ``verbose``.
    """
    # Parse arguments
    model_name = args.model
    checkpoints = args.checkpoints
    tasks = args.tasks
    num_tasks = len(tasks)
    batch_size = args.batchsz
    num_epochs = args.num_epochs
    lr = args.lr
    lora_dim = args.lora_dim
    lora_alpha = args.lora_alpha
    lora_dropout = args.lora_dropout
    device = args.device
    save_path = args.save
    seed = args.seed
    verb = args.verbose

    if verb:
        print(f"Command Line Args: \n {args}")

    # Set seed for dataset shuffle
    torch.manual_seed(seed)

    cache_root = "/var/local/nameredacted/.cache/huggingface"

    # Load task data and metrics. "mnli-mm" uses the "mnli" GLUE config.
    actual_tasks = ["mnli" if task == "mnli-mm" else task for task in tasks]
    tasks_str = "-".join(actual_tasks)
    task_datas = [
        load_dataset("glue", actual_task, cache_dir=f"{cache_root}/datasets")
        for actual_task in actual_tasks
    ]
    metrics = [
        load("glue", actual_task, cache_dir=f"{cache_root}/metrics")
        for actual_task in actual_tasks
    ]

    # Preprocess data
    tokenizer = AutoTokenizer.from_pretrained(
        model_name, padding_side="right", cache_dir=f"{cache_root}/tokenizers"
    )
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # Text column(s) of each GLUE task; the second entry is None for
    # single-sentence tasks.
    task_to_keys = {
        "cola": ("sentence", None),
        "mnli": ("premise", "hypothesis"),
        "mnli-mm": ("premise", "hypothesis"),
        "mrpc": ("sentence1", "sentence2"),
        "qnli": ("question", "sentence"),
        "qqp": ("question1", "question2"),
        "rte": ("sentence1", "sentence2"),
        "sst2": ("sentence", None),
        "stsb": ("sentence1", "sentence2"),
        "wnli": ("sentence1", "sentence2"),
    }
    sentence_keys = [task_to_keys[task] for task in tasks]

    def make_tokenize_fn(first_key, second_key):
        def tokenize_fn(examples):
            if second_key is None:
                return tokenizer(examples[first_key], truncation=True)
            return tokenizer(examples[first_key], examples[second_key], truncation=True, max_length=None)
        return tokenize_fn

    tokenized_datasets = []
    for task_data, (first_key, second_key) in zip(task_datas, sentence_keys):
        rm_cols = [key for key in (first_key, second_key) if key is not None]
        rm_cols.append("idx")
        tokenized = task_data.map(
            make_tokenize_fn(first_key, second_key), batched=True, remove_columns=rm_cols
        )
        # We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the
        # transformers library
        tokenized_datasets.append(tokenized.rename_column("label", "labels"))

    def collate_fn(examples):
        return tokenizer.pad(examples, padding="longest", return_tensors="pt")

    # Instantiate dataloaders.
    train_dataloaders = [
        DataLoader(t["train"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size)
        for t in tokenized_datasets
    ]
    eval_dataloaders = []
    for actual_task, t in zip(actual_tasks, tokenized_datasets):
        if actual_task == "mnli":
            # MNLI is evaluated on the matched and mismatched validation sets together.
            eval_set = concatenate_datasets([t["validation_matched"], t["validation_mismatched"]])
            eval_dataloaders.append(
                DataLoader(eval_set, shuffle=False, collate_fn=collate_fn, batch_size=batch_size, pin_memory=True)
            )
        else:
            eval_dataloaders.append(
                DataLoader(t["validation"], shuffle=False, collate_fn=collate_fn, batch_size=batch_size)
            )

    num_labels_list = [3 if task.startswith("mnli") else 1 if task == "stsb" else 2 for task in actual_tasks]

    # Load the pretrained model once, then attach the per-task adapters.
    model = AutoModelForSequenceClassification.from_pretrained(model_name, cache_dir=f"{cache_root}/transformers")
    add_adapters(model, adapter_dim=lora_dim, num_tasks=num_tasks, num_labels_list=num_labels_list, alpha=lora_alpha, p_dropout=lora_dropout)

    if checkpoints:
        checkpoint_data = [torch.load(checkpoint, map_location="cpu") for checkpoint in checkpoints]
        instantiate_multi_model(model, [checkpoint["model_state_dict"] for checkpoint in checkpoint_data])

    # One epoch runs for as many steps as the shortest task dataloader provides.
    num_batches = min(len(dloader) for dloader in train_dataloaders)
    total_steps = num_batches * num_epochs
    optimizer = AdamW(params=model.parameters(), lr=lr)
    lr_scheduler = get_linear_schedule_with_warmup(optimizer=optimizer, num_warmup_steps=0.06 * total_steps, num_training_steps=total_steps)

    model.to(device)

    # Baseline evaluation before any training step.
    baseline_metrics = _evaluate_tasks(model, eval_dataloaders, metrics, actual_tasks, device)
    for task_idx, eval_metric in enumerate(baseline_metrics):
        print(f"epoch {-1}, task {tasks[task_idx]}:", eval_metric)

    val_metrics = np.zeros((num_epochs, num_tasks))
    paths = [save_path + model_name + "-" + tasks_str + str(i) + ".pt" for i in range(num_epochs)]
    saved_arr = np.zeros(num_epochs, dtype=bool)

    for epoch in range(num_epochs):
        model.train()
        epoch_losses = np.zeros(num_tasks)
        iter_dloaders = [iter(dloader) for dloader in train_dataloaders]
        for _ in tqdm(range(num_batches)):
            try:
                batches = [next(dloader) for dloader in iter_dloaders]
            except StopIteration:
                break
            # Gradients from every task accumulate before the single step below.
            for task_idx, batch in enumerate(batches):
                set_active_task(model, task_idx)
                batch = batch.to(device)
                outputs = model(**batch)
                loss = outputs.loss
                # .item() yields a Python float, so no device tensor is mixed
                # into NumPy arithmetic.
                epoch_losses[task_idx] += loss.item()
                loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            lr_scheduler.step()

        eval_metrics = _evaluate_tasks(model, eval_dataloaders, metrics, actual_tasks, device)
        print(f"epoch {epoch} Training Loss:", epoch_losses)
        for task_idx in range(num_tasks):
            print(f"epoch {epoch}, task {tasks[task_idx]}:", eval_metrics[task_idx])

        if save_path:
            # Track the first reported value of each task's metric.
            curr_val_metrics = [list(eval_metric.values())[0] for eval_metric in eval_metrics]
            val_metrics[epoch, :] = curr_val_metrics

            if _prune_and_should_save(epoch, curr_val_metrics, val_metrics, saved_arr, paths):
                # Save this one
                saved_arr[epoch] = True
                torch.save({
                    'seed': seed,
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'lr': lr,
                    'batchsz': batch_size,
                    'lora_dim': lora_dim,
                    'lora_alpha': lora_alpha,
                    'checkpoints': checkpoints,
                    'meta_lora_gd': True,
                    'val_metrics': curr_val_metrics
                    }, paths[epoch])

def build_arg_parser():
    """Build the command-line parser for this training script.

    Returns:
        An ``argparse.ArgumentParser`` whose parsed namespace carries the
        attributes ``main`` reads: ``model``, ``checkpoints``, ``tasks``,
        ``batchsz``, ``num_epochs``, ``lr``, ``lora_dim``, ``lora_alpha``,
        ``lora_dropout``, ``device``, ``save``, ``seed`` and ``verbose``.
    """
    parser = argparse.ArgumentParser(description = "Run meta-lora algorithm using Roberta")
    valid_models = [
        "roberta-large",
        "roberta-base"
    ]
    valid_models.extend([m + "-openai-detector" for m in valid_models])
    # Offer the list built above so the "-openai-detector" variants are
    # accepted. "roberta-large-openai" was in the previous hard-coded list and
    # is kept only so existing invocations still parse.
    model_choices = valid_models + ["roberta-large-openai"]
    parser.add_argument('--model', metavar='m', action="store", choices = model_choices, type=str, help = "base model to use", default="roberta-large")
    parser.add_argument('--checkpoints', nargs="+", action="store", type=str, help = "checkpoints to load adapters", default = [])
    parser.add_argument('--tasks', metavar = 't', nargs='+', choices = ["cola", "mnli", "mnli-mm", "mrpc", "qnli", "qqp", "rte", "sst2", "stsb", "wnli"], type=str, help = "glue tasks to use", default = ["cola","sst2"])
    parser.add_argument('--batchsz', metavar='b', action="store", type=int, help= "batch size", default = 32)
    parser.add_argument('--num_epochs', metavar='n', action="store", type=int, help = "number of epochs", default = 20)
    parser.add_argument('--lr', action="store", type=float, help = "learning rate", default = 3e-5)
    parser.add_argument('--lora_dim', action="store", type=int, help = "lora adapter dimension", default = 8)
    parser.add_argument('--lora_alpha', action="store", type=float, help = "lora alpha for scaling", default = 16)
    parser.add_argument('--lora_dropout', action="store", type=float, help = "lora dropout probability", default = .1)
    parser.add_argument('--device', action="store", type=str, help="device to train on", default = "cuda")
    parser.add_argument('--save', action="store", type=str, help="save model checkpoint path", default = "")
    parser.add_argument('--seed', action = "store", help="random seed", type = int, default = 613)
    parser.add_argument('--verbose', action = "store_true", help="verbose output", default = True)
    # --verbose defaults to True, so on its own it can never be switched off;
    # --quiet writes False to the same destination.
    parser.add_argument('--quiet', dest = "verbose", action = "store_false", help="disable verbose output")
    return parser


if __name__ == '__main__':
    main(build_arg_parser().parse_args())