from torch_geometric.loader import DataLoader
from ogb.graphproppred import PygGraphPropPredDataset


def molpcba(batch_size: int = 512, root: str = "dataset"):
    """Build train/valid/test data loaders for the ``ogbg-molpcba`` dataset.

    The dataset is loaded through ``PygGraphPropPredDataset`` and partitioned
    with the index split returned by ``dataset.get_idx_split()``.

    Args:
        batch_size: Number of graphs per batch, used for all three loaders.
        root: Directory handed to ``PygGraphPropPredDataset`` as its ``root``.

    Returns:
        A ``(train_loader, valid_loader, test_loader)`` tuple of
        ``torch_geometric.loader.DataLoader`` objects.
    """
    dataset = PygGraphPropPredDataset(name="ogbg-molpcba", root=root)
    split_idx = dataset.get_idx_split()

    # Only the training split is shuffled. The evaluation splits keep a fixed
    # order so predictions stay aligned with the split indices and repeated
    # evaluation runs see the batches in the same order.
    train_loader = DataLoader(
        dataset[split_idx["train"]], batch_size=batch_size, shuffle=True
    )
    valid_loader = DataLoader(
        dataset[split_idx["valid"]], batch_size=batch_size, shuffle=False
    )
    test_loader = DataLoader(
        dataset[split_idx["test"]], batch_size=batch_size, shuffle=False
    )

    return train_loader, valid_loader, test_loader



if __name__ == "__main__":
    # Script entry point: build the three loaders once; the result is discarded.
    _ = molpcba()