"""
Handlers for various image datasets.
"""

import os

import torchvision.datasets


def fetch_dataset(dataset, root=None, test_split=None):
    """Load the train and evaluation splits of a named image dataset.

    The name is normalised by lower-casing it and removing ``-`` and ``_``,
    so e.g. ``"CIFAR-10"`` and ``"stanford_cars"`` are accepted. The
    normalised name is also the sub-directory of ``root`` that is read,
    i.e. data is loaded from ``<root>/<normalised name>/...``.

    No ``transform`` or ``download`` arguments are passed to torchvision,
    so the torchvision defaults for those apply.

    Args:
        dataset: Dataset name. Supported (after normalisation): caltech101,
            cifar10, cifar100, fruits360, oxfordflowers102, stanfordcars,
            stl10, svhn.
        root: Directory containing the datasets. A falsy value selects
            ``"~/Datasets"``. A leading ``~`` is expanded to the user's home.
        test_split: caltech101 only. Sub-directory used for the evaluation
            split (defaults to ``"test"``). Ignored for every other dataset.

    Returns:
        A tuple ``(dataset_train, dataset_eval, num_classes, img_size)``.
        ``img_size`` is ``None`` for datasets where no single image size is
        assumed by this function.

    Raises:
        ValueError: If the (normalised) dataset name is not supported.
    """
    name = dataset.lower().replace("-", "").replace("_", "")
    if name not in _DATASET_INFO:
        raise ValueError(
            "Unrecognised dataset: {!r}; expected one of: {}".format(
                dataset, ", ".join(sorted(_DATASET_INFO))
            )
        )
    num_classes, img_size = _DATASET_INFO[name]

    # Expand "~" here rather than relying on each torchvision dataset class
    # (and torchvision version) to do it.
    base = os.path.join(os.path.expanduser(root or _DEFAULT_ROOT), name)

    if name == "caltech101":
        eval_split = "test" if test_split is None else test_split
        dataset_train, dataset_eval = _image_folder_pair(
            base, ("train",), (eval_split,)
        )
    elif name == "fruits360":
        dataset_train, dataset_eval = _image_folder_pair(
            base, ("100px", "train"), ("100px", "test")
        )
    elif name == "oxfordflowers102":
        # This layout uses "val" (not "test") as the evaluation split.
        dataset_train, dataset_eval = _image_folder_pair(
            base, ("train",), ("val",)
        )
    elif name == "stanfordcars":
        dataset_train, dataset_eval = _image_folder_pair(
            base, ("train",), ("test",)
        )
    elif name in ("cifar10", "cifar100"):
        cifar_cls = (
            torchvision.datasets.CIFAR10
            if name == "cifar10"
            else torchvision.datasets.CIFAR100
        )
        dataset_train = cifar_cls(base, train=True)
        dataset_eval = cifar_cls(base, train=False)
    elif name == "stl10":
        dataset_train = torchvision.datasets.STL10(root=base, split="train")
        dataset_eval = torchvision.datasets.STL10(root=base, split="test")
    else:  # "svhn": the only remaining entry in _DATASET_INFO
        dataset_train = torchvision.datasets.SVHN(base, split="train")
        dataset_eval = torchvision.datasets.SVHN(base, split="test")

    return dataset_train, dataset_eval, num_classes, img_size


# Root directory used by fetch_dataset when the caller passes a falsy root.
_DEFAULT_ROOT = "~/Datasets"

# Normalised dataset name -> (num_classes, img_size). This table is also the
# list of names fetch_dataset accepts; img_size is None where no single
# image size is assumed.
_DATASET_INFO = {
    "caltech101": (101, None),
    "cifar10": (10, 32),
    "cifar100": (100, 32),
    "fruits360": (131, 100),
    "oxfordflowers102": (102, None),
    "stanfordcars": (196, None),
    "stl10": (10, 96),
    "svhn": (10, 32),
}


def _image_folder_pair(base, train_parts, eval_parts):
    """Build ImageFolder datasets for ``base/*train_parts`` and ``base/*eval_parts``."""
    return (
        torchvision.datasets.ImageFolder(os.path.join(base, *train_parts)),
        torchvision.datasets.ImageFolder(os.path.join(base, *eval_parts)),
    )
