import networkx as nx
import torch
from ogb.linkproppred import PygLinkPropPredDataset
import torch_geometric.transforms as T
import random
import numpy as np
from fluidc import asyn_fluidc
import multiprocessing


def get_com():
    """Detect communities on ogbl-citation2 and pick one anchor node per community.

    Seeds ``torch``, ``numpy`` and ``random`` with 123, loads the
    ``ogbl-citation2`` dataset as a sparse tensor and symmetrizes its
    adjacency. It then runs ``asyn_fluidc`` with ``k=15`` on the largest
    connected component, or on the whole graph when it has a single component.
    Each anchor is the node with the highest degree inside its community's
    induced subgraph.

    Side effects: saves the communities to ``citation2_com_async_fluid.pt`` and
    the anchors to ``citation2_anchor_async_fluid.pt``.

    Returns:
        tuple: ``(nodes, net, anchor_nodes)``, where:

        - ``nodes`` is ``list(range(data.num_nodes))``.
        - ``net`` is the undirected networkx graph built from the symmetrized
          adjacency.
        - ``anchor_nodes`` holds one node id per community, in the order
          ``asyn_fluidc`` yields them.
    """
    torch.manual_seed(123)
    np.random.seed(123)
    random.seed(123)
    dataset = PygLinkPropPredDataset(name='ogbl-citation2',
                                     transform=T.ToSparseTensor())
    data = dataset[0]

    adj = data.adj_t.to_symmetric().to_scipy()
    # networkx 3.0 removed from_scipy_sparse_matrix. Its replacement,
    # from_scipy_sparse_array, exists since 2.7; fall back for older installs.
    from_scipy = (getattr(nx, 'from_scipy_sparse_array', None)
                  or nx.from_scipy_sparse_matrix)
    net = from_scipy(adj)

    components = sorted(nx.connected_components(net), key=len, reverse=True)
    # Run community detection only on the largest component. Presumably this is
    # because fluid communities (as in networkx's asyn_fluidc) need a connected
    # graph. TODO confirm against the local fluidc module.
    target = nx.subgraph(net, components[0]) if len(components) > 1 else net
    com = asyn_fluidc(target, k=15)
    torch.save(com, 'citation2_com_async_fluid.pt')

    # Anchor = highest-degree node within the community's induced subgraph.
    # max() returns the first maximal element, matching the original stable
    # descending sort followed by taking element [0].
    anchor_nodes = [
        max(nx.subgraph(net, c).degree(), key=lambda node_deg: node_deg[1])[0]
        for c in com
    ]
    torch.save(anchor_nodes, 'citation2_anchor_async_fluid.pt')

    return list(range(data.num_nodes)), net, anchor_nodes

def get_pe(net, anchor_nodes, nodes, map_ret):
    """Compute anchor-distance encodings for ``nodes`` and store them in ``map_ret``.

    For each node, the encoding is a list with one entry per anchor, in the
    order of ``anchor_nodes``. Each entry is the number of nodes on a shortest
    path from the node to that anchor, i.e. hop distance + 1, or 0 when the
    anchor is unreachable or either endpoint is missing from ``net``. The +1
    offset keeps 0 free to act as the "unreachable" marker.

    Args:
        net: networkx graph to search.
        anchor_nodes: sequence of anchor node ids.
        nodes: node ids to encode.
        map_ret: mutable mapping (e.g. a ``multiprocessing.Manager().dict()``)
            that receives ``node -> encoding``.

    Prints one progress line per completed node.
    """
    for node in nodes:
        pos_enc = []
        for anchor_node in anchor_nodes:
            try:
                # +1 matches len(shortest_path(...)): count nodes, not edges.
                pos_enc.append(nx.shortest_path_length(net, node, anchor_node) + 1)
            except (nx.NetworkXNoPath, nx.NodeNotFound):
                pos_enc.append(0)
        # Single assignment = single round trip when map_ret is a manager proxy.
        map_ret[node] = pos_enc
        print(f"{node} node completed!!")

def _compute_pe_parallel(manager, net, anchor_nodes, nodes, worker_num):
    """Run ``get_pe`` over ``worker_num`` contiguous slices of ``nodes`` and merge results.

    Returns a plain ``dict`` mapping node id to its encoding list. Raises
    ``RuntimeError`` if any worker exits with a non-zero code. Without this
    check, missing nodes would only surface later as a ``KeyError``.
    """
    return_dict = manager.dict()
    # Ceiling division so every node lands in exactly one slice.
    bs = -(-len(nodes) // worker_num)
    processes = []
    for pidx in range(worker_num):
        chunk = nodes[pidx * bs:(pidx + 1) * bs]
        if not chunk:
            # Trailing slices can be empty; don't spawn idle processes for them.
            continue
        p = multiprocessing.Process(target=get_pe,
                                    args=(net, anchor_nodes, chunk, return_dict))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()

    failed = [p.exitcode for p in processes if p.exitcode != 0]
    if failed:
        raise RuntimeError(f"{len(failed)} get_pe worker(s) failed, exit codes: {failed}")
    # DictProxy.copy() returns a regular dict snapshot in one round trip.
    return return_dict.copy()


def _build_pe_matrix(ret, num_nodes, anchor_nodes):
    """Stack encodings for nodes ``0..num_nodes-1`` into a 2-D array and post-process it.

    Entries equal to 0 (unreachable anchor) are replaced by the largest value
    in the array. Then each anchor's entry for itself is set to 0.
    """
    pos_enc = np.array([ret[i] for i in range(num_nodes)])
    # initial=0 keeps this well-defined when there are no anchors/nodes.
    max_d = pos_enc.max(initial=0)
    # np.where returns a new array; the result must be assigned back.
    pos_enc = np.where(pos_enc == 0, max_d, pos_enc)
    for i, a in enumerate(anchor_nodes):
        pos_enc[a, i] = 0
    return pos_enc


def main():
    """Compute and save anchor-distance positional encodings for ogbl-citation2.

    Steps:

    1. Call ``get_com`` to obtain the nodes, graph and anchor nodes.
    2. Compute each node's encoding with ``get_pe`` across 32 worker
       processes.
    3. Save the raw ``node -> encoding`` dict to ``citation2_pe_dict.pt``.
    4. Build the ``(num_nodes, num_anchors)`` matrix. Unreachable (0) entries
       are replaced by the maximum value and anchor self-entries are zeroed.
    5. Save the matrix to ``citation2_pe_async_fluid.pt``.

    Returns:
        dict: node id -> list of per-anchor values, as produced by the workers.

    Raises:
        RuntimeError: if any worker process exits with a non-zero code.
    """
    WORKER_NUM = 32
    # Start the manager before loading the graph, as the original did, so its
    # server process does not inherit the large in-memory dataset.
    with multiprocessing.Manager() as mult_manager:
        nodes, net, anchor_nodes = get_com()
        ret = _compute_pe_parallel(mult_manager, net, anchor_nodes, nodes, WORKER_NUM)

    torch.save(ret, 'citation2_pe_dict.pt')
    pos_enc = _build_pe_matrix(ret, len(nodes), anchor_nodes)
    torch.save(pos_enc, 'citation2_pe_async_fluid.pt')

    return ret

if __name__ == "__main__":
    # Script entry point: run the whole pipeline. main()'s returned dict is
    # not needed here; all outputs are written to .pt files.
    main()