import os
import json
import logging
from PIL import Image

from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.structures import Boxes, BoxMode, PolygonMasks, RotatedBoxes

# Module-level logger for this file.
logger = logging.getLogger(__name__)

# Public API of this module: the JSON loader and the registration entry point.
__all__ = ["load_omnilabel_json", "register_omnilabel"]

  
def _get_builtin_metadata(dataset_name):
    """Return the built-in metadata dict for an OmniLabel dataset family.

    No global category list is supplied for OmniLabel (the loader in this
    module attaches a per-image label space instead), so this always yields
    the empty metadata produced by ``_get_metadata([])``.

    Args:
        dataset_name (str): Dataset family key (e.g. ``"omnilabel"``). It is
            currently unused and kept only for signature compatibility.

    Returns:
        dict: Metadata to splat into ``MetadataCatalog.get(name).set(...)``;
        currently always ``{}``.
    """
    # The former ``raise KeyError(...)`` that followed this return could never
    # execute, so it has been dropped.
    return _get_metadata([])

def _get_metadata(categories):
    """Build detectron2 "thing" metadata from a COCO-style category list.

    Args:
        categories (list[dict]): Entries with at least ``"id"`` and ``"name"``
            keys. Ids need not be contiguous or start at 1.

    Returns:
        dict: ``{}`` when ``categories`` is empty. Otherwise a dict with:

        - ``thing_dataset_id_to_contiguous_id``: maps each dataset id to a
          0-based index, assigned in ascending id order.
        - ``thing_classes``: category names in that same order.
    """
    if not categories:
        return {}
    id_to_name = {cat["id"]: cat["name"] for cat in categories}
    sorted_ids = sorted(id_to_name)
    # Derive the contiguous mapping from the real ids instead of assuming they
    # are exactly 1..N; otherwise sparse ids would map to indices that do not
    # line up with ``thing_classes`` (or to ids that do not exist at all).
    thing_dataset_id_to_contiguous_id = {
        cat_id: index for index, cat_id in enumerate(sorted_ids)
    }
    thing_classes = [id_to_name[cat_id] for cat_id in sorted_ids]
    return {
        "thing_dataset_id_to_contiguous_id": thing_dataset_id_to_contiguous_id,
        "thing_classes": thing_classes,
    }

def _get_onmi_builtin_metadata(name, metadata, json_file, image_root):
    """Register one OmniLabel split with detectron2's catalogs.

    Despite its name, this registers a dataset. It adds a lazy loader to
    ``DatasetCatalog`` and stores the paths, the evaluator type and any extra
    ``metadata`` entries on ``MetadataCatalog.get(name)``.

    Args:
        name (str): Dataset name to register.
        metadata (dict): Extra metadata key/values to attach.
        json_file (str): Path to the OmniLabel annotation JSON.
        image_root (str): Directory the image ``file_name`` entries live under.
    """

    def _load_split():
        # Only reads the annotations/images when the catalog invokes it.
        return load_omnilabel_json(json_file, image_root)

    DatasetCatalog.register(name, _load_split)

    catalog_entry = MetadataCatalog.get(name)
    catalog_entry.set(
        json_file=json_file,
        image_root=image_root,
        evaluator_type="omnilabel",
        **metadata,
    )
    
# def load_omnilabel_json(json_file, image_root): 
#     """
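#     Args:
#         json_file (str): Path to the OmniLabel annotation JSON.
#         image_root (str): Directory containing the images referenced by "file_name".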

#     Returns:
#         list: A list of records where each record contains a description, its annotations, and matched images.
#     """
#     with open(json_file, "r") as file:
#         data_json = json.load(file)
        
#     # Create dictionaries
#     image_dict = {img["id"]: {"image_id": img["id"], "file_name": img["file_name"]} for img in data_json["images"]}
#     annotation_dict = {
#         desc["id"]: [
#             {**ann, "bbox": ann["bbox"] if ann.get("anno_info", {}).get("posneg") == "P" else []}
#             for ann in data_json["annotations"]
#             if desc["id"] in ann["description_ids"]
#         ]
#         for desc in data_json["descriptions"]
#     }
    
#     # Build records for the OmniLabel validation set
#     records = [
#         {
#             "image_id": img["image_id"],
#             "file_name": os.path.join(image_root, img["file_name"]),
#             "description_ids": [desc["id"],],
#             "expressions": [desc["text"],],
#             #"text_prompt": desc["text"],
#             #"captions": desc,
#             "annotations": annotation_dict[desc["id"]],
#             "bbox_mode": BoxMode.XYWH_ABS,
#             "task": "grounding"
#             #**({"neg_category_id": desc["neg_description_ids"]} 
#             #   if "neg_description_ids" in desc else {}),
#         }
#         for desc in data_json["descriptions"]
#         for img_id in desc.get("image_ids", [])
#         if (img := image_dict.get(img_id)) and
#         (os.path.exists(os.path.join(image_root, img["file_name"])))
#     ]
#     return records

import omnilabeltools as olt

def load_omnilabel_json(json_file, image_root, max_width=2000):
    """Load an OmniLabel annotation file into detectron2 dataset dicts.

    One record is produced per image. Each record carries that image's own
    label space (free-form object descriptions and their ids) for inference.

    Args:
        json_file (str): Path to the OmniLabel annotation JSON, parsed with
            ``omnilabeltools.OmniLabel``.
        image_root (str): Directory that each image's ``file_name`` is
            relative to.
        max_width (int | None): Images wider than this many pixels are skipped.
            The default of 2000 reproduces the previous hard-coded limit;
            ``None`` keeps every image.

    Returns:
        list[dict]: One dict per kept image, with keys ``image_id``,
        ``file_name``, ``height``, ``width``, ``task``, ``dataset_name``,
        ``inference_obj_descriptions`` and ``description_ids``.

    Raises:
        TypeError: If ``json_file`` is not a ``str``.
    """
    # Explicit check instead of ``assert`` so it is not stripped under ``python -O``.
    if not isinstance(json_file, str):
        raise TypeError(f"json_file must be a str, got {type(json_file).__name__}")

    ol = olt.OmniLabel(json_file)
    dataset_dicts = []
    skipped = 0

    for img_id in ol.image_ids:
        img_sample = ol.get_image_sample(img_id)
        img_path = os.path.join(image_root, img_sample["file_name"])

        # Opened only to read the size; the context manager closes the file.
        with Image.open(img_path) as img:
            width, height = img.size

        # NOTE(review): skipping images shrinks the evaluated subset. This is
        # presumably done to bound memory on very wide images; confirm that
        # the evaluator expects this.
        if max_width is not None and width > max_width:
            skipped += 1
            continue

        labelspace = img_sample["labelspace"]
        dataset_dicts.append({
            "image_id": img_sample["id"],
            "file_name": img_path,
            "height": height,
            "width": width,
            "task": "omnilabel",
            "dataset_name": "omnilabel",
            "inference_obj_descriptions": [od["text"] for od in labelspace],
            "description_ids": [od["id"] for od in labelspace],
        })

    if skipped:
        # Make the previously silent filtering visible in the logs.
        logger.warning(
            "Skipped %d of %d images from %s wider than %d px.",
            skipped,
            skipped + len(dataset_dicts),
            json_file,
            max_width,
        )
    return dataset_dicts


# Dataset family -> {split name: (image directory, annotation JSON)}.
# Both paths are joined onto the ``root`` given to ``register_omnilabel``
# (the JSON path is left as-is if it contains "://").
_PREDEFINED_SPLITS = {
    "omnilabel": {
        "omnilabel_o365_val": (
            "omnilabel/image",
            "omnilabel/gt/dataset_all_val_v0.1.3_object365.json",
        ),
        "omnilabel_coco_val": (
            "omnilabel/image",
            "omnilabel/gt/dataset_all_val_v0.1.3_coco.json",
        ),
        "omnilabel_oid_val": (
            "omnilabel/image",
            "omnilabel/gt/dataset_all_val_v0.1.3_openimagesv5.json",
        ),
        "omnilabel_val": (
            "omnilabel/image",
            "omnilabel/gt/dataset_all_val_v0.1.3.json",
        ),
    },
}


def register_omnilabel(root):
    """Register every predefined OmniLabel split under ``root``.

    For each split in ``_PREDEFINED_SPLITS``:

    - The annotation path is joined onto ``root`` unless it contains ``"://"``,
      in which case it is used verbatim.
    - The image directory is always joined onto ``root``.

    Args:
        root (str): Base directory that the split paths are relative to.
    """
    for family, splits in _PREDEFINED_SPLITS.items():
        for split_name, (rel_image_root, rel_json_file) in splits.items():
            family_metadata = _get_builtin_metadata(family)
            if "://" in rel_json_file:
                json_path = rel_json_file
            else:
                json_path = os.path.join(root, rel_json_file)
            image_dir = os.path.join(root, rel_image_root)
            _get_onmi_builtin_metadata(split_name, family_metadata, json_path, image_dir)
