from typing import TypeVar

import torch
from torchvision import transforms
from torchvision.transforms import functional
from PIL import Image

from .image_utils import calc_top_left_coordinates, paste_fg_on_bg


class RandomScale(torch.nn.Module):
    """Randomly shrink a PIL image, keeping its aspect ratio, and return RGBA.

    A single factor is drawn uniformly from ``[min_size, 1)`` with
    ``torch.rand`` (so it follows torch's global RNG) and applied to both
    width and height.

    Args:
        min_size: Smallest scale factor that can be drawn (e.g. ``0.3``).
    """

    def __init__(self, min_size: float) -> None:
        super().__init__()
        self.min_size = min_size

    def forward(self, x: Image.Image) -> Image.Image:
        """Resize ``x`` by a random factor and convert the result to RGBA."""
        factor = (
            float(torch.rand(1).item())
            * (1 - self.min_size)
            + self.min_size
        )
        w, h = x.size
        # Never ask PIL for a 0-pixel side: a thin image or a factor close to
        # 0 would round down to 0, and ``Image.resize`` rejects sizes < 1.
        sw = max(1, int(round(w * factor)))
        sh = max(1, int(round(h * factor)))
        # NOTE(review): resizing before converting means palette ("P") inputs
        # are resampled by PIL with NEAREST; converting to RGBA first would
        # allow smoother resampling, but would change output pixels.
        return x.resize((sw, sh)).convert("RGBA")

class RandomMove(torch.nn.Module):
    """Place an image at a random spot on a transparent square RGBA canvas.

    Args:
        img_size: Side length, in pixels, of the square output canvas.
    """

    def __init__(self, img_size: int) -> None:
        super().__init__()
        self.img_size = img_size

    @staticmethod
    def _draw_relative_position() -> float:
        # One uniform draw in [0.1, 0.9) from torch's global RNG.
        return 0.1 + 0.8 * float(torch.rand(1).item())

    def forward(self, x: Image.Image) -> Image.Image:
        """Return ``x`` pasted onto a fresh ``img_size`` x ``img_size`` canvas.

        The horizontal value is drawn before the vertical one, so seeded runs
        consume the RNG in a fixed order. How the two values map to pixel
        offsets is decided by ``calc_top_left_coordinates``; presumably they
        are fractional placements within the canvas (TODO confirm).
        """
        rel_x = self._draw_relative_position()
        rel_y = self._draw_relative_position()
        side = self.img_size
        left, top = calc_top_left_coordinates(x, side, rel_x, rel_y)
        canvas = Image.new("RGBA", (side, side), (0, 0, 0, 0))
        return paste_fg_on_bg(x, canvas, left, top)

class ColorJitterWithTransparency(torch.nn.Module):
    """Apply ``transforms.ColorJitter`` while re-imposing the input's mask.

    The jittered image is pasted onto a fully transparent RGBA canvas of the
    same size, using the *original* image as the paste mask. Where the input
    was transparent, the output therefore stays transparent.

    Args:
        brightness: Forwarded to ``transforms.ColorJitter``.
        contrast: Forwarded to ``transforms.ColorJitter``.
        saturation: Forwarded to ``transforms.ColorJitter``.
        hue: Forwarded to ``transforms.ColorJitter``.
    """

    def __init__(
        self,
        brightness: float = 0,
        contrast: float = 0,
        saturation: float = 0,
        hue: float = 0,
    ) -> None:
        super().__init__()
        self.jitter = transforms.ColorJitter(
            brightness=brightness,
            contrast=contrast,
            saturation=saturation,
            hue=hue,
        )

    def forward(self, x: Image.Image) -> Image.Image:
        """Return a color-jittered RGBA copy of ``x`` masked by ``x`` itself."""
        recolored = self.jitter(x)
        width, height = x.size
        canvas = Image.new("RGBA", (width, height), (0, 0, 0, 0))
        # The untouched input doubles as the mask. This assumes x carries an
        # alpha band (e.g. RGBA), which PIL then uses as the paste mask.
        canvas.paste(recolored, box=(0, 0), mask=x)
        return canvas

# Registry of augmentation factories. Every value takes one positional
# argument (an image size; only "move" actually uses it) and builds a fresh
# transform instance on each call.
#
# Deliberately left out, with the reasons originally noted:
#   - "crop_32": transforms.RandomCrop(32) -- would need to run first.
#   - "perspective": transforms.RandomPerspective() -- judged too similar.
#   - "grayscale": transforms.RandomGrayscale()
#   - "posterize_2": transforms.RandomPosterize(bits=2)
#   - "invert": transforms.RandomInvert()
# "col_jitter" uses ColorJitterWithTransparency instead of a plain
# transforms.ColorJitter (previously brightness=0.2, saturation=0.5,
# hue=0.5) so that the input's transparency mask is re-applied.
TRANSFORMS = {
    "scale": lambda _: RandomScale(0.3),
    "move": lambda img_size: RandomMove(img_size),
    "v_flip": lambda _: transforms.RandomVerticalFlip(),
    "h_flip": lambda _: transforms.RandomHorizontalFlip(),
    "rot": lambda _: transforms.RandomRotation(degrees=(0, 360)),
    "col_jitter": lambda _: ColorJitterWithTransparency(
        brightness=0.4, saturation=0.5, hue=0.5, contrast=0,
    ),
    "blur": lambda _: transforms.GaussianBlur(9),
    "sharpen": lambda _: transforms.RandomAdjustSharpness(sharpness_factor=9),
}

_ImgT = TypeVar("_ImgT")


def none_transform(img: _ImgT) -> _ImgT:
    """Identity transform: return ``img`` unchanged.

    Serves as the "none" registry option so that "no augmentation" can be
    selected like any other transform. It is typed generically because the
    other transforms in this module take ``PIL.Image.Image`` inputs, not
    tensors.
    """
    return img
# Same factories as TRANSFORMS, with an identity "none" option placed first;
# any same-named key in TRANSFORMS would take precedence over it.
TRANSFORMS_WITH_NONE = {"none": lambda _: none_transform}
TRANSFORMS_WITH_NONE.update(TRANSFORMS)

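# Unused, kept for reference: affine shift by (+10, -10) px with scale 0.375
# and no rotation or shear, i.e. moves the shrunken image toward the upper right.
# UPPER_RIGHT_CORNER = transforms.Lambda(
#     lambda img: functional.affine(img, 0, [10, -10], 0.375, [0, 0])
# )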
