from data.rico import Rico
from data.publaynet import PubLayNet


def get_dataset(name, split, transform=None, data_path=None):
    """Construct the layout dataset registered under ``name``.

    Args:
        name: Dataset identifier: ``'rico'``, ``'publaynet'`` or ``'magazine'``.
        split: Split identifier, forwarded unchanged to the dataset class.
        transform: Optional transform, forwarded unchanged to the dataset class.
        data_path: Forwarded to ``Rico`` only; ignored for the other datasets.

    Returns:
        The dataset instance built by the matching dataset class.

    Raises:
        NotImplementedError: If ``name`` is not a supported dataset; the
            exception message is ``name`` itself.
    """
    if name == 'rico':
        return Rico(split, transform, data_path=data_path)

    elif name == 'publaynet':
        # NOTE(review): data_path is not forwarded here; confirm whether
        # PubLayNet supports a custom data location before passing it.
        return PubLayNet(split, transform)

    elif name == 'magazine':
        # Magazine was previously referenced without being imported, so this
        # branch always failed with NameError. Import it lazily so that a
        # missing or broken magazine module cannot break the other datasets.
        # Assumes the module follows the same data.<name> layout as the
        # rico/publaynet imports at the top of this file.
        from data.magazine import Magazine
        return Magazine(split, transform)

    raise NotImplementedError(name)
