from torch_geometric.graphgym.register import register_loader
from custom_graphgym.datasets.citations_dataset import CitationsDataset

@register_loader('citations')
def load_dataset_citations(format, name, dataset_dir):
    """Build the custom citations dataset when GraphGym asks for it.

    This loader handles exactly one dataset name, ``"citations"``. It
    returns ``None`` for any other name.

    Args:
        format: Dataset format value passed in by the caller. This loader
            does not read it.
        name: Requested dataset name. Only ``"citations"`` is handled.
        dataset_dir: Base directory. The dataset root is built as
            ``f'{dataset_dir}/{name}'``.

    Returns:
        A ``CitationsDataset`` constructed with the path
        ``<dataset_dir>/citations`` when ``name == "citations"``, else ``None``.
    """
    # Decline names this loader does not own.
    # NOTE(review): presumably GraphGym treats a None return as "try the next
    # registered loader" -- confirm against torch_geometric.graphgym.loader.
    if name != "citations":
        return None

    citations_root = f'{dataset_dir}/{name}'
    return CitationsDataset(citations_root)

