import json
import os
import sys

def analyze_coco(json_path, score_threshold=0.5):
    """Print and return summary statistics for a COCO-style annotation file.

    Two top-level JSON layouts are accepted:

    * a dict (COCO dataset format) with optional ``"images"`` and
      ``"annotations"`` lists; the image count is ``len(data["images"])``,
      so images without any annotation are still counted;
    * a list of annotation dicts; the image count is the number of distinct
      ``"image_id"`` values in that list.

    An annotation is counted when its ``"score"`` is ``>= score_threshold``.
    Annotations without a ``"score"`` key are treated as having score 1.0.

    Args:
        json_path: Path to the JSON file to analyze.
        score_threshold: Minimum score (inclusive) for an annotation to be
            counted.

    Returns:
        A dict with keys ``"num_images"``, ``"num_high_score_masks"`` and
        ``"avg_obj_per_image"`` holding the values that were printed.
        ``"avg_obj_per_image"`` is 0.0 when the image count is zero.

    Raises:
        ValueError: If the top-level JSON value is neither a dict nor a list.
        KeyError: If an annotation in a list-format file lacks ``"image_id"``.
    """
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    if isinstance(data, dict):
        annotations = data.get("annotations", [])
        num_images = len(data.get("images", []))
    elif isinstance(data, list):
        annotations = data
        num_images = len({ann["image_id"] for ann in annotations})
    else:
        raise ValueError("Input JSON must be either a dict (COCO format) or a list format.")

    # A missing score counts as 1.0, so score-less annotations (presumably
    # ground truth) pass any threshold <= 1.0.
    num_high_score_masks = sum(
        1 for ann in annotations if ann.get("score", 1.0) >= score_threshold
    )
    # Guard the empty case (no images listed / empty list file), which would
    # otherwise raise ZeroDivisionError.
    avg_obj_per_image = num_high_score_masks / num_images if num_images else 0.0

    print(f"Number of unique images: {num_images}")
    # Label uses ">=" to match the inclusive comparison above.
    print(f"Number of masks with score >= {score_threshold}: {num_high_score_masks}")
    print(f"Number of objects in per image: {avg_obj_per_image:.1f}")
    return {
        "num_images": num_images,
        "num_high_score_masks": num_high_score_masks,
        "avg_obj_per_image": avg_obj_per_image,
    }

# Base directory intended to be joined with each relative path in DATASET_PATH.
# The default "/data/xxx/datasets" looks like a placeholder; set the
# DATASET_ROOT environment variable to point elsewhere without editing this file.
DATASET_ROOT = os.environ.get("DATASET_ROOT", "/data/xxx/datasets")

# Dataset name -> annotation JSON path, relative to DATASET_ROOT.
DATASET_PATH = {
    "coco_val2017": "coco/annotations/coco_cls_agnostic_instances_val2017.json",
    "coco20k": "coco/annotations/coco20k_trainval_gt.json",
    "imagenet_val": "imagenet/annotations/imagenet_val_cls_agnostic_gt.json",
    # "kitti": "kitti/annotations/trainval_cls_agnostic.json",
    "voc": "voc/annotations/trainvaltest_2007_cls_agnostic.json",
    "clipart": "clipart/annotations/traintest_clipart_cls_agnostic.json",
    "watercolor": "watercolor/annotations/traintest_watercolor_cls_agnostic.json",
    "comic": "comic/annotations/traintest_comic_cls_agnostic.json",
    "lvis": "coco/annotations/lvis1.0_cocofied_val_cls_agnostic.json",
    "openimages": "openimages-v7/annotations/openimages_val_cls_agnostic.json",
}

if __name__ == "__main__":
    # Command-line entry point:
    #   <script> JSON_PATH [SCORE_THRESHOLD]  analyze a single annotation file
    #   <script> --all [SCORE_THRESHOLD]      analyze every entry of DATASET_PATH,
    #                                         each joined onto DATASET_ROOT
    # A wrong argument count or a non-numeric threshold prints usage to stderr
    # and exits with status 2 instead of silently exiting 0.
    usage = f"Usage: {sys.argv[0]} JSON_PATH|--all [SCORE_THRESHOLD]"
    cli_args = sys.argv[1:]
    if len(cli_args) not in (1, 2):
        print(usage, file=sys.stderr)
        sys.exit(2)

    # Forward a threshold only when one was given, so analyze_coco's own
    # default remains the single source of truth.
    kwargs = {}
    if len(cli_args) == 2:
        try:
            kwargs["score_threshold"] = float(cli_args[1])
        except ValueError:
            print(f"Invalid SCORE_THRESHOLD: {cli_args[1]!r}\n{usage}", file=sys.stderr)
            sys.exit(2)

    if cli_args[0] == "--all":
        for key, path in DATASET_PATH.items():
            print(f'======={key}=======')
            analyze_coco(os.path.join(DATASET_ROOT, path), **kwargs)
    else:
        analyze_coco(cli_args[0], **kwargs)
    sys.exit(0)

