import argparse
import json
import os

import cv2 as cv
from dm.extract import extract_feature
from tqdm import tqdm

# Command-line interface: both the dataset name and the split are mandatory.
parser = argparse.ArgumentParser(
    description=(
        "Extract per-object features from CLEVR-style scene mask annotations "
        "and write them to scenes/CLEVR_<split>_extracted.json."
    )
)

parser.add_argument(
    "--dataset",
    required=True,
    help="dataset directory name under data/ (must contain scenes/ and images/)",
)
parser.add_argument(
    "--split",
    required=True,
    help="split name, used for scenes/CLEVR_<split>_masks.json and images/<split>/",
)


def main(dataset: str, split: str, data_root: str = "data") -> None:
    """Extract per-object features for one split of a CLEVR-style dataset.

    Reads ``<data_root>/<dataset>/scenes/CLEVR_<split>_masks.json``. For every
    scene it loads the image from ``<data_root>/<dataset>/images/<split>/`` and
    calls ``extract_feature(img, obj["mask"])`` for each object. It then writes
    ``{"scenes": [...]}`` to
    ``<data_root>/<dataset>/scenes/CLEVR_<split>_extracted.json``.

    Each output object has two keys. ``"attributes"`` holds whatever
    ``extract_feature`` returns. ``"description"`` holds the annotated
    ``color``/``size``/``material``/``shape``. A scene is dropped entirely if
    any of its objects lacks a ``"mask"`` key.

    Args:
        dataset: Dataset directory name under ``data_root``.
        split: Split name, used in the JSON file names and the image subfolder.
        data_root: Root directory containing the datasets. Defaults to
            ``"data"``, the path this script has always used.

    Raises:
        FileNotFoundError: If the masks JSON is missing, or if an image
            referenced by a kept scene cannot be read by OpenCV.
    """
    dataset_dir = os.path.join(data_root, dataset)
    scenes_dir = os.path.join(dataset_dir, "scenes")
    images_dir = os.path.join(dataset_dir, "images", split)

    masks_path = os.path.join(scenes_dir, f"CLEVR_{split}_masks.json")
    with open(masks_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    extracted_scenes = []
    for scene in tqdm(data["scenes"]):
        objects = scene["objects"]

        # Decide up front whether the scene will be kept. This avoids decoding
        # the image and running feature extraction for a scene that is then
        # discarded because a later object has no mask.
        if any("mask" not in obj for obj in objects):
            continue

        img_fname = scene["image_filename"]
        img_path = os.path.join(images_dir, img_fname)
        img = cv.imread(img_path)
        # cv.imread does not raise on failure; it returns None. Fail loudly
        # instead of passing None into extract_feature.
        if img is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")

        new_objects = [
            {
                "attributes": extract_feature(img, obj["mask"]),
                "description": {
                    "color": obj["color"],
                    "size": obj["size"],
                    "material": obj["material"],
                    "shape": obj["shape"],
                },
            }
            for obj in objects
        ]

        extracted_scenes.append(
            {
                "image_index": scene["image_index"],
                "image_filename": img_fname,
                "objects": new_objects,
            }
        )

    output_path = os.path.join(scenes_dir, f"CLEVR_{split}_extracted.json")
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump({"scenes": extracted_scenes}, f)


if __name__ == "__main__":
    # Script entry point: parse the required CLI flags and hand them to main().
    cli_args = parser.parse_args()
    main(dataset=cli_args.dataset, split=cli_args.split)
