import dataset


class MnistAugment(dataset.Mnist):
    """MNIST variant that adds a random-shift augmentation for training.

    Overrides only ``NAME`` and ``get_training_augmentation``; everything else
    is inherited from ``dataset.Mnist``.
    """

    # Identifier for this dataset variant. Presumably used by the project to
    # look datasets up by name — TODO confirm against dataset.Mnist / callers.
    NAME = "mnist-augment"

    def get_training_augmentation(self):
        """Return the transform applied to training samples.

        Pads each image by 4 pixels on every side, then takes a random 32x32
        crop. Assuming 28x28 MNIST inputs (confirm against ``dataset.Mnist``),
        this changes the size from 28x28 to 32x32. It also shifts the digit by
        up to 2 pixels in each direction relative to a centered crop.

        Returns:
            A ``torchvision.transforms.RandomCrop`` instance.
        """
        # The module header only imports ``dataset``; without this import the
        # reference to ``torchvision`` below raises NameError on first call.
        import torchvision

        # Padded size is 28 + 2*4 = 36, so a 32-pixel crop can start at any
        # offset in [0, 4]. The centered offset is 2, giving a shift of at
        # most +/-2 pixels.
        return torchvision.transforms.RandomCrop(32, padding=4)