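r"""
Prepare CIFAR-10 data for CLIP score evaluation: save each image as a JPEG
file and write a JSON file mapping image indices to labels.
"""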

from datasets import Array3D, Features, load_dataset
from pathlib import Path
import json

class Counter:
    """Small mutable counter that can be shared with nested functions.

    The current count is kept in the public ``counter`` attribute.
    """

    def __init__(self, start=0):
        """Begin counting from ``start`` (0 when omitted)."""
        self.counter = start

    def get_value(self):
        """Return the current count without changing it."""
        return self.counter

    def inc(self):
        """Advance the count by one, in place."""
        self.counter += 1


def main(dataset_path: Path, save_path: Path):
    """Export CIFAR-10 images and labels in the layout used for CLIP score.

    Loads the ``cifar10`` dataset with ``dataset_path`` as the cache
    directory, saves every image as ``<save_path>/images/<index>.jpg`` and
    writes ``<save_path>/labels.json``, which maps each image index (as a
    string) to the ``label`` value of the corresponding row. The index is one
    running counter over every row the mapping visits, starting at 0.

    Args:
        dataset_path: Existing directory passed to ``load_dataset`` as
            ``cache_dir``.
        save_path: Output directory; it and its ``images`` subdirectory are
            created if they do not exist yet.

    Raises:
        FileNotFoundError: If ``dataset_path`` does not exist.
    """
    if not dataset_path.exists():
        # ``raise`` needs an exception object; raising a bare string fails
        # with an unrelated TypeError in Python 3 and loses the message.
        raise FileNotFoundError(f"Invalid dataset path: {dataset_path}")

    dataset = load_dataset('cifar10', cache_dir=dataset_path)

    # ``img_path`` lives below ``save_path``, so a single call with
    # ``parents=True`` creates both; ``exist_ok=True`` makes reruns harmless
    # and avoids the check-then-create race.
    img_path = save_path / 'images'
    img_path.mkdir(parents=True, exist_ok=True)

    counter = Counter()
    img_to_label = {}

    def transform(row):
        """Save one row's image to disk and record its label."""
        index = counter.get_value()
        row['img'].save(img_path / f'{index}.jpg')
        img_to_label[f'{index}'] = row['label']
        counter.inc()

    # ``map`` is used purely for its side effects (``transform`` returns
    # nothing), so the mapped dataset is not kept. ``num_proc=1`` keeps the
    # shared counter and dict in a single process.
    dataset.map(
        transform,
        keep_in_memory=True,
        num_proc=1,
    )

    with open(save_path / 'labels.json', 'w', encoding='utf-8') as file:
        json.dump(img_to_label, file, indent=4)

if __name__ == "__main__":
    # Script entry point: export with the hard-coded cache and output
    # directories below.
    main(
        dataset_path=Path('/data/vision/___/scratch/___ht/cifar_dir/hf'),
        save_path=Path('/data/vision/___/scratch/___ht/cifar_dir/clipscore_data'),
    )

