from torchvision.datasets import CIFAR10, CIFAR100, SVHN
from datasets import cifar, tinyimagenet

# Dataset modules consulted, in list order, by get_dataset_from_config below.
# Each entry must provide is_valid_dataset_name(name) and
# get_dataset_from_config(config); the first entry whose name check passes
# handles the request.
registered_datasets = [
    cifar,
    tinyimagenet,
]

def get_dataset_from_config(config):
    """Build a dataset by dispatching to the first registered module that accepts it.

    Each module in ``registered_datasets`` is asked, in order, whether it
    recognises ``config.dataset_name``. The first one that does is responsible
    for constructing the dataset.

    Args:
        config: Configuration object with a ``dataset_name`` attribute. The
            whole object is forwarded to the chosen module.

    Returns:
        Whatever the matching module's ``get_dataset_from_config(config)``
        returns.

    Raises:
        NotImplementedError: If no registered module accepts the dataset name.
    """
    for candidate in registered_datasets:
        if candidate.is_valid_dataset_name(config.dataset_name):
            handler = candidate
            break
    else:
        # Loop finished without `break`: no module claimed this name.
        raise NotImplementedError(f'Dataset {config.dataset_name} is not supported')

    return handler.get_dataset_from_config(config)