import itertools
import operator

from PIL import Image
import webdataset as wds


def streaming_dataset(path_pattern: str = "data/datacomp/data-000{00..80}.tar"):
    """Build a streaming WebDataset of (image, caption) pairs.

    Args:
        path_pattern: Shard path/URL pattern passed to ``wds.WebDataset``.
            Brace notation such as ``{00..80}`` is expanded by webdataset
            into one path per shard. The default points at the local shards
            ``data/datacomp/data-00000.tar`` through ``data-00080.tar``.

    Returns:
        A webdataset pipeline yielding 2-tuples ``(image, caption)``. The
        first item is the sample's ``jpg`` entry decoded by webdataset's
        ``"pil"`` decoder. The second is the ``"caption"`` field of the
        sample's decoded ``json`` entry.

    Samples that fail to decode, lack a ``jpg``/``json`` entry, or whose
    JSON has no ``"caption"`` key are skipped with a warning
    (``wds.warn_and_continue``) rather than aborting the stream.
    """
    # Pull the caption text out of the decoded JSON metadata mapping.
    txt_transform = operator.itemgetter("caption")

    dataset = (
        # Shards are divided among DataLoader worker processes.
        wds.WebDataset(path_pattern, workersplitter=wds.split_by_worker)
        .decode("pil", handler=wds.warn_and_continue)
        .to_tuple("jpg", "json", handler=wds.warn_and_continue)
        # None leaves the image element unchanged; only the JSON is mapped.
        .map_tuple(None, txt_transform, handler=wds.warn_and_continue)
    )
    return dataset

if __name__ == "__main__":
    ds = streaming_dataset()
    # Preview the first 10 samples: save each image as <index>.png in the
    # current directory and print its caption. islice stops after exactly
    # 10 items instead of pulling (and decoding) an 11th one before breaking.
    for i, (image, caption) in enumerate(itertools.islice(ds, 10)):
        image.save(f"{i}.png")
        print(caption)
