import random
from typing import Any, Dict, List

from PIL import ImageFilter, ImageOps

from torchvision.datasets import VisionDataset
from torchvision import transforms

class TransformDataset(VisionDataset):
    """Wrap an indexable dataset and apply a transform to each image on access.

    Every item of the wrapped dataset is unpacked as an ``(image, label)``
    pair. Only the image is passed through ``transform``; the label is
    returned untouched.

    Args:
        dataset: Object supporting ``dataset[index]`` and ``len(dataset)``
            (the original comments suggest a subset of a larger dataset).
        transform: Optional callable applied to the image. ``None`` leaves
            the image as returned by ``dataset``.
    """

    def __init__(self, dataset, transform=None):
        # Pass ``root`` explicitly: torchvision releases in which
        # ``VisionDataset.__init__`` declares ``root`` without a default
        # raise ``TypeError`` on a bare ``super().__init__()``.
        super().__init__(root=None)
        self.dataset = dataset
        # Assigned after the base initializer, which sets ``self.transform``
        # itself and would otherwise overwrite this value.
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(image, label)`` for ``index`` with the transform applied."""
        image, label = self.dataset[index]
        # Compare against None so a callable that happens to be falsy
        # (e.g. an empty container-like module) is still invoked.
        if self.transform is not None:
            image = self.transform(image)
        return image, label

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.dataset)

class GaussianBlur(object):
    """
    Randomly apply Gaussian blur to a PIL image.

    With probability ``p`` the image is filtered with
    ``ImageFilter.GaussianBlur`` using a radius drawn by
    ``random.uniform(radius_min, radius_max)``; otherwise the input image is
    returned unchanged.

    Args:
        p: Probability of blurring. Stored as ``self.prob``.
        radius_min: Lower bound passed to ``random.uniform`` for the radius.
        radius_max: Upper bound passed to ``random.uniform`` for the radius.
    """

    def __init__(self, p=0.1, radius_min=0.1, radius_max=2.0):
        self.prob = p
        self.radius_min = radius_min
        self.radius_max = radius_max

    def __call__(self, img):
        # random.random() lies in [0.0, 1.0), so a strict ``<`` makes p=0
        # never blur and p=1 always blur. This also matches the comparison
        # used by the other probabilistic transforms in this file.
        if not random.random() < self.prob:
            return img

        # The radius is drawn only when the blur is applied, keeping the
        # sequence of RNG draws the same as before.
        radius = random.uniform(self.radius_min, self.radius_max)
        return img.filter(ImageFilter.GaussianBlur(radius=radius))


class Solarization(object):
    """
    Randomly solarize a PIL image.

    On each call a fresh ``random.random()`` draw is compared against ``p``;
    when the draw is smaller the image goes through ``ImageOps.solarize``,
    otherwise the input image is handed back unchanged.
    """

    def __init__(self, p=0.2):
        # Probability of solarizing on any single call.
        self.p = p

    def __call__(self, img):
        should_solarize = random.random() < self.p
        return ImageOps.solarize(img) if should_solarize else img


class GrayScale(object):
    """
    Randomly convert an image to grayscale.

    The conversion is delegated to ``transforms.Grayscale`` configured with
    three output channels. It is applied when a ``random.random()`` draw is
    smaller than ``p``; otherwise the input image is returned unchanged.
    """

    def __init__(self, p=0.2):
        # Built once and reused for every call.
        self.transf = transforms.Grayscale(num_output_channels=3)
        # Probability of converting on any single call.
        self.p = p

    def __call__(self, img):
        if random.random() >= self.p:
            return img
        return self.transf(img)


class horizontal_flip(object):
    """
    Randomly flip an image horizontally.

    The flip itself is delegated to ``transforms.RandomHorizontalFlip(p=1.0)``.
    Whether it runs is decided here: it is applied when a ``random.random()``
    draw is smaller than ``p``, otherwise the input image is returned
    unchanged.

    ``activate_pred`` is accepted but not used by this class.
    """

    def __init__(self, p=0.2, activate_pred=False):
        # Inner transform is built with p=1.0; the random choice is made in __call__.
        self.transf = transforms.RandomHorizontalFlip(p=1.0)
        # Probability of flipping on any single call.
        self.p = p

    def __call__(self, img):
        if random.random() >= self.p:
            return img
        return self.transf(img)
