from PIL import ImageFilter
import random
import torchvision.transforms as transforms
from moco.augmentations import RandAugment2
import pdb

class GaussianBlur(object):
    """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709

    On every call a blur radius is drawn uniformly from
    ``[sigma[0], sigma[1]]`` and applied via ``ImageFilter.GaussianBlur``.

    Args:
        sigma: two-element (low, high) range for the blur radius. Defaults to
            ``[0.1, 2.0]`` when omitted.
    """

    def __init__(self, sigma=None):
        # A fresh list per instance: the former mutable default ``[.1, 2.]`` was a
        # single list shared by every instance created without ``sigma``, so
        # mutating one instance's range changed all of them.
        self.sigma = [.1, 2.] if sigma is None else sigma

    def __call__(self, x):
        """Return ``x`` blurred with a randomly drawn radius.

        ``x`` must provide a PIL-style ``filter`` method (presumably a
        ``PIL.Image.Image`` — confirm against callers).
        """
        radius = random.uniform(self.sigma[0], self.sigma[1])
        return x.filter(ImageFilter.GaussianBlur(radius=radius))


class TwoCropsTransform:
    """Produce two augmented views of the same input.

    The weak transform is applied first and the strong transform second.
    The result is returned as ``[strong_view, weak_view]``; the original
    naming treats the strong view as the query and the weak view as the key.
    """

    def __init__(self, base_transform_w, base_transform_s):
        # Callable producing the weak view.
        self.base_transform_w = base_transform_w
        # Callable producing the strong view.
        self.base_transform_s = base_transform_s

    def __call__(self, x):
        # Keep weak-before-strong evaluation so random draws happen in the same order.
        weak_view = self.base_transform_w(x)
        strong_view = self.base_transform_s(x)
        return [strong_view, weak_view]

class ThreeCropsTransform:
    """Produce three augmented views of the same input: two weak, one strong.

    The weak transform is applied twice, then the strong transform once.
    The result is returned as ``[weak_view_1, weak_view_2, strong_view]``.
    """

    def __init__(self, base_transform_w, base_transform_s):
        # Weak-view callable (used twice per call) and strong-view callable (used once).
        self.base_transform_w, self.base_transform_s = base_transform_w, base_transform_s

    def __call__(self, x):
        weak = self.base_transform_w
        views = [weak(x), weak(x)]
        views.append(self.base_transform_s(x))
        return views

class ThreeCropsTransform_moco:
    """Produce two weak augmented views of the same input (MoCo-style pair).

    Despite the class name, only two views are returned:
    ``[weak_view_1, weak_view_2]``. ``base_transform_s`` is accepted and
    stored for interface parity with ``ThreeCropsTransform`` but is never
    applied.
    """

    def __init__(self, base_transform_w, base_transform_s):
        # Weak-view callable, applied twice per call.
        self.base_transform_w = base_transform_w
        # Stored only; not used by __call__.
        self.base_transform_s = base_transform_s

    def __call__(self, x):
        return [self.base_transform_w(x) for _ in range(2)]

class MultiCropsTransform:
    """Produce two weak views plus one view per strong transform.

    Args:
        base_transform_w: callable applied twice to produce the two weak views.
        base_transform_s: re-iterable collection (list, tuple, ...) of callables;
            each is applied once per call, in iteration order. A one-shot
            iterator (e.g. a generator) would be exhausted after the first call.

    Calling the instance returns ``[q_w, k_w, qs_s]`` where ``qs_s`` is a list
    with one output per strong transform.
    """

    def __init__(self, base_transform_w, base_transform_s):
        self.base_transform_w = base_transform_w
        self.base_transform_s = base_transform_s

    def __call__(self, x):
        q_w = self.base_transform_w(x)  # first weak view
        k_w = self.base_transform_w(x)  # second weak view
        # Iterate the transforms directly instead of indexing by range(len(...)),
        # so any re-iterable container of callables works.
        qs_s = [transform(x) for transform in self.base_transform_s]
        return [q_w, k_w, qs_s]


def get_augmentation_weak1_strong1(sizes_w=224, sizes_s=224, scale_flag=False, strong_aug=False):
    """Build one weak and one strong augmentation pipeline.

    Both pipelines follow the MoCo v2 recipe, which is similar to SimCLR
    (https://arxiv.org/abs/2002.05709).

    Args:
        sizes_w: output size given to ``RandomResizedCrop`` for the weak view.
        sizes_s: output size given to ``RandomResizedCrop`` for the strong view.
        scale_flag: if True, the strong crop samples its area fraction from
            ``(0.05, 0.2)``; otherwise from ``(0.2, 1.0)``.
        strong_aug: if True, the strong pipeline applies five ``RandAugment2``
            ops right after cropping.

    Returns:
        ``(augmentation_w, augmentation_s)``, two ``transforms.Compose``
        pipelines. Neither normalizes; each ends at ``ToTensor``.
    """
    # Normalization is disabled in both pipelines. To enable it, append
    # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    # after ToTensor().
    augmentation_w = transforms.Compose([
        transforms.RandomResizedCrop(sizes_w, scale=(0.2, 1.)),
        transforms.RandomApply([
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)  # not strengthened
        ], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])

    # Area-fraction range for the strong crop: small crops when scale_flag is set.
    scale1, scale2 = (0.05, 0.2) if scale_flag else (0.2, 1.)

    if strong_aug:
        augmentation_s = transforms.Compose([
            transforms.RandomResizedCrop(sizes_s, scale=(scale1, scale2)),
            *[RandAugment2() for _ in range(5)],
            transforms.RandomApply([
                transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)  # not strengthened
            ], p=0.8),
            transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5),
            transforms.RandomHorizontalFlip(),
            # NOTE(review): grayscale runs after blur/flip here, unlike the other
            # pipelines in this file; kept to preserve existing behavior.
            transforms.RandomGrayscale(p=0.2),
            transforms.ToTensor(),
        ])
    else:
        augmentation_s = transforms.Compose([
            transforms.RandomResizedCrop(sizes_s, scale=(scale1, scale2)),
            transforms.RandomApply([
                transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)  # not strengthened
            ], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])

    return augmentation_w, augmentation_s



def get_augmentation(sizes_w=224, sizes_s=(224, 192, 160, 128, 96), scale_flag=False, strong_aug=False):
    """Build one weak pipeline and one strong pipeline per size in ``sizes_s``.

    Args:
        sizes_w: output size given to the weak view's ``RandomResizedCrop``.
        sizes_s: iterable of output sizes; one strong pipeline is built per
            entry, in order. Defaults to ``(224, 192, 160, 128, 96)``; a tuple
            replaces the former shared mutable list default.
        scale_flag: if True, each strong crop's area-fraction range is
            ``(size / 224 * 0.2, size / 224 * 1.0)``; otherwise ``(0.2, 1.0)``.
        strong_aug: if True, each strong pipeline applies five ``RandAugment2``
            ops right after cropping.

    Returns:
        ``(augmentation_w, augmentation_s)``: a ``transforms.Compose`` and a list
        of ``transforms.Compose`` aligned with ``sizes_s``. No pipeline
        normalizes; each ends at ``ToTensor``.
    """
    # Normalization is disabled. To enable it, append
    # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    # after ToTensor().
    augmentation_w = transforms.Compose([
        transforms.RandomResizedCrop(sizes_w, scale=(0.2, 1.)),
        transforms.RandomApply([
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)  # not strengthened
        ], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])

    augmentation_s = []
    for size in sizes_s:
        if scale_flag:
            # NOTE(review): for size > 224 the upper bound exceeds 1.0 (more than
            # the full image area), which RandomResizedCrop cannot sample and
            # will often fall back to a center crop — confirm this is intended.
            scale1, scale2 = size / 224.0 * 0.2, size / 224.0 * 1.0
        else:
            scale1, scale2 = 0.2, 1.

        # The strong and plain variants differ only by these RandAugment2 ops.
        rand_augments = [RandAugment2() for _ in range(5)] if strong_aug else []
        augmentation_s.append(transforms.Compose([
            transforms.RandomResizedCrop(size, scale=(scale1, scale2)),
            *rand_augments,
            transforms.RandomApply([
                transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)  # not strengthened
            ], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]))

    return augmentation_w, augmentation_s

