#import dgl
#import torch
#from ogb.graphproppred import DglGraphPropPredDataset
#from dgl.dataloading import GraphDataLoader


#def _collate_fn(batch):
    # batch is a list of (graph, label) tuples
#    graphs = [e[0] for e in batch]
#    g = dgl.batch(graphs)
#    labels = [e[1] for e in batch]
#    labels = torch.stack(labels, 0)
#    return g, labels

# load dataset
#dataset = DglGraphPropPredDataset(name='ogbg-molhiv')
#split_idx = dataset.get_idx_split()
# dataloader
#train_loader = GraphDataLoader(dataset[split_idx["train"]], batch_size=32, shuffle=True, collate_fn=_collate_fn)
#valid_loader = GraphDataLoader(dataset[split_idx["valid"]], batch_size=32, shuffle=False, collate_fn=_collate_fn)
#test_loader = GraphDataLoader(dataset[split_idx["test"]], batch_size=32, shuffle=False, collate_fn=_collate_fn)

#from torch_geometric.datasets import MoleculeNet
#root = './data'
#name = 'HIV'
#dataset1 = MoleculeNet(root, name)

#print(dataset1[split_idx["train"]])

#import data_utils as topo_data



#dataset_cls = topo_data.get_dataset_class(dataset='IMDB-BINARY')

#dataset = dataset_cls(batch_size=32, use_node_attributes=True, lift_to_simplex=True, max_simplex_dim=2)
#ds = dataset.prepare_data()

#print(ds)

#print(next(iter(dataset.train_dataloader())))
# print(next(iter(dataset.val_dataloader())))
# print(next(iter(dataset.test_dataloader())))

