from typing import Tuple

import torch
from numpy import random
from torch import Tensor
from torchvision.transforms import functional


class RandomAffine(torch.nn.Module):
    """Apply a randomly parameterised affine transform in a single ``functional.affine`` call.

    The transform has four components: rotation, translation, scaling and shear.
    Each one is enabled independently with its own probability. A component that
    does not fire keeps its identity value: angle 0.0, offset (0, 0), scale 1.0,
    shear (0.0, 0.0). All randomness comes from numpy's global RNG. It is consumed
    in a fixed order: rotation, translation, scaling, shear.

    Args:
        rotation_factor: ``(low, high)`` range for the uniformly drawn angle.
        rotation_probability: chance that a rotation is sampled.
        translation_factor: per-axis bound ``f``. Each integer offset is drawn
            uniformly from the inclusive range ``[-f, f]``.
        translation_probability: chance that a translation is sampled.
        scaling_factor: ``(low, high)`` range for the uniformly drawn scale.
        scaling_probability: chance that a scale is sampled.
        shear_factor: ``(low0, high0, low1, high1)``. These are the ranges for the
            first and second values of the shear pair.
        shear_probability: chance that a shear is sampled.
        interpolation: interpolation mode forwarded to ``functional.affine``.
    """

    def __init__(
        self,
        rotation_factor: Tuple[float, float],
        rotation_probability: float,
        translation_factor: Tuple[int, int],
        translation_probability: float,
        scaling_factor: Tuple[float, float],
        scaling_probability: float,
        shear_factor: Tuple[float, float, float, float],
        shear_probability: float,
        interpolation=functional.InterpolationMode.NEAREST,
    ) -> None:
        super().__init__()
        # Sampling range and trigger probability for each component.
        self.rotation_factor, self.rotation_probability = rotation_factor, rotation_probability
        self.translation_factor, self.translation_probability = translation_factor, translation_probability
        self.scaling_factor, self.scaling_probability = scaling_factor, scaling_probability
        self.shear_factor, self.shear_probability = shear_factor, shear_probability
        self.interpolation = interpolation

    def __repr__(self) -> str:
        settings = (
            ("rotation_factor", self.rotation_factor),
            ("rotation_probability", self.rotation_probability),
            ("translation_factor", self.translation_factor),
            ("translation_probability", self.translation_probability),
            ("scaling_factor", self.scaling_factor),
            ("scaling_probability", self.scaling_probability),
            ("shear_factor", self.shear_factor),
            ("shear_probability", self.shear_probability),
            ("interpolation", self.interpolation),
        )
        body = ", ".join(f"{key}={value}" for key, value in settings)
        return f"{type(self).__name__}({body})"

    @staticmethod
    def _triggered(probability: float) -> bool:
        """Draw one uniform sample from [0, 1) and report whether it falls below *probability*."""
        return random.uniform(0.0, 1.0) < probability

    def forward(self, image: Tensor) -> Tensor:
        # Identity parameters; each one is overwritten only if its component fires.
        angle = 0.0
        offset = (0, 0)
        zoom = 1.0
        skew = (0.0, 0.0)
        # Do not reorder these checks: the sequence defines the RNG stream.
        if self._triggered(self.rotation_probability):
            angle = random.uniform(self.rotation_factor[0], self.rotation_factor[1])
        if self._triggered(self.translation_probability):
            # numpy's randint excludes the upper bound, hence the +1.
            dx = random.randint(-self.translation_factor[0], self.translation_factor[0] + 1)
            dy = random.randint(-self.translation_factor[1], self.translation_factor[1] + 1)
            offset = (dx, dy)
        if self._triggered(self.scaling_probability):
            zoom = random.uniform(self.scaling_factor[0], self.scaling_factor[1])
        if self._triggered(self.shear_probability):
            first = random.uniform(self.shear_factor[0], self.shear_factor[1])
            second = random.uniform(self.shear_factor[2], self.shear_factor[3])
            skew = (first, second)
        return functional.affine(image, angle=angle, translate=offset, scale=zoom, shear=skew, interpolation=self.interpolation)


class RandomRotation(torch.nn.Module):
    """With probability ``probability``, rotate the image by a random angle.

    The angle is drawn uniformly from ``factor`` (a ``(low, high)`` pair) using
    numpy's global RNG. It is applied through ``functional.affine`` with no
    translation, unit scale and no shear.
    """

    def __init__(self, factor: Tuple[float, float], probability: float = 0.5, interpolation=functional.InterpolationMode.NEAREST) -> None:
        super().__init__()
        self.factor, self.probability, self.interpolation = factor, probability, interpolation

    def __repr__(self) -> str:
        return "{}(factor={}, probability={}, interpolation={})".format(
            type(self).__name__, self.factor, self.probability, self.interpolation
        )

    def forward(self, image: Tensor) -> Tensor:
        fires = random.uniform(0.0, 1.0) < self.probability
        if not fires:
            return image
        angle = random.uniform(*self.factor)
        return functional.affine(image, angle=angle, translate=(0, 0), scale=1.0, shear=(0.0, 0.0), interpolation=self.interpolation)


class RandomTranslation(torch.nn.Module):
    """With probability ``probability``, shift the image by a random integer offset.

    For each axis ``i``, the offset is drawn uniformly from the inclusive integer
    range ``[-factor[i], factor[i]]`` using numpy's global RNG; the first axis is
    drawn before the second. The shift is applied through ``functional.affine``
    with no rotation, unit scale and no shear.
    """

    def __init__(self, factor: Tuple[int, int], probability: float = 0.5, interpolation=functional.InterpolationMode.NEAREST) -> None:
        super().__init__()
        self.factor, self.probability, self.interpolation = factor, probability, interpolation

    def __repr__(self) -> str:
        return "{}(factor={}, probability={}, interpolation={})".format(
            type(self).__name__, self.factor, self.probability, self.interpolation
        )

    def forward(self, image: Tensor) -> Tensor:
        fires = random.uniform(0.0, 1.0) < self.probability
        if not fires:
            return image
        limit_x, limit_y = self.factor[0], self.factor[1]
        # numpy's randint has an exclusive upper bound; +1 makes the range symmetric and inclusive.
        shift = (random.randint(-limit_x, limit_x + 1), random.randint(-limit_y, limit_y + 1))
        return functional.affine(image, angle=0.0, translate=shift, scale=1.0, shear=(0.0, 0.0), interpolation=self.interpolation)


class RandomScaling(torch.nn.Module):
    """With probability ``probability``, rescale the image by a random factor.

    The scale is drawn uniformly from ``factor`` (a ``(low, high)`` pair) using
    numpy's global RNG. It is applied through ``functional.affine`` with no
    rotation, translation or shear.
    """

    def __init__(self, factor: Tuple[float, float], probability: float = 0.5, interpolation=functional.InterpolationMode.NEAREST) -> None:
        super().__init__()
        self.factor, self.probability, self.interpolation = factor, probability, interpolation

    def __repr__(self) -> str:
        return "{}(factor={}, probability={}, interpolation={})".format(
            type(self).__name__, self.factor, self.probability, self.interpolation
        )

    def forward(self, image: Tensor) -> Tensor:
        fires = random.uniform(0.0, 1.0) < self.probability
        if not fires:
            return image
        zoom = random.uniform(*self.factor)
        return functional.affine(image, angle=0.0, translate=(0, 0), scale=zoom, shear=(0.0, 0.0), interpolation=self.interpolation)


class RandomShear(torch.nn.Module):
    """With probability ``probability``, apply a random shear.

    ``factor`` is ``(low0, high0, low1, high1)``. The first shear value is drawn
    uniformly from ``[low0, high0]``, then the second from ``[low1, high1]``, both
    using numpy's global RNG. The pair is applied through ``functional.affine``
    with no rotation, translation or scaling.
    """

    def __init__(self, factor: Tuple[float, float, float, float], probability: float = 0.5, interpolation=functional.InterpolationMode.NEAREST) -> None:
        super().__init__()
        self.factor, self.probability, self.interpolation = factor, probability, interpolation

    def __repr__(self) -> str:
        return "{}(factor={}, probability={}, interpolation={})".format(
            type(self).__name__, self.factor, self.probability, self.interpolation
        )

    def forward(self, image: Tensor) -> Tensor:
        fires = random.uniform(0.0, 1.0) < self.probability
        if not fires:
            return image
        shear_x = random.uniform(self.factor[0], self.factor[1])
        shear_y = random.uniform(self.factor[2], self.factor[3])
        return functional.affine(image, angle=0.0, translate=(0, 0), scale=1.0, shear=(shear_x, shear_y), interpolation=self.interpolation)


class RandomBrightness(torch.nn.Module):
    """With probability ``probability``, adjust brightness by a random factor.

    The factor passed to ``functional.adjust_brightness`` is ``1 + u``. Here ``u``
    is drawn uniformly from ``factor`` (a ``(low, high)`` pair) using numpy's
    global RNG, so ``factor`` is an offset around the identity value 1.
    """

    def __init__(self, factor: Tuple[float, float], probability: float = 0.5) -> None:
        super().__init__()
        self.factor, self.probability = factor, probability

    def __repr__(self) -> str:
        return "{}(factor={}, probability={})".format(type(self).__name__, self.factor, self.probability)

    def forward(self, image: Tensor) -> Tensor:
        fires = random.uniform(0.0, 1.0) < self.probability
        if not fires:
            return image
        offset = random.uniform(*self.factor)
        return functional.adjust_brightness(image, 1.0 + offset)


class RandomSaturation(torch.nn.Module):
    """With probability ``probability``, adjust saturation by a random factor.

    The factor passed to ``functional.adjust_saturation`` is ``1 + u``. Here ``u``
    is drawn uniformly from ``factor`` (a ``(low, high)`` pair) using numpy's
    global RNG.
    """

    def __init__(self, factor: Tuple[float, float], probability: float = 0.5) -> None:
        super().__init__()
        self.factor, self.probability = factor, probability

    def __repr__(self) -> str:
        return "{}(factor={}, probability={})".format(type(self).__name__, self.factor, self.probability)

    def forward(self, image: Tensor) -> Tensor:
        fires = random.uniform(0.0, 1.0) < self.probability
        if not fires:
            return image
        offset = random.uniform(*self.factor)
        return functional.adjust_saturation(image, 1.0 + offset)


class RandomContrast(torch.nn.Module):
    """With probability ``probability``, adjust contrast by a random factor.

    The factor passed to ``functional.adjust_contrast`` is ``1 + u``. Here ``u``
    is drawn uniformly from ``factor`` (a ``(low, high)`` pair) using numpy's
    global RNG.
    """

    def __init__(self, factor: Tuple[float, float], probability: float = 0.5) -> None:
        super().__init__()
        self.factor, self.probability = factor, probability

    def __repr__(self) -> str:
        return "{}(factor={}, probability={})".format(type(self).__name__, self.factor, self.probability)

    def forward(self, image: Tensor) -> Tensor:
        fires = random.uniform(0.0, 1.0) < self.probability
        if not fires:
            return image
        offset = random.uniform(*self.factor)
        return functional.adjust_contrast(image, 1.0 + offset)


class RandomSharpness(torch.nn.Module):
    """With probability ``probability``, adjust sharpness by a random factor.

    The factor passed to ``functional.adjust_sharpness`` is ``1 + u``. Here ``u``
    is drawn uniformly from ``factor`` (a ``(low, high)`` pair) using numpy's
    global RNG.
    """

    def __init__(self, factor: Tuple[float, float], probability: float = 0.5) -> None:
        super().__init__()
        self.factor, self.probability = factor, probability

    def __repr__(self) -> str:
        return "{}(factor={}, probability={})".format(type(self).__name__, self.factor, self.probability)

    def forward(self, image: Tensor) -> Tensor:
        fires = random.uniform(0.0, 1.0) < self.probability
        if not fires:
            return image
        offset = random.uniform(*self.factor)
        return functional.adjust_sharpness(image, 1.0 + offset)


class RandomHue(torch.nn.Module):
    """With probability ``probability``, shift the hue by a random amount.

    Unlike the other colour transforms here, the drawn value is used directly,
    with no ``1 +`` offset. It is drawn uniformly from ``factor`` (a
    ``(low, high)`` pair) using numpy's global RNG and passed to
    ``functional.adjust_hue``.
    """

    def __init__(self, factor: Tuple[float, float], probability: float = 0.5) -> None:
        super().__init__()
        self.factor, self.probability = factor, probability

    def __repr__(self) -> str:
        return "{}(factor={}, probability={})".format(type(self).__name__, self.factor, self.probability)

    def forward(self, image: Tensor) -> Tensor:
        fires = random.uniform(0.0, 1.0) < self.probability
        if not fires:
            return image
        hue_shift = random.uniform(*self.factor)
        return functional.adjust_hue(image, hue_shift)


class RandomPosterize(torch.nn.Module):
    """With probability ``probability``, posterize the image to a random bit depth.

    The bit depth is drawn uniformly from the inclusive integer range
    ``[bits[0], bits[1]]`` using numpy's global RNG.

    Handling by input type:
      * floating-point tensors are assumed to hold values in [0, 1]. They are
        scaled to [0, 255], rounded and clamped to ``uint8``, posterized, and then
        mapped back to [0, 1] in the input's dtype.
      * ``uint8`` tensors are posterized directly. Previously they were multiplied
        by 255, which wraps around in ``uint8``.
      * anything else (presumably PIL images) is passed straight to
        ``functional.posterize``.
    """

    def __init__(self, bits: Tuple[int, int], probability: float = 0.5) -> None:
        super().__init__()
        self.bits = bits
        self.probability = probability

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(bits={self.bits}, probability={self.probability})"

    @staticmethod
    def _float_to_uint8(image: Tensor) -> Tensor:
        """Map a [0, 1] float tensor to uint8 [0, 255], rounding to nearest and clamping out-of-range values."""
        # Round rather than truncate, so that k/255 values map back to exactly k.
        return (image * 255).round().clamp(0, 255).to(torch.uint8)

    def forward(self, image: Tensor) -> Tensor:
        if (random.uniform(0.0, 1.0) < self.probability):
            # numpy's randint excludes the upper bound, hence the +1.
            depth = random.randint(self.bits[0], self.bits[1] + 1)
            if isinstance(image, torch.Tensor) and image.is_floating_point():
                posterized = functional.posterize(self._float_to_uint8(image), depth)
                image = posterized.to(image.dtype) / 255
            else:
                image = functional.posterize(image, depth)
        return image


class RandomSolarize(torch.nn.Module):
    """With probability ``probability``, solarize the image at a random threshold.

    ``threshold`` is a ``(low, high)`` pair expressed as a fraction of full intensity.
    How the threshold is sampled depends on the input:
      * floating-point tensors: drawn uniformly from ``[low, high]`` as-is.
      * anything else, i.e. ``uint8``/integer tensors or non-tensor images
        (presumably PIL): the range is scaled by 255 first.

    Previously, every tensor used the unscaled range, which was wrong for
    ``uint8`` tensors. Sampling uses numpy's global RNG.
    """

    def __init__(self, threshold: Tuple[float, float], probability: float = 0.5) -> None:
        super().__init__()
        self.threshold = threshold
        self.probability = probability

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(threshold={self.threshold}, probability={self.probability})"

    @staticmethod
    def _threshold_scale(image) -> int:
        """Return the multiplier that maps a [0, 1] threshold onto *image*'s value range."""
        if isinstance(image, torch.Tensor) and image.is_floating_point():
            return 1
        return 255

    def forward(self, image: Tensor) -> Tensor:
        if (random.uniform(0.0, 1.0) < self.probability):
            scale = self._threshold_scale(image)
            cutoff = random.uniform(float(self.threshold[0] * scale), float(self.threshold[1] * scale))
            image = functional.solarize(image, cutoff)
        return image


class RandomEqualize(torch.nn.Module):
    """With probability ``probability``, histogram-equalize the image.

    Handling by input type:
      * floating-point tensors are assumed to hold values in [0, 1]. They are
        scaled to [0, 255], rounded and clamped to ``uint8``, equalized, and then
        mapped back to [0, 1] in the input's dtype.
      * ``uint8`` tensors are equalized directly. Previously they were multiplied
        by 255, which wraps around in ``uint8``.
      * anything else (presumably PIL images) is passed straight to
        ``functional.equalize``.

    The decision is made with numpy's global RNG.
    """

    def __init__(self, probability: float = 0.5) -> None:
        super().__init__()
        self.probability = probability

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(probability={self.probability})"

    @staticmethod
    def _float_to_uint8(image: Tensor) -> Tensor:
        """Map a [0, 1] float tensor to uint8 [0, 255], rounding to nearest and clamping out-of-range values."""
        # Round rather than truncate, so that k/255 values map back to exactly k.
        return (image * 255).round().clamp(0, 255).to(torch.uint8)

    def forward(self, image: Tensor) -> Tensor:
        if (random.uniform(0.0, 1.0) < self.probability):
            if isinstance(image, torch.Tensor) and image.is_floating_point():
                equalized = functional.equalize(self._float_to_uint8(image))
                image = equalized.to(image.dtype) / 255
            else:
                image = functional.equalize(image)
        return image
