from networkx.algorithms.traversal.breadth_first_search import bfs_predecessors
import torch
from collections import defaultdict, OrderedDict
import numba
import numpy as np
import networkx as nx
import os, time
from multiprocessing import Pool
from functools import partial

def _group_by(keys, values) -> dict:
    """Bucket ``values`` by the matching row of ``keys``.

    Pair ``i`` is ``(keys[i], values[i])``. Both arguments must provide
    ``.tolist()``; every row of ``keys`` is converted to a tuple so that it
    can serve as a dictionary key.

    :param keys: tensor-like whose rows are the keys
    :param values: tensor-like holding one value per row of ``keys``
    :return: OrderedDict mapping each distinct key tuple (in order of first
        occurrence) to a ``torch.IntTensor`` with its values sorted ascending.

    """
    buckets = defaultdict(list)
    for row, item in zip(keys.tolist(), values.tolist()):
        buckets[tuple(row)].append(item)
    return OrderedDict(
        (group_key, torch.IntTensor(sorted(items)))
        for group_key, items in buckets.items()
    )


def index_KvsAll(dataset: "Dataset", split: str, key: str):
    """Return an index for the triples in split (''train'', ''valid'', ''test'')
    from the specified key (''sp'' or ''po'' or ''so'') to the indexes of the
    remaining constituent (''o'' or ''s'' or ''p'' , respectively.)

    The index maps from `tuple` to `torch.IntTensor`; it is built by
    `_group_by`, which also sorts the values stored for every key.

    The index is cached in the provided dataset under name `{split}_sp_to_o` or
    `{split}_po_to_s`, or `{split}_so_to_p`. If this index is already present, does not
    recompute it. The number of distinct keys is logged on every call, cached
    or not.

    :param dataset: object providing `_indexes`, `split(name)` and `config.log`
    :param split: name of the split whose triples are indexed
    :param key: ''sp'', ''po'' or ''so''
    :return: the (possibly cached) index
    :raises ValueError: if `key` is not one of ''sp'', ''po'', ''so''

    """
    if key == "sp":
        key_cols = [0, 1]
        value_column = 2
        value = "o"
    elif key == "po":
        key_cols = [1, 2]
        value_column = 0
        value = "s"
    elif key == "so":
        key_cols = [0, 2]
        value_column = 1
        value = "p"
    else:
        raise ValueError(
            f"Unknown key {key!r} for KvsAll index; expected 'sp', 'po' or 'so'."
        )

    name = split + "_" + key + "_to_" + value
    # A missing entry and a falsy (e.g. empty) cached entry are both rebuilt.
    if not dataset._indexes.get(name):
        triples = dataset.split(split)
        dataset._indexes[name] = _group_by(
            triples[:, key_cols], triples[:, value_column]
        )

    dataset.config.log(
        "{} distinct {} pairs in {}".format(len(dataset._indexes[name]), key, split),
        prefix="  ",
    )

    return dataset._indexes.get(name)


def index_KvsAll_to_torch(index):
    """Flatten an `index_KvsAll` index into three pytorch tensors.

    Returns, in this order:
      - keys: one row per key of the index (built with dtype ``torch.int``),
      - values: all value tensors of the index concatenated in key order,
      - offsets: cumulative sum of the group sizes with a leading 0, so it has
        one more entry than there are keys.

    Afterwards, it holds:
        index[keys[i]] = values[offsets[i]:offsets[i+1]]
    """
    key_rows = list(index)
    value_groups = list(index.values())
    group_sizes = [len(group) for group in value_groups]

    keys = torch.tensor(key_rows, dtype=torch.int)
    values = torch.cat(value_groups)
    offsets = torch.tensor([0, *group_sizes], dtype=torch.int).cumsum(0)
    return keys, values, offsets


def index_neighbor(dataset):
    """Build (or fetch from the cache) the one-hop neighbor index of the train split.

    The training triples are loaded into a directed graph with one edge
    ``s -> o`` per triple, carrying the relation id as edge attribute ``type``.
    For each entity in the graph, its successors and predecessors are shuffled
    together and at most 300 of them are kept.

    The cached value is a tuple ``(all_neighbor, all_neighbor_num)``:
      - all_neighbor: long tensor, num_entities x 2 x 300, zero padded;
        ``[:, 0]`` holds neighbor ids and ``[:, 1]`` the relation ids. The
        relation id of an edge leaving the entity is shifted by
        ``dataset.num_relations()``; edges entering it keep the plain id.
      - all_neighbor_num: long tensor with the number of stored neighbors.

    The selection is random and unseeded, so it differs between runs.
    """
    name = "neighbor"
    if not dataset._indexes.get(name):
        train_triples = dataset.split('train')
        graph = nx.DiGraph()
        for triple in train_triples:
            subj, rel, obj = triple.tolist()
            graph.add_node(subj)
            graph.add_node(obj)
            graph.add_edge(subj, obj, type=rel)

        max_neighbor_num = 300
        all_neighbor = torch.zeros((dataset.num_entities(), 2, max_neighbor_num), dtype=torch.long)
        all_neighbor_num = torch.zeros(dataset.num_entities(), dtype=torch.long)
        rng = np.random.default_rng()
        for ent in range(dataset.num_entities()):
            if ent not in graph:
                continue
            out_nodes = list(graph.successors(ent))
            in_nodes = list(graph.predecessors(ent))
            out_rels = [graph.get_edge_data(ent, v)['type'] + dataset.num_relations() for v in out_nodes]
            in_rels = [graph.get_edge_data(v, ent)['type'] for v in in_nodes]

            # Shuffle everything, then keep the first max_neighbor_num picks.
            order = rng.permutation(len(out_nodes) + len(in_nodes))[:max_neighbor_num]
            picked_nodes = np.asarray(out_nodes + in_nodes)[order]
            picked_rels = np.asarray(out_rels + in_rels)[order]

            count = len(picked_nodes)
            all_neighbor[ent, 0, :count] = torch.tensor(picked_nodes, dtype=torch.long)
            all_neighbor[ent, 1, :count] = torch.tensor(picked_rels, dtype=torch.long)
            all_neighbor_num[ent] = count
        dataset._indexes[name] = (all_neighbor, all_neighbor_num)

    dataset.config.log("One-hop Neighbors index finished", prefix="  ")

    return dataset._indexes.get(name)


def process_one_semantic_neighbor(t, neighbors, mask, max_hop, max_neighbor_num, G):
    """Find the relation chain from root ``t`` to each of its kept neighbors.

    ``neighbors[mask]`` selects the neighbors to keep (the mask filters out
    those with low similarity). If any remain, a BFS tree rooted at ``t`` is
    computed on ``G`` with depth limit ``max_hop``; for each kept neighbor the
    tree is walked from the neighbor back to the root and the ``type``
    attribute of every edge on the way is collected. The stored chain is
    ordered root --> ... --> neighbor.

    :param t: root node
    :param neighbors: candidate neighbors of ``t``; must support ``neighbors[mask]``
    :param mask: selector applied to ``neighbors``
    :param max_hop: BFS depth limit and width of the chain tensor
    :param max_neighbor_num: not used by this function
    :param G: graph passed to ``nx.bfs_predecessors``; edges carry ``type``
    :return: tuple ``(t, neighbors, neighbor_num, rel_chains, rel_lengths)`` with
        the kept neighbors as a long tensor, their count, a zero padded
        ``neighbor_num x max_hop`` long tensor of relation chains and a long
        tensor of chain lengths. A length is 0 for a neighbor equal to ``t``
        and -1 for a neighbor without a path in the BFS tree.
    """
    kept = neighbors[mask]
    neighbor_num = len(kept)

    rel_chains = torch.zeros([neighbor_num, max_hop], dtype=torch.long)
    rel_lengths = torch.zeros(neighbor_num, dtype=torch.long)

    # The BFS is only needed (and only run) when some neighbor survived the mask.
    if neighbor_num > 0:
        parent_of = dict(nx.bfs_predecessors(G, t, depth_limit=max_hop))
        for row, node in enumerate(kept):
            if node == t:
                continue

            # Climb from the neighbor towards the root, collecting edge types.
            chain = []
            child = node
            parent = parent_of.get(child)
            while parent is not None:
                chain.append(G.get_edge_data(parent, child)['type'])
                child = parent
                parent = parent_of.get(child)

            if chain:
                chain.reverse()  # collected leaf-to-root, stored root-to-leaf
                rel_chains[row, :len(chain)] = torch.tensor(chain, dtype=torch.long)
                rel_lengths[row] = len(chain)
            else:
                rel_lengths[row] = -1

    return (t, torch.tensor(kept, dtype=torch.long), neighbor_num, rel_chains, rel_lengths)


def process_one_graph_neighbor(s, dataset, max_neighbor_num, G):
    """Sample the one-hop neighbors of node ``s`` together with their relations.

    Successors and predecessors of ``s`` in ``G`` are shuffled together with
    numpy's global random state and at most ``max_neighbor_num`` are kept. The
    relation id of an edge leaving ``s`` is shifted by
    ``dataset.num_relations()``; edges entering ``s`` keep the plain id.

    :param s: node whose neighborhood is read
    :param dataset: provides ``num_relations()``
    :param max_neighbor_num: maximum number of neighbors to keep
    :param G: graph providing ``successors``, ``predecessors`` and
        ``get_edge_data``; edges carry the relation id as ``type``
    :return: tuple ``(s, neighbors, neighbor_num, rel_chains, rel_lengths)`` with
        the kept neighbors as a long tensor, their count, the relations as a
        ``neighbor_num x 1`` long tensor and a long tensor of ones (every
        chain has length 1).
    """
    out_nodes = list(G.successors(s))
    in_nodes = list(G.predecessors(s))
    out_rels = [G.get_edge_data(s, v)['type'] + dataset.num_relations() for v in out_nodes]
    in_rels = [G.get_edge_data(v, s)['type'] for v in in_nodes]

    # Shuffle everything, then keep the first max_neighbor_num picks.
    order = np.random.permutation(len(out_nodes) + len(in_nodes))[:max_neighbor_num]
    picked_nodes = np.asarray(out_nodes + in_nodes)[order]
    picked_rels = np.asarray(out_rels + in_rels)[order]

    neighbor_num = len(picked_nodes)
    rel_chains = torch.tensor(picked_rels, dtype=torch.long).reshape(neighbor_num, 1)
    rel_lengths = torch.ones(neighbor_num, dtype=torch.long)

    return (s, torch.tensor(picked_nodes, dtype=torch.long), neighbor_num, rel_chains, rel_lengths)



def extract_subgraph_information(dataset, G, neighbors=None, max_neighbor_num=300, max_hop=5, mask=None, padding_size=None):
    """Assemble padded per-entity neighbor and relation-chain tensors.

    One record is produced for every entity id in
    ``range(dataset.num_entities())`` that is contained in ``G``: by
    `process_one_semantic_neighbor` when ``neighbors`` is given, otherwise by
    `process_one_graph_neighbor` (only one-hop graph neighbors are considered).
    The records are then copied into zero padded tensors. Entities that are not
    in ``G`` keep all-zero rows, and an empty ``G`` yields all-zero tensors.

    :param dataset: provides ``num_entities()`` and ``config.log``; it is also
        forwarded to `process_one_graph_neighbor`
    :param G: graph of the training triples; only membership is tested here
    :param neighbors: optional container indexed by entity id with the
        candidate neighbors of that entity
    :param max_neighbor_num: forwarded to the per-entity functions; also the
        default for ``padding_size``
    :param max_hop: size of the last axis of the relation-chain tensor,
        forwarded to `process_one_semantic_neighbor`
    :param mask: container indexed by entity id; it is subscripted for every
        entity whenever ``neighbors`` is given, so it must be provided then
    :param padding_size: size of the neighbor axis of the returned tensors;
        every entity must have at most this many neighbors
    :return: tuple ``(all_neighbor, all_neighbor_num, all_rel_chain,
        all_rel_chain_num)`` of long tensors with shapes
        num_entities x padding_size, num_entities,
        num_entities x padding_size x max_hop and num_entities x padding_size
    """
    if padding_size is None:
        padding_size = max_neighbor_num

    num_entities = dataset.num_entities()
    all_neighbor = torch.zeros((num_entities, padding_size), dtype=torch.long)
    all_neighbor_num = torch.zeros(num_entities, dtype=torch.long)
    all_rel_chain = torch.zeros((num_entities, padding_size, max_hop), dtype=torch.long)
    all_rel_chain_num = torch.zeros(num_entities, padding_size, dtype=torch.long)

    all_nodes = [t for t in range(num_entities) if t in G]

    start = time.time()
    if neighbors is not None:
        results = [
            process_one_semantic_neighbor(node, neighbors[node], mask[node], max_hop, max_neighbor_num, G)
            for node in all_nodes
        ]
    else:
        # only one-hop graph neighbors are considered
        results = [
            process_one_graph_neighbor(node, dataset, max_neighbor_num, G)
            for node in all_nodes
        ]
    # No peeking at results[0] here: the list is empty when G has no entity.
    dataset.config.log("processing time:%f" % (time.time()-start,), prefix="  ")

    for (t, sources, neighbor_num, rel_chains, lengths) in results:
        all_neighbor[t, :neighbor_num] = sources
        all_neighbor_num[t] = neighbor_num
        # rel_chains is neighbor_num x max_hop; a narrower chain tensor (such
        # as a neighbor_num x 1 one) is broadcast along the hop axis.
        all_rel_chain[t, :neighbor_num] = rel_chains
        all_rel_chain_num[t, :neighbor_num] = lengths

    return (all_neighbor, all_neighbor_num, all_rel_chain, all_rel_chain_num)

def index_graph_neighbor(dataset):
    """Build (or fetch from the cache) the graph-neighbor index of the train split.

    The training triples are loaded into a directed graph with one edge
    ``s -> o`` per triple, carrying the relation id as edge attribute ``type``.
    For each entity in the graph, its successors and predecessors are shuffled
    together and at most 300 of them are kept.

    The cached value is a tuple of zero padded long tensors
    ``(all_neighbor, all_neighbor_num, all_rel_chain, all_rel_chain_num)``:
      - all_neighbor: num_entities x 300, neighbor ids
      - all_neighbor_num: num_entities, number of stored neighbors
      - all_rel_chain: num_entities x 300 x 1, relation id per neighbor. The
        id of an edge leaving the entity is shifted by
        ``dataset.num_relations()``; edges entering it keep the plain id.
      - all_rel_chain_num: num_entities x 300, set to 1 for every stored
        neighbor

    The selection is random and unseeded, so it differs between runs.
    """
    name = "graph_neighbor"

    if not dataset._indexes.get(name):
        train_triples = dataset.split('train')
        graph = nx.DiGraph()
        for triple in train_triples:
            subj, rel, obj = triple.tolist()
            graph.add_node(subj)
            graph.add_node(obj)
            graph.add_edge(subj, obj, type=rel)

        max_neighbor_num = 300
        max_hop = 1

        all_neighbor = torch.zeros((dataset.num_entities(), max_neighbor_num), dtype=torch.long)
        all_neighbor_num = torch.zeros(dataset.num_entities(), dtype=torch.long)
        all_rel_chain = torch.zeros((dataset.num_entities(), max_neighbor_num, max_hop), dtype=torch.long)
        all_rel_chain_num = torch.zeros(dataset.num_entities(), max_neighbor_num, dtype=torch.long)

        rng = np.random.default_rng()
        for ent in range(dataset.num_entities()):
            if ent not in graph:
                continue
            out_nodes = list(graph.successors(ent))
            in_nodes = list(graph.predecessors(ent))
            out_rels = [graph.get_edge_data(ent, v)['type'] + dataset.num_relations() for v in out_nodes]
            in_rels = [graph.get_edge_data(v, ent)['type'] for v in in_nodes]

            # Shuffle everything, then keep the first max_neighbor_num picks.
            order = rng.permutation(len(out_nodes) + len(in_nodes))[:max_neighbor_num]
            picked_nodes = np.asarray(out_nodes + in_nodes)[order]
            picked_rels = np.asarray(out_rels + in_rels)[order]

            count = len(picked_nodes)
            all_neighbor[ent, :count] = torch.tensor(picked_nodes, dtype=torch.long)
            all_neighbor_num[ent] = count
            all_rel_chain[ent, :count, 0] = torch.tensor(picked_rels, dtype=torch.long)
            all_rel_chain_num[ent, :count] = 1  # every chain has exactly one hop

        dataset._indexes[name] = (all_neighbor, all_neighbor_num, all_rel_chain, all_rel_chain_num)

    dataset.config.log("Graph Neighbors index finished", prefix="  ")

    return dataset._indexes.get(name)





def index_relation_types(dataset):
    """Classify relations into 1-N, M-1, 1-1, M-N.

    According to Bordes et al. "Translating embeddings for modeling multi-relational
    data.", NIPS13.

    Adds index `relation_types` with list that maps relation index to ("1-N", "M-1",
    "1-1", "M-N"). A relation is "M" on the subject side when it has on average
    more than 1.5 subjects per distinct (p, o) pair, and "N" on the object side
    when it has on average more than 1.5 objects per distinct (s, p) pair. A
    relation without training triples ends up as "1-1".

    """
    if "relation_types" not in dataset._indexes:
        num_relations = dataset.num_relations()
        sp_to_o = dataset.index("train_sp_to_o")
        po_to_s = dataset.index("train_po_to_s")

        # columns: #subjects, #distinct po keys, #objects, #distinct sp keys
        counts = torch.zeros((num_relations, 4))
        # (index, position of the relation inside the key, first column to update)
        for index, rel_pos, col in ((sp_to_o, 1, 2), (po_to_s, 0, 0)):
            for pair, labels in index.items():
                rel = pair[rel_pos]
                counts[rel, col] += len(labels)
                counts[rel, col + 1] += 1.0

        # 0/0 gives nan for unseen relations, which compares as not "many".
        many_subjects = ((counts[:, 0] / counts[:, 1]) > 1.5).tolist()
        many_objects = ((counts[:, 2] / counts[:, 3]) > 1.5).tolist()

        dataset._indexes["relation_types"] = [
            "{}-{}".format("M" if is_m else "1", "N" if is_n else "1")
            for is_m, is_n in zip(many_subjects, many_objects)
        ]
    return dataset._indexes["relation_types"]


def index_relations_per_type(dataset):
    """Group relation indexes by their relation type.

    Uses the `relation_types` index of the dataset and caches the result under
    `relations_per_type`. The size of every group is logged on each call,
    cached or not.

    :return: dict mapping each relation type to the set of relation indexes
        having that type
    """
    if "relations_per_type" in dataset._indexes:
        relations_per_type = dataset._indexes["relations_per_type"]
    else:
        relations_per_type = {}
        for rel_id, rel_type in enumerate(dataset.index("relation_types")):
            if rel_type not in relations_per_type:
                relations_per_type[rel_type] = set()
            relations_per_type[rel_type].add(rel_id)
        dataset._indexes["relations_per_type"] = relations_per_type

    dataset.config.log("Loaded relation index")
    for rel_type, members in relations_per_type.items():
        summary = "{} relations of type {}".format(len(members), rel_type)
        dataset.config.log(summary, prefix="  ")

    return relations_per_type


def index_frequency_percentiles(dataset, recompute=False):
    """Bucket entities and relations by how often they occur in the train split.

    Counts how often each entity occurs as subject, each relation as predicate
    and each entity as object. Ids are then ranked by ascending count (ties
    keep id order) and the ranking is cut into four consecutive quarters.

    Nothing is returned. The result is stored in
    ``dataset._indexes["frequency_percentiles"]`` as a dictionary
    {
        'subject':
        {25%, 50%, 75%, top} -> set of entities
        'relation':
        {25%, 50%, 75%, top} -> set of relations
        'object':
        {25%, 50%, 75%, top} -> set of entities
    }

    :param recompute: rebuild the index even if it is already present
    """
    if not recompute and "frequency_percentiles" in dataset._indexes:
        return

    subject_stats = torch.zeros((dataset.num_entities(), 1))
    relation_stats = torch.zeros((dataset.num_relations(), 1))
    object_stats = torch.zeros((dataset.num_entities(), 1))
    for triple in dataset.split("train"):
        s, p, o = triple
        subject_stats[s] += 1
        relation_stats[p] += 1
        object_stats[o] += 1

    quarters = (
        ("25%", 0.0, 0.25),
        ("50%", 0.25, 0.5),
        ("75%", 0.5, 0.75),
        ("top", 0.75, 1.0),
    )
    slots = (
        ("subject", subject_stats, dataset.num_entities()),
        ("relation", relation_stats, dataset.num_relations()),
        ("object", object_stats, dataset.num_entities()),
    )

    result = dict()
    for arg, stats, num in slots:
        counts = stats.tolist()
        # Stable sort: ids with equal counts stay in ascending id order.
        ranked = sorted(range(len(counts)), key=counts.__getitem__)
        result[arg] = {
            percentile: set(ranked[int(begin * num) : int(end * num)])
            for percentile, begin, end in quarters
        }
    dataset._indexes["frequency_percentiles"] = result


class IndexWrapper:
    """Callable object that binds keyword arguments to an index function.

    Wrapping the call in an object (rather than a closure) is what allows it to
    be pickled.
    """

    def __init__(self, fun, **kwargs):
        """Remember the index function and the keyword arguments to pass to it."""
        self.fun, self.kwargs = fun, kwargs

    def __call__(self, dataset: "Dataset", **kwargs):
        """Call the wrapped function on ``dataset`` with the stored keyword arguments.

        Keyword arguments given at call time are accepted but ignored, and the
        result of the wrapped function is not returned.
        """
        stored_kwargs = self.kwargs
        self.fun(dataset, **stored_kwargs)


def _invert_ids(dataset, obj: str):
    """Create (or reuse) the index that maps the ids of ``obj`` to their position.

    The ids are read with ``dataset.load_map(f"{obj}_ids")`` and the inverted
    mapping is cached under ``{obj}_id_to_index``. The number of indexed ids is
    logged on every call; nothing is returned.
    """
    index_name = f"{obj}_id_to_index"
    if index_name in dataset._indexes:
        inv = dataset._indexes[index_name]
    else:
        inv = {}
        for position, identifier in enumerate(dataset.load_map(f"{obj}_ids")):
            inv[identifier] = position
        dataset._indexes[index_name] = inv
    dataset.config.log(f"Indexed {len(inv)} {obj} ids", prefix="  ")


def create_default_index_functions(dataset: "Dataset"):
    """Register the default index functions in ``dataset.index_functions``.

    Registered are, in this order: the three KvsAll indexes (`sp_to_o`,
    `po_to_s`, `so_to_p`) for every triples file of the dataset, the
    `neighbor`, `relation_types`, `relations_per_type`,
    `frequency_percentiles` and `graph_neighbor` indexes, and the
    `entity_id_to_index` / `relation_id_to_index` inversions.
    """
    remaining_slot = {"sp": "o", "po": "s", "so": "p"}
    for split in dataset.files_of_type("triples"):
        for key, value in remaining_slot.items():
            # IndexWrapper stores split and key for the later call.
            wrapped = IndexWrapper(index_KvsAll, split=split, key=key)
            dataset.index_functions[f"{split}_{key}_to_{value}"] = wrapped

    plain_functions = (
        ("neighbor", index_neighbor),
        ("relation_types", index_relation_types),
        ("relations_per_type", index_relations_per_type),
        ("frequency_percentiles", index_frequency_percentiles),
        ("graph_neighbor", index_graph_neighbor),
    )
    for index_name, fun in plain_functions:
        dataset.index_functions[index_name] = fun

    for obj in ("entity", "relation"):
        dataset.index_functions[f"{obj}_id_to_index"] = IndexWrapper(
            _invert_ids, obj=obj
        )


@numba.njit
def where_in(x, y, not_in=False):
    """Retrieve the indices of the elements in x which are also in y.

    x and y are assumed to be 1 dimensional arrays. The function is compiled
    with numba in nopython mode.

    :param x: array whose elements are tested for membership
    :param y: array of the elements to look for
    :param not_in: if True, returns the indices of the elements in x
        which are not in y.
    :return: array with the positions in x that satisfy the test

    """
    # np.isin is not supported in numba. Also: "i in y" raises an error in numba
    # setting njit(parallel=True) slows down the function
    # Despite its name, list_y is a set, built once for the membership tests.
    list_y = set(y)
    # Comparing the membership flags with not_in flips the selection: with
    # not_in=False the members are kept, with not_in=True the non-members.
    return np.where(np.array([i in list_y for i in x]) != not_in)[0]
