# -*-coding:utf-8-*-
import math
import random

import numpy as np
import torch
from PIL import Image
from PIL import ImageFilter, ImageOps
from einops import rearrange
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.transforms import transforms


class GaussianBlur:
    """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709

    Each call draws a blur strength uniformly from ``[sigma[0], sigma[1]]``
    and applies ``PIL.ImageFilter.GaussianBlur`` with that value as its
    ``radius``.

    Args:
        sigma: Two-element sequence ``(low, high)`` bounding the sampled
            blur radius. Only indices 0 and 1 are read.
    """

    def __init__(self, sigma=(.1, 2.)):
        # The default is an immutable tuple: a list default would be a single
        # object shared (and mutable) across every default-constructed instance.
        self.sigma = sigma

    def __call__(self, x):
        """Return a blurred copy of ``x`` (anything exposing PIL's ``filter``)."""
        low, high = self.sigma[0], self.sigma[1]
        radius = random.uniform(low, high)
        return x.filter(ImageFilter.GaussianBlur(radius=radius))


class Solarize:
    """Solarize augmentation from BYOL: https://arxiv.org/abs/2006.07733

    Stateless callable that forwards the image to ``ImageOps.solarize`` with
    that function's default arguments and returns the result.
    """

    def __call__(self, x):
        solarized = ImageOps.solarize(x)
        return solarized


class MultiViewAugmentation:
    """Produce two full-size views plus two lists of smaller "mix" views.

    Six augmentation pipelines are built, in three pairs that share a crop
    size and crop-scale range:

    * ``base_transform1/2``: crop size ``args.input_size``,
      scale ``(args.min_crop, 1.0)``.
    * ``mix_transform1/2``: crop size ``input_size // sqrt(mix_n)``,
      scale ``(args.min_mix_crop, args.global_crop)``.
    * ``mix_transform3/4``: crop size ``input_size // sqrt(mix_n * mix_n2)``,
      same scale as the other mix pipelines.

    Within each pair, the odd-numbered pipeline always blurs (``p=1.0``) and
    never solarizes; the even-numbered one blurs with ``p=0.1`` and solarizes
    with ``p=0.2``.

    Args:
        args: Object exposing ``mix_n``, ``mix_n2``, ``switch_p``,
            ``input_size``, ``min_crop``, ``min_mix_crop`` and
            ``global_crop``.
    """

    def __init__(self, args):
        self.mix_n = args.mix_n
        self.mix_n2 = args.mix_n2
        self.switch_p = args.switch_p

        base_scale = (args.min_crop, 1.0)
        mix_scale = (args.min_mix_crop, args.global_crop)

        # ``//`` against the float from math.sqrt yields a float; cast so that
        # RandomResizedCrop always receives an integer side length.
        mix_size = int(args.input_size // math.sqrt(self.mix_n))
        sub_mix_size = int(
            args.input_size // math.sqrt(self.mix_n * self.mix_n2))

        self.base_transform1 = self._build_transform(
            args.input_size, base_scale, blur_p=1.0, solarize=False)
        self.base_transform2 = self._build_transform(
            args.input_size, base_scale, blur_p=0.1, solarize=True)

        self.mix_transform1 = self._build_transform(
            mix_size, mix_scale, blur_p=1.0, solarize=False)
        self.mix_transform2 = self._build_transform(
            mix_size, mix_scale, blur_p=0.1, solarize=True)

        self.mix_transform3 = self._build_transform(
            sub_mix_size, mix_scale, blur_p=1.0, solarize=False)
        self.mix_transform4 = self._build_transform(
            sub_mix_size, mix_scale, blur_p=0.1, solarize=True)

    @staticmethod
    def _build_transform(size, scale, blur_p, solarize):
        """Compose one pipeline; views differ only in crop, blur odds and solarize."""
        steps = [
            transforms.RandomResizedCrop(
                size, scale=scale, interpolation=Image.BICUBIC),
            transforms.RandomApply([
                transforms.ColorJitter(0.4, 0.4, 0.2, 0.1)  # not strengthened
            ], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.RandomApply([GaussianBlur([.1, 2.])], p=blur_p),
        ]
        if solarize:
            steps.append(transforms.RandomApply([Solarize()], p=0.2))
        steps.extend([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
        ])
        return transforms.Compose(steps)

    def __call__(self, x):
        """Augment one image into base views and mix views.

        Args:
            x: Input image accepted by the composed transforms.

        Returns:
            Tuple ``(img1, img2, imgs1, imgs2)``. ``img1``/``img2`` come from
            the two base pipelines. ``imgs1``/``imgs2`` are lists with
            ``mix_n`` entries each; every entry is either a single crop from
            ``mix_transform1``/``mix_transform2`` or a grid tiled from
            ``mix_n2`` crops of ``mix_transform3``/``mix_transform4``.
        """
        img1 = self.base_transform1(x)
        img2 = self.base_transform2(x)

        # Cast both counts the same way so float-valued settings still work.
        n_mix = int(self.mix_n)
        n_tiles = int(self.mix_n2)
        # Number of tile rows; einops infers the column count from n_tiles.
        grid_rows = int(math.sqrt(self.mix_n2))
        # ToTensor yields (C, H, W): stack tiles row-major into one image.
        tile_pattern = '(t1 t2) c h w -> c (t1 h) (t2 w)'

        imgs1 = []
        imgs2 = []
        for _ in range(n_mix):
            if np.random.rand() > self.switch_p:
                imgs1.append(self.mix_transform1(x))
                imgs2.append(self.mix_transform2(x))
                continue

            # Otherwise (roughly with probability switch_p) build the entry
            # from a grid of smaller crops. The two pipelines are called
            # alternately to keep the random-number consumption order stable.
            tiles_a = []
            tiles_b = []
            for _ in range(n_tiles):
                tiles_a.append(self.mix_transform3(x))
                tiles_b.append(self.mix_transform4(x))
            imgs1.append(
                rearrange(torch.stack(tiles_a, dim=0), tile_pattern, t1=grid_rows))
            imgs2.append(
                rearrange(torch.stack(tiles_b, dim=0), tile_pattern, t1=grid_rows))

        return img1, img2, imgs1, imgs2
