"""
This script processes a dataset of images encoded in base64 format and saves them to a specified output directory.

Usage:
    python seed_dataset.py \
        --input_path <path_to_input_json> \
        --output_dir <path_to_output_directory> \
        [--max_workers <number_of_workers>] \
        [--filename_prefix <key1> <key2> ...]

Arguments:
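    --input_path: Path to the input JSON file containing the dataset, a list of objects
        that each have an 'id' and a base64-encoded 'task_image' field.
    --output_dir: Directory where the images will be saved (created if it does not exist).
    --max_workers: (Optional) Number of parallel worker threads to use. Default is 8.
    --filename_prefix: (Optional) One or more item keys whose values are joined with '--'
        and prepended to each image filename.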

Example:
    python src/turtlegfx_datagen/utils/seed_dataset.py \
        --input_path "exps/datasetlabelling/results/dataset_graphics_sz721_labelled_1222.json" \
        --output_dir "exps/datasetlabelling/results/dataset_graphics_sz721_labelled_1222" \
        --filename_prefix category difficulty
"""

import json
import os
from src.turtlegfx.utils.base64img import convert_base64_to_img
from src.turtlegfx_datagen.utils.img_utils import get_image_filename_from_id
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor

def save_image(item, output_dir, filename_prefix=None):
    """
    Decode one dataset item's base64 image and save it under ``output_dir``.

    Args:
        item: A dataset entry. It must contain ``"id"`` and ``"task_image"`` (a
            base64 string), plus every key listed in ``filename_prefix``.
        output_dir: Directory to write into; created if it does not exist.
        filename_prefix: Optional sequence of item keys. Their values (converted
            to ``str``) are joined with ``"--"`` and prepended to the filename,
            separated from it by another ``"--"``. ``None`` or an empty sequence
            means no prefix.
    """
    img = convert_base64_to_img(item["task_image"])
    # Runs once per item; exist_ok=True keeps it safe when several worker
    # threads try to create the same directory concurrently.
    os.makedirs(output_dir, exist_ok=True)

    filename = get_image_filename_from_id(item["id"])
    if filename_prefix:
        # str() so non-string metadata values (e.g. numbers) don't make
        # str.join raise TypeError.
        prefix = "--".join(str(item[key]) for key in filename_prefix)
        filename = f"{prefix}--{filename}"
    img.save(f"{output_dir}/{filename}")

def save_dataset_images(input_path, output_dir, max_workers=4, filename_prefix=None):
    """
    Save every image in a JSON dataset to ``output_dir`` using a thread pool.

    The dataset must be a JSON list of the form
    ``[{'id': <uuid>, 'task_image': <base64>}, ...]``. Each item is written by
    ``save_image``.

    Args:
        input_path: Path to the UTF-8 encoded JSON dataset file.
        output_dir: Directory the images are written to.
        max_workers: Number of worker threads in the pool.
        filename_prefix: Optional sequence of item keys used to prefix each
            filename; passed through to ``save_image``.

    Raises:
        Any exception raised while saving an item is re-raised here, because
        iterating ``executor.map`` propagates worker exceptions.
    """
    # JSON text is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(input_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    def _save(item):
        save_image(item, output_dir, filename_prefix)

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Draining the lazy result iterator drives the progress bar and
        # surfaces the first worker exception, if any.
        for _ in tqdm(executor.map(_save, data), total=len(data)):
            pass

    print(f"{len(data)} Images saved to {output_dir}")


if __name__ == "__main__":
    from argparse import ArgumentParser

    # Command-line entry point; the module docstring describes each argument.
    cli = ArgumentParser()
    cli.add_argument("--input_path", type=str, required=True)
    cli.add_argument("--output_dir", type=str, required=True)
    cli.add_argument("--max_workers", type=int, default=8, help="Number of parallel workers")
    cli.add_argument(
        "--filename_prefix",
        type=str,
        nargs="+",
        default=None,
        help="List of keys for the filename, e.g., <key1> <key2> ...",
    )
    opts = cli.parse_args()

    save_dataset_images(
        input_path=opts.input_path,
        output_dir=opts.output_dir,
        max_workers=opts.max_workers,
        filename_prefix=opts.filename_prefix,
    )

    # Echo the paths that were used once processing has finished.
    for label, value in (("Input path", opts.input_path), ("Output path", opts.output_dir)):
        print(f"{label}: {value}")