import json
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

def compute_iou(box1, box2):
    """Return the intersection-over-union of two axis-aligned boxes.

    Args:
        box1: box in corner format ``[x1, y1, x2, y2]`` (assumes x1 <= x2, y1 <= y2).
        box2: box in the same corner format.

    Returns:
        The IoU as a float. Returns ``0.0`` when the boxes do not overlap, and
        also when their union has zero area (e.g. two degenerate point/line
        boxes) instead of raising ``ZeroDivisionError``.
    """
    x_left = max(box1[0], box2[0])
    y_top = max(box1[1], box2[1])
    x_right = min(box1[2], box2[2])
    y_bottom = min(box1[3], box2[3])

    # Disjoint boxes: the intersection rectangle would have negative extent.
    if x_right < x_left or y_bottom < y_top:
        return 0.0

    inter_area = (x_right - x_left) * (y_bottom - y_top)
    box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
    box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])

    union_area = box1_area + box2_area - inter_area
    # Two zero-area boxes give a zero union; IoU is undefined there, treat as no overlap.
    if union_area <= 0:
        return 0.0
    return inter_area / float(union_area)

gt_file = "/mnt/home/annonymous/iclr2026/yolov11_forestpersons/data_forestpersons_v3/test.json"

pred_file = 'converted_predictions_test_val30.json'

coco_gt = COCO(gt_file)
# COCO() has already parsed the annotation file; reuse its dict rather than
# reading and decoding the (potentially large) JSON a second time.
# gt_data shares its annotation dicts with coco_gt, so treat it as read-only.
gt_data = coco_gt.dataset

# image_id → list of GT entries {'bbox': [x1, y1, x2, y2], 'category_id': ...},
# i.e. the COCO [x, y, w, h] boxes converted to corner format for compute_iou.
image_to_gts = {}
for ann in gt_data['annotations']:
    x1, y1, w, h = ann['bbox']  # [x, y, w, h]
    entry = {'bbox': [x1, y1, x1 + w, y1 + h], 'category_id': ann['category_id']}
    image_to_gts.setdefault(ann['image_id'], []).append(entry)

with open(pred_file) as f:
    preds = json.load(f)


def _category_by_best_iou(pred, fallback_category=1):
    """Pick a category_id for one prediction from the GT box it overlaps most.

    Only GT entries of the prediction's own image (looked up in
    ``image_to_gts``) are considered. The first GT reaching the highest
    strictly positive IoU wins; if no GT overlaps at all, *fallback_category*
    is returned.
    """
    same_image_gts = image_to_gts.get(pred['image_id'], [])
    left, top, width, height = pred['bbox']  # [x, y, w, h]
    corners = [left, top, left + width, top + height]

    chosen_category, top_iou = fallback_category, 0
    for candidate in same_image_gts:
        overlap = compute_iou(corners, candidate['bbox'])
        if overlap > top_iou:
            chosen_category, top_iou = candidate['category_id'], overlap
    return chosen_category


# Overwrite each prediction's category with the one of its best-matching GT box.
for p in preds:
    p['category_id'] = _category_by_best_iou(p)

print("✅ Finished assigning category_ids by IoU match.")

attribute_name = 'pose'
secondary_attribute = 'visible_ratio'
attribute_values = ['standing', 'sitting', 'lying']
visible_ratios = [20, 40, 70, 100]

# Run COCO bbox evaluation once per (pose, visible_ratio) combination, restricted
# to the images that contain at least one GT annotation with that combination.
# NOTE(review): only the image set is restricted — every GT box on a selected
# image is still scored, including boxes with other pose/visible_ratio values.
# Confirm that this image-level (not annotation-level) breakdown is intended.
# NOTE(review): visible_ratio is matched with == against ints; if the JSON stores
# it as a string or float, nothing matches — verify against the annotation file.
for attr_value in attribute_values:
    for vis_value in visible_ratios:
        matching_anns = [
            ann for ann in gt_data['annotations']
            if 'attributes' in ann and
               ann['attributes'].get(attribute_name) == attr_value and
               ann['attributes'].get(secondary_attribute) == vis_value
        ]

        # The set gives O(1) membership tests when filtering predictions below
        # (a list here made the filter O(predictions × images)); the list form
        # is what gets handed to COCOeval.
        img_id_set = {ann['image_id'] for ann in matching_anns}
        img_ids = list(img_id_set)
        num_gt_annotations = len(matching_anns)

        filtered_preds = [p for p in preds if p['image_id'] in img_id_set]

        print(f'\n=== Evaluating for {attribute_name} = {attr_value}, '
              f'{secondary_attribute} = {vis_value} '
              f'({len(img_ids)} images, {num_gt_annotations} GT annotations, '
              f'{len(filtered_preds)} predictions) ===')

        if num_gt_annotations == 0:
            print("⚠️ Skipping: no matching GT annotations.")
            continue

        # COCO.loadRes inspects the first result, so it raises IndexError on an
        # empty list; report the case instead of aborting the whole sweep.
        if not filtered_preds:
            print("⚠️ Skipping: no predictions on the matching images.")
            continue

        coco_pred = coco_gt.loadRes(filtered_preds)
        coco_eval = COCOeval(coco_gt, coco_pred, iouType='bbox')
        coco_eval.params.imgIds = img_ids
        coco_eval.evaluate()
        coco_eval.accumulate()
        coco_eval.summarize()
