import datasets
from datasets import load_dataset,concatenate_datasets
from evaluate import load
import torch
from numpy import round
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.nn import NLLLoss,LogSoftmax
from transformers import RobertaTokenizer, RobertaForSequenceClassification, get_linear_schedule_with_warmup
from tqdm import tqdm
from conv_utils import preprocess, cluster_persona, read_langs, Lang, filter_data, get_dict_lists, get_dloader_list
import numpy as np

import sys
sys.path.insert(0, '../')
from utils import set_dropout, instantiate_base_model

import argparse

def _batch_forward(model, batch, device):
    """Score every (context, candidate) pair in one batch.

    ``batch['input_ids']`` and ``batch['attention_mask']`` are unpacked as
    rank-3 tensors ``(sz, num_cands, pad_length)`` and flattened to
    ``(sz * num_cands, pad_length)`` (row-major) before the forward pass.

    Returns:
        logits: model logits, one row per (batch row, candidate) pair.
        ans: ``batch['ans']`` moved to ``device`` and flattened to 1-D; it is
            used as the index of the gold candidate for each batch row.
        sz: number of batch rows.
        num_cands: number of candidates per batch row.
    """
    sz, num_cands, pad_length = batch['input_ids'].shape
    input_ids = batch['input_ids'].reshape((-1, pad_length)).to(device)
    attention_mask = batch['attention_mask'].reshape((-1, pad_length)).to(device)
    ans = batch['ans'].reshape(-1).to(device)
    logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
    return logits, ans, sz, num_cands


def _one_hot_targets(ans, sz, num_cands):
    """Build flattened 0/1 labels: 1 for each row's gold candidate, 0 elsewhere.

    The result has length ``sz * num_cands`` and lines up row-for-row with the
    logits returned by ``_batch_forward``.
    """
    targets = torch.zeros((sz, num_cands), dtype=torch.long, device=ans.device)
    targets[torch.arange(sz, device=ans.device), ans.long()] = 1
    return targets.reshape(-1)


def _count_correct(logits, ans, sz, num_cands):
    """Count batch rows whose top-scoring candidate is the gold one.

    A candidate's score is its class-1 logit; the prediction for a row is the
    argmax of those scores over its ``num_cands`` candidates. Returns a 0-d
    tensor on the same device as ``logits``.
    """
    scores = logits.detach().reshape(sz, num_cands, -1)[:, :, 1]
    preds = torch.argmax(scores, dim=1)
    return torch.eq(preds, ans).sum()


def _train_task(model, loader, optimizer, lr_scheduler, log_softmax, loss_fn, device):
    """Run one training pass over a single task's loader.

    The optimizer and LR scheduler are each stepped once per batch.

    Returns:
        (summed batch loss as a float, training accuracy as a float).
        Accuracy is NaN when the loader yields no batches.
    """
    epoch_loss = 0.0
    num_correct = 0
    num_seen = 0
    for batch in tqdm(loader):
        logits, ans, sz, num_cands = _batch_forward(model, batch, device)
        loss = loss_fn(log_softmax(logits), _one_hot_targets(ans, sz, num_cands))

        # Accumulate on-device; converted to Python floats once at the end.
        epoch_loss += loss.detach()
        num_correct += _count_correct(logits, ans, sz, num_cands)
        num_seen += sz

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        lr_scheduler.step()
        # Drop references before the next batch is fetched.
        del batch, logits, loss

    accuracy = float(num_correct) / num_seen if num_seen else float('nan')
    return float(epoch_loss), accuracy


def _evaluate_task(model, loader, device):
    """Return the candidate-ranking accuracy of ``model`` on one task's loader.

    Counters start from zero on every call, so each task is scored
    independently. Returns NaN when the loader yields no batches.
    """
    num_correct = 0
    num_seen = 0
    with torch.no_grad():
        for batch in tqdm(loader):
            logits, ans, sz, num_cands = _batch_forward(model, batch, device)
            num_correct += _count_correct(logits, ans, sz, num_cands)
            num_seen += sz
    return float(num_correct) / num_seen if num_seen else float('nan')


def main(args):
    """Fine-tune RoBERTa on per-persona candidate-ranking tasks.

    Loads the training conversations and splits each selected persona's data
    into train/validation parts. Each epoch trains one shared model over every
    selected persona in turn. A checkpoint is saved whenever the mean
    per-persona validation accuracy improves.

    Args:
        args: argparse.Namespace with the attributes read below (``model``,
            ``from_pretrained_base``, ``persona_nums``, ``batchsz``,
            ``num_epochs``, ``lr``, ``dropout``, ``device``, ``save``,
            ``seed``, ``verbose``, ``num_valid``).
    """
    # Parse arguments
    model_name = args.model
    pretrained_base_path = args.from_pretrained_base
    p_idxs = args.persona_nums
    if isinstance(p_idxs, int):
        # Tolerate a single int (e.g. an older CLI default of 0).
        p_idxs = [p_idxs]
    batch_size = args.batchsz
    num_epochs = args.num_epochs
    lr = args.lr
    p_drop = args.dropout
    device = args.device
    save_path = args.save
    seed = args.seed
    verb = args.verbose
    num_valid_valid = args.num_valid

    if verb:
        print(f"Command Line Args: \n {args}")

    # Seed torch (dataset shuffle) and numpy for reproducibility.
    torch.manual_seed(seed)
    np.random.seed(seed)

    file_train = '/var/local/nameredacted/data/raw_download/train_self_original.txt'
    cand = {}

    # Dict whose keys are persona info and whose values are conversations.
    train = read_langs(file_train, cand_list=cand, max_line=None)

    vocab = Lang()
    # {persona: {dial1: [[context, candidate, answer, persona], ...], dial2: [...]}}
    train = preprocess(train, vocab, False)

    # Group similar personas, then filter the result.
    train = filter_data(cluster_persona(train, 'train'), cut=1)

    # NOTE(review): entry 0 is discarded; the reason is not visible here.
    del train[0]

    # The official valid_self_original.txt split is not used: both the
    # "train" and "valid" lists below are carved out of the training file.
    train_train_dict_list, train_valid_dict_list = get_dict_lists(train, num_valid_valid)

    print("Size of Train Trains:")
    print([len(train_train_dict_list[p_idx]['ans']) for p_idx in p_idxs])

    print("Size of Train Valids: ")
    print([len(train_valid_dict_list[p_idx]['ans']) for p_idx in p_idxs])

    # NOTE(review): the backbone is always 'roberta-base'; --model only
    # affects the checkpoint file name below.
    tokenizer = RobertaTokenizer.from_pretrained('roberta-base', padding_side="right", cache_dir="/var/local/nameredacted/.cache/huggingface/tokenizers")
    model = RobertaForSequenceClassification.from_pretrained('roberta-base', num_labels=2, cache_dir="/var/local/nameredacted/.cache/huggingface/transformers")

    set_dropout(model, p_drop)

    # Presumably one loader per selected persona, in p_idxs order (so they are
    # indexed by position, not by persona id) -- confirm in get_dloader_list.
    tt_dloader_list = get_dloader_list([train_train_dict_list[p_idx] for p_idx in p_idxs], tokenizer, batch_size)
    tv_dloader_list = get_dloader_list([train_valid_dict_list[p_idx] for p_idx in p_idxs], tokenizer, batch_size)

    if pretrained_base_path:
        # torch.load is pickle-based: only load checkpoints from trusted sources.
        checkpoint_base = torch.load(pretrained_base_path, map_location="cpu")
        instantiate_base_model(model, checkpoint_base['model_state_dict'])
        base_metrics = checkpoint_base["val_metrics"]
        print(f"Base Validation Metrics: {base_metrics}")

    model.to(device)
    optimizer = AdamW(params=model.parameters(), lr=lr)

    # The scheduler is stepped once per batch of EVERY task, so size it from
    # all task loaders. The 5/3 stretch is kept from the original schedule:
    # training ends at ~60% of num_training_steps, before the LR reaches zero.
    steps_per_epoch = sum(len(loader) for loader in tt_dloader_list)
    total_steps = int(steps_per_epoch * num_epochs * (5 / 3))
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=int(0.06 * total_steps),
        num_training_steps=total_steps,
    )
    log_softmax = LogSoftmax(dim=1)
    loss_fn = NLLLoss()

    best_avg_acc = 0
    for epoch in range(num_epochs):
        model.train()
        tr_losses = []
        tr_accs = []
        for loader in tt_dloader_list:
            task_loss, task_acc = _train_task(model, loader, optimizer, lr_scheduler, log_softmax, loss_fn, device)
            tr_losses.append(task_loss)
            tr_accs.append(task_acc)

        model.eval()
        val_accs = np.array([_evaluate_task(model, loader, device) for loader in tv_dloader_list], dtype=float)

        avg_acc = np.mean(val_accs)
        if avg_acc > best_avg_acc:
            best_avg_acc = avg_acc

            # Save model. save_path is plain-concatenated, so a directory
            # prefix must end with a path separator.
            if save_path:
                tasks_str = "".join(str(t) for t in p_idxs)
                path = save_path + model_name + "-" + tasks_str + ".pt"
                torch.save({
                    'seed': seed,
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'lr': lr,
                    'batchsz': batch_size,
                    'persona_idxs': p_idxs,
                    'val_metrics': val_accs
                    }, path)
        if verb:
            print(f"Epoch {epoch} Training Losses:", tr_losses)
            print(f"Epoch {epoch} Training Accuracys:", tr_accs)
            print(f"Epoch {epoch} Validation Accuracys:", val_accs)

    print(f"Best Avg Validation Accuracy: {best_avg_acc}")

if __name__ == '__main__':
    # Build the CLI and hand the parsed Namespace to main().
    parser = argparse.ArgumentParser(description = "Run meta-lora algorithm using Roberta")
    valid_models = [
        "roberta-large",
        "roberta-base"
    ]
    valid_models.extend([m + "-openai-detector" for m in valid_models])
    # "roberta-large-openai" was accepted historically; keep it so existing
    # command lines still parse.
    model_choices = valid_models + ["roberta-large-openai"]
    parser.add_argument('--model', metavar='m', action = "store", choices = model_choices, type=str, help = "base model to use", default="roberta-large")
    parser.add_argument("--from_pretrained_base", action = "store", type=str, help = "path to load pretrained saved model",default = "")
    # nargs='+' yields a list, so the default must be a list as well
    # (main() iterates over it).
    parser.add_argument('--persona_nums', metavar='p', action = "store", type=int, nargs='+', help = "which persona indices to use", default=[0])
    parser.add_argument('--batchsz', metavar='b', action="store", type=int, help= "batch size", default = 32)
    parser.add_argument('--num_epochs', metavar='n', action = "store", type=int, help = "number of epochs", default = 20)
    parser.add_argument('--num_valid', metavar='v', action = "store", type=int, default = 1)
    parser.add_argument('--lr', action="store", type=float, help = "learning rate", default = 3e-4)
    parser.add_argument('--dropout', action="store", type=float, help = "dropout pct", default = .1)
    parser.add_argument('--device', action="store",help="device to train on",default = "cuda:0")
    parser.add_argument('--save', action = "store", help="save model checkpoint path", default = "")
    parser.add_argument('--seed', action = "store", help="random seed", type = int, default = 613)
    # Verbose output is on by default; --verbose is kept for compatibility
    # and --quiet provides the only way to switch it off.
    parser.add_argument('--verbose', action = "store_true", help="verbose output", default = True)
    parser.add_argument('--quiet', dest = "verbose", action = "store_false", help="disable verbose output")

    args = parser.parse_args()
    main(args)