import argparse
import json
import os

from tqdm import tqdm

# Command-line interface: both options are mandatory and are passed to main().
parser = argparse.ArgumentParser(
    description="Split a combined CLEVR scenes JSON file into one JSON file per scene.",
)
parser.add_argument(
    "--dataset",
    required=True,
    help="dataset directory name under data/ (input is read from data/<dataset>/scenes/)",
)
parser.add_argument(
    "--split",
    required=True,
    help="split name; selects CLEVR_<split>_normalised.json and names the output sub-directory",
)


def main(dataset: str, split: str) -> None:
    """Split a combined CLEVR scenes file into one JSON file per scene.

    Reads ``data/{dataset}/scenes/CLEVR_{split}_normalised.json`` and writes
    each entry of its ``"scenes"`` list to
    ``data/{dataset}/scenes/{split}/<name>.json``, where ``<name>`` is the
    scene's ``image_filename`` with its extension removed. Paths are relative
    to the current working directory. The output directory is created if it
    does not exist, and existing per-scene files are overwritten.

    Args:
        dataset: Dataset directory name under ``data/``.
        split: Split name, used in the input file name and as the output
            sub-directory.
    """
    scenes_dir = f"data/{dataset}/scenes"

    # Read as UTF-8 explicitly rather than relying on the platform's locale
    # encoding, which may differ from how the JSON file was written.
    with open(f"{scenes_dir}/CLEVR_{split}_normalised.json", encoding="utf-8") as f:
        data = json.load(f)

    # Write each individual scene to its own JSON file.
    output_dir = f"{scenes_dir}/{split}"
    os.makedirs(output_dir, exist_ok=True)
    for scene in tqdm(data["scenes"]):
        fname = os.path.splitext(scene["image_filename"])[0]
        with open(f"{output_dir}/{fname}.json", "w", encoding="utf-8") as f:
            json.dump(scene, f)


if __name__ == "__main__":
    # Script entry point: parse the command line and hand the values to main().
    cli_args = parser.parse_args()
    main(dataset=cli_args.dataset, split=cli_args.split)
