from datasets import load_dataset
from evaluate import load
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSequenceClassification

from utils import add_adapters, set_active_task, freeze_base_thaw_adapters, instantiate_model
from copy import deepcopy
import argparse
import csv

def _resolve_glue_task(task):
    """Map a task name to (model_task, dataset_config, eval_split, test_split).

    ``model_task`` selects the input columns, label count and metric config;
    ``dataset_config`` is the GLUE configuration passed to ``load_dataset``.
    """
    if task == "ax":
        # AX is scored with an MNLI model; it is expected to ship only a test
        # split, so evaluation is skipped later when "validation" is absent.
        return "mnli", "ax", "validation", "test"
    if task == "mnli":
        return "mnli", "mnli", "validation_matched", "test_matched"
    if task == "mnli-mm":
        return "mnli", "mnli", "validation_mismatched", "test_mismatched"
    return task, task, "validation", "test"


def _predict_batch(model, batch, device, is_regression):
    """Run one forward pass without gradients and return per-example predictions."""
    batch = batch.to(device)
    with torch.no_grad():
        outputs = model(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
    if is_regression:
        # A single-logit head is a regression score; argmax over one column
        # would always yield 0, so return the raw value instead.
        return outputs.logits.squeeze(-1)
    return outputs.logits.argmax(dim=-1)


def main(args):
    """Evaluate a LoRA-adapted checkpoint on a GLUE task and write test predictions.

    Args:
        args: Namespace providing ``name`` (base model id), ``model_path``
            (checkpoint file), ``task`` (GLUE task name), ``batchsz``,
            ``device``, ``print_eval`` and ``out_path`` (output directory).

    Side effects:
        Optionally prints the validation metric, then writes
        ``<out_path>/<TaskName>.tsv`` with an ``index``/``prediction`` header
        followed by one row per test example.

    Raises:
        KeyError: if ``task`` is not one of the known GLUE task names.
    """
    # Function-scope import so this block does not depend on the file header.
    import os

    # Parse arguments
    model_name = args.name
    model_path = args.model_path
    task = args.task
    batch_size = args.batchsz
    if isinstance(batch_size, (list, tuple)):
        # argparse with nargs=1 yields a one-element list; DataLoader needs an int.
        batch_size = batch_size[0]
    device = args.device
    print_eval = args.print_eval
    out_path = args.out_path

    # Set seed for dataset shuffle
    torch.manual_seed(613)

    # Load task data and metric. The metric is keyed by the model task because
    # the GLUE metric has no "ax" configuration of its own.
    cache_root = "/var/local/nameredacted/.cache/huggingface"
    actual_task, data_config, eval_split, test_split = _resolve_glue_task(task)
    task_data = load_dataset("glue", data_config, cache_dir=f"{cache_root}/datasets")
    metric = load("glue", actual_task, cache_dir=f"{cache_root}/metrics")

    # Preprocess data
    tokenizer = AutoTokenizer.from_pretrained(
        model_name, padding_side="right", cache_dir=f"{cache_root}/tokenizers"
    )
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    task_to_keys = {
        "cola": ("sentence", None),
        "mnli": ("premise", "hypothesis"),
        "mrpc": ("sentence1", "sentence2"),
        "qnli": ("question", "sentence"),
        "qqp": ("question1", "question2"),
        "rte": ("sentence1", "sentence2"),
        "sst2": ("sentence", None),
        "stsb": ("sentence1", "sentence2"),
        "wnli": ("sentence1", "sentence2"),
    }
    sentence_keys = task_to_keys[actual_task]

    def tokenize_function(examples):
        if sentence_keys[1] is None:
            return tokenizer(examples[sentence_keys[0]], truncation=True)
        return tokenizer(examples[sentence_keys[0]], examples[sentence_keys[1]], truncation=True, max_length=None)

    rm_cols = [key for key in sentence_keys if key is not None]
    tokenized_dataset = task_data.map(tokenize_function, batched=True, remove_columns=rm_cols)

    # We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the
    # transformers library
    tokenized_dataset = tokenized_dataset.rename_column("label", "labels")

    def collate_fn(examples):
        # Pad each batch only up to its own longest sequence.
        return tokenizer.pad(examples, padding="longest", return_tensors="pt")

    # Load model
    is_regression = actual_task == "stsb"
    num_labels = 3 if actual_task == "mnli" else 1 if is_regression else 2
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name, num_labels=num_labels, cache_dir=f"{cache_root}/transformers"
    )
    # NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
    checkpt = torch.load(model_path, map_location="cpu")
    # The adapters are created before the checkpoint weights are applied
    # (presumably so the adapter parameters exist in the model — confirm in utils).
    add_adapters(model, adapter_dim=checkpt['lora_dim'], num_tasks=1, alpha=checkpt['lora_alpha'], p_dropout=0)
    instantiate_model(model, checkpt['model_state_dict'])

    model.to(device)
    model.eval()

    if print_eval:
        if eval_split not in tokenized_dataset:
            # Some configurations have no labelled validation data; skip
            # rather than crash before the test predictions are written.
            print("no '{}' split for task '{}'; skipping eval metric".format(eval_split, task))
        else:
            eval_dataloader = DataLoader(tokenized_dataset[eval_split], shuffle=False, collate_fn=collate_fn, batch_size=batch_size)
            for batch in eval_dataloader:
                predictions = _predict_batch(model, batch, device, is_regression)
                metric.add_batch(
                    predictions=predictions,
                    references=batch["labels"],
                )
            print("eval metric: {}".format(metric.compute()))

    test_dataloader = DataLoader(tokenized_dataset[test_split], shuffle=False, collate_fn=collate_fn, batch_size=batch_size)
    pred_list = []
    idx_list = []
    for batch in test_dataloader:
        predictions = _predict_batch(model, batch, device, is_regression)
        pred_list.extend(predictions.tolist())
        idx_list.extend(batch['idx'].tolist())

    task_to_task_str = {
        "cola": "CoLA",
        "mnli": "MNLI-m",
        "mnli-mm": "MNLI-mm",
        "mrpc": "MRPC",
        "qnli": "QNLI",
        "qqp": "QQP",
        "rte": "RTE",
        "sst2": "SST-2",
        "stsb": "STS-B",
        "wnli": "WNLI",
        "ax": "AX"
    }
    task_str = task_to_task_str[task]

    # out_path is a directory: join instead of concatenating so a missing
    # trailing separator does not fuse the directory and file names.
    if out_path:
        os.makedirs(out_path, exist_ok=True)
    tsv_path = os.path.join(out_path, task_str + ".tsv")

    # NOTE(review): class predictions are written as integer ids; confirm
    # whether the consumer of this file expects label strings for some tasks.
    with open(tsv_path, 'w', newline='') as tsvfile:
        writer = csv.writer(tsvfile, delimiter='\t', lineterminator='\n')
        writer.writerow(["index", "prediction"])
        writer.writerows(zip(idx_list, pred_list))

if __name__ == '__main__':
    # Command-line entry point: parse the options and hand them to main().
    parser = argparse.ArgumentParser(description = "Evaluate a model on Glue")
    parser.add_argument('--name', metavar='m', choices = ["roberta-large","roberta-base","roberta-large-openai"], type=str, help = "base model to use", default="roberta-large")
    # The checkpoint and the task have no usable default, so make argparse
    # reject a missing value instead of failing later inside main().
    parser.add_argument('--model_path', metavar='-p', action="store", type=str, required=True, help = "path to model")
    parser.add_argument('--task', metavar = '-t', action= "store", choices = ["ax","cola", "mnli", "mrpc", "qnli", "qqp", "rte", "sst2", "stsb", "wnli"], type=str, required=True, help = "glue task to use")
    # No nargs=1 here: it would turn a user-supplied value into a one-element
    # list while the default stayed a plain int.
    parser.add_argument('--batchsz', metavar='-b', type=int, help= "batch size", default = 32)
    parser.add_argument('--device', action="store",help="device to evaluate on",default = "cuda:4")
    parser.add_argument('--print_eval', action="store_true",help="whether to print eval metric")
    parser.add_argument('--out_path', action="store",help="path to output directory", default = "")

    args = parser.parse_args()
    main(args)