# Copyright (c) 2024 ByteDance. All Rights Reserved.
import os
import json
import random

from detectron2.structures import Boxes, BoxMode, PolygonMasks, RotatedBoxes

from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.structures import BoxMode


def load_cocodetectiontsv_dicts(
    image_root,
    json_file,
    image_info="/home/dataset/objects365/objs_coco_train.json",
    max_objects=20,
):
    """Load object-grounding annotations as a list of detectron2 dataset dicts.

    Args:
        image_root: Directory holding the images. Each record's ``file_name``
            is joined onto it, and a relative ``image_info`` path is resolved
            against it as well.
        json_file: Annotation JSON mapping an image id (string key) to a dict
            of objects. Each object has an XYWH absolute ``"rect"`` box and an
            optional ``"des_mllm"`` list of candidate descriptions.
        image_info: JSON file whose top-level ``"images"`` list provides
            per-image metadata (at least ``"id"`` and ``"file_name"``). An
            absolute path is used as-is because ``os.path.join`` discards
            ``image_root`` in that case.
        max_objects: Keep at most this many described objects per image.
            ``None`` keeps all of them. The default of 20 matches the
            previous hard-coded cap.

    Returns:
        A list of record dicts, one per image with at least one usable
        object. Each record holds the image metadata plus ``file_name``
        (joined with ``image_root``), ``image_id`` (int), ``iscrowd``,
        ``annotations``, ``task`` ("vg") and ``dataset_name`` ("objg").
        Images without any usable object are skipped.

    Raises:
        KeyError: If an annotated image id is missing from ``image_info``.
    """
    image_info = os.path.join(image_root, image_info)

    # Per-image metadata.
    with open(image_info, "r", encoding="utf-8") as f:
        images = json.load(f)["images"]
    # JSON object keys are always strings, while ids inside "images" may be
    # ints (the usual COCO convention). Key by str(id) so both forms resolve.
    info_by_id = {str(item["id"]): item for item in images}

    # Object annotations, grouped by image id.
    with open(json_file, "r", encoding="utf-8") as f:
        anns = json.load(f)

    dataset_dicts = []
    for image_id, objects in anns.items():
        # Copy so the parsed metadata is not mutated in place.
        record = dict(info_by_id[str(image_id)])
        record["file_name"] = os.path.join(image_root, record["file_name"])
        record["image_id"] = int(image_id)
        record["iscrowd"] = 0

        objs = []
        for anno in objects.values():
            if max_objects is not None and len(objs) >= max_objects:
                break
            descriptions = anno.get("des_mllm")
            # Objects without descriptions cannot be grounded. An empty list
            # would also make random.choice raise IndexError.
            if not descriptions:
                continue
            objs.append({
                "bbox": anno["rect"],
                "bbox_mode": BoxMode.XYWH_ABS,
                "category_id": 0,
                # A description is sampled each time this function runs.
                "object_description": random.choice(descriptions),
            })

        if not objs:
            continue
        record["annotations"] = objs
        record["task"] = "vg"
        record["dataset_name"] = "objg"
        dataset_dicts.append(record)
    return dataset_dicts


def register_objg_dataset(name, image_root, json_file, **kwargs):
    """Register one object-grounding split in detectron2's catalogs.

    The loader is handed to ``DatasetCatalog`` as a callable rather than being
    invoked here, so the JSON files are read only when the dataset is requested.
    Extra keyword arguments are accepted for call-site compatibility but are
    currently ignored.
    """
    def _load_split():
        return load_cocodetectiontsv_dicts(image_root, json_file)

    DatasetCatalog.register(name, _load_split)
    metadata = MetadataCatalog.get(name)
    metadata.set(image_root=image_root, evaluator_type="coco")
    

# Split name -> (image directory, annotation JSON). Both paths are relative to
# the ``root`` given to ``register_objg``, except JSON paths containing "://",
# which are passed through unchanged.
_PREDEFINED_SPLITS_OBJG = dict(
    objects_llama=(
        "objects365/objg",
        "objects365/llava_one_vision_single_som_split0_1_llama_parsed_10k.json",
    ),
    objects_gpt=(
        "objects365/objg",
        "objects365/llava_one_vision_single_som_split0_1_gpt-4o-2024-08-06_parsed_10k.json",
    ),
)


def register_objg(root):
    """Register every split listed in ``_PREDEFINED_SPLITS_OBJG``.

    Image directories and annotation paths are resolved relative to ``root``
    (conventionally ``./datasets``). Annotation paths containing "://" are
    treated as URIs and passed through unchanged.
    """
    for split_name, (rel_image_root, rel_json) in _PREDEFINED_SPLITS_OBJG.items():
        json_path = rel_json if "://" in rel_json else os.path.join(root, rel_json)
        register_objg_dataset(
            split_name,
            os.path.join(root, rel_image_root),
            json_path,
        )