import argparse
import os.path as osp

from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from superpixels import SPDataset




def load_dataset(name, args=None, root='data'):
    """Load a graph dataset by name and return it shuffled.

    ``'mnist'`` and ``'cifar10'`` (matched case-insensitively) are loaded via
    ``SPDataset`` from ``<root>/superpixels``; any other name is loaded via
    ``TUDataset`` with ``root`` as its root directory.

    Args:
        name: Dataset name. It is forwarded unchanged (original casing) to
            the dataset constructor.
        args: Parsed command-line arguments. Currently unused; kept (now
            optional) so existing ``load_dataset(name, args)`` calls work.
        root: Root directory handed to the dataset constructor. Defaults to
            ``'data'``, the previously hard-coded location.

    Returns:
        The result of calling ``.shuffle()`` on the constructed dataset.
    """
    if name.lower() in ('mnist', 'cifar10'):
        dataset = SPDataset(osp.join(root, 'superpixels'), name=name)
    else:
        dataset = TUDataset(root, name=name)
    # Shuffle once here rather than in each branch; same result as before.
    return dataset.shuffle()


if __name__ == '__main__':
    # Command-line options: which dataset to load and the mini-batch size.
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='MUTAG')
    parser.add_argument('--batch_size', type=int, default=10)
    args = parser.parse_args()

    dataset = load_dataset(args.dataset, args)
    data_loader = DataLoader(dataset, shuffle=True, batch_size=args.batch_size)

    # Sanity check: print the dataset summary, its first graph, and the
    # first mini-batch the loader yields (nothing is printed if it is empty).
    print(dataset)
    print(dataset[0])
    first_batch = next(iter(data_loader), None)
    if first_batch is not None:
        print(first_batch)

    # Sample output from a run with the default arguments:
    # Data(edge_index=[2, 50], x=[22, 7], edge_attr=[50, 4], y=[1])
    # DataBatch(edge_index=[2, 470], x=[207, 7], edge_attr=[470, 4], y=[10], batch=[207], ptr=[11])

