import json
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

def compute_iou(box1, box2):
    """Return the intersection-over-union of two axis-aligned boxes.

    Args:
        box1: Box as ``[x1, y1, x2, y2]`` (top-left and bottom-right corners).
        box2: Box in the same ``[x1, y1, x2, y2]`` format.

    Returns:
        The IoU as a float. ``0.0`` is returned when the boxes do not
        overlap, and also when their union has no area (e.g. two coincident
        zero-size boxes), where the ratio would otherwise be undefined.
    """
    x_left = max(box1[0], box2[0])
    y_top = max(box1[1], box2[1])
    x_right = min(box1[2], box2[2])
    y_bottom = min(box1[3], box2[3])

    if x_right < x_left or y_bottom < y_top:
        return 0.0

    inter_area = (x_right - x_left) * (y_bottom - y_top)
    box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
    box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
    union_area = box1_area + box2_area - inter_area

    # Degenerate boxes can touch (passing the check above) while having a
    # zero-area union; avoid dividing by zero in that case.
    if union_area <= 0:
        return 0.0

    return inter_area / float(union_area)

# Ground-truth annotation file (COCO-style JSON).
gt_file = 'data_forestpersons_v3/test.json'

# Prediction files that are evaluated against the ground truth.
pred_files = [
    'converted_predictions_test_trained_with_summer.json',
    'converted_predictions_test_trained_with_winter.json',
    'converted_predictions_test_val30.json'
]

# The ground truth is read twice: once as a pycocotools index and once as
# the raw JSON dictionary used for the lookups below.
coco_gt = COCO(gt_file)
with open(gt_file) as f:
    gt_data = json.load(f)

# Map each image_id to its ground-truth boxes. COCO stores boxes as
# [x, y, w, h]; they are converted to corner form [x1, y1, x2, y2] here.
image_to_gts = {}
for gt_ann in gt_data['annotations']:
    owner_id = gt_ann['image_id']
    left, top, width, height = gt_ann['bbox']
    corners = [left, top, left + width, top + height]
    record = {'bbox': corners, 'category_id': gt_ann['category_id']}
    if owner_id not in image_to_gts:
        image_to_gts[owner_id] = []
    image_to_gts[owner_id].append(record)

# Annotation attribute used to slice the evaluation, and the values reported.
attribute_name = 'season'
attribute_values = ['summer', 'fall', 'winter']

# The ground-truth subset for each attribute value does not depend on the
# prediction file, so it is computed once up front:
# attr_value -> (image ids as list, image ids as set, number of GT annotations)
gt_subsets = {}
for attr_value in attribute_values:
    matching_anns = [
        ann for ann in gt_data['annotations']
        if 'attributes' in ann and
        ann['attributes'].get(attribute_name) == attr_value
    ]
    img_id_set = {ann['image_id'] for ann in matching_anns}
    gt_subsets[attr_value] = (list(img_id_set), img_id_set, len(matching_anns))


for pred_file in pred_files:

    print(f"\n=== Evaluating {pred_file} ===")
    print("\n==============================")

    with open(pred_file, encoding='utf-8') as f:
        preds = json.load(f)

    # Overwrite each prediction's category_id with the category of the
    # ground-truth box in the same image that it overlaps most (by IoU).
    for p in preds:
        img_id = p['image_id']
        px1, py1, pw, ph = p['bbox']  # [x, y, w, h]
        pred_box = [px1, py1, px1 + pw, py1 + ph]

        best_iou = 0
        best_cat = 1  # fallback category_id when nothing overlaps

        for gt in image_to_gts.get(img_id, []):
            iou = compute_iou(pred_box, gt['bbox'])
            if iou > best_iou:
                best_iou = iou
                best_cat = gt['category_id']

        p['category_id'] = best_cat

    print("✅ Finished assigning category_ids by IoU match.")

    for attr_value in attribute_values:
        img_ids, img_id_set, num_gt_annotations = gt_subsets[attr_value]

        # Set membership keeps this filter linear in the number of predictions.
        filtered_preds = [p for p in preds if p['image_id'] in img_id_set]

        print(f'\n=== Evaluating for {attribute_name} = {attr_value}, '
            f'({len(img_ids)} images, {num_gt_annotations} GT annotations, '
            f'{len(filtered_preds)} predictions) ===')

        if num_gt_annotations == 0:
            print("⚠️ Skipping: no matching GT annotations.")
            continue

        # loadRes() indexes the first result and fails on an empty list, so
        # a slice without any predictions is reported and skipped instead.
        if not filtered_preds:
            print("⚠️ Skipping: no predictions for the matching images.")
            continue

        coco_pred = coco_gt.loadRes(filtered_preds)
        coco_eval = COCOeval(coco_gt, coco_pred, iouType='bbox')
        # Restrict the evaluation to images carrying this attribute value.
        coco_eval.params.imgIds = img_ids
        coco_eval.evaluate()
        coco_eval.accumulate()
        coco_eval.summarize()