import torchvision
from torchvision import transforms
from .randaugment import RandAugmentMC


class DoubleTransforms:
    """Produce a ``(weak, strong)`` pair of augmented views of one input.

    Both views use the same geometric ops (``Resize``, ``RandomHorizontalFlip``,
    reflect-padded ``RandomCrop``) followed by the same ``ToTensor`` +
    ``Normalize`` tail. The strong view additionally inserts
    ``RandAugmentMC(n=2, m=10)`` between the geometric ops and ``ToTensor``.
    """

    # Per-channel normalization statistics shared by both pipelines, kept in
    # one place so the weak and strong views cannot drift apart. Three entries,
    # so 3-channel input is expected.
    # NOTE(review): values look dataset-specific; confirm their provenance.
    _MEAN = (.48, .07, .02)
    _STD = (.43, .77, .87)

    def __init__(self, size):
        """Build the weak and strong augmentation pipelines.

        Args:
            size: Target size passed to ``Resize`` and ``RandomCrop``. The crop
                padding is computed as ``int(size * 0.125)``, so a scalar
                (e.g. an ``int``) is expected here.
        """
        self.weak_augmentation = transforms.Compose(
            self._geometric_ops(size) + self._tensor_ops()
        )
        self.strong_augmentation = transforms.Compose(
            self._geometric_ops(size)
            + [RandAugmentMC(n=2, m=10)]
            + self._tensor_ops()
        )

    @staticmethod
    def _geometric_ops(size):
        """Return fresh instances of the geometric ops shared by both views."""
        return [
            transforms.Resize(size),
            transforms.RandomHorizontalFlip(),
            # Reflect-pad each border by size/8 pixels, then take a random
            # crop of ``size``.
            transforms.RandomCrop(size=size,
                                  padding=int(size * 0.125),
                                  padding_mode='reflect'),
        ]

    @classmethod
    def _tensor_ops(cls):
        """Return fresh ``ToTensor`` + ``Normalize`` ops using the shared stats."""
        return [
            transforms.ToTensor(),
            transforms.Normalize(cls._MEAN, cls._STD),
        ]

    def __call__(self, x):
        """Apply both pipelines to ``x`` and return ``(weak_view, strong_view)``.

        The weak pipeline runs first, then the strong one, each on the same
        original ``x``.
        """
        return self.weak_augmentation(x), self.strong_augmentation(x)

