import networkx as nx
import torch
from ogb.linkproppred import PygLinkPropPredDataset
import torch_geometric.transforms as T
import random
import numpy as np
from fluidc import asyn_fluidc

# OGB link-prediction datasets this script can process; index 2 picks 'ogbl-ddi'.
d_names = ['ogbl-ppa', 'ogbl-collab', 'ogbl-ddi', 'ogbl-citation2']
d_name = d_names[2]

# Seed torch, NumPy and the stdlib RNG (same order as before) so runs are reproducible.
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(123)

device = 'cpu'
# Number of fluid communities requested; one anchor node is picked per community.
num_local = 8
# Load the selected OGB dataset; ToSparseTensor stores the edges as `data.adj_t`.
sparse_transform = T.ToSparseTensor()
dataset = PygLinkPropPredDataset(name=d_name, transform=sparse_transform)
data = dataset[0]
# Grab the adjacency first, then move the graph object itself to the target device.
adj_t = data.adj_t.to(device)
data = data.to(device)

# Convert the sparse adjacency into a NetworkX graph (undirected nx.Graph, since
# no create_using is given).
adj = adj_t.to_scipy()
# networkx 3.0 removed from_scipy_sparse_matrix; from_scipy_sparse_array (>= 2.7)
# is its replacement. Use whichever this networkx install provides.
_from_scipy = getattr(nx, 'from_scipy_sparse_array', None) or nx.from_scipy_sparse_matrix
net = _from_scipy(adj)

# Connected components, largest first.
nets = sorted(nx.connected_components(net), key=len, reverse=True)
# Presumably (as in networkx) asyn_fluidc needs a connected graph, so when the
# graph is split, communities are only detected on the largest component.
if len(nets) > 1:
    lgst_net = nx.subgraph(net, nets[0])
    com = asyn_fluidc(lgst_net, k=num_local)
else:
    com = asyn_fluidc(net, k=num_local)
# NOTE(review): networkx's asyn_fluidc returns a one-shot iterator; if the local
# `fluidc` copy does too, iterating it below would leave nothing for the later
# torch.save(com). Materialize it so it can be iterated and saved.
com = list(com)

# One anchor per community: the node with the highest degree inside the
# community's induced subgraph. max() keeps the first maximal node, which is the
# same tie-break the previous stable sorted(..., reverse=True)[0] gave.
anchor_nodes = [
    max(nx.subgraph(net, c).degree(), key=lambda node_deg: node_deg[1])[0]
    for c in com
]
# Anchor-distance positional encoding, shape (num_nodes, num_anchors):
#   pos_enc[v, i] = number of nodes on a shortest path between v and
#                   anchor_nodes[i] (hop distance + 1, as len(shortest_path) gave),
#                   or 0 when v cannot reach that anchor.
# net is an undirected graph, so one BFS from each anchor yields its distance to
# every node. That is k searches instead of num_nodes * k separate
# shortest_path calls. It also drops the bare `except:` that swallowed every
# error, including KeyboardInterrupt.
pos_enc = np.zeros((data.num_nodes, len(anchor_nodes)), dtype=int)
for i, anchor_node in enumerate(anchor_nodes):
    print(f"anchor {i} (node {anchor_node}): computing BFS distances")
    hops = nx.single_source_shortest_path_length(net, anchor_node)
    reached = np.fromiter(hops.keys(), dtype=int, count=len(hops))
    dists = np.fromiter(hops.values(), dtype=int, count=len(hops))
    # Unreached nodes keep the 0 "unreachable" sentinel from np.zeros.
    pos_enc[reached, i] = dists + 1
# In pos_enc, 0 marks a (node, anchor) pair with no connecting path. Replace
# those entries with the largest observed value so "unreachable" reads as "far".
# Bug fix: the np.where result used to be discarded, so this replacement never
# happened.
max_d = pos_enc.max(initial=0)
pos_enc = np.where(pos_enc == 0, max_d, pos_enc)
# An anchor's distance to itself is defined as 0. This runs after the
# replacement above, so these zeros are kept.
pos_enc[anchor_nodes, np.arange(len(anchor_nodes))] = 0

# Derive the file prefix from the dataset name ('ogbl-ddi' -> 'ddi') instead of
# hard-coding 'ddi', so switching d_name does not overwrite the ddi outputs.
out_prefix = d_name.split('-', 1)[-1]
torch.save(anchor_nodes, f'{out_prefix}_anchor_async_fluid.pt')
torch.save(pos_enc, f'{out_prefix}_pe_async_fluid.pt')
torch.save(com, f'{out_prefix}_com_async_fluid.pt')
