from torch_geometric.datasets import TUDataset
from torch_geometric.utils import to_dense_adj
import numpy as np
from property import Graph
import os
import pickle

# names = ['IMDB-BINARY', 'REDDIT-BINARY', 'PROTEINS']
# TU benchmark datasets to preprocess. Each entry is handed to TUDataset as
# its `name` and also names the per-dataset folder under ./data.
names = [
    "AIDS",
    "DHFR",
    "Mutagenicity",
    "NCI1",
    "DD",
    "ENZYMES",
    "PROTEINS",
    "Cuneiform",
    "MSRC_21",
]
for name in names:
    # Folder that receives this dataset's .npy outputs; creating it is a
    # no-op when it already exists.
    after_dir = f'./data/{name}/after'
    os.makedirs(after_dir, exist_ok=True)

    # Load the dataset through torch_geometric, rooted at ./data.
    dataset = TUDataset(root='./data', name=name)

    # Show the dataset's `data` attribute as a quick sanity check.
    # NOTE(review): newer torch_geometric releases appear to warn on direct
    # `dataset.data` access — confirm against the installed version.
    print(dataset.data)

    def sparse_mat(adj_mat, size=128):
        """Zero-pad a 2-D adjacency matrix to a dense ``size`` x ``size`` array.

        Despite the name, the result is a dense NumPy array, not a sparse
        matrix: the input is placed in the top-left corner and the remaining
        rows and columns are filled with zeros.

        Args:
            adj_mat: 2-D array-like (anything ``np.array`` accepts) with at
                most ``size`` rows and ``size`` columns.
            size: Side length of the padded output. Defaults to 128, the
                value that was previously hard-coded.

        Returns:
            A new ``numpy.ndarray`` of shape ``(size, size)``.

        Raises:
            ValueError: If ``adj_mat`` is not 2-D or does not fit into a
                ``size`` x ``size`` matrix.
        """
        adj_mat = np.array(adj_mat)
        if adj_mat.ndim != 2:
            raise ValueError(
                f'expected a 2-D adjacency matrix, got {adj_mat.ndim} dimension(s)'
            )

        rows, cols = adj_mat.shape
        # Fail with a clear message instead of letting np.pad reject a
        # negative pad width.
        if rows > size or cols > size:
            raise ValueError(
                f'matrix of shape {adj_mat.shape} does not fit into {size}x{size}'
            )

        # Pad each axis from its own length so the output is always
        # (size, size), even if the input is not square.
        pad_width = ((0, size - rows), (0, size - cols))
        return np.pad(adj_mat, pad_width=pad_width, mode='constant', constant_values=0)

    # Per-dataset accumulators; only graphs with at most 128 nodes contribute.
    result = []    # padded adjacency matrices returned by sparse_mat
    props = []     # Graph(...).result() for each kept graph
    labels = []    # data.y for each kept graph
    features = []  # data.x for each kept graph
    total = len(dataset)
    for index, data in enumerate(dataset):
        # Graphs above the 128-node cap are skipped entirely.
        if data.num_nodes > 128:
            continue

        # Progress line: index of the current graph and the dataset size.
        print(index, total)
        features.append(data.x)
        labels.append(data.y)

        # Dense adjacency for this single graph; [0] takes the first entry
        # of what to_dense_adj returns.
        dense = to_dense_adj(data.edge_index, max_num_nodes=data.num_nodes)[0]

        # Properties are computed from the unpadded adjacency matrix.
        props.append(Graph(dense).result())

        # The padded copy of the matrix is what gets stored.
        result.append(sparse_mat(dense))

    # NOTE(review): `features` is collected above but never written out; the
    # save below is disabled.
    # features = np.array(features)
    # np.save(f'./data/{name}/after/features.npy', features)

    # Convert each accumulator to an array and write it into the dataset's
    # `after` folder.
    labels = np.array(labels)
    np.save(f'./data/{name}/after/labels.npy', labels)

    props = np.array(props)
    np.save(f'./data/{name}/after/props.npy', props)

    result = np.array(result)
    print(result)
    np.save(f'./data/{name}/after/data.npy', result)