from datasets import load_dataset
from torch.utils.data import IterableDataset


class ObjectNet(IterableDataset):
    """Stream ObjectNet evaluation samples from local webdataset shards.

    Samples are read lazily (``streaming=True``) from ``<root>/*.tar`` via
    the Hugging Face ``webdataset`` loader and yielded as ``(image, label)``
    pairs. ``image`` is the sample's ``jpg`` field, optionally passed through
    ``transform``. ``label`` is the sample's ``cls`` field, yielded as-is
    (not re-indexed). Samples with label 0 (the background class) are
    skipped.

    Args:
        transform: Optional callable applied to each image before it is
            yielded. ``None`` means the image is yielded unchanged.
        root: Directory containing the ``*.tar`` shards and
            ``classnames.txt``. Defaults to ``"data/objectnet"``, the
            previously hard-coded location.
    """

    def __init__(self, transform=None, root="data/objectnet"):
        super().__init__()
        self.root = root
        # Glob over every tar shard in ``root`` as a single "val" split.
        self.base_dataset = load_dataset(
            "webdataset",
            data_files={"val": f"{root}/*.tar"},
            split="val",
            streaming=True,
        )
        self.transform = transform
        self.build_class_to_idx()

    def __iter__(self):
        """Yield ``(image, label)`` pairs, skipping background samples."""
        for sample in self.base_dataset:
            image = sample["jpg"]
            label = sample["cls"]

            # Label 0 is the background class; exclude it.
            if label == 0:
                continue

            # Compare against None rather than truthiness, so a transform
            # object that happens to be falsy is still applied.
            if self.transform is not None:
                image = self.transform(image)
            yield image, label

    def build_class_to_idx(self):
        """Load class names from ``<root>/classnames.txt``.

        Populates ``self.class_names`` (one entry per line, surrounding
        whitespace stripped, in file order) and ``self.class_to_idx``
        (name -> 0-based line index). If a name appears more than once, the
        mapping keeps its last index.

        NOTE(review): presumably these indices share the label space of the
        ``cls`` field yielded by ``__iter__`` — confirm against the shards.
        """
        classes_path = f"{self.root}/classnames.txt"
        with open(classes_path, "r", encoding="utf-8") as f:
            self.class_names = [line.strip() for line in f]
        self.class_to_idx = {
            name: idx for idx, name in enumerate(self.class_names)
        }

if __name__ == "__main__":
    # Smoke test: write the first few non-background samples to disk.
    from itertools import islice

    dataset = ObjectNet()

    # islice stops cleanly when fewer than 5 samples are available, instead
    # of letting a bare StopIteration escape from next().
    for i, (image, label) in enumerate(islice(dataset, 5)):
        # Assumes the decoded "jpg" field is a PIL image (has .save) — confirm.
        image.save(f"objectnet_{i}_{label}.jpg")