import torch
import random


def add_salt_and_pepper_noise_varied(image, noise_prob=0.1, salt_prob=0.5):
    """Return a copy of ``image`` with independent salt-and-pepper noise.

    Each element is corrupted independently with probability ``noise_prob``.
    A corrupted element is set to ``1.0`` (salt) with probability
    ``salt_prob`` and to ``0.0`` (pepper) otherwise, which suits images
    scaled to ``[0, 1]``.

    The masks are drawn over the whole tensor at once, so for a batched
    4-D input every sample already receives its own independent noise
    pattern; no per-sample loop or reseeding is needed to get variation.

    Randomness comes from torch's global generator, which is NOT reseeded
    here: a caller's ``torch.manual_seed`` keeps the result reproducible,
    and the generator is not collapsed onto a small set of seeds (which
    would make noise patterns repeat across many calls).

    Args:
        image: Tensor of any shape (batched or not). ``torch.rand_like`` is
            applied to it, so it is expected to be floating point.
        noise_prob: Probability that an element is corrupted.
        salt_prob: Probability that a corrupted element becomes salt.

    Returns:
        A new tensor with the same shape, dtype and device as ``image``;
        the input is not modified.
    """
    noisy_image = image.clone()

    corrupt = torch.rand_like(image) < noise_prob
    salt = torch.rand_like(image) < salt_prob

    noisy_image[corrupt & salt] = 1.0
    noisy_image[corrupt & ~salt] = 0.0
    return noisy_image


def add_salt_and_pepper_noise(image, noise_prob=0.1, salt_prob=0.5):
    """Return a copy of ``image`` corrupted with salt-and-pepper noise.

    Each element is corrupted independently with probability ``noise_prob``.
    A corrupted element becomes ``1.0`` (salt) with probability
    ``salt_prob`` and ``0.0`` (pepper) otherwise, which suits images scaled
    to ``[0, 1]``.

    Masks are drawn from torch's global generator without reseeding it, so
    ``torch.manual_seed`` set by the caller makes the output reproducible.
    The generator's state is also not collapsed onto a small seed space,
    which would make noise patterns repeat across calls.

    Args:
        image: Tensor of any shape; ``torch.rand_like`` is applied to it,
            so it is expected to be floating point.
        noise_prob: Probability that an element is corrupted.
        salt_prob: Probability that a corrupted element becomes salt.

    Returns:
        A new tensor with the same shape, dtype and device as ``image``;
        the input is left untouched.
    """
    noisy_image = image.clone()

    noise_mask = torch.rand_like(image) < noise_prob
    salt_mask = torch.rand_like(image) < salt_prob

    noisy_image[noise_mask & salt_mask] = 1.0
    noisy_image[noise_mask & ~salt_mask] = 0.0
    return noisy_image


def add_gaussian_noise(image, mean=0.0, std=0.1):
    """Return ``image`` plus Gaussian noise, clamped to ``[0, 1]``.

    Noise is sampled element-wise as ``N(0, 1) * std + mean`` from torch's
    global generator, which is not reseeded here. A caller's
    ``torch.manual_seed`` therefore keeps the result reproducible, and
    repeated calls do not cycle through a small, repeating set of noise
    patterns.

    Args:
        image: Tensor of any shape; ``torch.randn_like`` is applied to it,
            so it is expected to be floating point.
        mean: Mean of the added noise.
        std: Standard deviation of the added noise.

    Returns:
        A new tensor of the same shape as ``image`` with every value
        clamped into ``[0.0, 1.0]``.
    """
    noise = torch.randn_like(image) * std + mean
    return torch.clamp(image + noise, 0.0, 1.0)


def add_mixed_noise(
    image,
    noise_types=("salt_pepper", "gaussian"),
    noise_probs=(0.5, 0.5),
    *,
    salt_pepper_prob=0.5,
    gaussian_std=0.1,
):
    """Randomly apply a sequence of noise types to a copy of ``image``.

    The entries of ``noise_types`` are visited in order, paired with
    ``noise_probs`` via ``zip``. Extra entries in the longer sequence are
    ignored. Each type is applied independently with its paired
    probability, decided with Python's ``random`` module rather than
    torch's generator. Unrecognized type names are skipped.

    Args:
        image: Tensor to corrupt; it is cloned, never modified in place.
        noise_types: Iterable of ``"salt_pepper"`` and/or ``"gaussian"``.
        noise_probs: Probability of applying each corresponding type.
        salt_pepper_prob: Per-element corruption probability passed to
            ``add_salt_and_pepper_noise`` (keeps the previous hard-coded
            default of 0.5).
        gaussian_std: Noise standard deviation passed to
            ``add_gaussian_noise`` (keeps the previous default of 0.1).

    Returns:
        A new tensor. It is an unmodified copy when no noise was applied.
    """
    # Tuple defaults avoid the shared-mutable-default pitfall.
    noisy_image = image.clone()

    for noise_type, prob in zip(noise_types, noise_probs):
        if random.random() >= prob:
            continue
        if noise_type == "salt_pepper":
            noisy_image = add_salt_and_pepper_noise(
                noisy_image, noise_prob=salt_pepper_prob
            )
        elif noise_type == "gaussian":
            noisy_image = add_gaussian_noise(noisy_image, std=gaussian_std)

    return noisy_image


class NoiseTransform:
    """Callable transform that applies one of this module's noise functions.

    Args:
        noise_type: ``"salt_pepper"``, ``"gaussian"`` or ``"mixed"``. Any
            other value makes the transform a no-op that returns its input
            object unchanged (no copy).
        noise_prob: Per-element corruption probability; only used by the
            ``"salt_pepper"`` mode.
        **kwargs: Mode-specific options read at call time:
            ``salt_prob`` (salt_pepper), ``mean`` / ``std`` (gaussian),
            ``noise_types`` / ``noise_probs`` (mixed).
    """

    def __init__(self, noise_type="salt_pepper", noise_prob=0.1, **kwargs):
        self.noise_type = noise_type
        self.noise_prob = noise_prob
        self.kwargs = kwargs

    def __call__(self, image):
        """Apply the configured noise to ``image`` and return the result.

        The global torch generator is deliberately not reseeded here.
        Reseeding would clobber the caller's RNG state and limit the
        transform to a small, repeating set of noise patterns.
        """
        if self.noise_type == "salt_pepper":
            return add_salt_and_pepper_noise(
                image,
                noise_prob=self.noise_prob,
                salt_prob=self.kwargs.get("salt_prob", 0.5),
            )
        if self.noise_type == "gaussian":
            return add_gaussian_noise(
                image,
                mean=self.kwargs.get("mean", 0.0),
                std=self.kwargs.get("std", 0.1),
            )
        if self.noise_type == "mixed":
            return add_mixed_noise(
                image,
                noise_types=self.kwargs.get(
                    "noise_types", ["salt_pepper", "gaussian"]
                ),
                noise_probs=self.kwargs.get("noise_probs", [0.5, 0.5]),
            )
        # Unknown noise type: pass the input through untouched.
        return image

    def __repr__(self):
        extra = "".join(f", {key}={value!r}" for key, value in self.kwargs.items())
        return (
            f"{type(self).__name__}(noise_type={self.noise_type!r}, "
            f"noise_prob={self.noise_prob!r}{extra})"
        )
