import argparse
import json
import os

import numpy as np
from detectron2.engine import DefaultPredictor
from dm.config import load_detectron, load_img, load_module_config
from dm.dotdict import DotDict
from dm.masks import convert_array_to_rle
from tqdm import tqdm

# Command-line interface: which dataset/split to process and whether to use CUDA.
parser = argparse.ArgumentParser(
    description=(
        "Generate Detectron instance masks for every image in a dataset split "
        "and save them as CLEVR-style scene JSON."
    )
)

parser.add_argument(
    "--use_cuda",
    action="store_true",
    help="Enable CUDA when loading Detectron.",
)
parser.add_argument(
    "--dataset",
    required=True,
    help="Dataset directory name under data/.",
)
parser.add_argument(
    "--split",
    required=True,
    help="Split subdirectory under data/<dataset>/images/.",
)


def _mask_only_object(rle_mask) -> dict:
    """Build a CLEVR-style object record carrying only an RLE instance mask.

    Every attribute other than ``mask`` is set to ``False``; presumably a later
    stage fills these in (NOTE(review): confirm with downstream consumers).
    """
    return {
        "color": False,
        "size": False,
        "material": False,
        "rotation": False,
        "shape": False,
        "3d_coords": False,
        "pixel_coords": False,
        "mask": rle_mask,
    }


def generate_masks(dataset: str, split: str, use_cuda: bool, module_info: DotDict):
    """Segment every image of a split with Detectron and save the masks as scenes.

    For each image in ``data/<dataset>/images/<split>``, every predicted instance
    mask is RLE-encoded and stored as a CLEVR-style object whose other
    attributes are ``False``. All scenes are written to
    ``data/<dataset>/scenes/CLEVR_<split>_masks.json``.

    Args:
        dataset: Dataset directory name under ``data/``.
        split: Image subdirectory name; also recorded in each scene's
            ``"split"`` field and used in the output filename.
        use_cuda: Forwarded to ``load_detectron``.
        module_info: Detectron module configuration forwarded to
            ``load_detectron``.
    """
    print("Loading Detectron...")
    detectron: DefaultPredictor = load_detectron(use_cuda, module_info)

    image_dir = os.path.join("data", dataset, "images", split)
    # os.listdir returns entries in arbitrary order; sort so the output file is
    # reproducible. Hidden files (e.g. macOS .DS_Store) are not images and would
    # break the image-index parsing below, so skip them.
    filenames = sorted(
        name for name in os.listdir(image_dir) if not name.startswith(".")
    )

    scenes = []
    for filename in tqdm(filenames):
        img: np.ndarray = load_img(dataset, split, filename)
        outputs = detectron(img)

        # One object per predicted instance mask; masks are moved to CPU and
        # converted to numpy before RLE encoding.
        objects = [
            _mask_only_object(convert_array_to_rle(mask.cpu().numpy()))
            for mask in outputs["instances"].pred_masks
        ]

        scenes.append(
            {
                # Trailing number of the stem, e.g. "CLEVR_val_000042.png" -> 42.
                "image_index": int(filename.split(".")[0].split("_")[-1]),
                "objects": objects,
                "relationships": {},
                "image_filename": filename,
                # The scene's "split" field records the split, not the dataset.
                "split": split,
                "directions": {},
            }
        )

    output_dir = os.path.join("data", dataset, "scenes")
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, f"CLEVR_{split}_masks.json")
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump({"scenes": scenes}, f)


def main() -> None:
    """Parse command-line arguments and run mask generation with Detectron."""
    cli_args = parser.parse_args()

    # Select the Detectron module config; load_module_config presumably reads
    # this attribute and populates cli_args.module_info (verify in dm.config).
    cli_args.module_config = "detectron.json"
    load_module_config(cli_args)

    generate_masks(
        dataset=cli_args.dataset,
        split=cli_args.split,
        use_cuda=cli_args.use_cuda,
        module_info=cli_args.module_info,
    )


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
