import os
import json

from collections import Counter

import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd

from allennlp.common.params import Params
from allennlp.common.tqdm import Tqdm
from allennlp.data import Instance
from allennlp.data.iterators.basic_iterator import BasicIterator
from allennlp.data.vocabulary import Vocabulary
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding
from allennlp.nn.util import get_text_field_mask
from allennlp.nn.util import masked_mean
from allennlp.nn import util as nn_util

from utils.common import convert_index_to_int
from class_encoder.transformer import TransformerGCN
from class_encoder.rev_transformer import RevTransformerGCN
from class_encoder.lstm import LSTMGCN
from class_encoder.gcn import InductiveGCN
from class_encoder.gat import GATConv
from class_encoder.rgcn import RGCN
from class_encoder.avg_label import AvgLabel
from class_encoder.description import DescEncoder
from class_encoder.sgcn import SGCN
from class_encoder.gcnz import GCNZ
from class_encoder.dgp import DGP

from example_encoder.text_encoder import TextEncoder

# Absolute directory containing this module (symlinks resolved); the
# relative data paths used by the encoders below are joined onto it.
DIR_PATH = os.path.dirname(os.path.realpath(__file__))

def get_label_encoder(options):
    """Return the class (label) encoder selected by the options.

    Arguments:
        options {dict} -- configuration; ``options['label_encoder_type']``
            picks the builder, and the whole mapping is handed to it.

    Returns:
        whatever the selected builder returns. Types without a dedicated
        builder are forwarded to ``graph_class_encoder``.
    """
    requested = options['label_encoder_type']

    # encoder types that have their own builder function
    dedicated_builders = (
        ('otyper', avg_label_class_encoder),
        ('dzet', desc_class_encoder),
        ('sgcn', sgcn_class_encoder),
        ('gcnz', gcnz_class_encoder),
        ('dgp', dgp_class_encoder),
    )

    for type_name, build in dedicated_builders:
        if requested == type_name:
            return build(options)

    # everything else is treated as a graph-based encoder
    return graph_class_encoder(requested, options)


def avg_label_class_encoder(options):
    """Build the averaged-label-word class encoder (``AvgLabel``).

    Every label read from ``train_labels.csv`` and ``test_labels.csv``
    (column ``LABELS``) is split into words on "/" and "_". The words are
    embedded with the pretrained vectors at ``options['glove_path']`` and
    reduced to one vector per label with a masked mean over the word axis.

    Arguments:
        options {dict} -- configuration; reads ``device``,
            ``dataset_path`` and ``glove_path`` and is also passed on to
            ``AvgLabel``.

    Returns:
        AvgLabel -- encoder wrapping the frozen per-label embedding table
    """
    device = options['device']
    dataset_path = options['dataset_path']

    # train labels first, then test labels; row order in the final
    # embedding table follows this concatenation
    label_names = []
    for csv_name in ('train_labels.csv', 'test_labels.csv'):
        frame = pd.read_csv(os.path.join(dataset_path, csv_name))
        label_names += frame['LABELS'].to_list()

    # "/a/b_c" -> ["a", "b", "c"]: split the path on "/" and then each
    # path segment on "_"
    words_per_label = [
        [word
         for segment in name.lstrip("/").split("/")
         for word in segment.split("_")]
        for name in label_names
    ]

    word_counts = Counter(
        [word for words in words_per_label for word in words])
    label_vocab = Vocabulary({'tokens': word_counts})

    word_embedder = Embedding.from_params(
        vocab=label_vocab,
        params=Params({
            "pretrained_file": options['glove_path'],
            "embedding_dim": 300,
            "trainable": False}))

    word_to_idx = label_vocab.get_token_to_index_vocabulary('tokens')
    word_ids = convert_token_to_idx(words_per_label, word_to_idx)
    word_mask = get_text_field_mask({"tokens": word_ids})

    word_vectors = word_embedder(word_ids)

    # average over the word axis (dim=1), weighting positions by the mask
    label_vectors = masked_mean(word_vectors, word_mask.unsqueeze(-1), dim=1)
    label_table = nn.Embedding.from_pretrained(
        label_vectors, freeze=True).to(device)

    return AvgLabel(label_table, device, options)


def desc_class_encoder(options):
    """Build the description-based class encoder (``DescEncoder``).

    Label descriptions are read from
    ``<DIR_PATH>/../misc_data/<dataset>/desc.json``, a mapping from label
    to its description tokens. A vocabulary and a frozen pretrained
    embedding are built from those tokens, and the descriptions of all
    train and test labels are converted to one padded id tensor.

    Arguments:
        options {dict} -- configuration; reads ``device``, ``dataset``,
            ``dataset_path`` and ``glove_path``.

    Returns:
        DescEncoder -- encoder built from the text encoder, the
            ``{"tokens": ids}`` description dict and the device
    """
    device = options['device']
    dataset = options['dataset']
    dataset_path = options['dataset_path']

    desc_file = os.path.join(DIR_PATH, "../misc_data/" + dataset + "/desc.json")
    with open(desc_file) as desc_fp:
        label_to_desc = json.load(desc_fp)

    # vocabulary over every token that appears in any description
    token_counts = Counter(
        [token for tokens in label_to_desc.values() for token in tokens])
    desc_vocab = Vocabulary({'tokens': token_counts})

    glove_embedding = Embedding.from_params(
        vocab=desc_vocab,
        params=Params({
            "pretrained_file": options['glove_path'],
            "embedding_dim": 300,
            "trainable": False}))
    embedder = BasicTextFieldEmbedder({"tokens": glove_embedding})

    # text encoder applied to the description tokens
    text_encoder = TextEncoder(embedder,
                               input_dim=300,
                               hidden_dim=100,
                               attn_dim=100)

    # train labels first, then test labels
    label_names = []
    for csv_name in ('train_labels.csv', 'test_labels.csv'):
        frame = pd.read_csv(os.path.join(dataset_path, csv_name))
        label_names += frame['LABELS'].to_list()

    # one description per label, in label order
    desc_tokens_per_label = [label_to_desc[name] for name in label_names]

    token_to_idx = desc_vocab.get_token_to_index_vocabulary()
    desc_ids = convert_token_to_idx(desc_tokens_per_label, token_to_idx)
    description_dict = {"tokens": desc_ids.to(device)}

    # TODO: combine the vocab with description vocab
    return DescEncoder(text_encoder, description_dict, device)


def sgcn_class_encoder(options):
    """Build the SGCN class encoder from ``../data/induced_graph.json``.

    The graph file (resolved relative to this module) must provide the
    keys ``wnids``, ``edges`` and ``vectors``. The edge list is extended
    with the reverse of every edge and with one self loop per node before
    being handed to ``SGCN``.

    Arguments:
        options {dict} -- configuration; only ``device`` is read here.

    Returns:
        SGCN -- encoder built from the node count, the augmented edge
            list, the normalised word vectors and the device
    """
    device = options['device']

    graph_file = os.path.join(DIR_PATH, '../data/induced_graph.json')
    # use a context manager so the file handle is closed deterministically
    with open(graph_file, 'r') as graph_fp:
        graph = json.load(graph_fp)

    n = len(graph['wnids'])
    edges = graph['edges']

    # make the graph symmetric, then add a self loop for every node
    edges = edges + [(v, u) for (u, v) in edges]
    edges = edges + [(u, u) for u in range(n)]

    # F.normalize defaults: L2 norm along dim=1
    # (assumes `vectors` is a 2-D list, one row per node -- TODO confirm)
    word_vectors = F.normalize(torch.tensor(graph['vectors']))

    return SGCN(n, edges, word_vectors, device)


def gcnz_class_encoder(options):
    """Build the GCNZ class encoder from ``../data/induced_graph.json``.

    The graph file (resolved relative to this module) must provide the
    keys ``wnids``, ``edges`` and ``vectors``. The edge list is extended
    with the reverse of every edge and with one self loop per node before
    being handed to ``GCNZ``.

    Arguments:
        options {dict} -- configuration; only ``device`` is read here.

    Returns:
        GCNZ -- encoder built from the node count, the augmented edge
            list, the normalised word vectors and the device
    """
    device = options['device']

    graph_file = os.path.join(DIR_PATH, '../data/induced_graph.json')
    # use a context manager so the file handle is closed deterministically
    with open(graph_file, 'r') as graph_fp:
        graph = json.load(graph_fp)

    n = len(graph['wnids'])
    edges = graph['edges']

    # make the graph symmetric, then add a self loop for every node
    edges = edges + [(v, u) for (u, v) in edges]
    edges = edges + [(u, u) for u in range(n)]

    # F.normalize defaults: L2 norm along dim=1
    # (assumes `vectors` is a 2-D list, one row per node -- TODO confirm)
    word_vectors = F.normalize(torch.tensor(graph['vectors']))

    return GCNZ(n, edges, word_vectors, device)


def dgp_class_encoder(options):
    """Build the DGP class encoder from ``../data/dense_graph.json``.

    The graph file (resolved relative to this module) must provide the
    keys ``wnids``, ``edges_set`` and ``vectors``. ``edges_set`` is a
    list of edge lists; every list past index ``lim`` is merged into the
    list at index ``lim`` so that at most ``lim + 1`` lists remain. The
    list sizes are printed before and after merging.

    Arguments:
        options {dict} -- configuration; only ``device`` is read here.

    Returns:
        DGP -- encoder built from the node count, the truncated edge
            sets, the normalised word vectors (on ``device``) and the
            device
    """
    device = options['device']

    graph_file = os.path.join(DIR_PATH, '../data/dense_graph.json')
    # use a context manager so the file handle is closed deterministically
    with open(graph_file, 'r') as graph_fp:
        graph = json.load(graph_fp)

    n = len(graph['wnids'])

    edges_set = graph['edges_set']
    print('edges_set', [len(l) for l in edges_set])

    # this is the K value; per the original author's note it indicates the
    # depth of ancestors and descendants (marked there as an assumption)
    lim = 4
    # fold every deeper edge list into the one at index `lim`; this is a
    # no-op when there are at most `lim + 1` lists
    for deeper_edges in edges_set[lim + 1:]:
        edges_set[lim].extend(deeper_edges)
    edges_set = edges_set[:lim + 1]
    print('edges_set', [len(l) for l in edges_set])

    # F.normalize defaults: L2 norm along dim=1
    # (assumes `vectors` is a 2-D list, one row per node -- TODO confirm)
    word_vectors = F.normalize(torch.tensor(graph['vectors'])).to(device)

    return DGP(n, edges_set, word_vectors, device)

def graph_class_encoder(label_type_encoder, options):
    """Build one of the graph-based class encoders.

    Node features are loaded from ``<graph_path>/concepts.pt`` (onto the
    CPU) and the adjacency lists from ``get_graph``; only the first
    (train) adjacency list is passed to the encoder.

    Arguments:
        label_type_encoder {str} -- kept for interface compatibility; the
            encoder type is actually read from
            ``options['label_encoder_type']``.
        options {dict} -- configuration; reads ``label_encoder_type``,
            ``device`` and ``graph_path`` and is passed on to the encoder.

    Returns:
        one of TransformerGCN, LSTMGCN, InductiveGCN, GATConv or RGCN

    Raises:
        ValueError -- if the encoder type is not one of ``transformer``,
            ``lstm``, ``gcn``, ``gat`` or ``rgcn``. Previously this case
            loaded all files and then silently returned ``None``.
    """
    label_encoder_type = options['label_encoder_type']
    device = options['device']
    graph_path = options['graph_path']

    # fail fast, before any file is read, on an unsupported encoder type
    supported_types = ('transformer', 'lstm', 'gcn', 'gat', 'rgcn')
    if label_encoder_type not in supported_types:
        raise ValueError(
            "unknown label_encoder_type {!r}; expected one of {}".format(
                label_encoder_type, ", ".join(supported_types)))

    feat_path = os.path.join(graph_path, 'concepts.pt')
    feat = torch.load(feat_path, map_location='cpu')

    # get_graph returns [train_adj_lists, test_adj_lists]; the encoders
    # are constructed with the train adjacency lists
    train_adj_lists = get_graph(graph_path)[0]

    if label_encoder_type == 'transformer':
        return TransformerGCN(feat, train_adj_lists, device,
                              options, gcn=True)

    if label_encoder_type == 'lstm':
        return LSTMGCN(feat, train_adj_lists, device,
                       options, gcn=False)

    if label_encoder_type == 'gcn':
        return InductiveGCN(feat, train_adj_lists, device, options)

    if label_encoder_type == 'gat':
        return GATConv(feat, train_adj_lists, device, options)

    # only 'rgcn' is left after the validation above
    return RGCN(feat, train_adj_lists, device, options)


def get_graph(graph_path):
    """Load the train and test adjacency lists stored under graph_path.

    Reads ``rw_train_adj_lists.json`` and ``rw_test_adj_lists.json`` and
    passes each loaded object through ``convert_index_to_int``.

    Arguments:
        graph_path {str} -- directory containing the two JSON files

    Returns:
        list -- ``[train_adj_lists, test_adj_lists]``
    """
    train_adj_path = os.path.join(graph_path, 'rw_train_adj_lists.json')
    test_adj_path = os.path.join(graph_path, 'rw_test_adj_lists.json')

    # context managers close both handles deterministically; both files
    # are still read before either is converted
    with open(train_adj_path) as train_fp:
        train_adj_lists = json.load(train_fp)

    with open(test_adj_path) as test_fp:
        test_adj_lists = json.load(test_fp)

    train_adj_lists = convert_index_to_int(train_adj_lists)
    test_adj_lists = convert_index_to_int(test_adj_lists)

    return [train_adj_lists, test_adj_lists]


def convert_token_to_idx(list_of_token_list, token_to_idx):
    """Map token strings to ids and pack them into one padded tensor.

    Every token list is converted with ``token_to_idx`` and right-padded
    with the id of ``'@@PADDING@@'`` up to the length of the longest list.

    Arguments:
        list_of_token_list {list} -- list of lists of token strings
        token_to_idx {dict} -- token to id mapping; must contain every
            token as well as the ``'@@PADDING@@'`` entry

    Returns:
        torch.tensor -- ids, one row per token list, one column per
            position of the longest list
    """
    # every row is padded to the longest sequence
    longest = max(len(sequence) for sequence in list_of_token_list)

    padded_rows = []
    for sequence in list_of_token_list:
        row = [token_to_idx[token] for token in sequence]
        missing = longest - len(row)
        row.extend([token_to_idx['@@PADDING@@']] * missing)
        padded_rows.append(row)

    return torch.tensor(padded_rows)