import torch
from torch.utils.data import Dataset
from torchvision import transforms  


class RandomImageDataset(Dataset):
    """Map-style dataset that yields Gaussian-noise images with a dummy label.

    Each ``__getitem__`` call performs these steps:

    1. Draw a fresh tensor of shape ``size`` from ``N(mean, std**2)``.
    2. Clamp it to ``mean ± 3*std``.
    3. Min-max rescale it to ``[0, 1]`` per image.
    4. Convert it with ``ToPILImage`` and apply ``vit_transforms``.

    The label is always ``0``. ``idx`` is ignored: images come from a random
    stream, not from the index.

    NOTE: because of the per-image min-max rescale, ``mean`` and ``std``
    (std > 0) cancel out of the returned image. They only affect the raw
    tensor returned by ``noise_gen``.

    Args:
        vit_transforms: Callable applied to the PIL image (e.g. a torchvision
            transform pipeline).
        mean: Mean of the sampled noise.
        std: Standard deviation of the sampled noise.
        size: Shape of the sampled tensor, passed to ``torch.randn``.
        num_images: Length reported by ``__len__``.
        seed: Controls where the noise comes from:

            - ``None``: noise is drawn from torch's global RNG, so
              ``torch.manual_seed`` and DataLoader's per-worker reseeding
              apply.
            - An int: a dedicated ``torch.Generator`` is used, seeded with
              ``seed`` in the main process and ``seed + worker_id`` inside a
              DataLoader worker. This keeps workers from producing duplicate
              images.
    """

    def __init__(
        self,
        vit_transforms,
        mean=0.0,
        std=1.0,
        size=(3, 224, 224),
        num_images=1_300_000,
        seed=None,
    ):
        self.mean = mean
        self.std = std
        self.size = size
        self.n = num_images
        self.tfm = vit_transforms
        self.seed = seed
        self.to_pil = transforms.ToPILImage()
        # The dedicated generator (seeded mode only) is created lazily by
        # _generator(). _gen_owner records which process created it: None
        # means the main process, an int means a DataLoader worker id.
        self._gen = None
        self._gen_owner = None

    def __len__(self):
        return self.n

    def __getstate__(self):
        # torch.Generator may not be picklable depending on the torch
        # version. Drop it so the dataset can be sent to spawn-started
        # workers; _generator() recreates it on first use.
        state = self.__dict__.copy()
        state["_gen"] = None
        state["_gen_owner"] = None
        return state

    def _generator(self):
        """Return the generator for the current process, or None for the global RNG."""
        if self.seed is None:
            return None
        info = torch.utils.data.get_worker_info()
        owner = None if info is None else info.id
        # Recreate on first use, and whenever a forked worker inherits a
        # generator built elsewhere. Otherwise every worker would replay the
        # same copied state.
        if self._gen is None or self._gen_owner != owner:
            offset = 0 if info is None else info.id
            self._gen = torch.Generator().manual_seed(self.seed + offset)
            self._gen_owner = owner
        return self._gen

    def noise_gen(self):
        """Draw one raw noise tensor of shape ``self.size`` from N(mean, std**2)."""
        noise = torch.randn(*self.size, generator=self._generator())
        return noise * self.std + self.mean

    def __getitem__(self, idx):
        """Return ``(transformed_noise_image, 0)``; ``idx`` is not used."""
        # Clamp to ±3 std around the mean, so the bounds follow the
        # distribution's centre for any ``mean``.
        lo = self.mean - 3 * self.std
        hi = self.mean + 3 * self.std
        img = self.noise_gen().clamp(lo, hi)
        # Per-image rescale to [0, 1]. ToPILImage scales float tensors by
        # 255. The 1e-8 avoids division by zero on a constant image.
        img = (img - img.min()) / (img.max() - img.min() + 1e-8)
        return self.tfm(self.to_pil(img)), 0