import torchvision

import dataset


class Mnist(dataset.DatapoolDisk):
    """MNIST loaded through torchvision, padded from 28x28 to 32x32."""

    NAME = "mnist"

    # === PROTECTED ===

    def get_classes(self):
        """Return the number of target classes as an int."""
        return 10

    def get_input_dim(self):
        """Return the input dimensions as a tuple of ints."""
        return 1, 32, 32

    def load_dataset_from_disk(self, path, download):
        """Return the full training set as a torch.utils.data.Dataset.

        The transform applies `get_training_augmentation()` first and
        then converts the image to a tensor.

        Parameters:
        ===========
        path: str path to data on disk.
        download: bool whether to download the data to disk or not.
        """
        transforms = torchvision.transforms
        train_pipeline = transforms.Compose(
            [self.get_training_augmentation(), transforms.ToTensor()]
        )
        return torchvision.datasets.MNIST(
            path,
            train=True,
            download=download,
            transform=train_pipeline,
        )

    def get_training_augmentation(self):
        """Return the transform applied to training images before ToTensor.

        Here this is `Pad(2)`, the same padding the test set pipeline
        uses to bring images from 28x28 to 32x32.
        """
        return torchvision.transforms.Pad(2)

    def load_testset_from_disk(self, path, download):
        """Return the full test set as a torch.utils.data.Dataset.

        Parameters:
        ===========
        path: str path to data on disk.
        download: bool whether to download the data to disk or not.
        """
        transforms = torchvision.transforms
        # Pad(2) changes 28x28 -> 32x32, keeping the digit centered.
        test_pipeline = transforms.Compose(
            [transforms.Pad(2), transforms.ToTensor()]
        )
        return torchvision.datasets.MNIST(
            path,
            train=False,
            download=download,
            transform=test_pipeline,
        )