import tempfile

from torch_geometric.data import Dataset, InMemoryDataset

from src.datasets.dataset_utils.dataset_splitting import add_set_masks, add_fold_masks, create_sets

def test_add_set_masks_dataset():
    """Smoke test: construct a bare ``Dataset`` rooted in a throwaway directory.

    A fresh temporary directory is used as ``root`` instead of the hard-coded
    relative path ``'tmp'``. This way the test does not depend on the current
    working directory, and anything created under ``root`` is removed when the
    ``with`` block exits.

    NOTE(review): this test is incomplete. It makes no assertions and never
    calls ``add_set_masks``. TODO: apply ``add_set_masks`` to ``ds`` and assert
    on the resulting masks once the intended API usage is confirmed.
    """
    with tempfile.TemporaryDirectory() as root:
        ds = Dataset(root=root)