import datasets
from datasets import load_dataset,concatenate_datasets
from evaluate import load
import torch
from numpy import round
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.nn import NLLLoss,LogSoftmax
from transformers import RobertaTokenizer, RobertaForSequenceClassification, get_linear_schedule_with_warmup
from tqdm import tqdm
from conv_utils import preprocess, cluster_persona, read_langs, Lang, filter_data, get_dict_lists, get_dloader_list

import sys
sys.path.insert(0, '../')
from utils import add_adapters, set_active_task, freeze_lora_thaw_base, instantiate_base_model, thaw

import argparse
import numpy as np


def _resolve_persona_idxs(args):
    """Return the persona indices selected on the command line.

    ``args.persona_nums_rng`` is used when its end is greater than its start
    (both ends inclusive); otherwise ``args.persona_nums`` is used.  A bare
    integer (e.g. an un-listed argparse default) is wrapped in a list so the
    result can always be iterated.
    """
    start, end = args.persona_nums_rng[0], args.persona_nums_rng[1]
    if end > start:
        return np.arange(start, end + 1)
    p_idxs = args.persona_nums
    if isinstance(p_idxs, (int, np.integer)):
        return [int(p_idxs)]
    return p_idxs


def _forward_candidates(model, batch, device):
    """Run ``model`` on every candidate of every example in ``batch``.

    ``batch['input_ids']`` and ``batch['attention_mask']`` are 3-D; they are
    flattened over their first two dimensions before the forward pass.

    Returns:
        ``(logits, ans, sz, num_cands)`` where ``sz`` and ``num_cands`` are the
        first two dimensions of ``batch['input_ids']`` and ``ans`` is
        ``batch['ans']`` moved to ``device``.
    """
    sz, num_cands, pad_length = batch['input_ids'].shape
    input_ids = batch['input_ids'].reshape((-1, pad_length)).to(device)
    attention_mask = batch['attention_mask'].reshape((-1, pad_length)).to(device)
    ans = batch['ans'].to(device)
    logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
    return logits, ans, sz, num_cands


def _count_correct(logits, ans, sz, num_cands):
    """Count examples whose highest class-1 logit falls on the answer candidate.

    Returns a CPU tensor holding the number of matches.
    NOTE(review): assumes ``ans`` has shape ``(sz,)`` so it compares
    element-wise against the predictions -- confirm against the data loader.
    """
    logits_resized = logits.detach().reshape(sz, num_cands, 2)
    preds = torch.argmax(logits_resized[:, :, 1], dim=1)
    return torch.sum(torch.eq(preds, ans)).to("cpu")


def main(args):
    """Train one adapter per persona on a RoBERTa classifier and report accuracy.

    Reads the training dialogues, builds one train and one validation loader
    per selected persona, adds task adapters to the model, then alternates
    between the tasks batch by batch.  After every epoch each task is
    evaluated; when the mean validation accuracy improves and ``args.save`` is
    non-empty, a checkpoint is written.

    Args:
        args: parsed command-line namespace.  Fields read: ``model``,
            ``from_pretrained_base``, ``persona_nums``, ``persona_nums_rng``,
            ``batchsz``, ``num_epochs``, ``lr``, ``lora_dim``, ``lora_alpha``,
            ``lora_dropout``, ``device``, ``save``, ``seed``, ``only_qv``,
            ``verbose`` and ``num_valid``.

    Returns:
        None.  Results are printed and optionally saved to disk.
    """
    # Parse arguments
    model_name = args.model
    pretrained_base_path = args.from_pretrained_base
    p_idxs = _resolve_persona_idxs(args)
    batch_size = args.batchsz
    num_epochs = args.num_epochs
    lr = args.lr
    lora_dim = args.lora_dim
    lora_alpha = args.lora_alpha
    lora_dropout = args.lora_dropout
    device = args.device
    save_path = args.save
    seed = args.seed
    only_qv = args.only_qv
    verb = args.verbose
    num_valid_valid = args.num_valid

    if verb:
        print(f"Command Line Args: \n {args}")
        print(f"Persona IDXs: {p_idxs}")

    # Set seed for dataset shuffle
    torch.manual_seed(seed)

    # Only the training file is used; the separate validation file is not read.
    file_train = '/var/local/nameredacted/data/raw_download/train_self_original.txt'
    cand = {}

    # Dict where keys are persona info, vals are convo
    train = read_langs(file_train, cand_list=cand, max_line=None)

    vocab = Lang()
    # {persona: {dial1: [[context, candidate, answer, persona], ...], dial2: [...]}}
    train = preprocess(train, vocab, False)

    # Can cluster similar personas
    train = filter_data(cluster_persona(train, 'train'), cut=1)

    # Entry 0 of the clustered data is discarded (reason not visible in this file).
    del train[0]

    train_train_dict_list, train_valid_dict_list = get_dict_lists(train, num_valid_valid)

    tt_sizes = [len(train_train_dict_list[p_idx]['ans']) for p_idx in p_idxs]
    print("Size of Train Trains:")
    print(tt_sizes)

    print("Total Size of Train Trains:")
    print(sum(tt_sizes))

    print("Size of Train Valids: ")
    print([len(train_valid_dict_list[p_idx]['ans']) for p_idx in p_idxs])

    # NOTE(review): the tokenizer and model are always 'roberta-base';
    # ``model_name`` (args.model) is only used in the checkpoint file name below.
    tokenizer = RobertaTokenizer.from_pretrained('roberta-base', padding_side="right", cache_dir="/var/local/nameredacted/.cache/huggingface/tokenizers")
    model = RobertaForSequenceClassification.from_pretrained('roberta-base', num_labels=2, cache_dir="/var/local/nameredacted/.cache/huggingface/transformers")

    tt_dloader_list = get_dloader_list([train_train_dict_list[p_idx] for p_idx in p_idxs], tokenizer, batch_size)
    tv_dloader_list = get_dloader_list([train_valid_dict_list[p_idx] for p_idx in p_idxs], tokenizer, batch_size)

    # One task (adapter + 2-label head) per selected persona.
    num_tasks = len(tt_dloader_list)
    add_adapters(model, adapter_dim=lora_dim, num_tasks=num_tasks, num_labels_list=[2] * num_tasks, alpha=lora_alpha, p_dropout=lora_dropout, only_qv=only_qv)
    thaw(model)

    if pretrained_base_path:
        checkpoint_base = torch.load(pretrained_base_path, map_location="cpu")
        instantiate_base_model(model, checkpoint_base['model_state_dict'])
        base_metrics = checkpoint_base["val_metrics"]
        print("Loading Saved Model...")
        print(f"Base Validation Metrics: {base_metrics}")
    else:
        print("Training from Downloaded Model...")

    optimizer = AdamW(params=model.parameters(), lr=lr)
    # NOTE(review): the schedule length is derived from the first task's loader
    # only, while the scheduler is stepped once per batch of every task.
    total_steps = len(tt_dloader_list[0]) * num_epochs * (5 / 3)
    lr_scheduler = get_linear_schedule_with_warmup(optimizer=optimizer, num_warmup_steps=0.06 * total_steps, num_training_steps=total_steps)
    log_softmax = LogSoftmax(dim=1)
    loss_fn = NLLLoss()

    model.to(device)
    best_val_accs = np.zeros(num_tasks)
    best_avg_acc = 0
    for epoch in range(num_epochs):
        model.train()
        losses = torch.zeros(num_tasks)
        tr_num_correct = torch.zeros(num_tasks)
        tr_num = torch.zeros(num_tasks)

        # Builtin ``bool``: the ``np.bool`` alias no longer exists in current NumPy.
        keep_training_arr = np.ones(num_tasks, dtype=bool)
        # Only the first task's loader is wrapped in a progress bar.
        iter_dloaders = [iter(tqdm(loader)) if i == 0 else iter(loader) for i, loader in enumerate(tt_dloader_list)]
        while np.any(keep_training_arr):
            for task_idx in range(num_tasks):
                if not keep_training_arr[task_idx]:
                    continue
                try:
                    batch = next(iter_dloaders[task_idx])
                except StopIteration:
                    # Task exhausted: mark it done and restart the round from
                    # the first task that still has data.
                    keep_training_arr[task_idx] = False
                    break
                set_active_task(model, task_idx)

                logits, ans, sz, num_cands = _forward_candidates(model, batch, device)

                # Per-candidate binary labels: 1 at the answer index, 0 elsewhere,
                # flattened to line up with the flattened logits.
                targets = torch.zeros((sz, num_cands), dtype=torch.long, device=ans.device)
                targets.scatter_(1, ans.reshape(sz, 1).long(), 1)
                targets = targets.reshape(-1)

                loss = loss_fn(log_softmax(logits), targets)
                losses[task_idx] += loss.detach().to("cpu")

                tr_num[task_idx] += sz
                tr_num_correct[task_idx] += _count_correct(logits, ans, sz, num_cands)

                loss.backward()
                # Drop references early so the graph can be freed before the next batch.
                del logits
                del loss
                del batch

                # Parameters are updated after every single task batch.
                optimizer.step()
                optimizer.zero_grad()
                lr_scheduler.step()

        tr_accs = torch.div(tr_num_correct, tr_num)

        model.eval()
        val_num_correct = torch.zeros(num_tasks)
        val_num = torch.zeros(num_tasks)
        val_accs = torch.zeros(num_tasks)
        for task_idx in range(num_tasks):
            set_active_task(model, task_idx)
            for batch in tqdm(tv_dloader_list[task_idx]):
                with torch.no_grad():
                    logits, ans, sz, num_cands = _forward_candidates(model, batch, device)

                val_num[task_idx] += sz
                val_num_correct[task_idx] += _count_correct(logits, ans, sz, num_cands)
            val_accs[task_idx] = val_num_correct[task_idx] / val_num[task_idx]
            # Track the best accuracy each task has reached in any epoch.
            best_val_accs[task_idx] = max(best_val_accs[task_idx], val_accs[task_idx].item())

        # Plain float so the running best never mixes ints and tensors.
        mean = torch.mean(val_accs).item()
        if mean > best_avg_acc:
            best_avg_acc = mean

            # Save model
            if save_path:
                tasks_str = "".join([str(t) for t in p_idxs])
                path = save_path + model_name + "-" + tasks_str + ".pt"
                torch.save({
                    'seed': seed,
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'lr': lr,
                    'batchsz': batch_size,
                    'persona_idxs': p_idxs,
                    'val_metrics': val_accs
                    }, path)
        if verb:
            print(f"Epoch {epoch} Training Losses:", losses)
            print(f"Epoch {epoch} Training Accuracys:", tr_accs)
            print(f"Epoch {epoch} Validation Accuracy:", val_accs)

    print("Best Val Accuracys: {}".format(best_val_accs))
    print(f"Best Avg Accuracy: {best_avg_acc}")
    return

if __name__ == '__main__':
    # Command-line entry point: declare the training options, parse them and
    # hand the resulting namespace to main().
    parser = argparse.ArgumentParser(description = "Run meta-lora algorithm using Roberta")
    valid_models = [
        "roberta-large",
        "roberta-base"
    ]
    valid_models.extend([m + "-openai-detector" for m in valid_models])
    # "roberta-large-openai" is the spelling this option accepted before; it is
    # kept so existing invocations still parse.
    model_choices = valid_models + ["roberta-large-openai"]
    parser.add_argument('--model', metavar='m', action = "store", choices = model_choices, type=str, help = "base model to use", default="roberta-large")
    parser.add_argument("--from_pretrained_base", action = "store", type=str, help = "path to load pretrained saved model",default = "")
    # The default must be a list: with nargs='+' the parsed value is a list and
    # main() iterates over it.
    parser.add_argument('--persona_nums', metavar='p', action = "store", type=int, nargs='+', help = "list of persona indices", default=[0])
    parser.add_argument('--persona_nums_rng', metavar='p', action = "store", type=int, nargs=2, help = "start and end persona indices", default=[0,0])
    parser.add_argument('--batchsz', metavar='b', action="store", type=int, help= "batch size", default = 32)
    parser.add_argument('--num_valid', metavar='v', action = "store", type=int, default = 1)
    parser.add_argument('--num_epochs', metavar='n', action = "store", type=int, help = "number of epochs", default = 20)
    parser.add_argument('--lr', action="store", type=float, help = "learning rate", default = 3e-4)
    parser.add_argument('--lora_dim', action = "store", type=int, help = "lora adapter dimension", default = 8)
    parser.add_argument('--lora_alpha', action = "store", type=float, help = "lora alpha for scaling", default = 16)
    parser.add_argument('--lora_dropout', action = "store", type=float, help = "lora dropout probability", default = .1)
    parser.add_argument('--device', action="store",help="device to train on",default = "cuda:0")
    parser.add_argument('--save', action = "store", help="save model checkpoint path", default = "")
    parser.add_argument('--seed', action = "store", help="random seed", type = int, default = 613)
    # store_true so that passing the flag actually enables the option
    # (store_false with default False could never yield True).
    parser.add_argument('--only_qv', action = "store_true", help="only adapt q,v matrices", default = False)
    parser.add_argument('--verbose', action = "store_true", help="verbose output", default = True)
    # Verbose output is on by default; --quiet is the way to switch it off.
    parser.add_argument('--quiet', dest = "verbose", action = "store_false", help="disable verbose output", default = True)

    args = parser.parse_args()
    main(args)