from datasets import load_dataset
from torch.utils.data import IterableDataset
import json


class ImageNet1k(IterableDataset):
    """Streaming ImageNet-1k split read from local WebDataset ``.tar`` shards.

    Samples are read through Hugging Face ``datasets`` in streaming mode from
    ``<root>/*.tar`` (registered as the ``"val"`` split) and yielded as
    ``(image, label)`` pairs. A ``class_to_idx`` mapping is built from
    ``<root>/classnames.txt`` at construction time.

    Args:
        transform: Optional callable applied to each image before it is
            yielded. ``None`` means images are yielded unchanged.
        root: Directory holding the ``*.tar`` shards and ``classnames.txt``.
            Defaults to ``"data/imagenet1k"``.
    """

    DEFAULT_ROOT = "data/imagenet1k"

    def __init__(self, transform=None, root=DEFAULT_ROOT):
        super().__init__()
        self.root = root
        # streaming=True makes load_dataset return an iterable dataset that
        # reads the shards lazily instead of materializing them up front.
        self.base_dataset = load_dataset(
            "webdataset",
            data_files={"val": f"{root}/*.tar"},
            split="val",
            streaming=True,
        )

        self.transform = transform

        self.build_class_to_idx()

    def __iter__(self):
        """Yield ``(image, label)`` pairs from the underlying stream.

        ``image`` is each sample's ``"jpg"`` entry, passed through
        ``self.transform`` when one is set. ``label`` is the sample's ``"cls"``
        entry, yielded as-is (presumably an integer class id — TODO confirm
        against the shard contents).
        """
        # A plain for-loop replaces the manual next()/StopIteration handling.
        # Unlike the old try/except, which also wrapped the transform call, a
        # StopIteration raised by the transform no longer silently ends the
        # stream (PEP 479 turns it into a RuntimeError).
        for sample in self.base_dataset:
            image = sample["jpg"]
            label = sample["cls"]

            if self.transform is not None:
                image = self.transform(image)
            yield image, label

    def build_class_to_idx(self, class_name_file=None):
        """Populate ``self.class_to_idx`` from a one-name-per-line file.

        The name on line ``i`` (0-based, surrounding whitespace stripped)
        maps to index ``i``. Repeated names are made unique by appending
        underscores (``"name"``, ``"name_"``, ``"name__"``, ...), so every
        line keeps its own index.

        Args:
            class_name_file: Path to the class-names file. Defaults to
                ``<self.root>/classnames.txt``.
        """
        if class_name_file is None:
            class_name_file = f"{self.root}/classnames.txt"
        self.class_to_idx = {}
        with open(class_name_file, "r", encoding="utf-8") as f:
            for idx, line in enumerate(f):
                class_name = line.strip()
                # Keep appending "_" until the key is free. Appending only
                # once would let a third occurrence overwrite the second.
                while class_name in self.class_to_idx:
                    class_name += "_"
                self.class_to_idx[class_name] = idx
        # Debug output: number of class keys read from the file.
        print(len(self.class_to_idx))


if __name__ == "__main__":
    # Smoke check: constructing the dataset opens the shard stream and
    # builds the class-name mapping.
    imagenet_val = ImageNet1k()