import os
import random

import torchvision.transforms as T
from PIL import ImageFilter
from torchvision.datasets import ImageFolder

from .base import BaseDataset
from .transforms import MultiSample, aug_transform


class RandomBlur:
    """Gaussian-blur a PIL image with a radius drawn uniformly on every call.

    Args:
        r0: One end of the blur-radius range passed to ``random.uniform``.
        r1: The other end of the blur-radius range.
    """

    def __init__(self, r0, r1):
        self.r0, self.r1 = r0, r1

    def __call__(self, image):
        """Return ``image`` filtered with ``GaussianBlur`` of a random radius.

        ``image`` must expose PIL's ``Image.filter`` API.
        """
        # A fresh radius is sampled per call, so repeated applications
        # (e.g. several augmented views) get independent blur strengths.
        r = random.uniform(self.r0, self.r1)
        return image.filter(ImageFilter.GaussianBlur(radius=r))

    def __repr__(self):
        # Makes printed transform pipelines (Compose / RandomApply) show the
        # radius range instead of a default object address.
        return f"{type(self).__name__}(r0={self.r0}, r1={self.r1})"


def base_transform():
    """Build the deterministic pipeline: PIL image -> tensor -> normalized tensor.

    Returns:
        A ``T.Compose`` that applies ``T.ToTensor`` and then ``T.Normalize``
        with fixed per-channel mean / std values (the commonly used ImageNet
        statistics).
    """
    # Per-channel mean and standard deviation, one entry per channel.
    channel_mean = (0.485, 0.456, 0.406)
    channel_std = (0.229, 0.224, 0.225)

    steps = [
        T.ToTensor(),
        T.Normalize(channel_mean, channel_std),
    ]
    return T.Compose(steps)


class ImageNet(BaseDataset):
    """ImageNet-style dataset backed by ``ImageFolder`` directories.

    Reads splits from the ``train``, ``clf`` and ``test`` sub-directories of
    ``self.aug_cfg.imagenet_path``. Each sub-directory must use the layout
    ``ImageFolder`` expects.
    """

    def _image_folder(self, split, transform):
        """Return an ``ImageFolder`` rooted at ``<imagenet_path>/<split>``."""
        # os.path.join works whether or not the configured root ends with a
        # separator, and also accepts path-like objects. Plain ``+``
        # concatenation silently produced ".../imagenettrain" when the
        # trailing slash was missing.
        root = os.path.join(self.aug_cfg.imagenet_path, split)
        return ImageFolder(root=root, transform=transform)

    def ds_train(self):
        """Training split with random augmentations.

        The augmentation from ``aug_transform`` (first argument 224,
        presumably the output resolution; confirm in ``aug_transform``) is
        extended with a Gaussian blur of radius U(0.1, 2.0), applied with
        probability 0.5. It is then wrapped in ``MultiSample`` with
        ``n=self.aug_cfg.num_samples``.
        """
        aug_with_blur = aug_transform(
            224,
            base_transform,
            self.aug_cfg,
            extra_t=[T.RandomApply([RandomBlur(0.1, 2.0)], p=0.5)],
        )
        t = MultiSample(aug_with_blur, n=self.aug_cfg.num_samples)
        return self._image_folder("train", t)

    def ds_clf(self):
        """Classifier split, using only the deterministic ``base_transform``."""
        return self._image_folder("clf", base_transform())

    def ds_test(self):
        """Test split, using only the deterministic ``base_transform``."""
        return self._image_folder("test", base_transform())
