from PIL import Image
from datasets import load_from_disk
from pathlib import Path

def load_samples_to_dir(
    dataset_path: str,
    save_path: str,
    num_samples: int = 30,
    column: str = 'y',
) -> None:
    """Export the first rows of a saved dataset as numbered PNG files.

    Loads the dataset at ``dataset_path`` with ``load_from_disk``, keeps the
    first ``num_samples`` rows, reads ``column`` as ``uint8`` numpy arrays and
    writes each one to ``<save_path>/<idx>.png`` (``idx`` starting at 0).

    Args:
        dataset_path: Directory to pass to ``datasets.load_from_disk``.
        save_path: Output directory; created (including parents) if missing.
            Existing files with the same ``<idx>.png`` names are overwritten.
        num_samples: Maximum number of rows to export. If the dataset has
            fewer rows, all of them are exported instead of raising.
        column: Name of the column holding the image arrays.
    """
    ds = load_from_disk(dataset_path)
    # Clamp to the dataset size: ``select`` fails on out-of-range indices.
    ds = ds.select(range(min(num_samples, len(ds))))
    # uint8 lets ``Image.fromarray`` infer a standard 8-bit mode (L/RGB/RGBA).
    ds.set_format('numpy', dtype='uint8')
    folder_to_save = Path(save_path)
    folder_to_save.mkdir(parents=True, exist_ok=True)
    # NOTE(review): assumes each entry is HxW or HxWxC, the layout
    # Image.fromarray expects — confirm against how the dataset was built.
    # Convert and save one image at a time rather than holding every
    # PIL Image in memory at once.
    for idx, image_arr in enumerate(ds[column]):
        Image.fromarray(image_arr).save(folder_to_save / f'{idx}.png')

if __name__ == '__main__':
    # Root under which each run writes a ``samples_mask<N>/`` subfolder.
    samples_root = '/data/vision/___/scratch/___ht/diffusion-priors/experiments/celeba/'
    # Saved datasets keyed by the number in their directory name
    # (presumably the masking percentage — confirm).
    dataset_paths = {
        75: '/data/vision/___/scratch/___ht/celeba_64_mask75/',
        0: '/data/vision/___/scratch/___ht/celeba_64_mask0/',
    }
    # Select the dataset to export. The output folder is derived from the
    # same key so samples from one dataset never land in another's folder.
    mask_ratio = 0
    save_path = f'{samples_root}samples_mask{mask_ratio}/'
    load_samples_to_dir(dataset_paths[mask_ratio], save_path)