import os
import json
import argparse

import torch
import pandas as pd

from allennlp.common.params import Params
from allennlp.common.tqdm import Tqdm
from allennlp.nn import util as nn_util
from allennlp.data.iterators.basic_iterator import BasicIterator
from allennlp.data.vocabulary import Vocabulary
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding

from datasets.paths import DATASETS
from datasets.fine_grained import FineEntityTyping
from example_encoder.attentive_ner import MentionEncoder

from utils.eval import get_true_and_prediction, strict, loose_macro, loose_micro
from utils.common import set_seed, get_save_path, init_device, change_graph
from model.label_encoder import get_label_encoder, get_graph
from model.bilinear import BiLinearModel


# Turn cuDNN off, disable its benchmark auto-tuner and request deterministic
# algorithms.
# NOTE(review): presumably done for run-to-run reproducibility together with
# set_seed() -- confirm.
torch.backends.cudnn.enabled = False
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Directory that contains this file (symlinks resolved); data paths are built
# relative to it.
DIR_PATH = os.path.dirname(os.path.realpath(__file__))
# Path of the GloVe vectors file under DIR_PATH. Not referenced by any code
# visible in this file.
GLOVE_PATH = os.path.join(DIR_PATH, "data/glove.840B.300d.txt")

def setup(all_dataset, options):
    """Build the bilinear typing model and the iterator that batches data.

    Args:
        all_dataset: Not used by this function; kept in the signature for
            existing callers.
        options: Run configuration dict. This function reads
            ``options['vocab_path']``, ``options['device']`` and
            ``options['graph_path']`` and passes the whole dict on to
            ``get_label_encoder`` and ``BiLinearModel``.

    Returns:
        A ``(model, iterator)`` tuple: the ``BiLinearModel`` instance and a
        ``BasicIterator`` with batch size 1000, indexed with the vocabulary
        loaded from ``options['vocab_path']``.
    """
    embedding_dim = 300

    # TODO: change vocab path
    vocabulary = Vocabulary.from_files(options['vocab_path'])

    # The iterator needs the vocabulary to turn tokens into indices.
    batcher = BasicIterator(batch_size=1000)
    batcher.index_with(vocabulary)

    num_tokens = len(vocabulary.get_index_to_token_vocabulary('tokens'))

    # Random, frozen (trainable=False) embedding weights.
    # NOTE(review): presumably placeholders that are overwritten when the
    # caller loads a saved state dict -- confirm.
    placeholder_weights = torch.randn((num_tokens, embedding_dim))
    placeholder_weights = placeholder_weights.to(options['device'])

    # Example (mention) encoder on top of the token embeddings.
    token_embedding = Embedding(num_tokens, embedding_dim,
                                weight=placeholder_weights,
                                trainable=False)
    text_embedder = BasicTextFieldEmbedder({"tokens": token_embedding})
    mention_encoder = MentionEncoder(text_embedder,
                                     input_dim=embedding_dim,
                                     hidden_dim=100,
                                     attn_dim=100)

    # Graph adjacency lists, then the label encoder built from the options.
    graphs = get_graph(options['graph_path'])
    label_enc = get_label_encoder(options)

    # Tie everything together in the bilinear model.
    bilinear = BiLinearModel(vocabulary, mention_encoder, label_enc,
                             graphs, options=options)

    return bilinear, batcher


def test_model(model, dataset, iterator, options):
    """Evaluate ``model`` on ``dataset`` over all, seen and unseen labels.

    Label positions ``0 .. train_len - 1`` (one per row of
    ``train_labels.csv``) are treated as seen; the remaining positions up to
    ``len(mapping)`` are treated as unseen.

    Args:
        model: Model exposing ``adj_lists``, which must unpack into exactly
            two graphs; the second one is used for evaluation.
        dataset: Dataset forwarded unchanged to ``eval_model``.
        iterator: Batch iterator forwarded unchanged to ``eval_model``.
        options: Run configuration dict; ``options['graph_path']`` must
            contain ``mapping.json`` and ``options['dataset_path']`` must
            contain ``train_labels.csv`` with a ``LABELS`` column.

    Returns:
        Whatever ``eval_model`` returns for the test graph.

    Raises:
        KeyError: If ``mapping.json`` lacks an entry for a label position, or
            has fewer entries than ``train_labels.csv`` has rows.
    """
    # Only the second graph is needed here; the two-way unpacking still
    # enforces that ``model.adj_lists`` holds exactly two graphs.
    _, test_graph = tuple(model.adj_lists)

    # mapping.json maps each label position (as a string key) to the index
    # that is handed to the model (the "conceptnet idx" per the original
    # author's note).
    mapping_path = os.path.join(options['graph_path'], 'mapping.json')
    with open(mapping_path, encoding='utf-8') as mapping_file:
        mapping = json.load(mapping_file)

    label_path = os.path.join(options['dataset_path'], 'train_labels.csv')
    train_len = len(pd.read_csv(label_path)['LABELS'])

    num_labels = len(mapping)
    if train_len > num_labels:
        # Fail before running the evaluation: seen label positions would
        # otherwise point past the columns the model produces.
        raise KeyError(
            'train_labels.csv has {} labels but mapping.json only has {} '
            'entries'.format(train_len, num_labels))

    test_idx = [mapping[str(idx)] for idx in range(num_labels)]

    seen_idx = list(range(train_len))
    unseen_idx = list(range(train_len, num_labels))

    # Evaluate against the test graph.
    change_graph(model, test_graph)
    return eval_model(model, dataset, iterator, test_idx,
                      seen_idx=seen_idx, unseen_idx=unseen_idx)


def eval_model(model, dataset, iterator, label_idx,
               seen_idx=None, unseen_idx=None):
    """Run ``model`` over ``dataset`` and print/return typing metrics.

    Args:
        model: Model called as ``model(batch, label_idx)``; must expose
            ``device`` and ``cuda_device``.
        dataset: Dataset to iterate over for a single epoch.
        iterator: Callable as ``iterator(dataset, num_epochs=1)`` and
            providing ``get_num_batches(dataset)``.
        label_idx: Label indices passed to the model; its length defines the
            columns of the ``overall`` subset.
        seen_idx: Column positions scored as the ``seen`` subset. When
            ``None`` the subset is skipped (indexing a tensor with ``None``
            would insert an axis instead of selecting columns).
        unseen_idx: Column positions scored as the ``unseen`` subset;
            skipped when ``None``.

    Returns:
        Dict keyed by subset name (``'overall'``, plus ``'seen'`` and
        ``'unseen'`` when provided), each mapping to a dict with keys
        ``'strict'``, ``'loose_micro'`` and ``'loose_macro'``. Each value is
        element ``[2]`` of the corresponding metric helper's result.
        NOTE(review): presumably the F1 score -- confirm in ``utils.eval``.
    """
    model.eval()

    generator_tqdm = Tqdm.tqdm(iterator(dataset, num_epochs=1),
                               total=iterator.get_num_batches(dataset))

    # Collect per-batch tensors and concatenate once after the loop, rather
    # than re-copying the growing result on every batch.
    label_batches = []
    prob_batches = []

    with torch.no_grad():
        for batch in generator_tqdm:
            batch = nn_util.move_to_device(batch, model.cuda_device)
            logits = model(batch, label_idx)

            label_batches.append(batch['labels'].float())
            prob_batches.append(torch.sigmoid(logits))

    if label_batches:
        one_hot = torch.cat(label_batches, dim=0)
        prob_preds = torch.cat(prob_batches, dim=0)
    else:
        # No batches were produced: fall back to empty tensors on the
        # model's device (the column indexing below then raises, as there
        # is nothing to score).
        one_hot = torch.Tensor().to(model.device)
        prob_preds = torch.Tensor().to(model.device)

    subsets = [
        ('overall', list(range(len(label_idx)))),
        ('seen', seen_idx),
        ('unseen', unseen_idx),
    ]

    result = {}

    for subset_name, columns in subsets:
        if columns is None:
            continue

        one_hot_list = one_hot[:, columns].cpu().numpy().tolist()
        prob_list = prob_preds[:, columns].cpu().numpy().tolist()
        true_and_pred = get_true_and_prediction(prob_list, one_hot_list)

        # Drop examples whose true-label entry is empty for this subset.
        clean_true_pred = [pair for pair in true_and_pred if pair[0]]

        strict_acc = strict(clean_true_pred)[2]
        loose_micro_acc = loose_micro(clean_true_pred)[2]
        loose_macro_acc = loose_macro(clean_true_pred)[2]

        result[subset_name] = {
            'strict': strict_acc,
            'loose_micro': loose_micro_acc,
            'loose_macro': loose_macro_acc
        }

        print('{}: strict={:.4f}, '
              'loose micro={:.4f}, '
              'loose macro={:.4f}'.format(subset_name.upper(), strict_acc,
                                          loose_micro_acc,
                                          loose_macro_acc))

    return result



def load_dataset(options):
    """Read the test split using a label index built from train + test labels.

    Args:
        options: Run configuration dict; ``options['dataset_path']`` must
            contain ``train_labels.csv``, ``test_labels.csv`` (both with a
            ``LABELS`` column) and ``clean_test.json``.

    Returns:
        The dataset produced by ``FineEntityTyping.read`` for
        ``clean_test.json``.
    """
    # TODO: do this for ontonotes or other datasets
    root = options['dataset_path']

    def read_labels(csv_name):
        # Return the LABELS column of one label file as a plain list.
        frame = pd.read_csv(os.path.join(root, csv_name))
        return frame['LABELS'].to_list()

    # Train labels come first, so they occupy the lowest positions.
    combined_labels = read_labels('train_labels.csv')
    combined_labels = combined_labels + read_labels('test_labels.csv')

    # If a label occurs more than once, its last position wins.
    label_to_idx = {label: position
                    for position, label in enumerate(combined_labels)}

    reader = FineEntityTyping(label_to_idx)
    return reader.read(os.path.join(root, 'clean_test.json'))

def main(argv=None):
    """Command-line entry point: load a saved model and evaluate it.

    Parses the command line, assembles the ``options`` dict, loads the test
    dataset, builds the model, restores its saved weights and runs
    ``test_model``.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read
            ``sys.argv[1:]``.

    Returns:
        The result returned by ``test_model``.
    """
    parser = argparse.ArgumentParser()
    # Required: DATASETS is indexed by this value below, so omitting it
    # could only end in a KeyError.
    parser.add_argument('--dataset', required=True, help='zero shot dataset')
    parser.add_argument('--label_encoder_type', help='label encoder type')
    parser.add_argument('--seed', default=0, type=int, help='seed no.')
    parser.add_argument('--gpu', default=0, type=int, help='gpu')
    parser.add_argument('--pd1', default=150, type=int)
    parser.add_argument('--pd2', default=64, type=int)
    parser.add_argument('--fh1', default=150, type=int)
    parser.add_argument('--fh2', default=64, type=int)
    parser.add_argument('--decay', default=0., type=float)
    parser.add_argument('--n1', default=0, type=int)
    parser.add_argument('--n2', default=0, type=int)
    args = parser.parse_args(argv)

    device, cuda_device = init_device(args.gpu)

    # Resolve every data location relative to this file's directory.
    dataset_path = os.path.join(DIR_PATH, DATASETS[args.dataset]['dataset'])
    vocab_path = os.path.join(DIR_PATH, DATASETS[args.dataset]['vocab_path'])
    graph_path = os.path.join(DIR_PATH, DATASETS[args.dataset]['graph_path'])
    model_path = os.path.join(DIR_PATH, 'data/models/' + args.dataset)
    result_path = os.path.join(DIR_PATH, 'data/results/' + args.dataset)

    # Single configuration dict shared by the helpers in this file and by
    # the project functions they call.
    options = {
        'seed': args.seed,
        'label_encoder_type': args.label_encoder_type,
        'dataset': args.dataset,
        'gpu': args.gpu,
        'cuda_device': cuda_device,
        'device': device,
        'dataset_path': dataset_path,
        'model_path': model_path,
        'result_path': result_path,
        'vocab_path': vocab_path,
        'graph_path': graph_path,
        'pd1': args.pd1,
        'pd2': args.pd2,
        'fh1': args.fh1,
        'fh2': args.fh2,
        'decay': args.decay,
        'n1': args.n1,
        'n2': args.n2
    }

    # set the seed
    set_seed(args.seed)

    test_dataset = load_dataset(options)

    model, iterator = setup(test_dataset, options)

    # Load the saved weights onto the CPU first, then move the whole model
    # to the target device.
    save_path = get_save_path(model_path, options)
    model.load_state_dict(torch.load(save_path, map_location='cpu'))
    model = model.to(device)

    print('dataset: {}'.format(args.dataset))
    print('label encoder: {}'.format(args.label_encoder_type))
    print('seed: {}'.format(args.seed))

    return test_model(model, test_dataset, iterator, options)


if __name__ == "__main__":
    main()

