import datasets
from datasets import load_dataset,concatenate_datasets
from evaluate import load
import torch
from numpy import round
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.nn import NLLLoss,LogSoftmax
from transformers import RobertaTokenizer, RobertaForSequenceClassification, get_linear_schedule_with_warmup
from tqdm import tqdm
from conv_utils import preprocess, cluster_persona, read_langs, Lang, filter_data, get_dict_lists, get_dloader_list

import sys
sys.path.insert(0, '../')
from utils import set_dropout, instantiate_base_model

import argparse


def _score_candidates(model, batch, device):
    """Run the classifier over every candidate in a batch and pick one per row.

    ``batch['input_ids']`` and ``batch['attention_mask']`` are unpacked as
    three-dimensional tensors whose axes are used here as
    ``(batch, candidates, padded_length)``. They are flattened to two
    dimensions so that each candidate becomes one sequence for the model.

    Args:
        model: callable returning an object whose ``logits`` attribute holds
            two scores per flattened sequence.
        batch: mapping with ``'input_ids'``, ``'attention_mask'`` and ``'ans'``.
        device: device the tensors are moved to before the forward pass.

    Returns:
        ``(logits, ans, preds, sz, n_cand)``: the raw flattened logits, the
        ``'ans'`` tensor on ``device``, the index of the candidate with the
        highest class-1 logit for each row, the batch size and the number of
        candidates per row.
    """
    sz, n_cand, pad_length = batch['input_ids'].shape

    input_ids = batch['input_ids'].reshape((-1, pad_length)).to(device)
    attention_mask = batch['attention_mask'].reshape((-1, pad_length)).to(device)
    ans = batch['ans'].to(device)

    logits = model(input_ids=input_ids, attention_mask=attention_mask).logits

    # Regroup the flat logits per row and rank candidates by their class-1 score.
    per_candidate = logits.detach().reshape(sz, n_cand, 2)
    preds = torch.argmax(per_candidate[:, :, 1], dim=1)
    return logits, ans, preds, sz, n_cand


def main(args):
    """Fine-tune a RoBERTa sequence classifier on a single persona's data.

    Reads the train/validation conversation files, preprocesses and clusters
    them by persona, builds data loaders for the persona selected by
    ``args.persona_num``, optionally initialises the model from a saved base
    checkpoint, then trains for ``args.num_epochs`` epochs. After every epoch
    the held-out split is scored; whenever validation accuracy improves and
    ``args.save`` is non-empty, a checkpoint is written.

    Args:
        args: parsed command-line namespace providing ``model``,
            ``from_pretrained_base``, ``persona_num``, ``batchsz``,
            ``num_epochs``, ``lr``, ``dropout``, ``device``, ``save``,
            ``seed``, ``verbose`` and ``num_valid``.

    Returns:
        None. Progress and the best validation accuracy are printed.
    """
    # Parse arguments
    model_name = args.model
    pretrained_base_path = args.from_pretrained_base
    p_idx = args.persona_num
    batch_size = args.batchsz
    num_epochs = args.num_epochs
    lr = args.lr
    p_drop = args.dropout
    device = args.device
    save_path = args.save
    seed = args.seed
    verb = args.verbose
    num_valid = args.num_valid

    if verb:
        print(f"Command Line Args: \n {args}")

    # Set seed for dataset shuffle
    torch.manual_seed(seed)

    file_train = '/var/local/nameredacted/data/raw_download/train_self_original.txt'
    file_val = '/var/local/nameredacted/data/raw_download/valid_self_original.txt'
    cand = {}

    # Dicts where keys are persona info, vals are convo
    train = read_langs(file_train, cand_list=cand, max_line=None)
    valid = read_langs(file_val, cand_list=cand, max_line=None)

    vocab = Lang()
    # Per the original author's note, the preprocessed structure is roughly
    # {persona: {dial1: [[context, candidate, answer, persona], ...], dial2: [...]}}
    train = preprocess(train, vocab, False)
    valid = preprocess(valid, vocab, False)

    # Can cluster similar personas
    train = filter_data(cluster_persona(train, 'train'), cut=1)
    valid = filter_data(cluster_persona(valid, 'valid'), cut=1)

    # NOTE(review): entry 0 is discarded from both splits; the reason is not
    # visible in this file -- confirm against cluster_persona/filter_data.
    del train[0]
    del valid[0]

    train_train_dict_list, train_valid_dict_list = get_dict_lists(train, num_valid)
    # The validation-file splits are still computed although nothing below
    # consumes them, so the sequence of helper calls stays as it was.
    valid_train_dict_list, valid_valid_dict_list = get_dict_lists(valid, num_valid)

    print(f"Size of Train Train: {len(train_train_dict_list[p_idx]['ans'])}")
    print(f"Size of Train Valid: {len(train_valid_dict_list[p_idx]['ans'])}")

    # NOTE(review): 'roberta-base' is always loaded here; args.model is only
    # used to name the saved checkpoint file. Confirm this is intended.
    tokenizer = RobertaTokenizer.from_pretrained('roberta-base',padding_side="right",cache_dir="/var/local/nameredacted/.cache/huggingface/tokenizers")
    model = RobertaForSequenceClassification.from_pretrained('roberta-base', num_labels=2,cache_dir="/var/local/nameredacted/.cache/huggingface/transformers")

    set_dropout(model, p_drop)

    print(f"dropout: {model.roberta.embeddings.dropout.p}")

    # One-element slices: only the selected persona gets data loaders.
    tt_dloader_list = get_dloader_list(train_train_dict_list[p_idx:p_idx+1], tokenizer, batch_size)
    tv_dloader_list = get_dloader_list(train_valid_dict_list[p_idx:p_idx+1], tokenizer, batch_size)

    if pretrained_base_path:
        # torch.load unpickles the file: only point this at trusted checkpoints.
        checkpoint_base = torch.load(pretrained_base_path, map_location="cpu")
        instantiate_base_model(model, checkpoint_base['model_state_dict'])
        base_metrics = checkpoint_base["val_metrics"]
        print(f"Base Validation Metrics: {base_metrics}")

    # No LR scheduler is stepped below; the optimizer is created with `lr` only.
    optimizer = AdamW(params=model.parameters(), lr=lr)
    log_softmax = LogSoftmax(dim = 1)
    loss_fn = NLLLoss()

    model.to(device)
    best_acc = 0
    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0
        tr_num_correct = 0
        tr_num = 0
        for batch in tqdm(tt_dloader_list[0]):
            logits, ans, preds, sz, n_cand = _score_candidates(model, batch, device)

            # Binary target per (row, candidate): 1 for the answer, 0 otherwise.
            # Built with one indexed assignment instead of a per-row Python loop,
            # and sized from the batch rather than a fixed candidate count.
            targets = torch.zeros((sz, n_cand), dtype=torch.long, device=device)
            targets[torch.arange(sz, device=device), ans.reshape(-1).long()] = 1
            targets = targets.reshape(-1)

            loss = loss_fn(log_softmax(logits), targets)
            epoch_loss += loss.detach()

            tr_num += sz
            tr_num_correct += torch.sum(torch.eq(preds, ans))

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            # Drop references to this step's tensors before the next batch.
            del logits
            del loss
            del batch

        epoch_tr_acc = tr_num_correct/tr_num

        model.eval()
        val_num_correct = 0
        val_num = 0
        for batch in tqdm(tv_dloader_list[0]):
            with torch.no_grad():
                _, ans, preds, sz, _ = _score_candidates(model, batch, device)

            val_num += sz
            val_num_correct += torch.sum(torch.eq(preds, ans))

        epoch_val_acc = val_num_correct/val_num
        if verb:
            print(f"Epoch {epoch} Training Loss:", epoch_loss)
            print(f"Epoch {epoch} Training Accuracy:", epoch_tr_acc)
            print(f"Epoch {epoch} Validation Accuracy:", epoch_val_acc)

        if epoch_val_acc > best_acc:
            best_acc = epoch_val_acc
            if save_path:
                # Save model. `save_path` is used as a plain string prefix, so it
                # must end with a path separator to write inside a directory.
                model_str = model_name + "-lr" + str(int(round(lr*10000))) + ".pt"
                torch.save({
                    'seed': seed,
                    'epoch': epoch,
                    'val_acc': best_acc,
                    'tr_acc': epoch_tr_acc,
                    'loss': epoch_loss,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict()},
                    save_path + model_str)
    print("Best Val Accuracy: {}".format(best_acc))

if __name__ == '__main__':
    # Command-line entry point: build the argument parser and hand the parsed
    # namespace to main().
    parser = argparse.ArgumentParser(description = "Run meta-lora algorithm using Roberta")
    # NOTE(review): main() uses --model only to name the checkpoint file; the
    # network it loads is fixed to 'roberta-base'.
    parser.add_argument('--model', metavar='m', action = "store", choices = ["roberta-large","roberta-base","roberta-large-openai"], type=str, help = "base model to use", default="roberta-large")
    parser.add_argument("--from_pretrained_base", action = "store", type=str, help = "path to load pretrained saved model",default = "")
    parser.add_argument('--persona_num', metavar='p', action = "store", type=int, help = "which persona indices to use", default=0)
    parser.add_argument('--batchsz', metavar='b', action="store", type=int, help= "batch size", default = 32)
    parser.add_argument('--num_epochs', metavar='n', action = "store", type=int, help = "number of epochs", default = 20)
    parser.add_argument('--num_valid', metavar='v', action = "store", type=int, default = 1)
    parser.add_argument('--lr', action="store", type=float, help = "learning rate", default = 3e-4)
    parser.add_argument('--dropout', action="store", type=float, help = "dropout pct", default = .1)
    parser.add_argument('--device', action="store",help="device to train on",default = "cuda:0")
    parser.add_argument('--save', action = "store", help="save model checkpoint path", default = "")
    parser.add_argument('--seed', action = "store", help="random seed", type = int, default = 613)
    # --verbose is kept for existing command lines, but it defaults to True and
    # so cannot switch anything off on its own; --no-verbose writes False to
    # the same destination to allow quiet runs.
    parser.add_argument('--verbose', action = "store_true", help="verbose output", default = True)
    parser.add_argument('--no-verbose', dest = "verbose", action = "store_false", help="disable verbose output")

    args = parser.parse_args()
    main(args)