import torchvision

import dataset


class Svhn(dataset.DatapoolDisk):
    """Datapool for the SVHN dataset, loaded from disk via torchvision.

    Both the training and the test split are wrapped with a transform that
    converts each image with ``torchvision.transforms.ToTensor``.
    """

    NAME = "svhn"

    def get_classes(self):
        """Return the int number of classes."""
        class_count = 10
        return class_count

    def get_input_dim(self):
        """Return tuple of int dimensions of input.

        The value is fixed at (3, 32, 32); presumably (channels, height, width)
        to match ``ToTensor`` output -- confirm against callers.
        """
        channels, height, width = 3, 32, 32
        return (channels, height, width)

    @staticmethod
    def _svhn_split(path, split, download):
        """Build a torchvision SVHN dataset for one named split.

        Parameters:
        ===========
        path: str path to data on disk.
        split: str name of the SVHN split passed through to torchvision.
        download: bool whether to download the data to disk or not.
        """
        # Kept as a Compose pipeline (rather than a bare ToTensor) so the
        # dataset's ``transform`` attribute has the same type as before.
        pipeline = torchvision.transforms.Compose(
            [torchvision.transforms.ToTensor()]
        )
        return torchvision.datasets.SVHN(
            path, split=split, download=download, transform=pipeline
        )

    def load_dataset_from_disk(self, path, download):
        """Return torch.utils.data.Dataset instance, full dataset.

        Loads the "train" split.

        Parameters:
        ===========
        path: str path to data on disk.
        download: bool whether to download the data to disk or not.
        """
        return self._svhn_split(path, "train", download)

    def load_testset_from_disk(self, path, download):
        """Return a torch.utils.data.Dataset instance, the full test set.

        Loads the "test" split.

        Parameters:
        ===========
        path: str path to data on disk.
        download: bool whether to download the data to disk or not.
        """
        return self._svhn_split(path, "test", download)