import json
import torch
import numpy as np

from nltk.corpus import wordnet as wn

class GloVe():
    """In-memory table of word vectors loaded from a text embedding file.

    Each non-blank line of the file is read as a token followed by its
    space-separated float components. Lookups accept a single word or an
    iterable of words and return the mean of whatever vectors can be found,
    falling back to an all-zero vector when nothing matches.

    Attributes are plain instance fields:
      dimension -- length of the first vector read (None until one is read)
      embedding -- dict mapping token -> 1-D float32 torch tensor
    """

    def __init__(self, file_path):
        """Read every vector in *file_path* into memory.

        Args:
            file_path: path of a text file with one ``token v1 v2 ...``
                entry per line, fields separated by single spaces.
        """
        self.dimension = None
        self.embedding = dict()

        # Decode explicitly so loading does not depend on the platform's
        # default encoding (assumes the file is UTF-8 -- confirm for
        # custom embedding files).
        with open(file_path, encoding='utf-8') as fp:
            for line in fp:
                line = line.rstrip()
                if not line:
                    # A blank line would otherwise register '' with an
                    # empty vector, which later breaks averaging.
                    continue
                fields = line.split(' ')
                vector = torch.from_numpy(np.asarray(fields[1:],
                                                     dtype='float32'))
                self.embedding[fields[0]] = vector
                if self.dimension is None:
                    self.dimension = len(vector)

    def _average(self, vectors):
        """Return the element-wise mean of *vectors*, or None if empty."""
        total = self.zeros()
        count = 0
        for vec in vectors:
            total += vec
            count += 1
        return total / count if count > 0 else None

    def _fix_word(self, word):
        """Approximate the vector of an out-of-vocabulary *word*.

        The word is split on underscores and spaces into terms. A term that
        is itself unknown is further split on hyphens and represented by the
        mean of its known pieces. Returns the mean over all terms that could
        be resolved, or None when none could.
        """
        resolved = []
        for term in word.replace('_', ' ').split(' '):
            v = self.embedding.get(term)
            if v is None:
                v = self._average(self.embedding[piece]
                                  for piece in term.split('-')
                                  if piece in self.embedding)
            if v is not None:
                resolved.append(v)
        return self._average(resolved)

    def __getitem__(self, words):
        """Return the mean vector of *words*.

        Args:
            words: a single string, or an iterable of strings.

        Returns:
            A new 1-D tensor of length ``self.dimension``: the mean of the
            vectors of all words that could be resolved (directly or via
            ``_fix_word``), or zeros when none could.
        """
        # isinstance (not an exact type check) so str subclasses are treated
        # as one word instead of being iterated character by character.
        if isinstance(words, str):
            words = [words]
        found = []
        for word in words:
            v = self.embedding.get(word)
            if v is None:
                v = self._fix_word(word)
            if v is not None:
                found.append(v)
        mean = self._average(found)
        return mean if mean is not None else self.zeros()

    def zeros(self):
        """Return a fresh all-zero vector of length ``self.dimension``."""
        return torch.zeros(self.dimension)


def getnode(x):
    """Return the WordNet noun synset for the wnid string *x*.

    The first character of *x* is dropped and the remainder is parsed as
    the integer synset offset; the part of speech is always ``'n'``.
    """
    offset = int(x[1:])
    return wn.synset_from_pos_and_offset('n', offset)


def getwnid(u):
    """Return the wnid string for *u*: ``'n'`` plus its offset.

    The offset returned by ``u.offset()`` is left-padded with zeros to at
    least eight characters; longer offsets are kept as they are.
    """
    digits = str(u.offset())
    return 'n' + digits.rjust(8, '0')


def getedges(s):
    """Return the hypernym edges among the nodes of *s* as index pairs.

    For every node at position ``i`` and each of its ``hypernyms()`` that
    also occurs in *s* at position ``j``, the pair ``(i, j)`` is emitted.
    Hypernyms that are not part of *s* are ignored.
    """
    position = {node: idx for idx, node in enumerate(s)}
    pairs = []
    for child_idx, child in enumerate(s):
        pairs.extend((child_idx, position[parent])
                     for parent in child.hypernyms()
                     if parent in position)
    return pairs


def induce_parents(s, stop_set):
    """Extend the list *s* in place with hypernym ancestors of its nodes.

    *s* doubles as the work queue: nodes are processed in order, and each
    not-yet-seen hypernym is appended to *s* so that it is processed later
    as well. A node that is in *stop_set* is kept but not expanded.
    Returns None.
    """
    queue = s  # same list object: appends are visible to the caller
    seen = set(s)
    head = 0
    while head < len(queue):
        node = queue[head]
        head += 1
        if node not in stop_set:
            for parent in node.hypernyms():
                if parent in seen:
                    continue
                seen.add(parent)
                queue.append(parent)


def make_induced_graph(word_file, output_file, glove_path):
    """Build a WordNet graph with GloVe node vectors and dump it as JSON.

    Args:
        word_file: text file whose non-blank lines each start with a wnid
            (only the first whitespace-separated field is used).
        output_file: path of the JSON file to write.
        glove_path: path of the embedding text file passed to ``GloVe``.

    The written object has three keys: ``wnids`` (list of wnid strings),
    ``vectors`` (one averaged lemma-name vector per wnid, as nested lists)
    and ``edges`` (hypernym index pairs as produced by ``getedges``).
    """
    print('making graph ...')

    # Blank lines are skipped; they would otherwise fail on split()[0].
    xml_wnids = []
    with open(word_file) as fp:
        for line in fp:
            fields = line.split()
            if fields:
                xml_wnids.append(fields[0])

    xml_nodes = list(map(getnode, xml_wnids))
    xml_set = set(xml_nodes)

    key_wnids = list(xml_wnids)

    s = list(map(getnode, key_wnids))

    # NOTE(review): the seeds and the stop set come from the same wnid list,
    # so every seed is in xml_set and induce_parents expands nothing here;
    # likewise the loop below never appends. Kept as-is -- confirm whether
    # hypernym ancestors were meant to be included.
    induce_parents(s, xml_set)

    s_set = set(s)
    for u in xml_nodes:
        if u not in s_set:
            s.append(u)

    wnids = list(map(getwnid, s))
    edges = getedges(s)

    print('making glove embedding ...')

    glove = GloVe(glove_path)
    # s and wnids are index-aligned, so the synsets can be used directly
    # instead of resolving each wnid back to a synset.
    vectors = torch.stack([glove[node.lemma_names()] for node in s])

    print('dumping ...')

    obj = {
        'wnids': wnids,
        'vectors': vectors.tolist(),
        'edges': edges,
    }
    # Context manager guarantees the output is flushed and closed.
    with open(output_file, 'w') as fp:
        json.dump(obj, fp)


def make_dense_graph(input_file, output_file):
    """Expand a graph JSON into per-distance reachability edge lists.

    Reads ``wnids``, ``vectors`` and ``edges`` from *input_file*, runs a
    breadth-first search from every node along the directed edges, and
    writes ``wnids``, ``vectors`` and ``edges_set`` to *output_file*.

    ``edges_set[k]`` lists every pair ``(u, x)`` such that ``x`` is first
    reached from ``u`` after ``k`` hops; ``edges_set[0]`` therefore holds
    the self pairs. The number of levels grows with the graph (no fixed
    depth limit), and an empty graph yields an empty ``edges_set``.

    Args:
        input_file: path of the JSON file to read.
        output_file: path of the JSON file to write.
    """
    with open(input_file, 'r') as fp:
        js = json.load(fp)
    wnids = js['wnids']
    vectors = js['vectors']
    edges = js['edges']

    n = len(wnids)
    adjs = {i: [] for i in range(n)}
    for u, v in edges:
        adjs[u].append(v)

    # One bucket per hop count, created on demand. BFS distances are
    # contiguous, so no bucket is ever left empty.
    new_edges = []

    for u in range(n):
        queue = [u]
        head = 0
        dist = {u: 0}
        while head < len(queue):
            x = queue[head]
            head += 1
            for y in adjs[x]:
                if y not in dist:
                    dist[y] = dist[x] + 1
                    queue.append(y)
        for x, dis in dist.items():
            while len(new_edges) <= dis:
                new_edges.append([])
            new_edges[dis].append((u, x))

    with open(output_file, 'w') as fp:
        json.dump({'wnids': wnids, 'vectors': vectors,
                   'edges_set': new_edges}, fp)

