from datasets import load_dataset
import logging

# Script: download the ScienceQA dataset from the Hugging Face Hub, drop every
# sample that has no image, and save the filtered splits to local disk.

# Without explicit configuration the root logger only emits WARNING and above,
# so the per-split filtering report below would be silently discarded.
# basicConfig is a no-op if the root logger already has handlers.
logging.basicConfig(level=logging.INFO)

map_datasets = load_dataset("derek-thomas/ScienceQA")
# NOTE(review): presumably a local copy of the original dataset lives at
# /XXXX-5/home-XXXX-3/data/MSD/datasets/_original/ScienceQA -- confirm.

# Filter out samples without images, one split at a time, logging how many
# samples each split had before and after filtering.
# Iterate over a snapshot of the items because entries are reassigned in the loop.
for split, dataset in list(map_datasets.items()):
    len_before = len(dataset)
    map_datasets[split] = dataset.filter(lambda x: x['image'] is not None)
    logging.info(f"[Dataset] Filter out samples w.o images for {split} dataset: {len_before} -> {len(map_datasets[split])} samples")


# Persist all filtered splits in the `datasets` on-disk format.
map_datasets.save_to_disk("/XXXX-5/home-XXXX-3/data/MSD/datasets/ScienceQA")

# Disabled snippet, kept for reference: dumps a single validation sample to a
# pickle file as a batch example.
#
# import pickle
# with open('/XXXX-5/home-XXXX-3/data/MSD/datasets/ScienceQA/example/batch_example.pkl', 'wb') as f:
#     pickle.dump(map_datasets['validation'][930], f)


# Emits a single blank line on stdout once the script has finished.
print()