import os

from src.data import load_data
import torch
import dgl

# Directory (relative to the current working directory) that receives the
# per-dataset cached subgraph files written by this script.
GRAPH_CACHE_DIR = "cached_graphs"

if __name__ == "__main__":
    # For every node of each dataset, extract its 3-hop in-subgraph and cache
    # the results on disk: one "<dataset>.bin" file per dataset inside
    # GRAPH_CACHE_DIR.

    # exist_ok=True replaces the separate isdir() check: it is a no-op when the
    # directory already exists and avoids the check-then-create race.
    os.makedirs(GRAPH_CACHE_DIR, exist_ok=True)

    for dataset in ['cora', 'citeseer', 'pubmed']:
        graph = load_data(dataset)

        # Strip the node-level tensors before extracting subgraphs.
        # NOTE(review): presumably so the extracted subgraphs do not each carry
        # a slice of these tensors -- confirm. A missing key raises, exactly as
        # the original multi-target `del` statement did.
        for key in ('feat', 'label', 'train_mask', 'test_mask', 'val_mask'):
            del graph.ndata[key]

        subgraphs = []      # one k-hop in-subgraph per node, in node-ID order
        seed_indices = []   # second value returned by khop_in_subgraph, per node
        seed_degrees = []   # number of non-self predecessors of each seed node

        for node_id in range(graph.number_of_nodes()):
            # With its default settings khop_in_subgraph returns the subgraph
            # together with the seed node's ID inside that (relabelled)
            # subgraph, per the DGL documentation.
            subgraph, seed = dgl.khop_in_subgraph(graph, [node_id], k=3)
            subgraphs.append(subgraph)
            seed_indices.append(seed)

            # Predecessors of the seed within its own subgraph, with entries
            # equal to the seed itself (self-loops) filtered out.
            in_neighbors = subgraph.predecessors(seed)
            in_neighbors = in_neighbors[in_neighbors != seed]
            seed_degrees.append(len(in_neighbors))

        # Extra tensors stored alongside the graph list in the same file.
        labels = {
            'inverse_indices': torch.cat(seed_indices),
            'degrees': torch.tensor(seed_degrees),
        }
        cache_path = os.path.join(GRAPH_CACHE_DIR, "{}.bin".format(dataset))
        dgl.data.utils.save_graphs(cache_path, subgraphs, labels)