from datasets import load_dataset,concatenate_datasets
from evaluate import load
import torch
from numpy import round
from torch.utils.data import DataLoader
from torch.optim import AdamW
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSequenceClassification, get_linear_schedule_with_warmup

from utils import add_adapters, set_active_task, freeze_base_thaw_adapters, freeze_head, instantiate_base_model
from copy import deepcopy
import argparse

def _train_one_epoch(model, dataloader, optimizer, lr_scheduler, device):
    """Run one optimisation pass over ``dataloader``.

    Every batch is moved to ``device`` and fed to ``model``, which computes its
    own loss (batches carry a ``labels`` key). The optimizer and scheduler are
    each stepped once per batch.

    Returns:
        float: sum of the per-batch losses (0.0 for an empty dataloader).
    """
    model.train()
    # Accumulate on-device to avoid a host sync on every step; converted once at the end.
    epoch_loss = 0
    for batch in tqdm(dataloader):
        batch = batch.to(device)
        outputs = model(**batch)
        loss = outputs.loss
        epoch_loss += loss.detach()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        lr_scheduler.step()
        # Release references before the next batch so their memory can be reused.
        del outputs, loss, batch
    return float(epoch_loss)


def _evaluate(model, dataloader, metric, device, is_regression):
    """Score ``model`` on ``dataloader`` and return ``metric.compute()``.

    Classification predictions are the arg-max over the logits. For regression
    the model has a single output column, which is squeezed to shape (batch,)
    so it lines up with the 1-D reference labels.
    """
    model.eval()
    for batch in tqdm(dataloader):
        batch = batch.to(device)
        with torch.no_grad():
            logits = model(**batch).logits
        predictions = logits.squeeze(-1) if is_regression else logits.argmax(dim=-1)
        metric.add_batch(predictions=predictions, references=batch["labels"])
    return metric.compute()


def main(args):
    """Train LoRA adapters on one GLUE task and report the best validation score.

    Args:
        args: ``argparse.Namespace`` with the attributes ``model``,
            ``from_pretrained_base``, ``task``, ``batchsz``, ``num_epochs``,
            ``lr``, ``lora_dim``, ``lora_alpha``, ``lora_dropout``, ``device``,
            ``save``, ``seed``, ``only_qv`` and ``verbose``.

    Side effects:
        Downloads/caches data, metric, model and tokenizer. Prints progress.
        When ``args.save`` is non-empty, writes a checkpoint to
        ``args.save + "<model>-lr<lr*1e4>.pt"`` every time the first reported
        validation metric improves.
    """
    # Unpack command-line arguments
    model_name = args.model
    pretrained_base_path = args.from_pretrained_base
    task = args.task
    batch_size = args.batchsz
    num_epochs = args.num_epochs
    lr = args.lr
    lora_dim = args.lora_dim
    lora_alpha = args.lora_alpha
    lora_dropout = args.lora_dropout
    device = args.device
    save_path = args.save
    seed = args.seed
    only_qv = args.only_qv
    verb = args.verbose

    if verb:
        print(f"Command Line Args: \n {args}")

    # Seed before any model init or shuffling so runs are repeatable
    torch.manual_seed(seed)

    # Load task data, metric, and pretrained model
    cache_root = "/var/local/nameredacted/.cache/huggingface"
    # "mnli-mm" reuses the MNLI dataset and metric
    actual_task = "mnli" if task == "mnli-mm" else task
    is_regression = task == "stsb"
    task_data = load_dataset("glue", actual_task, cache_dir=f"{cache_root}/datasets")
    metric = load("glue", actual_task, cache_dir=f"{cache_root}/metrics")
    num_labels = 3 if task.startswith("mnli") else 1 if is_regression else 2
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name, num_labels=num_labels, cache_dir=f"{cache_root}/transformers"
    )

    # Preprocess data
    tokenizer = AutoTokenizer.from_pretrained(
        model_name, padding_side="right", cache_dir=f"{cache_root}/tokenizers"
    )
    if tokenizer.pad_token_id is None:
        # Some tokenizers ship without a pad token; fall back to EOS for padding.
        tokenizer.pad_token_id = tokenizer.eos_token_id

    task_to_keys = {
        "cola": ("sentence", None),
        "mnli": ("premise", "hypothesis"),
        "mnli-mm": ("premise", "hypothesis"),
        "mrpc": ("sentence1", "sentence2"),
        "qnli": ("question", "sentence"),
        "qqp": ("question1", "question2"),
        "rte": ("sentence1", "sentence2"),
        "sst2": ("sentence", None),
        "stsb": ("sentence1", "sentence2"),
        "wnli": ("sentence1", "sentence2"),
    }
    first_key, second_key = task_to_keys[task]

    def tokenize_function(examples):
        if second_key is None:
            return tokenizer(examples[first_key], truncation=True)
        return tokenizer(examples[first_key], examples[second_key], truncation=True)

    # Drop raw text and index columns; only token ids, masks and labels remain
    rm_cols = [key for key in (first_key, second_key) if key is not None] + ["idx"]
    tokenized_dataset = task_data.map(tokenize_function, batched=True, remove_columns=rm_cols)

    # Rename 'label' to 'labels', the name transformers models expect for the loss
    tokenized_dataset = tokenized_dataset.rename_column("label", "labels")

    def collate_fn(examples):
        # Dynamic padding: pad each batch only to its own longest sequence
        return tokenizer.pad(examples, padding="longest", return_tensors="pt")

    def make_loader(dataset, shuffle):
        return DataLoader(dataset, shuffle=shuffle, collate_fn=collate_fn,
                          batch_size=batch_size, pin_memory=True)

    # Instantiate dataloaders
    train_dataloader = make_loader(tokenized_dataset["train"], shuffle=True)
    if actual_task == "mnli":
        # Both "mnli" and "mnli-mm" are scored on matched + mismatched validation combined
        eval_split = concatenate_datasets(
            [tokenized_dataset["validation_matched"], tokenized_dataset["validation_mismatched"]]
        )
    else:
        eval_split = tokenized_dataset["validation"]
    eval_dataloader = make_loader(eval_split, shuffle=False)

    add_adapters(model, adapter_dim=lora_dim, num_tasks=1, num_labels_list=[num_labels],
                 alpha=lora_alpha, p_dropout=lora_dropout, only_qv=only_qv)
    if pretrained_base_path:
        # NOTE: torch.load unpickles the file -- only load checkpoints from trusted sources.
        checkpoint_base = torch.load(pretrained_base_path, map_location="cpu")
        instantiate_base_model(model, checkpoint_base["model_state_dict"])
        # Checkpoints written by this script use "val_metric"; accept either key.
        base_metrics = checkpoint_base.get("val_metrics", checkpoint_base.get("val_metric"))
        print(f"Base Validation Metrics: {base_metrics}")

    total_steps = len(train_dataloader) * num_epochs
    optimizer = AdamW(params=model.parameters(), lr=lr)
    # Linear decay with warm-up over the first 6% of all optimisation steps
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=int(0.06 * total_steps),
        num_training_steps=total_steps,
    )

    freeze_base_thaw_adapters(model)

    model.to(device)
    # Start at -inf so the first epoch always counts as an improvement and
    # metrics that can be zero or negative (e.g. correlations) are handled.
    best_val_metric = float("-inf")
    for epoch in range(num_epochs):
        epoch_loss = _train_one_epoch(model, train_dataloader, optimizer, lr_scheduler, device)
        print(f"epoch {epoch} Training Loss:", epoch_loss)

        eval_metric = _evaluate(model, eval_dataloader, metric, device, is_regression)
        if verb:
            print(f"epoch {epoch} Eval Metric:", eval_metric)

        # Model selection uses the first value the metric reports
        curr_val_metric = next(iter(eval_metric.values()))
        if curr_val_metric > best_val_metric:
            best_val_metric = curr_val_metric
            if save_path:
                # File name encodes the learning rate scaled by 1e4 (e.g. 3e-4 -> "lr3")
                model_str = model_name + "-lr" + str(int(round(lr * 10000))) + ".pt"
                torch.save({
                    'seed': seed,
                    'epoch': epoch,
                    'val_metric': best_val_metric,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'lora_dim': lora_dim,
                    'lora_alpha': lora_alpha},
                    save_path + model_str)
    print("Best Val Metric: {}".format(best_val_metric))

def build_arg_parser():
    """Build the command-line parser whose namespace is passed to ``main``.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser(description = "Run meta-lora algorithm using Roberta")
    valid_models = [
        "roberta-large",
        "roberta-base"
    ]
    # Also accept the "-openai-detector" variant of each base checkpoint
    valid_models.extend([m + "-openai-detector" for m in valid_models])
    parser.add_argument('--model', metavar='m', action="store", choices=valid_models, type=str, help="base model to use", default="roberta-large")
    parser.add_argument("--from_pretrained_base", action="store", type=str, help="path to load pretrained saved model", default="")
    parser.add_argument('--task', metavar='t', action="store", choices=["cola", "mnli", "mnli-mm", "mrpc", "qnli", "qqp", "rte", "sst2", "stsb", "wnli"], type=str, help="glue task to use", default="cola")
    parser.add_argument('--batchsz', metavar='b', action="store", type=int, help="batch size", default=32)
    parser.add_argument('--num_epochs', metavar='n', action="store", type=int, help="number of epochs", default=20)
    parser.add_argument('--lr', action="store", type=float, help="learning rate", default=3e-4)
    parser.add_argument('--lora_dim', action="store", type=int, help="lora adapter dimension", default=8)
    parser.add_argument('--lora_alpha', action="store", type=float, help="lora alpha for scaling", default=16)
    parser.add_argument('--lora_dropout', action="store", type=float, help="lora dropout probability", default=.1)
    # NOTE(review): machine-specific default device; override with --device on other hosts.
    parser.add_argument('--device', action="store", help="device to train on", default="cuda:4")
    parser.add_argument('--save', action="store", help="save model checkpoint path", default="")
    parser.add_argument('--seed', action="store", help="random seed", type=int, default=613)
    # store_true: passing the flag enables q/v-only adaptation (default stays False)
    parser.add_argument('--only_qv', action="store_true", help="only adapt q,v matrices", default=False)
    # Verbose output is on by default; --quiet provides a way to turn it off.
    verbosity = parser.add_mutually_exclusive_group()
    verbosity.add_argument('--verbose', action="store_true", help="verbose output", default=True)
    verbosity.add_argument('--quiet', dest='verbose', action="store_false", help="disable verbose output")
    return parser


if __name__ == '__main__':
    args = build_arg_parser().parse_args()
    main(args)