#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates.

import argparse
import json
import numpy as np
import os
from collections import defaultdict
import cv2
import tqdm
from itertools import chain

from detectron2.config import get_cfg
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.structures import Boxes, BoxMode, Instances
from detectron2.utils.file_io import PathManager
from detectron2.utils.logger import setup_logger
from detectron2.utils.visualizer import Visualizer
from detectron2.projects.glee.config import add_glee_config


def create_instances(predictions, image_size):
    """Build an ``Instances`` object from the JSON predictions of one image.

    Args:
        predictions: list of prediction dicts for a single image. Each dict
            is read for ``"score"``, ``"bbox"`` (converted from XYWH_ABS) and
            ``"category_id"``; ``"segmentation"`` is used when present.
        image_size: image size passed straight to ``Instances``.

    Returns:
        An ``Instances`` holding ``scores``, ``pred_boxes`` (XYXY_ABS) and
        ``pred_classes`` for predictions whose score is strictly above
        ``args.conf_threshold``; ``pred_masks`` is set only when every kept
        prediction carries a ``"segmentation"`` entry.

    Note:
        Relies on the module-level ``args`` and ``dataset_id_map`` that the
        ``__main__`` section defines before calling this function.
    """
    instances = Instances(image_size)

    all_scores = np.asarray([p["score"] for p in predictions])
    # Indices of the predictions that pass the confidence threshold.
    keep = np.nonzero(all_scores > args.conf_threshold)[0]
    kept = [predictions[i] for i in keep]

    # reshape keeps a (0, 4) array when nothing passes the threshold.
    xywh = np.asarray([p["bbox"] for p in kept]).reshape(-1, 4)
    xyxy = BoxMode.convert(xywh, BoxMode.XYWH_ABS, BoxMode.XYXY_ABS)

    class_ids = np.asarray([dataset_id_map(p["category_id"]) for p in kept])

    instances.scores = all_scores[keep]
    instances.pred_boxes = Boxes(xyxy)
    instances.pred_classes = class_ids

    # Masks are optional: a missing "segmentation" key just leaves them unset.
    try:
        instances.pred_masks = [p["segmentation"] for p in kept]
    except KeyError:
        pass
    return instances

def setup(args):
    """Create the frozen config used by this script.

    Starts from the default detectron2 config, registers the GLEE options,
    merges the file given by ``args.config_file``, forces
    ``DATALOADER.NUM_WORKERS`` to 0 and freezes the result.

    Args:
        args: parsed command-line arguments; only ``config_file`` is read.

    Returns:
        The frozen config node.
    """
    config = get_cfg()
    add_glee_config(config)
    config.merge_from_file(args.config_file)
    config.DATALOADER.NUM_WORKERS = 0
    config.freeze()
    return config

if __name__ == "__main__":
    # Script entry point: for every dataset record, render the ground truth and
    # the thresholded predictions side by side, append the record's object
    # descriptions as text underneath, and write the result to --output.
    #
    # NOTE: `args` and `dataset_id_map` must remain module-level names because
    # create_instances() reads them as globals.
    parser = argparse.ArgumentParser(
        description="A script that visualizes the json predictions from COCO or LVIS dataset."
    )
    # The config file is merged unconditionally in setup(), so it is mandatory.
    parser.add_argument(
        "--config-file", metavar="FILE", required=True, help="path to config file"
    )
    parser.add_argument("--input", required=True, help="JSON file produced by the model")
    parser.add_argument("--output", required=True, help="output directory")
    parser.add_argument("--dataset", help="name of the dataset", default="coco_2017_val")
    parser.add_argument("--conf-threshold", default=0.5, type=float, help="confidence threshold")
    args = parser.parse_args()
    logger = setup_logger()
    logger.info("Arguments: " + str(args))
    cfg = setup(args)

    # Class names / id mapping come from the first training dataset of the config.
    metadata = MetadataCatalog.get(cfg.DATASETS.TRAIN[0])

    with PathManager.open(args.input, "r") as f:
        all_predictions = json.load(f)

    # Group the flat prediction list by image so each record can look up its own.
    pred_by_image = defaultdict(list)
    for p in all_predictions:
        pred_by_image[p["image_id"]].append(p)

    # NOTE(review): the visualized dataset is hard-coded here and --dataset is
    # only consulted for the LVIS check / error message below. Other values
    # that were used previously: 'refcoco-unc-val', 'd3_intra_scenario'.
    data = ('omnilabel_coco_val',)
    dicts = list(chain.from_iterable(DatasetCatalog.get(k) for k in data))

    if hasattr(metadata, "thing_dataset_id_to_contiguous_id"):

        def dataset_id_map(ds_id):
            return metadata.thing_dataset_id_to_contiguous_id[ds_id]

    elif "lvis" in args.dataset:
        # LVIS results are in the same format as COCO results, but have a different
        # mapping from dataset category id to contiguous category id in [0, #categories - 1]
        def dataset_id_map(ds_id):
            return ds_id - 1

    else:
        raise ValueError("Unsupported dataset: {}".format(args.dataset))

    os.makedirs(args.output, exist_ok=True)

    # Text rendering settings for the description strip (loop-invariant).
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.7
    font_thickness = 1
    text_color = (0, 0, 0)
    line_height = 20
    min_strip_height = 50

    # Running suffix used to de-duplicate output file names across the whole run.
    idx = 0

    for dic in tqdm.tqdm(dicts):
        bgr = cv2.imread(dic["file_name"], cv2.IMREAD_COLOR)
        if bgr is None:
            # cv2.imread signals an unreadable/missing file by returning None.
            logger.warning("Could not read image %s; skipping.", dic["file_name"])
            continue
        img = bgr[:, :, ::-1]  # BGR -> RGB for the Visualizer
        basename = os.path.basename(dic["file_name"])
        instances = create_instances(
            pred_by_image.get(dic["image_id"], []), img.shape[:2]
        )

        # prediction visualization
        vis = Visualizer(img, metadata)
        vis_pred = vis.draw_instance_predictions(instances).get_image()

        # gt visualization
        vis = Visualizer(img, metadata)
        vis_gt = vis.draw_dataset_dict(dic).get_image()

        # concatenate [gt, prediction] horizontally
        concat = np.concatenate((vis_gt, vis_pred), axis=1)

        # White strip below the images listing the record's descriptions.
        # It is at least 50 px tall and grows so that every line fits.
        expressions = dic.get("inference_obj_descriptions", [])
        _, w, _ = concat.shape
        strip_height = max(min_strip_height, line_height * len(expressions) + 10)
        blank_space = np.full((strip_height, w, 3), 255, dtype=np.uint8)
        y_offset = line_height
        for expression in expressions:
            cv2.putText(
                blank_space,
                expression,
                (10, y_offset),
                font,
                font_scale,
                text_color,
                font_thickness,
                cv2.LINE_AA,
            )
            y_offset += line_height

        concat = np.concatenate((concat, blank_space), axis=0)

        # Never overwrite an existing file: the same image can appear in
        # several records, so fall back to "<basename>_<idx>.png".
        output_path = os.path.join(args.output, basename)
        while os.path.exists(output_path):
            output_filename = f"{basename}_{idx}.png"
            output_path = os.path.join(args.output, output_filename)
            idx = idx + 1

        # Back to BGR for OpenCV.
        if not cv2.imwrite(output_path, concat[:, :, ::-1]):
            logger.warning("Failed to write visualization to %s", output_path)