import os
import cv2
import numpy as np
from ultralytics import YOLO

def compute_iou(boxA, boxB):
    """Return the Intersection over Union (IoU) of two axis-aligned boxes.

    Both boxes are given as [x_min, y_min, x_max, y_max]. The result is 0.0
    when the union of the two boxes has zero area.
    """
    def area(box):
        return (box[2] - box[0]) * (box[3] - box[1])

    # Corners of the overlap rectangle (may be "inside out" if disjoint).
    left = max(boxA[0], boxB[0])
    top = max(boxA[1], boxB[1])
    right = min(boxA[2], boxB[2])
    bottom = min(boxA[3], boxB[3])

    # Clamp each extent at zero so disjoint boxes contribute no overlap.
    overlap = max(0, right - left) * max(0, bottom - top)

    union = area(boxA) + area(boxB) - overlap
    return overlap / union if union != 0 else 0.0

def yolo_to_box(yolo_label, img_width, img_height):
    """
    Turn a normalized YOLO label [class, x_center, y_center, width, height]
    into absolute corner coordinates [x_min, y_min, x_max, y_max].

    The class entry is ignored; the remaining four values are scaled by the
    image width/height before the corners are derived.
    """
    _cls, cx_norm, cy_norm, w_norm, h_norm = yolo_label

    # Center in pixels.
    cx = cx_norm * img_width
    cy = cy_norm * img_height

    # Half-extents in pixels.
    half_w = w_norm * img_width / 2
    half_h = h_norm * img_height / 2

    return [cx - half_w, cy - half_h, cx + half_w, cy + half_h]

def filter_overlapping_boxes(boxes, iou_threshold=0.0):
    """
    Greedily drop overlapping boxes, preferring larger ones.

    Boxes are visited from largest to smallest area; a box is kept only if
    its IoU with every already-kept box is at most iou_threshold. The kept
    boxes are returned in that largest-first order as a new list.
    """
    def area(box):
        return (box[2] - box[0]) * (box[3] - box[1])

    survivors = []
    for candidate in sorted(boxes, key=area, reverse=True):
        clashes = any(
            compute_iou(candidate, kept) > iou_threshold for kept in survivors
        )
        if not clashes:
            survivors.append(candidate)
    return survivors

def match_boxes(gt_boxes, pred_boxes):
    """
    Pair two ground-truth boxes with two predicted boxes.

    Of the two possible one-to-one assignments, the one with the larger
    summed IoU is returned as a list of (gt_box, pred_box) tuples. On a tie
    the in-order assignment (first with first, second with second) wins.
    """
    gt_first, gt_second = gt_boxes[0], gt_boxes[1]
    pred_first, pred_second = pred_boxes[0], pred_boxes[1]

    in_order_score = (compute_iou(gt_first, pred_first)
                      + compute_iou(gt_second, pred_second))
    swapped_score = (compute_iou(gt_first, pred_second)
                     + compute_iou(gt_second, pred_first))

    if in_order_score >= swapped_score:
        return [(gt_first, pred_first), (gt_second, pred_second)]
    return [(gt_first, pred_second), (gt_second, pred_first)]

def _predict_boxes(model, image_file, valid_classes, conf=0.25):
    """Run the detector on one image and return at most two candidate boxes.

    Predictions whose class id is not in valid_classes are discarded, then
    overlapping boxes are removed (any overlap at all, larger box wins). If
    more than two boxes survive, only the two largest by area are returned.
    Boxes are [x_min, y_min, x_max, y_max] lists taken from the result's
    ``xyxy`` tensor.
    """
    results = model.predict(image_file, conf=conf)
    pred_boxes = []
    if len(results) > 0 and len(results[0].boxes) > 0:
        boxes_xyxy = results[0].boxes.xyxy.cpu().numpy()
        box_classes = results[0].boxes.cls.cpu().numpy()
        pred_boxes = [
            box.tolist()
            for box, cls in zip(boxes_xyxy, box_classes)
            if int(cls) in valid_classes
        ]

    pred_boxes = filter_overlapping_boxes(pred_boxes, iou_threshold=0.0)
    if len(pred_boxes) > 2:
        pred_boxes = sorted(
            pred_boxes,
            key=lambda box: (box[2] - box[0]) * (box[3] - box[1]),
            reverse=True,
        )[:2]
    return pred_boxes


def _load_gt_boxes(label_file, img_width, img_height):
    """Read a YOLO label file and return its boxes in absolute coordinates.

    Lines that do not consist of exactly five whitespace-separated fields
    are ignored. Each remaining line is parsed as floats and converted with
    yolo_to_box using the given image size.
    """
    gt_boxes = []
    with open(label_file, 'r') as f:
        for line in f:
            parts = line.strip().split()
            if len(parts) != 5:
                continue
            gt_boxes.append(yolo_to_box(list(map(float, parts)), img_width, img_height))
    return gt_boxes


def main():
    """Evaluate the detector on the test set and print the mean IoU.

    For every ``.png`` in the test image directory the model is run, the
    predictions are reduced to two boxes, and those are matched against the
    two ground-truth boxes from the label file of the same base name. The
    per-image score is the mean IoU of the two matched pairs; the reported
    figure is the mean of the per-image scores (0.0 if no image qualified).

    Images are excluded from the average when:
      * the prediction or the ground truth does not contain exactly two
        boxes (counted in "without exactly 2 valid boxes"), or
      * the image cannot be read or has no label file (counted as skipped).
    The prediction check runs first, so an image with no label file AND a
    bad prediction is counted under "without exactly 2 valid boxes".
    """
    test_images_dir = "/local2/acc/OAI/Knee_Joint/KneeXrayData/YOLO/images/test"
    test_labels_dir = "/local2/acc/OAI/Knee_Joint/KneeXrayData/YOLO/labels/test"
    model_path = "knee_detector_yolo11xb32.pt"

    # Load the trained YOLO model.
    model = YOLO(model_path)

    image_files = sorted(
        os.path.join(test_images_dir, f)
        for f in os.listdir(test_images_dir) if f.endswith('.png')
    )

    iou_list = []             # Per-image average IoU.
    images_not_two_boxes = 0  # Images whose prediction or GT is not exactly 2 boxes.
    images_skipped = 0        # Images that could not be read or have no label file.

    valid_classes = {0, 1, 2, 3, 4}

    for image_file in image_files:
        # The image is read here only to obtain its size for label conversion.
        img = cv2.imread(image_file)
        if img is None:
            images_skipped += 1
            continue
        height, width = img.shape[:2]

        pred_boxes = _predict_boxes(model, image_file, valid_classes)
        if len(pred_boxes) != 2:
            images_not_two_boxes += 1
            continue

        # Locate the label file that shares the image's base name.
        base_name = os.path.splitext(os.path.basename(image_file))[0]
        label_file = os.path.join(test_labels_dir, base_name + ".txt")
        if not os.path.exists(label_file):
            images_skipped += 1
            continue
        gt_boxes = _load_gt_boxes(label_file, width, height)

        # Only consider images with exactly 2 ground truth boxes.
        if len(gt_boxes) != 2:
            images_not_two_boxes += 1
            continue

        # Pair predictions with ground truth and score the image.
        matches = match_boxes(gt_boxes, pred_boxes)
        ious = [compute_iou(gt, pred) for gt, pred in matches]
        iou_list.append(np.mean(ious))

    overall_avg_iou = np.mean(iou_list) if iou_list else 0.0

    print(f"Average IoU (images with 2 valid boxes): {overall_avg_iou:.4f}")
    print(f"Number of images without exactly 2 valid boxes: {images_not_two_boxes}")
    # Report the denominator and the previously silent skips so the average
    # can be interpreted and data problems are visible.
    print(f"Number of images evaluated: {len(iou_list)}")
    print(f"Number of images skipped (unreadable image or missing label file): {images_skipped}")

# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
