# Module-level registry: maps each name given to ``register_dataset`` to the
# class that was decorated with it.
_DATASET_DICT = dict()


def register_dataset(name):
    """Return a class decorator that records the decorated class under ``name``.

    The class is stored in the module-level registry and handed back
    unchanged, so it remains usable directly. Registering another class under
    an already-used ``name`` replaces the earlier entry.
    """

    def _record(dataset_cls):
        _DATASET_DICT.update({name: dataset_cls})
        return dataset_cls

    return _record


def get_dataset(config):
    """Build the train and test datasets described by ``config``.

    Args:
        config: Object exposing ``dataset`` (a name previously passed to
            ``register_dataset``), plus ``train_data_config`` and
            ``test_data_config``, mappings that are unpacked as keyword
            arguments into the registered class's constructor.

    Returns:
        A ``(train_set, test_set)`` tuple, constructed with ``split='train'``
        and ``split='test'`` respectively.

    Raises:
        KeyError: If ``config.dataset`` has not been registered.
    """
    try:
        dataset_cls = _DATASET_DICT[config.dataset]
    except KeyError:
        # Keep the original exception type for callers, but name the valid
        # choices so a typo in the config is immediately diagnosable.
        registered = ', '.join(sorted(str(key) for key in _DATASET_DICT)) or '<none>'
        raise KeyError(
            f'Unknown dataset {config.dataset!r}; registered datasets: {registered}'
        ) from None
    train_set = dataset_cls(split='train', **config.train_data_config)
    test_set = dataset_cls(split='test', **config.test_data_config)
    return train_set, test_set
