import torch
import numpy as np
from torch import nn, optim
import math
import random
from PIL import Image, ImageOps, ImageFilter
import torchvision
import torchvision.transforms as transforms


class GaussianBlur:
    """Randomly blur an image through ``ImageFilter.GaussianBlur``.

    On each call one random draw decides whether to blur (probability ``p``).
    When blurring, a second draw picks the radius as ``0.1 + 1.9 * u`` with
    ``u`` from ``random.random()``.
    """

    def __init__(self, p):
        # Probability of applying the blur on any given call.
        self.p = p

    def __call__(self, img):
        """Return ``img.filter(...)`` with probability ``self.p``, else ``img``."""
        apply_blur = random.random() < self.p
        if not apply_blur:
            return img

        # The radius is only drawn when the blur is actually applied.
        radius = 0.1 + 1.9 * random.random()
        return img.filter(ImageFilter.GaussianBlur(radius))


class Solarization:
    """Randomly solarize an image via ``ImageOps.solarize``.

    One ``random.random()`` draw per call decides whether the image is
    solarized (probability ``p``) or returned untouched.
    """

    def __init__(self, p):
        # Probability of solarizing on any given call.
        self.p = p

    def __call__(self, img):
        """Return the solarized image with probability ``self.p``, else ``img``."""
        solarize = random.random() < self.p
        return ImageOps.solarize(img) if solarize else img


class Identity:
    """No-op transform: hands back exactly the object it receives."""

    def __init__(self):
        """Nothing to configure."""

    def __call__(self, img):
        """Return ``img`` unchanged."""
        return img


class HighPassFilter:
    """Subtract a Gaussian-blurred copy of the input from the input itself.

    Args:
        args: Stored on the instance; not read by this class.
        kernel_size: Kernel size forwarded to
            ``transforms.functional.gaussian_blur``.
        sigma: Standard deviation forwarded to the same call.
    """

    def __init__(self, args=None, kernel_size=13, sigma=5):
        self.args, self.kernel_size, self.sigma = args, kernel_size, sigma

    def __call__(self, img):
        """Return ``img`` minus its Gaussian-blurred version."""
        low_frequencies = transforms.functional.gaussian_blur(
            img,
            kernel_size=self.kernel_size,
            sigma=self.sigma,
        )
        return img - low_frequencies


def focal_mask(ratio, size=224):
    """Build a mask that keeps one randomly placed square window.

    The window side is ``int(size * sqrt(1 - ratio))``, so roughly ``ratio``
    of the canvas area is masked out (zeros) and the rest is kept (ones).

    Args:
        ratio: Fraction of the area to mask out; must lie in ``[0, 1]``.
        size: Side length of the square canvas. Defaults to 224.

    Returns:
        A float tensor of shape ``(1, size, size)`` that is 1 inside the
        window and 0 elsewhere.

    Raises:
        ValueError: If ``ratio`` is outside ``[0, 1]`` (this also rejects NaN).
    """
    # Out-of-range ratios would otherwise give a NaN side length (ratio > 1)
    # or a window larger than the canvas (ratio < 0).
    if not 0.0 <= ratio <= 1.0:
        raise ValueError(f"ratio must be in [0, 1], got {ratio!r}")

    d = int(size * np.sqrt(1 - ratio))

    # Top-left corner, drawn so the window always fits inside the canvas.
    x = int((size - d) * np.random.random())
    y = int((size - d) * np.random.random())

    mask = torch.zeros(1, size, size)
    mask[0, x:x + d, y:y + d] = 1
    return mask


def grid_mask(ratio, s, size=224):
    """Build a random block mask on a regular grid of ``s`` x ``s`` cells.

    Each cell of the ``(size // s)`` x ``(size // s)`` grid is kept
    (value 1) when its uniform draw exceeds ``ratio`` and masked (value 0)
    otherwise; every cell is then blown up to ``s`` x ``s`` pixels.

    Args:
        ratio: Threshold compared against uniform draws in ``[0, 1)``;
            cells whose draw is not greater than ``ratio`` are masked.
        s: Side length of one grid cell in pixels; must divide ``size``.
        size: Side length of the square canvas. Defaults to 224.

    Returns:
        An ``int32`` tensor of shape ``(1, size, size)`` holding 0/1 values.

    Raises:
        ValueError: If ``s`` is not positive or does not evenly divide
            ``size``.
    """
    if s <= 0 or size % s != 0:
        raise ValueError(
            f"grid cell size must be positive and divide {size}, got {s!r}"
        )

    n = size // s
    coarse = (torch.rand(1, n, n) > ratio).int()

    # The earlier expand/transpose/reshape formulation upsampled the
    # *transposed* coarse grid; transpose here too so results under a fixed
    # seed stay identical.
    coarse = coarse.transpose(1, 2)
    return coarse.repeat_interleave(s, dim=1).repeat_interleave(s, dim=2)


class GridMasking:
    """Random mask generator configured from an ``args`` namespace.

    Reads ``mask_ratio``, ``grid_size``, ``color`` and ``focal`` from
    ``args``. ``color`` is the probability of drawing an independent mask
    for each of three channels; ``focal`` is the probability of using
    ``focal_mask`` rather than ``grid_mask``.
    """

    def __init__(self, args):
        self.args = args
        self.r = args.mask_ratio
        self.s = args.grid_size
        # Number of grid cells along one side of a 224-pixel canvas.
        self.n = 224 // self.s

    def generate_mask(self):
        """Draw a fresh mask.

        Returns:
            The concatenation of three independently drawn masks when the
            per-channel branch is taken, otherwise a single mask as returned
            by ``focal_mask`` / ``grid_mask``.
        """
        # Both decisions are always drawn, and in this order, so the random
        # stream is consumed the same way on every path.
        per_channel = bool(torch.rand(1) < self.args.color)
        use_focal = bool(torch.rand(1) < self.args.focal)

        def draw():
            if use_focal:
                return focal_mask(self.r)
            return grid_mask(self.r, self.s)

        if not per_channel:
            return draw()

        return torch.cat([draw() for _ in range(3)])


class Standard_Transform:
    """Produce two independently augmented views of one image.

    Three pipelines are built up front and ``args.augmentation`` picks one
    at call time:

    * ``"baseline"``  - crop, flip, colour jitter, grayscale, Gaussian blur.
    * ``"no-blur"``   - same as baseline without the Gaussian blur.
    * ``"crop-only"`` - crop and flip only.

    Every pipeline ends with ``ToTensor`` and the same ``Normalize`` call.
    """

    # Maps each supported ``args.augmentation`` value to the name of the
    # attribute holding the matching pipeline.
    _PIPELINE_ATTRS = {
        "baseline": "baseline",
        "no-blur": "no_blur",
        "crop-only": "crop_only",
    }

    def __init__(self, args):
        self.args = args
        self.baseline = self._build_pipeline(color=True, blur=True)
        self.no_blur = self._build_pipeline(color=True, blur=False)
        self.crop_only = self._build_pipeline(color=False, blur=False)

    @staticmethod
    def _build_pipeline(color, blur):
        """Compose one pipeline, optionally with colour and blur steps."""
        steps = [
            transforms.RandomResizedCrop(
                224, interpolation=transforms.InterpolationMode.BICUBIC
            ),
            transforms.RandomHorizontalFlip(p=0.5),
        ]
        if color:
            steps.append(
                transforms.RandomApply(
                    [transforms.ColorJitter(brightness=0.4, contrast=0.4,
                                            saturation=0.2, hue=0.1)],
                    p=0.8
                )
            )
            steps.append(transforms.RandomGrayscale(p=0.2))
        if blur:
            steps.append(GaussianBlur(p=0.5))
        steps.append(transforms.ToTensor())
        steps.append(
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        )
        return transforms.Compose(steps)

    def __call__(self, x):
        """Return two views of ``x`` from the selected pipeline.

        Raises:
            ValueError: If ``args.augmentation`` is not one of the supported
                names. Previously this surfaced as an ``UnboundLocalError``.
        """
        name = self.args.augmentation
        try:
            attr = self._PIPELINE_ATTRS[name]
        except (KeyError, TypeError):
            supported = ", ".join(repr(key) for key in self._PIPELINE_ATTRS)
            raise ValueError(
                f"unknown augmentation {name!r}; expected one of {supported}"
            ) from None

        pipeline = getattr(self, attr)
        # Two separate calls, so each view gets its own random augmentation.
        y1 = pipeline(x)
        y2 = pipeline(x)
        return y1, y2


class Multimask_Transform:
    """Produce four augmented, high-pass-filtered and masked views of an image.

    Two views come from ``transform`` (crop scale from 0.4, blur always on,
    solarization off) and two from ``transform_prime`` (crop scale from 0.1,
    blur probability 0.1, solarization probability 0.2). Each view is then
    passed through ``apply_mask``.
    """

    def __init__(self, args):
        self.args = args
        self.transform = self._make_view(
            args, min_scale=0.4, blur_p=1.0, solarize_p=0.0
        )
        self.transform_prime = self._make_view(
            args, min_scale=0.1, blur_p=0.1, solarize_p=0.2
        )
        self.mask = GridMasking(args)

    @staticmethod
    def _make_view(args, min_scale, blur_p, solarize_p):
        """Compose one pipeline; only crop scale and blur/solarize odds vary."""
        crop = transforms.RandomResizedCrop(
            224,
            scale=(min_scale, 1.0),
            interpolation=transforms.InterpolationMode.BICUBIC,
        )
        jitter = transforms.RandomApply(
            [transforms.ColorJitter(brightness=0.4, contrast=0.4,
                                    saturation=0.2, hue=0.1)],
            p=0.8,
        )
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        return transforms.Compose([
            crop,
            transforms.RandomHorizontalFlip(p=0.5),
            jitter,
            transforms.RandomGrayscale(p=0.2),
            GaussianBlur(p=blur_p),
            Solarization(p=solarize_p),
            transforms.ToTensor(),
            normalize,
            HighPassFilter(args),
        ])

    def apply_mask(self, x):
        """Replace masked-out pixels of ``x`` with noise and add a random offset.

        A fresh mask is drawn, zero-mask positions are filled with Gaussian
        noise whose scale is drawn uniformly from [0, 0.2), and a single
        random scalar (standard normal times 1.2) is added to every element.
        """
        keep = self.mask.generate_mask()

        noise_scale = torch.rand(1) * 0.2
        fill = torch.randn(3, 224, 224) * noise_scale
        offset = torch.randn(1) * 1.2

        visible = keep * x
        hidden = (1 - keep) * fill
        return visible + hidden + offset

    def __call__(self, x):
        """Return a 4-tuple of masked views of ``x``."""
        # All four pipelines run before any mask is drawn, so the random
        # streams are consumed in a fixed order.
        views = [
            self.transform(x),
            self.transform(x),
            self.transform_prime(x),
            self.transform_prime(x),
        ]
        return tuple(self.apply_mask(view) for view in views)
        

class Eval_Transform:
    """Deterministic evaluation pipeline.

    Resize to 256, center-crop to 224, convert to a tensor and normalize.
    The final step is ``Identity`` when ``args.aug`` is truthy, otherwise a
    ``HighPassFilter`` with its default settings.
    """

    def __init__(self, args):
        self.args = args

        steps = [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ]
        if args.aug:
            steps.append(Identity())
        else:
            steps.append(HighPassFilter())

        self.transform = transforms.Compose(steps)

    def __call__(self, x):
        """Run ``x`` through the evaluation pipeline and return the result."""
        return self.transform(x)
