import datasets
from datasets import load_dataset,concatenate_datasets
from evaluate import load
import torch
from numpy import round
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.nn import NLLLoss,LogSoftmax
from transformers import RobertaTokenizer, RobertaForSequenceClassification, get_linear_schedule_with_warmup
from tqdm import tqdm
from conv_utils import preprocess, cluster_persona, read_langs, Lang, filter_data, get_dict_lists, get_dloader_list, file_to_dict_lists

import sys
sys.path.insert(0, '../')
from utils import add_adapters, set_active_task, freeze_base_thaw_adapters, instantiate_base_model

import argparse
import numpy as np


def _forward_candidates(model, batch, device):
    """Score every candidate of every example in ``batch`` with ``model``.

    ``batch['input_ids']`` / ``batch['attention_mask']`` are 3-D with shape
    ``(batch_size, num_candidates, pad_length)``; they are flattened to
    ``(batch_size * num_candidates, pad_length)`` before the forward pass.

    Returns:
        ``(logits, ans, batch_size, num_candidates)`` where ``logits`` has one
        row per flattened candidate and ``ans`` is ``batch['ans']`` on ``device``.
    """
    batch_size, num_cands, pad_length = batch['input_ids'].shape
    input_ids = batch['input_ids'].reshape((-1, pad_length)).to(device)
    attention_mask = batch['attention_mask'].reshape((-1, pad_length)).to(device)
    ans = batch['ans'].to(device)
    outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    return outputs.logits, ans, batch_size, num_cands


def _candidate_targets(ans, batch_size, num_cands):
    """Build flat binary labels: 1 for each example's answer candidate, else 0.

    ``ans[i]`` is the index of the correct candidate of example ``i``. The
    result has shape ``(batch_size * num_cands,)`` and lines up row-for-row
    with the flattened logits produced by ``_forward_candidates``.
    """
    targets = torch.zeros((batch_size, num_cands), dtype=torch.long, device=ans.device)
    # Vectorized one-hot: row i gets a 1 in column ans[i].
    targets[torch.arange(batch_size, device=ans.device), ans.long()] = 1
    return targets.reshape(-1)


def _predict(logits, batch_size, num_cands):
    """Return, per example, the candidate index with the largest label-1 logit.

    Label 1 is the "correct candidate" class (see ``_candidate_targets``).
    """
    per_candidate = logits.detach().reshape(batch_size, num_cands, 2)
    return torch.argmax(per_candidate[:, :, 1], dim=1)


def _train_task(model, dloader, optimizer, scheduler, loss_fn, log_softmax, device):
    """Run one pass of training over ``dloader`` for the currently active task.

    Returns:
        ``(summed batch loss, training accuracy)``; accuracy is NaN when the
        loader yields no examples.
    """
    # Accumulate on-device to avoid a host sync every step.
    loss_sum = torch.zeros((), device=device)
    num_correct = torch.zeros((), dtype=torch.long, device=device)
    num_seen = 0
    for batch in tqdm(dloader):
        logits, ans, batch_size, num_cands = _forward_candidates(model, batch, device)
        targets = _candidate_targets(ans, batch_size, num_cands)

        loss = loss_fn(log_softmax(logits), targets)
        loss_sum += loss.detach()

        num_seen += batch_size
        num_correct += torch.sum(torch.eq(_predict(logits, batch_size, num_cands), ans))

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()
        # Drop references before the next batch is fetched to keep peak memory down.
        del logits, loss, batch

    accuracy = num_correct.item() / num_seen if num_seen else float("nan")
    return loss_sum.item(), accuracy


def _evaluate_task(model, dloader, device):
    """Return the candidate-ranking accuracy of ``model`` on ``dloader``.

    Counters are local to this call, so each task is scored independently.
    Returns NaN when the loader yields no examples.
    """
    num_correct = torch.zeros((), dtype=torch.long, device=device)
    num_seen = 0
    with torch.no_grad():
        for batch in tqdm(dloader):
            logits, ans, batch_size, num_cands = _forward_candidates(model, batch, device)
            num_seen += batch_size
            num_correct += torch.sum(torch.eq(_predict(logits, batch_size, num_cands), ans))
    return num_correct.item() / num_seen if num_seen else float("nan")


def main(args):
    """Train one LoRA adapter task per selected persona and report accuracies.

    Each persona index in the selection becomes its own adapter task on a
    ``roberta-base`` sequence classifier (2 labels). Training and validation
    accuracy measure whether the highest label-1 logit among a batch row's
    candidates lands on ``batch['ans']``.

    Args:
        args: ``argparse.Namespace`` with the attributes produced by this
            script's parser (persona selection, optimisation and LoRA
            hyper-parameters, device, seed, verbosity, ...).

    Returns:
        None. Results are printed.
    """
    # Parse arguments
    # NOTE(review): args.model and args.save are parsed but not used here;
    # the backbone below is hard-coded to roberta-base and nothing is saved.
    split = args.split
    rng_start, rng_end = args.persona_nums_rng
    if rng_end > rng_start:
        p_idxs = np.arange(rng_start, rng_end + 1)
    elif isinstance(args.persona_nums, (list, tuple)):
        p_idxs = args.persona_nums
    else:
        # A scalar persona index (e.g. an int default) is not iterable.
        p_idxs = [args.persona_nums]
    batch_size = args.batchsz
    num_epochs = args.num_epochs
    lr = args.lr
    lr_sched_mult = args.lr_sched_mult
    lora_dim = args.lora_dim
    lora_alpha = args.lora_alpha
    lora_dropout = args.lora_dropout
    device = args.device
    pretrained_base_path = args.from_pretrained_base
    seed = args.seed
    only_qv = args.only_qv
    verb = args.verbose
    num_valid = args.num_valid

    if verb:
        print(f"Command Line Args: \n {args}")

    # Set seed for dataset shuffle
    torch.manual_seed(seed)

    if split == "train":
        file = '/var/local/nameredacted/data/raw_download/train_self_original.txt'
    else:
        file = '/var/local/nameredacted/data/raw_download/valid_self_original.txt'

    print(f"Running on {split} Data")

    train_dict_list, valid_dict_list = file_to_dict_lists(file, num_valid)

    print("Size of Trains:")
    print([len(train_dict_list[p_idx]['ans']) for p_idx in p_idxs])

    print("Size of Valids: ")
    print([len(valid_dict_list[p_idx]['ans']) for p_idx in p_idxs])

    tokenizer = RobertaTokenizer.from_pretrained('roberta-base',padding_side="right",cache_dir="/var/local/nameredacted/.cache/huggingface/tokenizers")
    model = RobertaForSequenceClassification.from_pretrained('roberta-base', num_labels=2,cache_dir="/var/local/nameredacted/.cache/huggingface/transformers")

    tr_dloader_list = get_dloader_list([train_dict_list[p_idx] for p_idx in p_idxs], tokenizer, batch_size)
    val_dloader_list = get_dloader_list([valid_dict_list[p_idx] for p_idx in p_idxs], tokenizer, batch_size)

    num_tasks = len(tr_dloader_list)
    add_adapters(model, adapter_dim=lora_dim, num_tasks=num_tasks, num_labels_list=[2] * num_tasks, alpha=lora_alpha, p_dropout=lora_dropout, only_qv=only_qv)

    if pretrained_base_path:
        checkpoint_base = torch.load(pretrained_base_path, map_location="cpu")
        instantiate_base_model(model, checkpoint_base['model_state_dict'])
        base_metrics = checkpoint_base["val_metrics"]
        print("Loading Saved Model...")
        print(f"Base Validation Metrics: {base_metrics}")
    else:
        print("Training from Downloaded Model...")

    # One optimizer / linear-warmup schedule per task; warmup is
    # lr_sched_mult * 6% of that task's total training steps.
    optimizers = [AdamW(params=model.parameters(), lr=lr) for _ in range(num_tasks)]
    lr_schedulers = []
    for optimizer, dloader in zip(optimizers, tr_dloader_list):
        total_steps = len(dloader) * num_epochs
        lr_schedulers.append(get_linear_schedule_with_warmup(
            optimizer=optimizer,
            num_warmup_steps=lr_sched_mult * 0.06 * total_steps,
            num_training_steps=total_steps,
        ))
    log_softmax = LogSoftmax(dim=1)
    loss_fn = NLLLoss()

    freeze_base_thaw_adapters(model)

    model.to(device)
    best_val_accs = np.zeros(num_tasks)
    best_avg_acc = 0
    for epoch in range(num_epochs):
        model.train()
        losses = np.zeros(num_tasks)
        tr_accs = np.zeros(num_tasks)
        for task_idx in range(num_tasks):
            set_active_task(model, task_idx)
            losses[task_idx], tr_accs[task_idx] = _train_task(
                model, tr_dloader_list[task_idx], optimizers[task_idx],
                lr_schedulers[task_idx], loss_fn, log_softmax, device,
            )

        model.eval()
        val_accs = np.zeros(num_tasks)
        for task_idx in range(num_tasks):
            # Score each persona with its own adapter, not the last-trained one.
            set_active_task(model, task_idx)
            val_accs[task_idx] = _evaluate_task(model, val_dloader_list[task_idx], device)

        # fmax ignores NaN (empty loader), matching a strict ">" update.
        best_val_accs = np.fmax(best_val_accs, val_accs)

        mean = np.mean(val_accs)
        if mean > best_avg_acc:
            best_avg_acc = mean
        if verb:
            print(f"Epoch {epoch} Training Losses:", losses)
            print(f"Epoch {epoch} Training Accuracys:", tr_accs)
            print(f"Epoch {epoch} Validation Accuracy:", val_accs)

    print("Best Val Accuracys: {}".format(best_val_accs))
    print(f"Best Avg Accuracy: {best_avg_acc}")
    return

def build_parser():
    """Build the command-line parser whose namespace is passed to ``main``.

    Returns:
        argparse.ArgumentParser configured with this script's options.
    """
    parser = argparse.ArgumentParser(description = "Run meta-lora algorithm using Roberta")
    valid_models = [
        "roberta-large",
        "roberta-base"
    ]
    valid_models.extend([m + "-openai-detector" for m in valid_models])
    # "roberta-large-openai" stays accepted for existing command lines.
    parser.add_argument('--model', metavar='m', action = "store", choices = valid_models + ["roberta-large-openai"], type=str, help = "base model to use", default="roberta-large")
    parser.add_argument("--from_pretrained_base", action = "store", type=str, help = "path to load pretrained saved model",default = "")
    # main() loads the valid file for anything other than "train"; reject typos.
    parser.add_argument('--split', metavar='s', action = "store", choices = ["train", "valid"], type=str, help = "train or valid", default="valid")
    # nargs='+' yields a list, so the default must be a list too.
    parser.add_argument('--persona_nums', metavar='p', action = "store", type=int, nargs='+', help = "which persona indices to use", default=[0])
    parser.add_argument('--persona_nums_rng', metavar='p', action = "store", type=int, nargs=2, help = "start and end persona indices", default=[0,0])
    parser.add_argument('--batchsz', metavar='b', action="store", type=int, help= "batch size", default = 32)
    parser.add_argument('--num_valid', metavar='v', action = "store", type=int, default = 1)
    parser.add_argument('--num_epochs', metavar='n', action = "store", type=int, help = "number of epochs", default = 20)
    parser.add_argument('--lr', action="store", type=float, help = "learning rate", default = 3e-4)
    parser.add_argument('--lr_sched_mult', action="store", type=float, help = "learning rate schedule multiplier", default = 1)
    parser.add_argument('--lora_dim', action = "store", type=int, help = "lora adapter dimension", default = 8)
    parser.add_argument('--lora_alpha', action = "store", type=float, help = "lora alpha for scaling", default = 16)
    parser.add_argument('--lora_dropout', action = "store", type=float, help = "lora dropout probability", default = .1)
    parser.add_argument('--device', action="store",help="device to train on",default = "cuda:0")
    parser.add_argument('--save', action = "store", help="save model checkpoint path", default = "")
    parser.add_argument('--seed', action = "store", help="random seed", type = int, default = 613)
    # store_true: passing the flag enables q/v-only adapters (store_false made it a no-op).
    parser.add_argument('--only_qv', action = "store_true", help="only adapt q,v matrices", default = False)
    parser.add_argument('--verbose', action = "store_true", help="verbose output", default = True)
    # --verbose defaults to True, so provide a way to turn it off.
    parser.add_argument('--quiet', dest = "verbose", action = "store_false", help="disable verbose output")
    return parser


if __name__ == '__main__':
    args = build_parser().parse_args()
    main(args)