import torchvision

import dataset


class Cifar10(dataset.DatapoolDisk):
    """CIFAR-10 data pool loaded from disk via torchvision.

    Both the training and the test split are converted to tensors with
    ``ToTensor`` and normalized per channel with ``MEAN``/``STD``. The
    training split additionally runs the transforms returned by
    ``create_training_augmentation`` before tensor conversion.
    ``get_dataset_class`` selects the torchvision dataset class that is
    instantiated (``torchvision.datasets.CIFAR10`` here).
    """

    NAME = "cifar10"

    # Per-channel mean / std fed to the Normalize transform.
    # NOTE(review): presumably RGB statistics of the CIFAR-10 training
    # split on the [0, 1] scale produced by ToTensor — confirm source.
    MEAN = [0.4914, 0.4824, 0.4467]
    STD = [0.2471, 0.2435, 0.2616]

    # === PROTECTED ===

    def get_classes(self):
        """Return the int number of classes."""
        return 10

    def get_input_dim(self):
        """Return tuple of int dimensions of input."""
        return (3, 32, 32)

    def load_dataset_from_disk(self, path, download):
        """Return torch.utils.data.Dataset instance, full dataset.

        The transform applies, in order: the training augmentations from
        ``create_training_augmentation``, ``ToTensor``, then per-channel
        normalization.

        Parameters:
        ===========
        path: str path to data on disk.
        download: bool whether to download the data to disk or not.
        """
        return self.get_dataset_class()(
            path, train=True, download=download,
            transform=torchvision.transforms.Compose([
                # Augmentations run before ToTensor, i.e. on the raw
                # samples the dataset class yields.
                torchvision.transforms.Compose(
                    self.create_training_augmentation()
                ),
                torchvision.transforms.ToTensor(),
                self._create_image_normalizer()
            ])
        )

    def create_training_augmentation(self):
        """Return a list of transforms applied only to the training split.

        The default is no augmentation; subclasses may override this.
        """
        return []

    def get_image_channel_mean(self):
        """Return the per-channel means used for normalization."""
        return Cifar10.MEAN

    def get_image_channel_std(self):
        """Return the per-channel standard deviations used for normalization."""
        return Cifar10.STD

    def get_dataset_class(self):
        """Return the torchvision dataset class instantiated by the loaders."""
        return torchvision.datasets.CIFAR10

    def _create_image_normalizer(self):
        """Return a Normalize transform built from the channel mean/std hooks."""
        return torchvision.transforms.Normalize(
            mean=self.get_image_channel_mean(),
            std=self.get_image_channel_std()
        )

    def load_testset_from_disk(self, path, download):
        """Return a torch.utils.data.Dataset instance, the full test set.

        The test set receives no augmentation: only ``ToTensor`` followed
        by per-channel normalization.

        Parameters:
        ===========
        path: str path to data on disk.
        download: bool whether to download the data to disk or not.
        """
        return self.get_dataset_class()(
            # Honor the caller's download flag (was hard-coded to True,
            # which forced a download attempt even when disabled).
            path, train=False, download=download,
            transform=torchvision.transforms.Compose([
                torchvision.transforms.ToTensor(),
                self._create_image_normalizer()
            ])
        )