import torchvision.transforms as T
import numpy as np
import torch
from torchvision.transforms.functional import to_pil_image

class IMAGENETTransform():
    """Single-view preprocessing pipeline built from torchvision transforms.

    Training applies a random resized crop followed by a random horizontal
    flip; evaluation applies a centre crop.  Both modes first coerce the input
    to an RGB PIL image and finish with ``ToTensor`` + ``Normalize``.

    Args:
        config: Object exposing ``config.dataset.image_size`` (the crop size).
        normalize: Positional arguments forwarded to ``T.Normalize``
            (presumably a ``(mean, std)`` pair -- confirm against callers).
        is_train: Selects the training (random) or evaluation (centre-crop)
            pipeline.
        transform_single: Stored on the instance only; ``__call__`` returns a
            single view regardless of its value.
    """

    def __init__(self, config, normalize, is_train=True, transform_single=True):
        image_size = config.dataset.image_size
        self.transform_single = transform_single

        # Shared front end.  The evaluation pipeline previously skipped the
        # RGB conversion, so non-RGB inputs (e.g. grayscale) reached
        # ``Normalize`` with a different channel count than during training.
        # Named helpers are used instead of inline lambdas so the composed
        # transform can be pickled by reference.
        to_rgb_pil = [T.Lambda(self._to_pil), T.Lambda(self._to_rgb)]
        to_normalized_tensor = [T.ToTensor(), T.Normalize(*normalize)]

        if is_train:
            spatial = [
                T.RandomResizedCrop(image_size, scale=(0.2, 1.0)),
                T.RandomHorizontalFlip(),
            ]
        else:
            # NOTE(review): no resize precedes the centre crop; inputs are
            # presumably already close to ``image_size`` -- confirm.
            spatial = [T.CenterCrop(image_size)]

        self.transform = T.Compose(to_rgb_pil + spatial + to_normalized_tensor)

    @staticmethod
    def _to_pil(img):
        """Convert a torch tensor to a PIL image; pass anything else through."""
        return to_pil_image(img) if torch.is_tensor(img) else img

    @staticmethod
    def _to_rgb(img):
        """Return ``img`` in RGB mode, converting only when necessary."""
        return img if img.mode == 'RGB' else img.convert('RGB')

    def __call__(self, x):
        """Apply the configured pipeline to ``x`` and return the result."""
        return self.transform(x)

class SimSiamTransform:
    """SimSiam-style augmentation pipeline producing one or two views.

    Training applies random resized crop, horizontal flip, colour jitter
    (p=0.8), random grayscale (p=0.2) and Gaussian blur, then ``ToTensor`` +
    ``Normalize``.  Evaluation applies a centre crop, ``ToTensor`` and
    ``Normalize``.  In both modes a torch tensor input is first converted to a
    PIL image.

    Args:
        config: Object exposing ``config.dataset.image_size``.
        normalize: Positional arguments forwarded to ``T.Normalize``
            (presumably a ``(mean, std)`` pair -- confirm against callers).
        is_train: Selects the augmenting or the centre-crop pipeline.
        transform_single: When true ``__call__`` returns one view, otherwise
            a tuple of two independently transformed views.
    """

    def __init__(self, config, normalize, is_train=True, transform_single=True):
        self.transform_single = transform_single
        image_size = config.dataset.image_size

        # Blur is applied with probability 0.5, except for images of size 32
        # or smaller (the CIFAR setting) where it is disabled.  Per the
        # original author's note this follows SimCLR because the SimSiam
        # paper does not specify the value.
        p_blur = 0.5 if image_size > 32 else 0

        # The evaluation pipeline already accepted tensor inputs; training did
        # not, and ``ToTensor`` rejects tensors.  Both modes now share the
        # same tensor -> PIL front end (a no-op for non-tensor inputs).
        front = [T.Lambda(self._to_pil)]
        back = [T.ToTensor(), T.Normalize(*normalize)]

        if is_train:
            # Odd kernel size that scales with the image size.
            blur_kernel = image_size // 20 * 2 + 1
            middle = [
                T.RandomResizedCrop(image_size, scale=(0.2, 1.0)),
                T.RandomHorizontalFlip(),
                T.RandomApply([T.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
                T.RandomGrayscale(p=0.2),
                T.RandomApply(
                    [T.GaussianBlur(kernel_size=blur_kernel, sigma=(0.1, 2.0))],
                    p=p_blur,
                ),
            ]
        else:
            middle = [T.CenterCrop(image_size)]

        self.transform = T.Compose(front + middle + back)

    @staticmethod
    def _to_pil(img):
        """Convert a torch tensor to a PIL image; pass anything else through."""
        return to_pil_image(img) if torch.is_tensor(img) else img

    def __call__(self, x):
        """Return one transformed view, or two when ``transform_single`` is false."""
        first_view = self.transform(x)
        if self.transform_single:
            return first_view
        # Second, independent pass through the same pipeline.
        second_view = self.transform(x)
        return first_view, second_view


class MaskGenerator:
    """Draws random binary masks on a coarse grid and upsamples them.

    The image is divided into ``input_size // mask_patch_size`` cells per
    side.  On each call a fixed number of cells, ``ceil(token_count *
    mask_ratio)``, is selected at random; every cell is then repeated
    ``mask_patch_size // model_patch_size`` times along both axes, so the
    returned mask has ``input_size // model_patch_size`` entries per side.

    Args:
        input_size: Side length of the input; must be a multiple of
            ``mask_patch_size``.
        mask_patch_size: Side length of one maskable cell; must be a multiple
            of ``model_patch_size``.
        model_patch_size: Side length of one model patch.
        mask_ratio: Fraction of cells to select (rounded up).
    """

    def __init__(self, input_size=192, mask_patch_size=32, model_patch_size=4, mask_ratio=0.6):
        self.input_size = input_size
        self.mask_patch_size = mask_patch_size
        self.model_patch_size = model_patch_size
        self.mask_ratio = mask_ratio

        # Both divisions have to be exact; with the defaults they give a
        # 6 x 6 grid and an upsampling factor of 8.
        cells_per_side, input_leftover = divmod(self.input_size, self.mask_patch_size)
        assert input_leftover == 0
        upsample_factor, patch_leftover = divmod(self.mask_patch_size, self.model_patch_size)
        assert patch_leftover == 0

        self.rand_size = cells_per_side
        self.scale = upsample_factor

        # With the defaults: 36 cells, of which ceil(21.6) = 22 are selected.
        self.token_count = self.rand_size * self.rand_size
        self.mask_count = int(np.ceil(self.token_count * self.mask_ratio))

    def __call__(self):
        """Return a new integer mask (1 = selected cell) as a 2-D array.

        Uses NumPy's global random state.
        """
        flat = np.zeros(self.token_count, dtype=int)
        chosen = np.random.permutation(self.token_count)[:self.mask_count]
        flat[chosen] = 1

        coarse = flat.reshape(self.rand_size, self.rand_size)
        # Expand every coarse cell into a scale x scale block.
        stretched_rows = np.repeat(coarse, self.scale, axis=0)
        return np.repeat(stretched_rows, self.scale, axis=1)

class NonBinaryMaskGenerator:
    """Draws real-valued masks on a coarse grid and upsamples them.

    Each of the ``(input_size // mask_patch_size) ** 2`` grid cells receives
    an independent Gaussian value with mean 0.4 and standard deviation 0.5
    (values are not clipped).  Cells are then repeated
    ``mask_patch_size // model_patch_size`` times along both axes.

    Args:
        input_size: Side length of the input; must be a multiple of
            ``mask_patch_size``.
        mask_patch_size: Side length of one grid cell; must be a multiple of
            ``model_patch_size``.
        model_patch_size: Side length of one model patch.
        mask_ratio: Used only to derive ``mask_count``, which ``__call__``
            does not read.
        sample: Stored on the instance; ``__call__`` does not consult it and
            always samples from the Gaussian described above.
    """

    def __init__(self, input_size=192, mask_patch_size=32, model_patch_size=4, mask_ratio=0.6, sample='uniform'):
        self.input_size = input_size
        self.mask_patch_size = mask_patch_size
        self.model_patch_size = model_patch_size
        self.mask_ratio = mask_ratio
        self.sample = sample

        # Both divisions have to be exact; with the defaults they give a
        # 6 x 6 grid and an upsampling factor of 8.
        cells_per_side, input_leftover = divmod(self.input_size, self.mask_patch_size)
        assert input_leftover == 0
        upsample_factor, patch_leftover = divmod(self.mask_patch_size, self.model_patch_size)
        assert patch_leftover == 0

        self.rand_size = cells_per_side
        self.scale = upsample_factor

        # With the defaults: 36 cells and ceil(21.6) = 22.
        self.token_count = self.rand_size * self.rand_size
        self.mask_count = int(np.ceil(self.token_count * self.mask_ratio))

    def __call__(self):
        """Return a new float mask as a 2-D array.

        Uses NumPy's global random state.
        """
        # NOTE: per the original author, these constants were tuned for
        # mask_ratio == 0.6 and token_count == 36.
        center, spread = 0.4, 0.5
        noise = np.random.standard_normal(self.token_count)
        flat = noise * spread + center

        coarse = flat.reshape(self.rand_size, self.rand_size)
        # Expand every coarse cell into a scale x scale block.
        stretched_rows = np.repeat(coarse, self.scale, axis=0)
        return np.repeat(stretched_rows, self.scale, axis=1)
        
class SimMIMTransform:
    """SimMIM-style preprocessing: an augmented image plus an optional mask.

    Training applies a random resized crop (scale 0.67-1, aspect ratio
    3/4-4/3) and a random horizontal flip, and builds a ``MaskGenerator``
    from the config.  Evaluation applies a centre crop and has no mask
    generator.  Both modes first coerce the input to an RGB PIL image and
    finish with ``ToTensor`` + ``Normalize``.

    Args:
        config: Object exposing ``config.dataset.image_size`` and, for
            training, ``config.model.type`` (``'swin'`` or ``'vit'``),
            the matching ``config.model.<type>.patch_size``,
            ``config.model.mask_patch_size`` and ``config.model.mask_ratio``.
        normalize: Positional arguments forwarded to ``T.Normalize``
            (presumably a ``(mean, std)`` pair -- confirm against callers).
        is_train: Selects the training or the evaluation pipeline.
        transform_single: When true ``__call__`` returns only the image,
            otherwise an ``(image, mask)`` tuple.

    Raises:
        NotImplementedError: In training mode when ``config.model.type`` is
            neither ``'swin'`` nor ``'vit'``.
    """

    def __init__(self, config, normalize, is_train=True, transform_single=True):
        self.transform_single = transform_single
        # Always defined so that evaluation-mode instances do not fail with
        # AttributeError when called; only training mode replaces it.
        self.mask_generator = None

        image_size = config.dataset.image_size
        # Shared front end; evaluation previously lacked the RGB conversion
        # that training applied.
        front = [T.Lambda(self._to_pil), T.Lambda(self._to_rgb)]
        back = [T.ToTensor(), T.Normalize(*normalize)]

        if is_train:
            spatial = [
                T.RandomResizedCrop(image_size, scale=(0.67, 1.), ratio=(3. / 4., 4. / 3.)),
                T.RandomHorizontalFlip(),
            ]
            self.mask_generator = self._build_mask_generator(config)
        else:
            spatial = [T.CenterCrop(image_size)]

        self.transform = T.Compose(front + spatial + back)

    @staticmethod
    def _to_pil(img):
        """Convert a torch tensor to a PIL image; pass anything else through."""
        return to_pil_image(img) if torch.is_tensor(img) else img

    @staticmethod
    def _to_rgb(img):
        """Return ``img`` in RGB mode, converting only when necessary."""
        return img if img.mode == 'RGB' else img.convert('RGB')

    @staticmethod
    def _build_mask_generator(config):
        """Create the ``MaskGenerator`` matching the configured backbone."""
        model_type = config.model.type
        if model_type == 'swin':
            model_patch_size = config.model.swin.patch_size
        elif model_type == 'vit':
            model_patch_size = config.model.vit.patch_size
        else:
            raise NotImplementedError(f"unsupported model type: {model_type!r}")

        return MaskGenerator(
            input_size=config.dataset.image_size,
            mask_patch_size=config.model.mask_patch_size,
            model_patch_size=model_patch_size,
            mask_ratio=config.model.mask_ratio,
        )

    def __call__(self, img):
        """Transform ``img``; also return a fresh mask unless ``transform_single``.

        Returns:
            The transformed image when ``transform_single`` is true,
            otherwise an ``(image, mask)`` tuple.

        Raises:
            RuntimeError: If a mask is requested but the instance was built
                with ``is_train=False`` and therefore has no mask generator.
        """
        if self.transform_single:
            # The mask would be discarded, so none is drawn (and the NumPy
            # global RNG is left untouched).
            return self.transform(img)

        if self.mask_generator is None:
            raise RuntimeError(
                "SimMIMTransform has no mask generator (built with "
                "is_train=False); cannot return a mask"
            )
        mask = self.mask_generator()
        aug_x = self.transform(img)
        return aug_x, mask


