import os
import json
import logging

from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.structures import Boxes, BoxMode, PolygonMasks, RotatedBoxes

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Names exported by a star-import of this module.
__all__ = ["load_omnilabel_json", "register_omnilabel"]

  
def _get_builtin_metadata(dataset_name):
    """
    Return the built-in metadata for a dataset family.

    Args:
        dataset_name (str): Dataset family key. It is currently unused:
            every family receives the metadata built from an empty
            category list.

    Returns:
        dict: The result of ``_get_metadata([])``.
    """
    return _get_metadata([])

def _get_metadata(categories):
    """
    Build "thing" metadata from a list of category dicts.

    Args:
        categories (list[dict]): Category entries, each with an ``"id"`` and
            a ``"name"`` key. If an id occurs more than once, the last
            entry wins.

    Returns:
        dict: Empty when ``categories`` is empty. Otherwise a dict with
            ``"thing_dataset_id_to_contiguous_id"`` (category id -> index in
            ``[0, N)``) and ``"thing_classes"`` (category names), both
            ordered by ascending category id so that index ``i`` in one
            refers to the same category as index ``i`` in the other.
    """
    if not categories:
        return {}

    id_to_name = {cat["id"]: cat["name"] for cat in categories}
    # Derive both outputs from the same sorted id list; assuming ids are
    # exactly 1..N would mis-map datasets with sparse or duplicate ids.
    sorted_ids = sorted(id_to_name)
    return {
        "thing_dataset_id_to_contiguous_id": {
            dataset_id: index for index, dataset_id in enumerate(sorted_ids)
        },
        "thing_classes": [id_to_name[dataset_id] for dataset_id in sorted_ids],
    }

def _get_onmi_builtin_metadata(name, metadata, json_file, image_root):
    """
    Register one dataset split with ``DatasetCatalog`` and ``MetadataCatalog``.

    Despite its name, this function returns nothing; it only registers.

    Args:
        name (str): Name to register the split under.
        metadata (dict): Extra key/value pairs stored on the metadata entry.
        json_file (str): Annotation file handed to ``load_omnilabel_json``.
        image_root (str): Image directory handed to ``load_omnilabel_json``.
    """

    def _load_split():
        # Runs only when the catalog invokes the registered callable.
        return load_omnilabel_json(json_file, image_root)

    DatasetCatalog.register(name, _load_split)

    split_metadata = MetadataCatalog.get(name)
    split_metadata.set(
        json_file=json_file,
        image_root=image_root,
        evaluator_type="omnilabel",
        **metadata,
    )
    
def load_omnilabel_json(json_file, image_root):
    """
    Load an OmniLabel-style annotation file into a list of grounding records.

    One record is produced for every (description, image id) pair in which
    the image id is declared under ``"images"`` and the image file exists
    below ``image_root``. Pairs failing either check are skipped silently.

    Args:
        json_file (str): Path to a JSON file holding ``"images"`` and
            ``"descriptions"`` lists and, optionally, an ``"annotations"``
            list.
        image_root (str): Directory that each image ``file_name`` is joined
            onto.

    Returns:
        list[dict]: Records with the keys ``image_id``, ``file_name``,
            ``description_ids``, ``expressions``, ``annotations``,
            ``bbox_mode`` and ``task``. ``annotations`` holds a copy of every
            annotation whose ``description_ids`` contains the description's
            id; its ``bbox`` is kept only when ``anno_info.posneg`` is
            ``"P"`` and is replaced by ``[]`` otherwise. Records of the same
            description share one ``annotations`` list object.
    """
    with open(json_file, "r", encoding="utf-8") as file:
        data_json = json.load(file)

    descriptions = data_json["descriptions"]

    # image id -> (image id as stored in the file, full path of the image)
    image_dict = {
        img["id"]: (img["id"], os.path.join(image_root, img["file_name"]))
        for img in data_json["images"]
    }

    # Index annotations by description in one pass over the annotations,
    # rather than rescanning the whole annotation list for each description.
    annotation_dict = {desc["id"]: [] for desc in descriptions}
    for ann in data_json.get("annotations", []):
        is_positive = ann.get("anno_info", {}).get("posneg") == "P"
        # set(): attach an annotation at most once per description even if
        # its id is listed repeatedly.
        for desc_id in set(ann["description_ids"]):
            matched = annotation_dict.get(desc_id)
            if matched is None:
                continue  # points at a description that is not in this file
            matched.append({**ann, "bbox": ann["bbox"] if is_positive else []})

    # NOTE(review): annotations are selected by description only, so a record
    # also carries boxes that may belong to other images of the same
    # description - confirm whether they should be filtered per image.

    # An image is usually referenced by many descriptions; stat it only once.
    exists_cache = {}

    # Build records for the OmniLabel validation set
    records = []
    for desc in descriptions:
        desc_annotations = annotation_dict[desc["id"]]
        for img_id in desc.get("image_ids", []):
            img = image_dict.get(img_id)
            if img is None:
                continue
            image_id, file_name = img
            if file_name not in exists_cache:
                exists_cache[file_name] = os.path.exists(file_name)
            if not exists_cache[file_name]:
                continue
            records.append(
                {
                    "image_id": image_id,
                    "file_name": file_name,
                    "description_ids": [desc["id"]],
                    "expressions": [desc["text"]],
                    "annotations": desc_annotations,
                    "bbox_mode": BoxMode.XYWH_ABS,
                    "task": "grounding",
                }
            )
    return records


# Built-in splits, grouped by dataset family.
# Each split name maps to an (image directory, annotation JSON) pair.
_PREDEFINED_SPLITS = {
    "omnilabel": {
        # Disabled split:
        # "omnilabel_o365_val": (
        #     "omnilabel/images",
        #     "omnilabel/annotations/dataset_all_val_v0.1.3_object365.json",
        # ),
        "omnilabel_coco2_val": (
            "omnilabel/images",
            "omnilabel/annotations/dataset_all_val_v0.1.3_coco.json",
        ),
        # Disabled split:
        # "omnilabel_oid_val": (
        #     "openimages/images",
        #     "omnilabel/annotations/dataset_all_val_v0.1.3_openimagesv5.json",
        # ),
        # Disabled split:
        # "omnilabel_val": (
        #     "omnilabel/images",
        #     "omnilabel/annotations/dataset_all_val_v0.1.3.json",
        # ),
    },
}


def register_omnilabel(root):
    """
    Register every split listed in ``_PREDEFINED_SPLITS``.

    Args:
        root (str): Base directory that the image directory and annotation
            file of each split are joined onto.
    """
    for family, splits in _PREDEFINED_SPLITS.items():
        for split_name, (image_dir, ann_file) in splits.items():
            metadata = _get_builtin_metadata(family)

            # An annotation path containing "://" is passed through as-is;
            # anything else is resolved against ``root``.
            if "://" in ann_file:
                ann_path = ann_file
            else:
                ann_path = os.path.join(root, ann_file)

            image_path = os.path.join(root, image_dir)
            _get_onmi_builtin_metadata(split_name, metadata, ann_path, image_path)
