import logging
import random

from datasets import load_dataset
from torch.utils.data import Dataset

# Set HF_HOME before ImageNet training to choose where the Hugging Face
# cache (including the downloaded imagenet-1k data) is stored, e.g.:
# export HF_HOME="..."

class ImageNet(Dataset):
    """ImageNet-1k served from the Hugging Face ``imagenet-1k`` dataset.

    The constructor mirrors the torchvision dataset signature. Each item is
    an ``(image, target)`` pair whose image has been converted to RGB (when
    it was not already) before ``transform`` is applied.

    If loading, converting or transforming a sample raises, a warning is
    logged and a randomly chosen replacement sample is tried instead, up to
    ``max_attempts`` samples per ``__getitem__`` call.

    Args:
        root: Stored for API compatibility; not used to locate data here
            (the data location is managed by Hugging Face, e.g. ``HF_HOME``).
        train: Load the ``"train"`` split if True, else ``"validation"``.
        transform: Optional callable applied to the RGB image.
        target_transform: Optional callable applied to the label.
        download: Stored for API compatibility; not used here.
        max_attempts: Maximum number of samples tried per ``__getitem__``
            call before giving up. Must be >= 1.

    Raises:
        ValueError: If ``max_attempts`` is less than 1.
    """

    def __init__(self, root: str, train: bool = True, transform=None, target_transform=None,
                 download: bool = False, *, max_attempts: int = 10):
        super().__init__()
        if max_attempts < 1:
            raise ValueError(f"max_attempts must be >= 1, got {max_attempts}")

        self.root = root
        self.train = train
        self.download = download

        self.transform = transform
        self.target_transform = target_transform
        self.max_attempts = max_attempts

        split = "train" if train else "validation"

        self.dataset = load_dataset("imagenet-1k", split=split)

    def _load_sample(self, index: int):
        """Fetch sample ``index``, convert its image to RGB and apply transforms."""
        sample = self.dataset[index]
        image = sample["image"]
        if image.mode != "RGB":
            image = image.convert("RGB")
        target = sample["label"]

        if self.transform is not None:
            image = self.transform(image)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return image, target

    def __getitem__(self, index: int):
        """Return ``(image, target)`` for ``index``, skipping unreadable samples.

        Raises:
            IndexError: If ``index`` is outside ``[-len(self), len(self))``.
            RuntimeError: If ``max_attempts`` consecutive samples fail to
                load; the last underlying exception is chained as the cause.
        """
        size = len(self.dataset)
        # An out-of-range index is a caller error, not a corrupted image: it
        # must raise IndexError (Python's sequence-iteration protocol relies
        # on it to stop) instead of being replaced by a random sample.
        if not -size <= index < size:
            raise IndexError(f"Index {index} out of range for dataset of size {size}")

        last_error = None
        for _ in range(self.max_attempts):
            try:
                return self._load_sample(index)
            except Exception as e:
                # Kept broad on purpose: image decoding may be deferred, so
                # corruption can surface in conversion or in the transforms.
                last_error = e
                logging.getLogger(__name__).warning(
                    "Skipping corrupted image at index %s: %s", index, e
                )
                index = random.randint(0, size - 1)

        raise RuntimeError("Too many corrupted images in a row.") from last_error

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.dataset)

