import torchvision.transforms as transforms
from torchvision.transforms.autoaugment import AutoAugmentPolicy
from utils import *
# Training-time augmentation pipeline for CIFAR-10. Steps, in order:
# random horizontal flip, random 32x32 crop after 4-pixel padding,
# AutoAugment with torchvision's CIFAR-10 policy, conversion to a tensor,
# Cutout (one 16-pixel hole), and per-channel normalization.
# NOTE(review): input is presumably a 32x32 PIL image — confirm with the
# dataset that consumes this transform.
# NOTE(review): the mean/std below are the widely used ImageNet channel
# statistics, not statistics computed on CIFAR-10. Presumably intentional
# (e.g. to match an ImageNet-pretrained backbone) — confirm before changing,
# and keep cifar10_transform_ts in sync if you do.
# NOTE(review): Cutout (from utils) runs after ToTensor and before Normalize,
# so it presumably expects a CHW tensor in [0, 1] — confirm.
cifar10_transform_tr = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4),
        transforms.AutoAugment(policy=AutoAugmentPolicy.CIFAR10),
        transforms.ToTensor(),
        Cutout(n_holes=1, length=16),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Evaluation-time pipeline for CIFAR-10: no augmentation, only tensor
# conversion followed by normalization with the same mean/std values that
# cifar10_transform_tr uses.
cifar10_transform_ts = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Training-time augmentation pipeline for CIFAR-100. Steps, in order:
# random 32x32 crop after 4-pixel padding, random horizontal flip, random
# rotation of up to +/-15 degrees, AutoAugment, conversion to a tensor,
# Cutout (one 16-pixel hole), and per-channel normalization.
# (Indentation normalized to 4 spaces; the original mixed hard tabs and
# spaces inside this expression.)
cifar100_transform_tr = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        # torchvision's AutoAugmentPolicy has no CIFAR-100 member (only
        # IMAGENET, CIFAR10 and SVHN); the CIFAR-10 policy is reused here.
        transforms.AutoAugment(policy=AutoAugmentPolicy.CIFAR10),
        transforms.ToTensor(),
        # NOTE(review): Cutout (from utils) runs after ToTensor and before
        # Normalize, so it presumably expects a CHW tensor — confirm.
        Cutout(n_holes=1, length=16),
        # Presumably the per-channel CIFAR-100 training-set statistics —
        # must match cifar100_transform_ts.
        transforms.Normalize(
            mean=[0.5071, 0.4865, 0.4409],
            std=[0.2673, 0.2564, 0.2762],
        ),
    ]
)
# Evaluation-time pipeline for CIFAR-100: no augmentation, only tensor
# conversion followed by normalization with the same mean/std values that
# cifar100_transform_tr uses.
cifar100_transform_ts = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.5071, 0.4865, 0.4409],
            std=[0.2673, 0.2564, 0.2762],
        ),
    ]
)

class TwoCropTransform:
    """Apply a transform independently several times to produce multiple views.

    Each call runs ``transform`` once per view on the same input, so a
    stochastic transform yields differently augmented copies of ``x``.

    Args:
        transform: Callable applied to the input once for every view.
        n_views: Number of views to return. Defaults to 2, which keeps the
            original two-crop behaviour.

    Raises:
        ValueError: If ``n_views`` is less than 1.
    """

    def __init__(self, transform, n_views=2):
        if n_views < 1:
            raise ValueError(f"n_views must be >= 1, got {n_views}")
        self.transform = transform
        self.n_views = n_views

    def __call__(self, x):
        """Return a list of ``n_views`` independently transformed copies of ``x``."""
        # Invoke the transform separately per view (rather than reusing one
        # result) so random augmentations differ between views.
        return [self.transform(x) for _ in range(self.n_views)]

    def __repr__(self):
        return (
            f"{type(self).__name__}"
            f"(transform={self.transform!r}, n_views={self.n_views})"
        )


