import multiprocessing as mp
import random
from multiprocessing import get_context

import networkx as nx
import numpy as np
import torch
from tqdm.auto import tqdm


def get_communities(remove_feature):
    """Build a synthetic community graph together with node-pair labels.

    The graph is a connected caveman graph of 20 cliques with 20 nodes each.
    Each edge is then rewired with probability 1%: the edge ``(u, v)`` is
    replaced by ``(u, x)`` for a uniformly drawn node ``x``, unless ``(u, x)``
    already exists. Self-loops created by the rewiring are removed.

    Args:
        remove_feature: If truthy, every node gets the constant feature ``1``
            (a ``torch`` tensor of shape ``(n, 1)``). Otherwise the feature
            matrix is an ``n x n`` NumPy identity matrix with randomly
            permuted columns.

    Returns:
        dict with the keys

        * ``"edge_index"``: ``torch`` long tensor of shape ``(2, 2 * E)``
          holding every edge in both directions,
        * ``"feature"``: the node features described above,
        * ``"positive_edges"``: NumPy array of shape ``(2, P)`` with all node
          pairs ``(u, v)``, ``v < u``, that lie in the same community,
        * ``"num_nodes"``: number of nodes ``n``.
    """
    community_size = 20

    # Create 20 cliques (communities) of size 20,
    # then rewire a single edge in each clique to a node in an adjacent clique
    graph = nx.connected_caveman_graph(20, community_size)

    # Randomly rewire 1% edges.
    # Iterate over a snapshot of the edges: the loop body removes and adds
    # edges, and mutating the graph while walking its live edge view is
    # unsafe (edges may be skipped/revisited or a RuntimeError raised).
    node_list = list(graph.nodes)
    for u, v in list(graph.edges()):
        if random.random() < 0.01:
            x = random.choice(node_list)
            if graph.has_edge(u, x):
                continue
            graph.remove_edge(u, v)
            graph.add_edge(u, x)

    # remove self-loops (x may have been drawn equal to u above)
    graph.remove_edges_from(nx.selfloop_edges(graph))
    edge_index = np.array(list(graph.edges))
    # Add (i, j) for an edge (j, i)
    edge_index = np.concatenate((edge_index, edge_index[:, ::-1]), axis=0)
    edge_index = torch.from_numpy(edge_index).long().permute(1, 0)

    n = graph.number_of_nodes()
    label = np.zeros((n, n), dtype=int)
    for u in node_list:
        # the node IDs are simply consecutive integers from 0
        for v in range(u):
            if u // community_size == v // community_size:
                label[u, v] = 1

    if remove_feature:
        feature = torch.ones((n, 1))
    else:
        rand_order = np.random.permutation(n)
        feature = np.identity(n)[:, rand_order]

    data = {
        "edge_index": edge_index,
        "feature": feature,
        "positive_edges": np.stack(np.nonzero(label)),
        "num_nodes": feature.shape[0],
    }

    return data


def to_single_directed(edges):
    """Keep a single direction of every edge in an edge list.

    Only the columns whose source ID is strictly smaller than the destination
    ID are kept, in their original order. Self-loops are therefore dropped.

    The result is sized by the number of columns actually kept, so the input
    does not have to be perfectly symmetric: no zero-filled padding columns
    are produced and no index overflow can occur.

    Args:
        edges: Integer array-like of shape ``(2, E)``; column ``i`` is the
            edge ``(edges[0, i], edges[1, i])``.

    Returns:
        Integer NumPy array of shape ``(2, K)`` with the kept columns.
    """
    edges = np.asarray(edges)
    keep = edges[0] < edges[1]
    return edges[:, keep].astype(int)


def split_edges(p, edges, data, non_train_ratio=0.2):
    """Shuffle an edge list and store train/val/test splits in ``data``.

    The columns of ``edges`` are permuted uniformly at random. The first
    ``1 - non_train_ratio`` fraction becomes the training split and the rest
    is divided evenly between validation and test. The split is purely
    random: nothing ensures that every node keeps an edge in the training
    split.

    Args:
        p: Key prefix; the splits are stored under ``"<p>_edges_train"``,
            ``"<p>_edges_val"`` and ``"<p>_edges_test"``.
        edges: Array of shape ``(2, E)``.
        data: Dictionary that is updated in place.
        non_train_ratio: Fraction of edges held out for validation + test.
    """
    num_edges = edges.shape[1]
    shuffled = edges[:, np.random.permutation(num_edges)]

    # Column boundaries: [0, train_end) train, [train_end, val_end) val,
    # [val_end, E) test.
    train_end = int((1 - non_train_ratio) * num_edges)
    val_end = int((1 - non_train_ratio / 2) * num_edges)

    train_part = shuffled[:, :train_end]  # 80% with the default ratio
    val_part = shuffled[:, train_end:val_end]  # 10%
    test_part = shuffled[:, val_end:]  # 10%

    data.update(
        {
            "{}_edges_train".format(p): train_part,
            "{}_edges_val".format(p): val_part,
            "{}_edges_test".format(p): test_part,
        }
    )


def to_bidirected(edges):
    """Append the reversed copy of every edge to an edge list.

    Args:
        edges: Array whose first axis has the two endpoint rows, e.g. shape
            ``(2, E)``.

    Returns:
        Array with the original columns followed by the same columns with
        the two rows swapped, e.g. shape ``(2, 2 * E)``.
    """
    flipped = edges[::-1, :]  # swap source and destination rows
    both_directions = (edges, flipped)
    return np.concatenate(both_directions, axis=-1)


def get_negative_edges(positive_edges, num_nodes, num_negative_edges):
    """Sample node pairs that are not positive edges.

    Each negative edge is drawn as two distinct node IDs from
    ``range(num_nodes)``; the draw is repeated until the pair is not a
    positive edge in either direction. Sampled negatives are not checked
    against each other, so duplicates among them are possible.

    Args:
        positive_edges: Array of shape ``(2, P)`` with the positive edges.
        num_nodes: Number of nodes to sample endpoints from.
        num_negative_edges: Number of negative edges to produce.

    Returns:
        Array of shape ``(2, num_negative_edges)`` with the dtype of
        ``positive_edges``.
    """
    positive_edges = to_bidirected(positive_edges)
    # Hash set of (u, v) tuples for constant-time rejection tests.
    known_pairs = {
        tuple(positive_edges[:, col])
        for col in range(positive_edges.shape[1])
    }

    negative_edges = np.zeros(
        (2, num_negative_edges), dtype=positive_edges.dtype
    )
    filled = 0
    while filled < num_negative_edges:
        candidate = tuple(
            np.random.choice(num_nodes, size=(2,), replace=False)
        )
        if candidate in known_pairs:
            # Rejected: draw again for the same output column.
            continue
        negative_edges[:, filled] = candidate
        filled += 1

    return negative_edges


def get_pos_neg_edges(data, infer_link_positive=True):
    """Add train/val/test splits of positive and negative edges to ``data``.

    Args:
        data: Dictionary with at least ``"edge_index"`` (a tensor exposing
            ``.numpy()``), ``"num_nodes"`` and, when ``infer_link_positive``
            is false, ``"positive_edges"``. It is modified in place.
        infer_link_positive: If true, ``data["positive_edges"]`` is
            overwritten with the single-directed version of
            ``data["edge_index"]``.

    Returns:
        The same ``data`` dictionary, extended with the
        ``"positive_edges_*"`` and ``"negative_edges_*"`` splits.
    """
    if infer_link_positive:
        data["positive_edges"] = to_single_directed(data["edge_index"].numpy())

    positives = data["positive_edges"]
    split_edges("positive", positives, data)

    # Draw as many negative edges as there are positive ones, then split
    # them the same way.
    negatives = get_negative_edges(
        positives,
        data["num_nodes"],
        num_negative_edges=positives.shape[1],
    )
    split_edges("negative", negatives, data)

    return data


def shortest_path(graph, node_range, cutoff):
    """Compute shortest-path lengths from each of the given source nodes.

    Args:
        graph: networkx graph to search in.
        node_range: Iterable of source nodes.
        cutoff: Value forwarded as the ``cutoff`` argument of
            ``nx.single_source_shortest_path_length``.

    Returns:
        dict mapping every source node to the result of
        ``nx.single_source_shortest_path_length`` for that node.
    """
    return {
        source: nx.single_source_shortest_path_length(graph, source, cutoff)
        for source in tqdm(node_range, leave=False)
    }


def merge_dicts(dicts):
    """Merge several dictionaries into a new one.

    Args:
        dicts: Iterable of dictionaries.

    Returns:
        A new dict containing every key; when a key occurs more than once,
        the value from the last dictionary containing it wins.
    """
    return {
        key: value
        for part in dicts
        for key, value in dict(part).items()
    }


def all_pairs_shortest_path(graph, cutoff=None, num_workers=4):
    """Compute shortest-path lengths from every node using a process pool.

    The (shuffled) node list is cut into ``num_workers`` contiguous chunks
    and each chunk is handed to ``shortest_path`` in a worker process.

    Args:
        graph: networkx graph.
        cutoff: Forwarded to ``shortest_path`` for every chunk.
        num_workers: Number of worker processes and of node chunks.

    Returns:
        dict mapping every node of ``graph`` to the per-source result
        returned by ``shortest_path``.
    """
    nodes = list(graph.nodes)
    random.shuffle(nodes)

    # Integer chunk boundaries: bounds[0] == 0 and bounds[-1] == len(nodes)
    # hold exactly, so no node can be dropped by floating-point rounding.
    num_nodes = len(nodes)
    bounds = [num_nodes * i // num_workers for i in range(num_workers + 1)]

    # The context manager terminates the pool if a worker raises, so the
    # worker processes are never leaked.
    with mp.Pool(processes=num_workers) as pool:
        results = [
            pool.apply_async(
                shortest_path,
                args=(graph, nodes[start:stop], cutoff),
            )
            for start, stop in zip(bounds, bounds[1:])
        ]
        output = [result.get() for result in results]
        pool.close()
        pool.join()

    return merge_dicts(output)


def precompute_dist_data(edge_index, num_nodes, approximate=0):
    """Compute a dense matrix of reciprocal shortest-path distances.

    Here dist is 1 / (real_dist + 1): higher actually means closer, a node
    has value 1 to itself, and 0 means disconnected (or farther away than
    the cutoff).

    Args:
        edge_index: Edge list whose ``transpose(1, 0).tolist()`` yields
            ``(u, v)`` pairs, e.g. a NumPy array of shape ``(2, E)``. Node
            IDs are used directly as row/column indices of the result.
        num_nodes: Size ``n`` of the returned ``n x n`` matrix.
        approximate: If positive, used as the shortest-path cutoff (maximum
            number of hops); otherwise no cutoff is applied.

    Returns:
        NumPy float array of shape ``(num_nodes, num_nodes)``. Nodes that do
        not appear in ``edge_index`` keep all-zero rows and columns,
        including the diagonal entry.
    """
    graph = nx.Graph()
    graph.add_edges_from(edge_index.transpose(1, 0).tolist())

    dists_array = np.zeros((num_nodes, num_nodes))
    dists_dict = all_pairs_shortest_path(
        graph, cutoff=approximate if approximate > 0 else None
    )
    for node_i in graph.nodes():
        # Only reachable nodes are present in the per-source mapping, so
        # every other entry simply keeps its initial value 0.
        for node_j, dist in dists_dict[node_i].items():
            dists_array[node_i, node_j] = 1 / (dist + 1)
    return dists_array


def get_dataset(args):
    """Generate the community dataset and its distance matrix.

    Args:
        args: Object providing the attributes ``inductive`` (passed as
            ``remove_feature`` to ``get_communities``), ``task`` (compared
            against ``"link"``) and ``k_hop_dist`` (passed as
            ``approximate`` to ``precompute_dist_data``).

    Returns:
        The data dictionary with edge splits and a float tensor ``"dists"``.
        For the ``"link"`` task the distances are computed on the training
        positive edges only, and ``"edge_index"`` is replaced by those edges
        in both directions.
    """
    link_task = args.task == "link"

    # Generate graph data, then get positive and negative edges
    data = get_pos_neg_edges(
        get_communities(args.inductive), infer_link_positive=link_task
    )

    # Pre-compute shortest path length on the edges visible to the model
    if link_task:
        train_edges = data["positive_edges_train"]
        visible_edges = train_edges
    else:
        visible_edges = data["edge_index"].numpy()

    dists = precompute_dist_data(
        visible_edges,
        data["num_nodes"],
        approximate=args.k_hop_dist,
    )
    data["dists"] = torch.from_numpy(dists).float()

    if link_task:
        # Held-out edges must not be part of the graph used for training.
        data["edge_index"] = torch.from_numpy(
            to_bidirected(train_edges)
        ).long()

    return data


def get_anchors(n):
    """Get a list of NumPy arrays, each of them is an anchor node set.

    With ``m = int(log2(n))``, ``m`` set sizes ``int(n / 2 ** (i + 1))`` for
    ``i = 0 .. m - 1`` are used, and ``m`` sets are drawn per size, each by
    sampling node IDs from ``range(n)`` without replacement.

    Args:
        n: Number of nodes.

    Returns:
        List of ``m * m`` integer arrays, grouped by decreasing set size.
    """
    num_levels = int(np.log2(n))
    set_sizes = [int(n / np.exp2(level + 1)) for level in range(num_levels)]
    return [
        np.random.choice(n, size=set_size, replace=False)
        for set_size in set_sizes
        for _ in range(num_levels)
    ]


def get_dist_max(anchor_set_id, dist):
    """Find, for every node, its closest anchor in each anchor set.

    Args:
        anchor_set_id: Sequence of K anchor sets, each an array-like of node
            IDs.
        dist: N x N tensor of reciprocal shortest distances (larger means
            closer).

    Returns:
        Tuple ``(dist_max, dist_argmax)`` of N x K tensors: the largest
        ``dist`` value from each node to a member of each anchor set, and
        the node ID of the anchor attaining it (long tensor).
    """
    # N x K, N is number of nodes, K is the number of anchor sets
    num_nodes, num_sets = dist.shape[0], len(anchor_set_id)
    dist_max = torch.zeros((num_nodes, num_sets))
    dist_argmax = torch.zeros((num_nodes, num_sets)).long()

    for col, anchors in enumerate(anchor_set_id):
        anchor_ids = torch.as_tensor(anchors, dtype=torch.long)
        # Reciprocal shortest distance from every node to each anchor of
        # this set (columns follow the order of anchor_ids)
        to_anchors = dist.index_select(1, anchor_ids)
        # Best value per node and its position inside the anchor set
        best_value, best_position = to_anchors.max(dim=-1)
        dist_max[:, col] = best_value
        # Translate positions inside the set back into node IDs
        dist_argmax[:, col] = anchor_ids[best_position]

    return dist_max, dist_argmax


def get_a_graph(dists_max, dists_argmax):
    """Build the node-to-anchor edge list for one collection of anchor sets.

    For every node ``i`` one edge per *distinct* closest anchor is created,
    going from the anchor to ``i``.

    Args:
        dists_max: N x K tensor, best reciprocal distance of each node to
            each anchor set.
        dists_argmax: N x K tensor, node ID of the corresponding closest
            anchor.

    Returns:
        Tuple ``(g, anchor_eid, edge_weight)`` where

        * ``g`` is ``(dst, src)``: two parallel lists giving, per edge, the
          anchor node and the node it is closest to,
        * ``anchor_eid`` has N * K entries (row-major over
          ``dists_argmax``), each the edge index in ``g`` of the pair
          (anchor, node),
        * ``edge_weight`` holds, per edge, the ``dists_max`` value at the
          first anchor set in which that anchor was chosen for the node.
    """
    dists_max = dists_max.numpy()

    src, dst, edge_weight = [], [], []
    real_src, real_dst = [], []
    for node in range(dists_max.shape[0]):
        closest = dists_argmax[node, :]
        # Get unique closest anchor nodes for this node across all anchor
        # sets, with the position of the first occurrence of each
        unique_anchors, first_seen = np.unique(closest, return_index=True)

        # De-duplicated edges (anchor -> node) and their weights
        src += [node] * unique_anchors.shape[0]
        dst += list(unique_anchors)
        edge_weight += dists_max[node, first_seen].tolist()

        # Full (possibly repeated) assignment, one entry per anchor set
        real_src += [node] * closest.shape[0]
        real_dst += list(closest.numpy())

    # Look up the edge ID of every (anchor, node) assignment
    eid_dict = {pair: eid for eid, pair in enumerate(zip(dst, src))}
    anchor_eid = [eid_dict.get(pair) for pair in zip(real_dst, real_src)]

    return (dst, src), anchor_eid, edge_weight


def get_graphs(data, anchor_sets):
    """Build the anchor graphs for several collections of anchor sets.

    Args:
        data: Dictionary providing the distance tensor under ``"dists"``.
        anchor_sets: Iterable whose elements are each passed to
            ``get_dist_max`` as one collection of anchor sets.

    Returns:
        Tuple of four lists with one entry per element of ``anchor_sets``:
        the graphs, the anchor edge IDs, the ``dists_max`` tensors and the
        edge weights.
    """
    # Buckets, in order: graphs, anchor_eids, dists_max_list, edge_weights
    collected = ([], [], [], [])
    for anchor_set in tqdm(anchor_sets, leave=False):
        dists_max, dists_argmax = get_dist_max(anchor_set, data["dists"])
        graph, anchor_eid, edge_weight = get_a_graph(dists_max, dists_argmax)
        per_set = (graph, anchor_eid, dists_max, edge_weight)
        for bucket, item in zip(collected, per_set):
            bucket.append(item)

    return collected


def merge_result(outputs):
    """Concatenate the per-worker results of ``get_graphs``.

    Args:
        outputs: Iterable of 4-tuples ``(graphs, anchor_eids, dists_max,
            edge_weights)``, each component being an iterable.

    Returns:
        Tuple of four lists, each the concatenation of the corresponding
        components in the order of ``outputs``.
    """
    # Buckets, in order: graphs, anchor_eids, dists_max_list, edge_weights
    merged = ([], [], [], [])
    for chunk in outputs:
        graphs, anchor_eids, dists_max, edge_weights = chunk
        parts = (graphs, anchor_eids, dists_max, edge_weights)
        for bucket, part in zip(merged, parts):
            bucket.extend(part)

    return merged


def preselect_anchor(data, args, num_workers=4):
    """Pre-compute the anchor sets and anchor graphs for every epoch.

    One collection of anchor sets is drawn per epoch; the collections are
    cut into ``num_workers`` contiguous chunks and each chunk is processed
    by ``get_graphs`` in a worker process started with the ``"spawn"``
    method.

    Args:
        data: Dictionary providing ``"num_nodes"`` and whatever
            ``get_graphs`` reads from it; it is sent to every worker.
        args: Object providing the attribute ``epoch_num``.
        num_workers: Number of worker processes and of chunks.

    Returns:
        Tuple ``(graphs, anchor_eids, dists_max_list, edge_weights)`` of
        lists with one entry per epoch, as merged by ``merge_result``.
    """
    # Pre-compute anchor sets, a collection of anchor sets per epoch
    anchor_set_ids = [
        get_anchors(data["num_nodes"]) for _ in range(args.epoch_num)
    ]

    # Integer chunk boundaries: bounds[0] == 0 and bounds[-1] == number of
    # epochs hold exactly, so no epoch can be dropped by floating-point
    # rounding.
    num_sets = len(anchor_set_ids)
    bounds = [num_sets * i // num_workers for i in range(num_workers + 1)]

    # The context manager terminates the pool if a worker raises, so the
    # worker processes are never leaked.
    with get_context("spawn").Pool(processes=num_workers) as pool:
        results = [
            pool.apply_async(
                get_graphs,
                args=(data, anchor_set_ids[start:stop]),
            )
            for start, stop in zip(bounds, bounds[1:])
        ]
        output = [result.get() for result in results]
        pool.close()
        pool.join()

    graphs, anchor_eids, dists_max_list, edge_weights = merge_result(output)

    return graphs, anchor_eids, dists_max_list, edge_weights
