import re
import torch
import json
import torch
import os

from collections import Counter

from allennlp.modules.token_embedders import Embedding
from allennlp.data.tokenizers import Token
from allennlp.common.params import Params
from allennlp.data.token_indexers.single_id_token_indexer import SingleIdTokenIndexer
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder

# label encoders
from allennlp.data.fields import TextField
from allennlp.data.instance import Instance
from allennlp.data.vocabulary import Vocabulary
from allennlp.nn.util import get_text_field_mask
from allennlp.nn.util import masked_mean

from example_encoder.text_encoder import TextEncoder

from class_encoder.gcn import GCN
from class_encoder.transfomer import TransformerGCN
from class_encoder.rgcn import RGCN
from class_encoder.gat import GAT
from class_encoder.lstm import LSTMGCN

# Directory that holds this module; graph resource paths are joined onto it.
DIR_PATH = os.path.dirname(os.path.realpath(__file__))

# Layer sizes (not referenced by the functions visible in this part of the
# file -- presumably consumed elsewhere; verify before changing).
INPUT_DIM = 300
HIDDEN_DIM = 100
ATTN_DIM = 100

# Label id -> concept URI. The insertion order (0, 1, 2, 5, 4, 3, 6) is kept
# exactly as originally written in case any consumer iterates over the dict.
label_graph_mapping = {
    0: '/c/en/weather',
    1: '/c/en/music',
    2: '/c/en/restaurant',
    5: '/c/en/book',
    4: '/c/en/movie',
    3: '/c/en/search',
    6: '/c/en/playlist',
}


# Encoder type names that get_label_encoder dispatches to the graph builder.
GRAPH_ENCODERS = [
    'gcn',
    'rgcn',
    'gat',
    'lstm',
    'transformer',
]

def get_label_encoder(label_encoder_type, vocab, device, options=None):
    """Create the label encoder for the requested encoder type.

    Only the graph encoders listed in GRAPH_ENCODERS are supported; the
    construction itself is delegated to construct_graph_encoder.

    Arguments:
        label_encoder_type {str} -- the type of label encoder; must be one
            of the names in GRAPH_ENCODERS
        vocab {Vocabulary} -- the vocabulary from the training set
        device {torch.device} -- the device information; cpu or cuda

    Keyword Arguments:
        options {dict} -- encoder options (default: {None}); graph encoders
            read options['graph_path'], so a dict containing that key is
            required in practice

    Returns:
        tuple -- contains the label encoder and the vocab

    Raises:
        ValueError -- if label_encoder_type is not a supported encoder type
    """
    # Fail loudly instead of implicitly returning None, which would only
    # surface later as an opaque tuple-unpacking error at the call site.
    if label_encoder_type not in GRAPH_ENCODERS:
        raise ValueError(
            f'unsupported label encoder type {label_encoder_type!r}; '
            f'expected one of {GRAPH_ENCODERS}'
        )

    return construct_graph_encoder(vocab, label_encoder_type, options, device)


def convert_index_to_int(adj_lists):
    """Return a copy of the adjacency dict whose keys are cast to int.

    The neighbour lists themselves are reused as-is (not copied); only the
    keys are converted, e.g. the string keys produced by loading JSON.
    """
    return {int(node_key): neighbours for node_key, neighbours in adj_lists.items()}


def _load_json(path):
    """Parse the JSON file at path, closing the file handle afterwards."""
    with open(path, encoding='utf-8') as handle:
        return json.load(handle)


def construct_graph_encoder(vocab, label_encoder_type, options, device):
    """Load the graph resources from disk and build the graph label encoder.

    The train/dev/test adjacency lists, the initial concept features and the
    label-id -> concept mapping are read from the directory
    DIR_PATH/options['graph_path'].

    Arguments:
        vocab {Vocabulary} -- the vocabulary; returned unchanged
        label_encoder_type {str} -- one of 'gcn', 'rgcn', 'gat',
            'transformer' or 'lstm'
        options {dict} -- must contain 'graph_path'; also forwarded to the
            encoder constructor (except for 'transformer')
        device {torch.device} -- the device passed to the encoder

    Returns:
        tuple -- the constructed graph label encoder and the vocab

    Raises:
        ValueError -- if label_encoder_type is not one of the known types
    """
    print('load train, val and test graph for the fold')
    graph_path = os.path.join(DIR_PATH, options['graph_path'])
    # JSON object keys are always strings, so node ids are cast back to int.
    adj_lists = [
        convert_index_to_int(_load_json(os.path.join(graph_path, file_name)))
        for file_name in ('rw_train_adj_lists.json',
                          'rw_dev_adj_lists.json',
                          'rw_test_adj_lists.json')
    ]

    print('load the init embeddings')
    # NOTE(review): torch.load unpickles the file; only point 'graph_path'
    # at trusted data.
    init_feats = torch.load(os.path.join(graph_path, 'concepts.pt'))

    print('load mapping from label id to concept')
    mapping = _load_json(os.path.join(graph_path, 'mapping.json'))

    if label_encoder_type in ['rgcn']:
        print('pruning graph ')
        # Labels 0-2 are hard-coded here; presumably these are the training
        # label ids (an `options['train_ids']` lookup was previously
        # commented out) -- TODO confirm.
        concept_ids = [mapping[str(i)] for i in [0, 1, 2]]
        rel_ids = get_rel_ids(adj_lists[0], [50, 100], concept_ids)
        # Relation id 50 is always kept; its meaning is not visible in this
        # file -- verify against the graph construction code.
        rel_ids.add(50)
        # Restrict the dev and test graphs to relations seen around the
        # training concepts.
        adj_lists[1] = prune_graph(adj_lists[1], rel_ids)
        adj_lists[2] = prune_graph(adj_lists[2], rel_ids)

    print('label tensor')
    # Concept entries for the seven label ids 0..6, in label-id order.
    label_list = [mapping[str(i)] for i in range(7)]

    print('choose the gnn model')
    if label_encoder_type == 'gcn':
        graph_label_encoder = GCN(init_feats,
                                  label_list,
                                  adj_lists,
                                  device,
                                  options=options)
    elif label_encoder_type == 'rgcn':
        graph_label_encoder = RGCN(init_feats,
                                   label_list,
                                   adj_lists,
                                   device,
                                   options=options)
    elif label_encoder_type == 'gat':
        graph_label_encoder = GAT(init_feats,
                                  label_list,
                                  adj_lists,
                                  device,
                                  options=options)
    elif label_encoder_type == 'transformer':
        # Unlike the other encoders this one receives gcn=True rather than
        # the options dict (kept exactly as in the original call).
        graph_label_encoder = TransformerGCN(init_feats,
                                             label_list,
                                             adj_lists,
                                             device,
                                             gcn=True)
    elif label_encoder_type == 'lstm':
        graph_label_encoder = LSTMGCN(init_feats,
                                      label_list,
                                      adj_lists,
                                      device,
                                      options=options)
    else:
        # Previously this fell through to an UnboundLocalError at the return.
        raise ValueError(
            f'unknown graph label encoder type: {label_encoder_type!r}'
        )

    print(f'initialized the gnn encoder - {label_encoder_type}')

    return graph_label_encoder, vocab


def get_rel_ids(adj_lists, neigh_sizes, node_ids):
    """Collect the relation ids reachable from the given start nodes.

    One expansion round is performed per entry in neigh_sizes. In each
    round, every node gathered so far (start nodes included) contributes its
    neighbour entries; when a node has at least `sample_size` entries only
    the `sample_size` entries with the largest third element are used. The
    neighbours visited in a round are added to the node pool for the next
    round.

    Arguments:
        adj_lists {dict} -- maps a node id to a list of
            (neighbour node, relation id, weight) entries
        neigh_sizes {list} -- the neighbour sample size for each round
        node_ids {list} -- the initial node ids; not modified

    Returns:
        set -- the relation ids encountered during the expansion

    Raises:
        KeyError -- if a visited node has no entry in adj_lists
    """
    # Copy so that the caller's list is not extended with visited neighbours.
    nodes = list(node_ids)
    all_rels = set()
    for sample_size in neigh_sizes:
        # Snapshot the current pool before this round appends to `nodes`.
        to_neighs = [adj_lists[node] for node in nodes]
        for to_neigh in to_neighs:
            if len(to_neigh) >= sample_size:
                # Keep only the top entries ranked by the third element.
                selected = sorted(to_neigh, key=lambda x: x[2], reverse=True)[:sample_size]
            else:
                selected = to_neigh
            for node, rel, _ in selected:
                all_rels.add(rel)
                nodes.append(node)

    return all_rels


def prune_graph(adj_lists, relations):
    """Drop every edge whose relation id is not in `relations`.

    Arguments:
        adj_lists {dict} -- maps a node to a list of
            (neighbour node, relation id, third value) entries
        relations {set} -- the relation ids to keep

    Returns:
        dict -- a new graph with the same nodes as adj_lists (nodes left
            without edges map to an empty list); kept edges are emitted as
            tuples
    """
    return {
        source: [
            (target, relation, weight)
            for target, relation, weight in edges
            if relation in relations
        ]
        for source, edges in adj_lists.items()
    }
