from .ImageNet import ImageNet
from .OxfordIIITPet import OxfordIIITPet
from .EuroSAT import EuroSAT
from .Caltech101 import Caltech101
from .Food101 import Food101
from .CIFAR import CIFAR10, CIFAR100
from .CUB200 import CUB200
from .multiDatasets import multiDatasets
from .Flowers102 import Flowers102
from .TinyImageNet import TinyImageNet
from .Imagenet_R import Imagenet_R
from .StanfordCars import StanfordCars
from .FGVCAircraft import FGVCAircraft
from .SUN397 import SUN397

# Names exported by a star-import of this package: the dataset objects
# imported at the top of this module.
# NOTE(review): ``datasets`` and ``get_dataset`` (defined further down in this
# module) are not listed here, so a star-import does not expose them --
# confirm this is intended.
__all__ = [
    "CUB200",
    "multiDatasets",
    "Flowers102",
    "TinyImageNet",
    "CIFAR10",
    "CIFAR100",
    "ImageNet",
    "Imagenet_R",
    "Food101",
    "Caltech101",
    "StanfordCars",
    "EuroSAT",
    "OxfordIIITPet",
    "FGVCAircraft",
    "SUN397",
]

# Registry mapping a dataset name to a 2-tuple: the imported dataset object
# and an integer. The integer appears to be the dataset's number of classes
# (e.g. 10 for CIFAR10, 1000 for ImageNet) -- TODO confirm against the
# individual dataset modules. No mean/std values are stored here.
# Entries are looked up by name in get_dataset().
datasets = {
    "cifar10": (CIFAR10, 10),
    "cifar100": (CIFAR100, 100),
    "tinyimagenet": (TinyImageNet, 200),
    "cub200": (CUB200, 200),
    "imagenet": (ImageNet, 1000),
    "imagenet-r": (Imagenet_R, 200),
    "food101": (Food101, 101),
    # NOTE(review): registered with 100 although the name suggests 101 --
    # confirm this matches what the Caltech101 loader actually exposes.
    "caltech101": (Caltech101, 100),
    "eurosat": (EuroSAT, 10),
    "flowers102": (Flowers102, 102),
    "oxford_pet": (OxfordIIITPet, 37),
    "stanfordcars": (StanfordCars, 196),
    "aircraft": (FGVCAircraft, 100),
    "sun397": (SUN397, 397)
}


def get_dataset(name):
    """Look up one or more registered datasets by name.

    Args:
        name: A key of ``datasets``, or several keys separated by commas
            (e.g. ``"cifar10,cifar100"``). Whitespace around each key is
            ignored, so ``"cifar10, cifar100"`` works as well.

    Returns:
        For a single name, the registered tuple exactly as stored in
        ``datasets``. For several names, a list with one tuple per field of
        the registered entries: first a tuple of all first elements, then a
        tuple of all second elements, each in the order the names were given.

    Raises:
        KeyError: If any requested name is not a key of ``datasets``. The
            message lists the unknown names and the available ones.
    """
    keys = [key.strip() for key in name.split(",")]

    # Validate every name up front so the error reports all bad names at once
    # and tells the caller what the valid choices are.
    unknown = [key for key in keys if key not in datasets]
    if unknown:
        raise KeyError(
            "Unknown dataset name(s) {}; available: {}".format(
                unknown, sorted(datasets)
            )
        )

    if len(keys) == 1:
        return datasets[keys[0]]

    # Transpose [(a1, b1), (a2, b2), ...] into [(a1, a2, ...), (b1, b2, ...)].
    return list(zip(*(datasets[key] for key in keys)))
