from .cifar10 import CIFAR10
from .cifar100 import CIFAR100
from .stl10 import STL10
from .tiny_in import TinyImageNet
from .imagenet import ImageNet


# Dataset names accepted by get_ds().
DS_LIST = [
    "cifar10",
    "cifar100",
    "stl10",
    "tinyimagenet",
    "imagenet",
]


def get_ds(name):
    """Look up the dataset object registered under ``name``.

    Args:
        name: One of the strings listed in ``DS_LIST``.

    Returns:
        The object imported at the top of this module for that dataset
        (e.g. ``CIFAR10`` for ``"cifar10"``). It is returned as-is, not
        called or instantiated.

    Raises:
        AssertionError: If ``name`` is not in ``DS_LIST``.
    """
    # Raise explicitly instead of using an `assert` statement: asserts are
    # removed when Python runs with -O, which would let an unknown name fall
    # through and silently return None. AssertionError is kept as the
    # exception type so existing callers observe the same failure.
    if name not in DS_LIST:
        raise AssertionError(
            f"unknown dataset {name!r}; expected one of {DS_LIST}"
        )

    datasets = {
        "cifar10": CIFAR10,
        "cifar100": CIFAR100,
        "stl10": STL10,
        "tinyimagenet": TinyImageNet,
        "imagenet": ImageNet,
    }
    return datasets[name]
