import torchvision

import dataset


class Cifar100(dataset.Cifar10):
    """CIFAR-100 variant of the project's ``dataset.Cifar10`` wrapper.

    Only the dataset name, the class count, the per-channel normalization
    statistics and the torchvision dataset class are overridden here; all
    other behavior is inherited unchanged from ``dataset.Cifar10``.
    """

    NAME = "cifar100"

    # Per-channel normalization statistics, one entry per image channel
    # (presumably R, G, B order as produced by torchvision — confirm).
    MEAN = [
        0.5071,
        0.4867,
        0.4408,
    ]
    STD = [
        0.2675,
        0.2565,
        0.2761,
    ]

    def get_classes(self):
        """Number of target classes in CIFAR-100 (always the int 100)."""
        num_classes = 100
        return num_classes

    def get_image_channel_mean(self):
        """Per-channel mean used for input normalization.

        Returns the shared ``Cifar100.MEAN`` list itself, not a copy, so
        callers must not mutate it.
        """
        channel_mean = Cifar100.MEAN
        return channel_mean

    def get_image_channel_std(self):
        """Per-channel standard deviation used for input normalization.

        Returns the shared ``Cifar100.STD`` list itself, not a copy, so
        callers must not mutate it.
        """
        channel_std = Cifar100.STD
        return channel_std

    def get_dataset_class(self):
        """The torchvision dataset class that backs this wrapper."""
        backing_class = torchvision.datasets.CIFAR100
        return backing_class