import json
import os
import cv2
from pycocotools.coco import COCO

# Colour tuples passed as the colour argument to the cv2 drawing calls below.
# Seven slots, selected per box by its category id.
palette = [
    (255, 0, 0),
    (255, 255, 0),
    (0, 255, 0),
    (0, 255, 255),
    (0, 0, 255),
    (255, 0, 255),
    (255, 255, 255),
]


# data_root = "./UnSniffer/detection/dataset/JPEGImages"
# save_dir = "./UnSniffer/show"
# prediction_json_file = './json_output_dir/ood/soda_instances_results_unsniffer.json'
# annotation_json_file = './json_output_dir/ood/val.json'

# data_root = "./owod/PROB/data/OWOD/JPEGImages"
# save_dir = "./auto-gd/show"
# prediction_json_file = './auto-gd/json_output_dir/auto-gd_super_prompt.json'
# annotation_json_file = './json_output_dir/val.json'
# known_prediction_thre = 0.3
# unknown_prediction_thre = 0.1
# auto_gd = True

# Active configuration: image directory, output directory, and the
# COCO-format prediction / ground-truth files to visualise.
data_root: str = "./owod/PROB/data/OWOD/JPEGImages"
save_dir: str = "./owod/PROB/show"
prediction_json_file: str = './auto-gd/json_output_dir/auto-gd_super_prompt.json'
annotation_json_file: str = './json_output_dir/val.json'

# A prediction box is drawn only when its score strictly exceeds the threshold
# for its class group (label 7 uses the "unknown" threshold).
known_prediction_thre: float = 0.3
unknown_prediction_thre: float = 0.1

# When True, only the last path component of each image's file_name is used,
# and prediction category ids are remapped before drawing.
auto_gd: bool = True

def _xywh_to_corners(bbox):
    """Convert a COCO ``[x, y, w, h]`` box to integer ``(top_left, bottom_right)`` points.

    Each component is truncated with ``int`` *before* adding, so
    ``xmax == int(x) + int(w)`` (not ``int(x + w)``), matching the original drawing code.
    """
    xmin, ymin = int(bbox[0]), int(bbox[1])
    return (xmin, ymin), (xmin + int(bbox[2]), ymin + int(bbox[3]))


def _draw_box(canvas, bbox, color, label=None):
    """Draw ``bbox`` on ``canvas`` in ``color`` and return the drawn image.

    If ``label`` is given, it is written just above the box's top-left corner.
    """
    top_left, bottom_right = _xywh_to_corners(bbox)
    canvas = cv2.rectangle(canvas, top_left, bottom_right, color, 2)
    if label is not None:
        canvas = cv2.putText(
            canvas,
            label,
            (top_left[0], top_left[1] - 2),
            cv2.FONT_HERSHEY_COMPLEX,
            0.5,
            color,
            2,
        )
    return canvas


gt_instance = COCO(annotation_json_file)
prediction_instance = gt_instance.loadRes(prediction_json_file)
# Category-id -> category-dict mapping from the ground-truth file; loop-invariant.
categories = gt_instance.cats

gt_save_dir = os.path.join(save_dir, "gt")
pred_save_dir = os.path.join(save_dir, "prediction")
# Create the output tree once; makedirs also creates any missing parents.
os.makedirs(gt_save_dir, exist_ok=True)
os.makedirs(pred_save_dir, exist_ok=True)

# Render one GT image and one prediction image for every image that has GT annotations.
for image_id in gt_instance.imgToAnns:
    image_file_name = gt_instance.imgs[image_id]["file_name"]
    if auto_gd:
        # In auto-gd mode only the last path component of file_name is used.
        image_file_name = image_file_name.split("/")[-1]
    image_file_path = os.path.join(data_root, image_file_name)

    image = cv2.imread(image_file_path)
    if image is None:
        # cv2.imread signals a missing/unreadable file by returning None.
        print(f"warning: could not read {image_file_path}; skipping")
        continue
    # Decode once; draw predictions on an independent copy.
    pred_image = image.copy()

    gt_anns = gt_instance.loadAnns(gt_instance.getAnnIds(imgIds=[image_id]))
    for gt_res in gt_anns:
        gt_category_id = gt_res["category_id"]
        # NOTE(review): GT boxes index the palette with category_id directly,
        # while predictions use category_id - 1 -- presumably GT ids are
        # 0-based and prediction ids 1-based; confirm against the data files.
        # Category 6 GT boxes are drawn without a text label.
        label = None if gt_category_id == 6 else gt_res["name"]
        image = _draw_box(image, gt_res["bbox"], palette[gt_category_id], label=label)

    pred_anns = prediction_instance.loadAnns(prediction_instance.getAnnIds(imgIds=[image_id]))
    for pred_res in pred_anns:
        category_id = pred_res["category_id"]
        if auto_gd:
            # Remap auto-gd ids: 7 -> 5, and anything above 7 -> 7.
            if category_id == 7:
                category_id = 5
            elif category_id > 7:
                category_id = 7
        score = pred_res["score"]
        if category_id != 7:
            if score > known_prediction_thre:
                pred_image = _draw_box(
                    pred_image,
                    pred_res["bbox"],
                    palette[category_id - 1],
                    label=categories[category_id - 1]["name"],
                )
        elif score > unknown_prediction_thre:
            # Label 7 boxes use the "unknown" threshold and get no text label.
            pred_image = _draw_box(pred_image, pred_res["bbox"], palette[category_id - 1])

    outputs = (
        (os.path.join(gt_save_dir, image_file_name), image),
        (os.path.join(pred_save_dir, image_file_name), pred_image),
    )
    for out_path, canvas in outputs:
        # cv2.imwrite reports failure via its return value rather than raising.
        if not cv2.imwrite(out_path, canvas):
            print(f"warning: failed to write {out_path}")
    

# categories_id = annotation_json["categories"]
# cat_ids = {category_id["id"]: category_id["name"] for category_id in categories_id}

# for image_info in annotation_json["images"]:
#     image_file_name = image_info["file_name"]
#     image_file_path = os.path.join(data_root, image_file_name)
#     image = cv2.imread(image_file_path)
