import torchvision.transforms as transforms


# Per-dataset normalization statistics. Each value is a two-element list:
# index 0 is the per-channel mean tuple and index 1 is the per-channel std
# tuple, in the order expected by transforms.Normalize(mean, std).
normalization_dict = {
    'CIFAR10': [
        (0.4914, 0.4822, 0.4465),  # mean
        (0.247, 0.243, 0.261),     # std
    ],
    'ImageNet': [
        (0.485, 0.456, 0.406),     # mean
        (0.229, 0.224, 0.225),     # std
    ],
}

class TwoCropsTransform:
    """Wrap a transform so that it yields two views of the same input.

    Each view comes from a separate call to ``base_transform`` on the same
    input. A random transform therefore produces two different augmentations,
    which serve as the query and key views.
    """

    def __init__(self, base_transform):
        # Callable invoked once per view on every __call__.
        self.base_transform = base_transform

    def __call__(self, x):
        """Return a two-element list of ``base_transform(x)`` results.

        The first element is the query view and the second is the key view.
        """
        return [self.base_transform(x) for _ in range(2)]

def CIFARTransform(augmentation=False, contrastive=False):
    """Build the image transform pipeline for CIFAR-10.

    Args:
        augmentation: If falsy, return a deterministic pipeline made of
            ``ToTensor`` followed by CIFAR-10 channel normalization. If
            truthy, apply random crop, flip, color jitter and grayscale
            before those two steps.
        contrastive: Checked only when ``augmentation`` is truthy. In that
            case the augmented pipeline is wrapped in ``TwoCropsTransform``,
            so each call returns a list of two independently augmented
            views.

    Returns:
        A ``transforms.Compose`` pipeline, or a ``TwoCropsTransform``
        wrapping one when both flags are set.
    """
    mean, std = normalization_dict['CIFAR10']
    normalize = transforms.Normalize(mean, std)

    if augmentation:
        pipeline = transforms.Compose([
            # Crop a random region and resize it to 32x32.
            transforms.RandomResizedCrop(size=32),
            transforms.RandomHorizontalFlip(p=0.5),
            # Apply color jitter to 80% of samples.
            transforms.RandomApply(
                [transforms.ColorJitter(brightness=0.4, contrast=0.4,
                                        saturation=0.4, hue=0.1)],
                p=0.8,
            ),
            transforms.RandomGrayscale(p=0.2),
            transforms.ToTensor(),
            normalize,
        ])
        return TwoCropsTransform(pipeline) if contrastive else pipeline

    # Without augmentation, `contrastive` is ignored.
    return transforms.Compose([transforms.ToTensor(), normalize])



def KvasirTransform(augmentation=False, contrastive=False):
    """Build the augmentation pipeline for Kvasir images.

    Unlike ``CIFARTransform``, this pipeline has no ``ToTensor`` or
    ``Normalize`` step. NOTE(review): presumably the caller or dataset
    handles tensor conversion and normalization; confirm.

    Args:
        augmentation: If falsy, no transform is built and ``None`` is
            returned.
        contrastive: Checked only when ``augmentation`` is truthy. In that
            case the pipeline is wrapped in ``TwoCropsTransform`` so each
            call returns two independently augmented views.

    Returns:
        ``None``, a ``transforms.Compose``, or a ``TwoCropsTransform``
        wrapping one.
    """
    if augmentation:
        steps = transforms.Compose([
            # Output size is given as (height, width) = (80, 100).
            transforms.RandomResizedCrop(size=(80, 100)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply(
                [transforms.ColorJitter(brightness=0.4, contrast=0.4,
                                        saturation=0.4, hue=0.1)],
                p=0.8,
            ),
            transforms.RandomGrayscale(p=0.2),
            # The blur is applied to every sample; it is not wrapped in
            # RandomApply.
            transforms.GaussianBlur(kernel_size=(9, 9)),
        ])
        if contrastive:
            return TwoCropsTransform(steps)
        return steps
    return None
