from torch_geometric.loader import DataLoader
from ogb.graphproppred import PygGraphPropPredDataset
import torch





def molfreesolv(*, root="./dataset", batch_size=32):
    """Build train/valid/test DataLoaders for the OGB ``ogbg-molfreesolv`` dataset.

    The dataset is partitioned using the index split returned by
    ``dataset.get_idx_split()``. Only the training loader shuffles; the
    validation and test loaders keep a fixed order so evaluation is
    reproducible.

    Args:
        root: Directory handed to ``PygGraphPropPredDataset`` as its
            dataset root.
        batch_size: Number of graphs per batch, applied to all three loaders.

    Returns:
        A ``(train_loader, valid_loader, test_loader)`` tuple.
    """
    dataset = PygGraphPropPredDataset(name="ogbg-molfreesolv", root=root)
    split_idx = dataset.get_idx_split()

    # Evaluation loaders: deterministic order.
    valid_loader = DataLoader(
        dataset[split_idx["valid"]], batch_size=batch_size, shuffle=False
    )
    test_loader = DataLoader(
        dataset[split_idx["test"]], batch_size=batch_size, shuffle=False
    )
    # Training loader: reshuffled each time it is iterated.
    train_loader = DataLoader(
        dataset[split_idx["train"]], batch_size=batch_size, shuffle=True
    )

    return train_loader, valid_loader, test_loader



if __name__ == "__main__":
    # Running the module directly just builds the three loaders as a quick
    # check that the dataset can be loaded; the returned loaders are discarded.
    molfreesolv()