import os
import json
import torch
import torch.nn.functional as F
import pandas as pd
from class_encoder.dgp import DGP
from class_encoder.gcnz import GCNZ
from class_encoder.sgcn import SGCN

# Absolute directory of this module (symlinks resolved); the mapping CSV and
# the ../data graph files are located relative to it.
DIR_PATH = os.path.split(os.path.realpath(__file__))[0]

# Number of leading rows of snips_mapping.csv whose wnids are used as class
# labels (previously hard-coded as range(7) in each branch).
_NUM_LABELS = 7

# DGP "K" value: edge sets at index > _DGP_K are merged into
# edges_set[_DGP_K]. Original author's note: it indicates the depth of
# ancestors and descendants, with the value marked "assuming this is right".
_DGP_K = 4

_SUPPORTED_ENCODERS = ('gcnz', 'dgp', 'sgcn')


def _load_graph(rel_path):
    """Load and return the JSON graph at ``rel_path`` (relative to DIR_PATH).

    Uses a ``with`` block so the file handle is closed instead of leaked.
    """
    with open(os.path.join(DIR_PATH, rel_path), 'r') as f:
        return json.load(f)


def _label_indices(wnids, wn_mapping):
    """Return the graph node index of each of the first _NUM_LABELS mapped wnids.

    Raises KeyError if a mapped wnid is not among ``wnids``.
    """
    wnid_to_idx = {wnid: idx for idx, wnid in enumerate(wnids)}
    return [wnid_to_idx[wn_mapping['wnid'][i]] for i in range(_NUM_LABELS)]


def _symmetrize_with_self_loops(edges, n):
    """Return ``edges`` plus every reversed edge plus a self-loop for nodes 0..n-1."""
    edges = edges + [(v, u) for (u, v) in edges]
    return edges + [(u, u) for u in range(n)]


def _collapse_edge_sets(edges_set, k):
    """Merge every edge set beyond index ``k`` into ``edges_set[k]``; keep the first k+1.

    Note: ``edges_set[k]`` is extended in place.
    """
    for i in range(k + 1, len(edges_set)):
        edges_set[k].extend(edges_set[i])
    return edges_set[:k + 1]


def get_label_encoder(label_encoder_type, vocab, device, options=None):
    """Build a graph-based label encoder of the requested type.

    The class labels are the wnids in the first ``_NUM_LABELS`` rows of
    ``snips_mapping.csv`` (next to this module). Each is located among the
    nodes of a graph JSON file under ``../data``, and the encoder is built
    from that graph's edges and its word vectors passed through
    ``F.normalize``.

    Args:
        label_encoder_type: ``'gcnz'`` or ``'sgcn'`` (built from
            ``induced_graph.json``) or ``'dgp'`` (built from
            ``dense_graph.json``).
        vocab: not used here; returned unchanged.
        device: forwarded to the encoder constructor. For ``'dgp'`` the word
            vectors are also moved to it here.
        options: currently unused; kept for interface compatibility.

    Returns:
        A ``(label_encoder, vocab)`` tuple.

    Raises:
        ValueError: if ``label_encoder_type`` is not supported. Previously
            this surfaced as an ``UnboundLocalError`` after reading the CSV.
        KeyError: if a mapped wnid is not a node of the loaded graph.
    """
    # Validate before touching the filesystem so bad input fails fast.
    if label_encoder_type not in _SUPPORTED_ENCODERS:
        raise ValueError(
            f'Unsupported label_encoder_type {label_encoder_type!r}; '
            f'expected one of {_SUPPORTED_ENCODERS}')

    # load the wn mapping
    wn_mapping = pd.read_csv(os.path.join(DIR_PATH, 'snips_mapping.csv'))

    if label_encoder_type == 'dgp':
        graph = _load_graph('../data/dense_graph.json')
        wnids = graph['wnids']
        n = len(wnids)

        edges_set = graph['edges_set']
        print('edges_set', [len(level) for level in edges_set])
        edges_set = _collapse_edge_sets(edges_set, _DGP_K)
        print('edges_set', [len(level) for level in edges_set])

        # Only the DGP path moves the vectors to ``device`` here.
        # NOTE(review): GCNZ/SGCN below receive un-moved tensors; confirm
        # those classes handle device placement themselves.
        word_vectors = F.normalize(torch.tensor(graph['vectors'])).to(device)
        label_idx = _label_indices(wnids, wn_mapping)

        label_encoder = DGP(n, edges_set, word_vectors, label_idx, device)
        return label_encoder, vocab

    # 'gcnz' and 'sgcn' share identical preprocessing of the induced graph.
    graph = _load_graph('../data/induced_graph.json')
    wnids = graph['wnids']
    n = len(wnids)

    label_idx = _label_indices(wnids, wn_mapping)
    edges = _symmetrize_with_self_loops(graph['edges'], n)
    word_vectors = F.normalize(torch.tensor(graph['vectors']))

    encoder_cls = GCNZ if label_encoder_type == 'gcnz' else SGCN
    label_encoder = encoder_cls(n, edges, word_vectors, label_idx, device)
    return label_encoder, vocab
