
import os
import random
import sys
import json
import numpy as np
import torch
import click
import copy
import torch.optim as optim

from tqdm import tqdm
from sklearn.metrics import accuracy_score
from allennlp.nn import util as nn_util
from allennlp.common.tqdm import Tqdm
from allennlp.data.vocabulary import Vocabulary
from allennlp.data.iterators import BasicIterator
from allennlp.data.dataset_readers import DatasetReader
from allennlp.data.tokenizers import Token
from allennlp.data.iterators import BasicIterator
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding
from allennlp.common.params import Params

from dataset import SnipDataset
from model.bilinear import BiLinear
from model.label_encoder import get_label_encoder
from example_encoder.text_encoder import TextEncoder


# Model / batching hyper-parameters.
INPUT_DIM = 300
HIDDEN_DIM = 32
ATTN_DIM = 20
EMBEDDING_DIM = 300
BATCH_SIZE = 32

# Directory containing this script; data and model paths below are built from it.
DIR_PATH = os.path.dirname(os.path.realpath(__file__))

# Seed applied at import time; main() reseeds with --seed through set_seed().
manualSeed = 1

# Seed NumPy, the stdlib RNG and torch (CPU, current CUDA device, all CUDA
# devices), in that order. The CUDA calls are harmless when using the CPU.
for _seed_fn in (np.random.seed, random.seed, torch.manual_seed,
                 torch.cuda.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(manualSeed)
del _seed_fn

# Favor reproducibility over speed: disable cuDNN, turn off its autotuner and
# request deterministic kernels.
_cudnn = torch.backends.cudnn
_cudnn.enabled = False
_cudnn.benchmark = False
_cudnn.deterministic = True
del _cudnn

@click.command()
@click.option('--dataset', help='name of the dataset')
@click.option('--label_encoder_type', help='name of the label encoder')
@click.option('--seed', default=0, type=int, help='seed value; default')
@click.option('--lr', default=0.001, type=float)
@click.option('--weight-decay', default=0.00001, type=float)
@click.option('--joint-dim', default=16, type=int)
@click.option('--bases', default=1, type=int, help='no of bases for rgcn')
@click.option('--gpu', default=0, type=int, help='gpu id')
def main(dataset, label_encoder_type, seed, lr, weight_decay,
         joint_dim, bases, gpu):
    """Set up and train the bilinear model for a dataset / label-encoder pair.

    The function trains a bilinear model with a bilstm text encoder and the
    label encoder named by ``label_encoder_type``. Instances of the training
    file whose label is a seen class form the training set; the remaining
    ones form the dev set, with their labels re-indexed within the dev
    classes.

    Arguments:
        dataset {str} -- name of the dataset (forwarded to ``setup``)
        label_encoder_type {str} -- name of the label encoder
        seed {int} -- seed for the NumPy / random / torch RNGs
        lr {float} -- Adam learning rate
        weight_decay {float} -- Adam weight decay
        joint_dim {int} -- stored as options['joint']
        bases {int} -- number of bases for the rgcn label encoder
        gpu {int} -- CUDA device id used when CUDA is available

    Raises:
        ValueError -- if a training instance belongs to neither the seen
            classes nor the dev classes
    """

    eprint("*"*20)
    eprint("Training Details")
    eprint("DATASET: ", dataset)
    eprint("ENCODER: ", label_encoder_type)
    eprint("*"*20)

    set_seed(seed)
    # filter rel, unk rel, add weight don't really do anything
    options = {
        'lr': lr,
        'joint': joint_dim,
        'weight_decay': weight_decay,
        'seed': seed,
        'graph_path': '../data/subgraphs/snips_graph',
        'bases': bases,
        'gpu': gpu
    }

    # get the vocab and everything else for training the models;
    # setup() also fills options['seen_classes'], options['dev_classes'], ...
    model, iterator, optimizer, train_dataset, test_dataset = setup(dataset, label_encoder_type, options)

    dev = []
    train = []
    for instance in train_dataset:
        if instance.fields['labels'].label in model.seen_classes:
            train.append(instance)
        else:
            dev.append(instance)

    # Re-index dev labels to their position in the dev classes chosen by
    # setup() (with dev classes [3, 4]: 3 -> 0, 4 -> 1). Deriving the mapping
    # from options keeps it in sync with setup() instead of hard-coding it.
    dev_label_map = {label: index
                     for index, label in enumerate(options['dev_classes'])}
    for instance in dev:
        label_field = instance.fields['labels']
        try:
            new_label = dev_label_map[label_field.label]
        except KeyError:
            raise ValueError(
                f"training instance has label {label_field.label!r}, which is "
                f"neither a seen class nor one of the dev classes "
                f"{options['dev_classes']}") from None
        # LabelField keeps both the label and its cached index; update both.
        label_field._label_id = new_label
        label_field.label = new_label

    datasets = (train, dev, test_dataset)
    train_epochs(model, iterator, optimizer, datasets, epochs=10)

def set_seed(seed_val):
    """Reseed NumPy's global RNG, Python's ``random`` module and torch.

    Arguments:
        seed_val {int} -- the seed applied to every generator
    """
    for seeder in (np.random.seed, random.seed, torch.manual_seed):
        seeder(seed_val)


def _load_token_embedding(vocab):
    """Return the GloVe token embedding, loading the cached copy if present.

    On a cache miss the embedding is built from the pretrained GloVe file for
    ``vocab`` and saved to data/word_emb.pt for later runs.
    """
    token_embed_path = os.path.join(DIR_PATH, 'data/word_emb.pt')
    if os.path.exists(token_embed_path):
        # NOTE(review): the cache is not tied to the vocabulary; delete it if
        # the data files change. torch.load unpickles, so only load caches this
        # script wrote itself.
        return torch.load(token_embed_path)

    token_embedding = Embedding.from_params(vocab=vocab,
                                            params=Params({
                                                "pretrained_file": "(http://nlp.stanford.edu/data/glove.840B.300d.zip)#glove.840B.300d.txt",
                                                "embedding_dim": EMBEDDING_DIM,
                                                "trainable": False}))
    torch.save(token_embedding, token_embed_path)
    return token_embedding


def setup(dataset='snips', label_encoder_type='transformer-self', options=None):
    """Build the vocabulary, iterator, model and optimizer for training.

    NOTE: ``dataset`` is currently not used -- the SNIPS reader, the data
    paths and ``options['dataset'] = 'snips'`` are hard-coded below.

    Arguments:
        dataset {str} -- dataset name (currently ignored)
        label_encoder_type {str} -- label encoder name passed to get_label_encoder
        options {dict} -- training options; read here are 'gpu', 'lr' and
            'weight_decay', and the dict is also handed to get_label_encoder
            and BiLinear. It is updated IN PLACE with 'seen_classes',
            'device', 'cuda_device', 'dataset', 'label_encoder_type',
            'unseen_classes' and 'dev_classes'.

    Returns:
        tuple -- (model, iterator, optimizer, train_dataset, test_dataset)

    Raises:
        ValueError -- if ``options`` is None
    """
    if options is None:
        raise ValueError(
            "setup() requires an options dict with at least 'gpu', 'lr' and "
            "'weight_decay'")

    if torch.cuda.is_available():
        cuda_device = options['gpu']
        device = torch.device('cuda:' + str(cuda_device))
    else:
        device = torch.device('cpu')
        cuda_device = -1  # allennlp convention for "stay on CPU"

    reader = SnipDataset()
    print("Reading data...")

    train_path = os.path.join(DIR_PATH, 'data/train_shuffle.txt')
    test_path = os.path.join(DIR_PATH, 'data/test.txt')

    train_dataset = reader.read(train_path)
    test_dataset = reader.read(test_path)

    vocab = Vocabulary.from_instances(train_dataset + test_dataset)

    # create the iterator
    iterator = BasicIterator(batch_size=BATCH_SIZE)
    iterator.index_with(vocab)

    print("Loading GloVe...")
    token_embedding = _load_token_embedding(vocab)

    print("word embeddings created...")
    word_embeddings = BasicTextFieldEmbedder({"tokens": token_embedding})

    # create the text encoder
    print("Loading the text encoder...")
    text_encoder = TextEncoder(word_embeddings, INPUT_DIM, HIDDEN_DIM, ATTN_DIM)

    print("Loading the label encoder...")
    label_encoder, vocab = get_label_encoder(label_encoder_type, vocab, device, options=options)

    print("Instantiating the BiLinear Model...")
    seen_classes = [0, 1, 2]
    dev_classes = [3, 4]
    unseen_classes = [5, 6]

    options['seen_classes'] = seen_classes
    options['device'] = device
    options['cuda_device'] = cuda_device
    options['dataset'] = 'snips'
    options['label_encoder_type'] = label_encoder_type
    options['unseen_classes'] = unseen_classes
    options['dev_classes'] = dev_classes

    model = BiLinear(text_encoder, label_encoder, vocab, options)
    # model.to(device) already places the model on the selected GPU, so no
    # extra model.cuda(...) call is needed.
    model.to(device)

    # create directory for saving the model (no-op if it already exists)
    os.makedirs(os.path.join(DIR_PATH, 'data/model/snips/'), exist_ok=True)

    optimizer = optim.Adam(model.parameters(), lr=options['lr'],
                           weight_decay=options['weight_decay'])

    return model, iterator, optimizer, train_dataset, test_dataset


def train_epochs(model, iterator, optimizer, datasets, epochs=10):
    """Train for ``epochs`` epochs and keep the weights with the lowest dev loss.

    Each epoch trains on the training graph, computes the dev loss on the dev
    graph (checkpointing through save_model), and evaluates on the test
    graph. Afterwards the best weights are reloaded, evaluated on the test
    set once more, and the resulting metrics are saved.

    Arguments:
        model {nn.Module} -- the bilinear model; this function reads
            ``label_encoder.adj_lists`` (train, dev and test graphs at
            indices 0, 1, 2), ``options``, ``seen_classes``, ``cuda_device``,
            ``dataset`` and ``label_encoder_type`` from it
        iterator {BasicIterator} -- the allennlp iterator used for batching
        optimizer {torch.optim.Optimizer} -- the optimizer (Adam in setup)
        datasets {tuple} -- (train_dataset, dev_dataset, test_dataset); the
            dev set is used to pick the best weights, the test set is the
            common evaluation set
        epochs {int} -- number of training epochs (default: {10})
    """
    val_loss = []

    train_dataset, dev_dataset, test_dataset = tuple(datasets)

    print("*"*20)
    print("Training Details")
    print("DATASET: ", model.dataset)
    print("ENCODER: ", model.label_encoder_type)
    print("*"*20)
    train_graph = model.label_encoder.adj_lists[0]
    dev_graph = model.label_encoder.adj_lists[1]
    test_graph = model.label_encoder.adj_lists[2]

    best_model_wts = copy.deepcopy(model.state_dict())
    seed = model.options['seed']
    # Epoch index used to tag the final prediction files. Relying on the loop
    # variable after the loop raised NameError when epochs <= 0.
    last_epoch = max(epochs - 1, 0)

    for epoch in range(epochs):
        print(f"**** Epoch {epoch} ****")
        change_graph(model, train_graph)
        model, optimizer = train_model(model, train_dataset, iterator, optimizer)

        # dev loss decides which weights are kept
        change_graph(model, dev_graph)
        loss = compute_loss(model, dev_dataset, iterator, dev=True)
        print(f"Val Loss {loss}")

        val_loss, best_model_wts = save_model(model, loss, val_loss,
                                              seed, best_model_wts)

        change_graph(model, test_graph)
        test_model(model, test_dataset, iterator, model.seen_classes,
                   model.cuda_device, epoch=epoch, seed=seed)

    # load the best model and predict
    model.load_state_dict(best_model_wts)
    change_graph(model, test_graph)

    # NOTE: this reuses the last epoch's index, so the per-epoch prediction
    # files test_model writes for that epoch are overwritten with the best
    # model's predictions.
    predictions = test_model(model, test_dataset, iterator,
                             model.seen_classes, model.cuda_device,
                             epoch=last_epoch, seed=seed)

    if model.label_encoder_type in ['rgcn']:
        torch.save(predictions, os.path.join(DIR_PATH,
                                             'data/model/snips/'+model.label_encoder_type+'_basis_'+str(model.options['bases']) +
                                             '_pred_seed_'+str(seed)+'.pt'))

    torch.save(predictions, os.path.join(DIR_PATH, 'data/model/snips/'+model.label_encoder_type+'_pred_seed_'+str(seed)+'.pt'))
    grid_save(predictions, model, seed)
    print('done!')
    print('done!')

def change_graph(model, adj_lists):
    """Set ``adj_lists`` on both sub-encoders (enc2, then enc1) of the label encoder.

    train_epochs uses this to switch between the train, dev and test graphs.
    """
    label_encoder = model.label_encoder
    for sub_encoder in (label_encoder.enc2, label_encoder.enc1):
        sub_encoder.adj_lists = adj_lists


def grid_save(predictions, model, seed):
    """Save ``predictions`` alongside the hyper-parameter-tagged checkpoint.

    The file is the model path from get_model_path with a ``.pred`` suffix.
    """
    print('get the file path for saving')
    pred_path = '{}.pred'.format(get_model_path(model, seed))

    print(f'save with .pred extention at  {pred_path}')
    torch.save(predictions, pred_path)


def train_model(model, dataset, iterator, optimizer, max_grad_norm=5.0):
    """Train the model for a single epoch.

    Arguments:
        model {nn.Module} -- the bilinear model; batches are moved to
            ``model.cuda_device``
        dataset {dataset} -- the training instances loaded with the allennlp reader
        iterator {Iterator} -- the allennlp iterator indexed with the vocabulary
        optimizer {torch.optim.Optimizer} -- stepped once per batch

    Keyword Arguments:
        max_grad_norm {float or None} -- gradients are clipped to this total
            norm before every optimizer step (default: {5.0}); pass None to
            disable clipping

    Returns:
        tuple -- (model, optimizer)
    """
    model.train()
    total_batch_loss = 0.
    batches = iterator(dataset, num_epochs=1, shuffle=False)
    generator_tqdm = Tqdm.tqdm(batches, total=iterator.get_num_batches(dataset))
    for batch in generator_tqdm:
        optimizer.zero_grad()
        batch = nn_util.move_to_device(batch, model.cuda_device)
        output = model(batch['sentence'], batch['labels'], train=True)
        loss = output['loss']
        total_batch_loss += loss.item()
        loss.backward()

        # Clipping was introduced after a NaN loss showed up during training.
        if max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        optimizer.step()

    # sum of the per-batch losses over the epoch
    print(f"Train loss = {total_batch_loss}")

    return model, optimizer


def compute_loss(model, dataset, iterator, train=False, dev=False):
    """Sum the model's loss over every batch of ``dataset`` without gradients.

    train_epochs calls this on the dev set and uses the result to decide
    which weights to keep.

    Arguments:
        model {nn.Module} -- the bilinear or gile model; put in eval mode
        dataset {nn.Dataset} -- the train or dev instances
        iterator -- the allennlp iterator indexed with the vocabulary

    Keyword Arguments:
        train {bool} -- forwarded to the model as ``train=`` (default: {False})
        dev {bool} -- forwarded to the model as ``dev=`` (default: {False})

    Returns:
        float -- the sum of the per-batch losses
    """
    model.eval()
    total_loss = 0.
    with torch.no_grad():
        batches = iterator(dataset, num_epochs=1, shuffle=False)
        progress = Tqdm.tqdm(batches, total=iterator.get_num_batches(dataset))
        for raw_batch in progress:
            on_device = nn_util.move_to_device(raw_batch, model.cuda_device)
            result = model(on_device['sentence'], on_device['labels'],
                           train=train, dev=dev)
            total_loss += result['loss'].item()

    return total_loss


def test_model(model, dataset, iterator, seen_classes, cuda_device, print_result=True, epoch=0, seed=0):
    """Evaluate the model on a dev/test set and save its predictions.

    Two accuracies are computed against the gold labels: the generalized
    accuracy (argmax over ``output['logits']``) and the unseen accuracy
    (argmax over ``output['unseen_probs']``). The generalized predictions,
    unseen predictions and raw logits are saved under data/model/snips/ with
    the epoch number in the file names.

    Arguments:
        model {nn.Module} -- the torch bilinear model
        dataset {dataset} -- the instances to evaluate
        iterator -- the allennlp iterator indexed with the vocabulary
        seen_classes -- not used by this function; kept for interface compatibility
        cuda_device {int} -- device the batches are moved to

    Keyword Arguments:
        print_result {bool} -- print both accuracies (default: {True})
        epoch {int} -- epoch number embedded in the saved file names (default: {0})
        seed {int} -- not used by this function; kept for interface compatibility

    Returns:
        dict -- {'gen_unseen_acc': float, 'unseen_acc': float}
    """
    model.eval()
    all_preds = []
    all_true = []
    all_unseen_preds = []
    output_logits = []

    # predicting on the dev/test set
    with torch.no_grad():
        generator_tqdm = Tqdm.tqdm(iterator(dataset, num_epochs=1, shuffle=False),
                                   total=iterator.get_num_batches(dataset))
        for batch in generator_tqdm:
            batch = nn_util.move_to_device(batch, cuda_device)
            output = model(batch['sentence'])
            logits = output['logits']
            all_true += batch['labels'].cpu().numpy().tolist()
            all_preds += torch.argmax(logits, dim=1).cpu().numpy().tolist()
            output_logits += logits.cpu().numpy().tolist()

            # predictions restricted to the unseen classes
            unseen_preds = torch.argmax(output['unseen_probs'], dim=1)
            all_unseen_preds += unseen_preds.cpu().numpy().tolist()

    gen_unseen_acc = accuracy_score(all_true, all_preds)
    unseen_acc = accuracy_score(all_true, all_unseen_preds)
    if print_result:
        print(f'generalized unseen acc = {gen_unseen_acc:.4f}')
        print(f'unseen acc             = {unseen_acc:.4f}')

    prefix = os.path.join(DIR_PATH, "data/model/snips/" + model.label_encoder_type)
    torch.save(all_preds, prefix + "_gen_pred" + str(epoch) + ".pt")
    torch.save(all_unseen_preds, prefix + "_unseen_pred" + str(epoch) + ".pt")
    torch.save(output_logits, prefix + "_logits" + str(epoch) + ".pt")

    return {
        'gen_unseen_acc': gen_unseen_acc,
        'unseen_acc': unseen_acc
    }


def get_model_path(model, seed):
    """Return the extension-less checkpoint path for ``model`` and ``seed``.

    The file name concatenates the label encoder type, learning rate, weight
    decay, joint dimension, the number of bases (rgcn only) and the seed:
    ``<encoder>_lr_<lr>_decay_<wd>_joint_<joint>[_bases_<b>]_seed_<seed>``.
    """
    opts = model.options
    name_parts = [
        model.label_encoder_type,
        '_lr_' + str(opts['lr']),
        '_decay_' + str(opts['weight_decay']),
        '_joint_' + str(opts['joint']),
    ]
    if model.label_encoder_type in ['rgcn']:
        name_parts.append('_bases_' + str(opts['bases']))
    name_parts.append('_seed_' + str(seed))

    return os.path.join(DIR_PATH, 'data/model/snips/', ''.join(name_parts))


def save_model(model, loss, val_loss, seed, best_model):
    """Checkpoint the model when ``loss`` is the best dev loss so far.

    Arguments:
        model {nn.Module} -- the model being trained
        loss {float} -- dev loss of the current epoch
        val_loss {list} -- dev losses of the previous epochs; ``loss`` is
            appended to it IN PLACE
        seed {int} -- seed embedded in the checkpoint file name
        best_model {dict} -- state_dict of the best model so far

    Returns:
        tuple -- (val_loss, best_model); best_model is a deep copy of the
            current state_dict when this epoch was saved, otherwise the
            argument unchanged

    Side effects:
        Writes ``<model path>.pt`` when the model is saved and always
        rewrites ``<model path>.json`` with the full loss history.
    """
    import math

    # Compare only against non-NaN history: min() over a list starting with
    # NaN returns NaN and `x < nan` is always False, so a single NaN epoch
    # used to block every later checkpoint.
    non_nan_history = [v for v in val_loss if not math.isnan(v)]
    if not val_loss:
        _save = True
    elif math.isnan(loss):
        _save = False
    else:
        _save = not non_nan_history or loss < min(non_nan_history)

    save_path = get_model_path(model, seed)

    json_path = save_path + '.json'
    save_path += '.pt'
    if _save:
        best_model = copy.deepcopy(model.state_dict())
        print(save_path)
        print("saving model!")
        with open(save_path, 'wb') as f:
            torch.save(model.state_dict(), f)

    val_loss.append(loss)

    with open(json_path, 'w+') as fp:
        json.dump(val_loss, fp)

    return val_loss, best_model


def eprint(*args, **kwargs):
    """Like print(), but writes to standard error; other keywords are passed through."""
    stream = sys.stderr
    print(*args, file=stream, **kwargs)

# Script entry point: main is a click command, so click parses the
# command-line options and passes them to main().
if __name__ == "__main__":
    main()
