import data
from config import opt

# Registry mapping the name part of a dataset spec string (the text before the
# first '-') to the class that get_dataset_by_name instantiates with
# ``input_size=...`` and ``partition=...`` keyword arguments.
# Keys are matched exactly, so lookups are case-sensitive.
dataset_dict = dict(
    cifar10=data.CIFAR10,
    cifar100=data.CIFAR100,
    TinyImageNet=data.TinyImageNet,
    noise=data.Noise,
    svhn=data.SVHN,
    ImageNet=data.ImageNet,
)

def get_dataset_by_name(dataset_name, input_size, dataset_dict=dataset_dict):
    """Instantiate a dataset from a ``"<name>[-<partition>]"`` spec string.

    The text before the first ``'-'`` selects the dataset class from
    *dataset_dict*. The text between the first and second ``'-'`` (if any)
    is forwarded as the ``partition`` keyword argument; when there is no
    ``'-'``, ``partition`` is ``None``. Any further ``'-'``-separated
    segments are ignored.

    Args:
        dataset_name: Spec string such as ``"cifar10"`` or ``"cifar10-train"``.
        input_size: Forwarded unchanged to the dataset constructor.
        dataset_dict: Mapping from dataset name to a callable accepting
            ``input_size`` and ``partition`` keyword arguments. Defaults to
            the module-level registry.

    Returns:
        Whatever the selected constructor returns.

    Raises:
        KeyError: If the name part of *dataset_name* is not a key of
            *dataset_dict*. The message lists the available names.
    """
    name, *extra = dataset_name.split('-')
    # Only the first segment after the name is used as the partition.
    partition = extra[0] if extra else None

    # Keep the try narrow so KeyErrors raised inside the dataset constructor
    # are not misreported as an unknown dataset name.
    try:
        dataset_cls = dataset_dict[name]
    except KeyError:
        available = ', '.join(sorted(str(key) for key in dataset_dict))
        raise KeyError(
            'Unknown dataset {!r} (from spec {!r}); available: {}'.format(
                name, dataset_name, available)
        ) from None

    return dataset_cls(input_size=input_size, partition=partition)