import os
import cv2
import numpy as np
import pandas as pd
from ultralytics import YOLO

def compute_iou(boxA, boxB):
    """Return the Intersection over Union (IoU) of two axis-aligned boxes.

    Both boxes are indexable as ``[x_min, y_min, x_max, y_max]``. The overlap
    width and height are clamped at zero, so disjoint boxes score 0. When the
    union area is exactly zero (e.g. two degenerate boxes) 0.0 is returned
    instead of dividing by zero.
    """
    ax0, ay0, ax1, ay1 = boxA[0], boxA[1], boxA[2], boxA[3]
    bx0, by0, bx1, by1 = boxB[0], boxB[1], boxB[2], boxB[3]

    # Overlap rectangle: innermost edges of the two boxes, clamped at zero.
    overlap_w = max(0, min(ax1, bx1) - max(ax0, bx0))
    overlap_h = max(0, min(ay1, by1) - max(ay0, by0))
    intersection = overlap_w * overlap_h

    area_a = (ax1 - ax0) * (ay1 - ay0)
    area_b = (bx1 - bx0) * (by1 - by0)
    union = area_a + area_b - intersection
    return 0.0 if union == 0 else intersection / union

def yolo_to_box(yolo_label, img_width, img_height):
    """Convert one normalized YOLO label to absolute corner coordinates.

    ``yolo_label`` is ``[class, x_center, y_center, width, height]``, with the
    four geometric values expressed as fractions of the image size. The class
    id is ignored. Returns ``[x_min, y_min, x_max, y_max]`` in pixels.
    """
    _cls, cx_norm, cy_norm, w_norm, h_norm = yolo_label

    # Scale the center to pixels, then step half the pixel size each way.
    cx = cx_norm * img_width
    cy = cy_norm * img_height
    half_w = w_norm * img_width / 2
    half_h = h_norm * img_height / 2

    return [cx - half_w, cy - half_h, cx + half_w, cy + half_h]

def filter_overlapping_boxes(boxes, iou_threshold=0.0):
    """Greedily suppress overlapping boxes, preferring the larger one.

    Boxes (``[x_min, y_min, x_max, y_max]``) are visited from largest to
    smallest area; equal areas keep their input order. A box is kept unless
    its IoU with some already-kept box is greater than ``iou_threshold``.
    With the default of 0.0, any positive overlap discards the smaller box.
    The result is returned in descending-area order.
    """
    def _area(b):
        return (b[2] - b[0]) * (b[3] - b[1])

    survivors = []
    for candidate in sorted(boxes, key=_area, reverse=True):
        clashes = any(
            compute_iou(candidate, kept) > iou_threshold for kept in survivors
        )
        if not clashes:
            survivors.append(candidate)
    return survivors

def get_output_paths(input_path):
    """Map a bilateral scan path to the paths of its left and right knee crops.

    Every occurrence of "OAI" in ``input_path`` is replaced by "YOLO_OAI"
    (plain ``str.replace``). Then "_L" / "_R" is inserted just before the
    file extension. Returns the tuple ``(left_path, right_path)``.
    """
    redirected = input_path.replace("OAI", "YOLO_OAI")
    stem, suffix = os.path.splitext(redirected)
    return f"{stem}_L{suffix}", f"{stem}_R{suffix}"

def _collect_image_paths(df):
    """Return the sorted, de-duplicated image paths from the two pair columns.

    Paths come from ``df["image_earlier"]`` and ``df["image_later"]``. Missing
    (NaN) cells are dropped so every entry is a string. Sorting makes the
    processing order reproducible across runs, which iterating a set is not.
    """
    paths = pd.concat([df["image_earlier"], df["image_later"]], ignore_index=True)
    return sorted(set(paths.dropna().astype(str)))


def _box_area(box):
    """Area of an ``[x_min, y_min, x_max, y_max]`` box."""
    return (box[2] - box[0]) * (box[3] - box[1])


def _box_center_x(box):
    """Horizontal center of an ``[x_min, y_min, x_max, y_max]`` box."""
    return (box[0] + box[2]) / 2


def _clip_box(box, width, height):
    """Round a box to integer pixels and clamp it to a ``width`` x ``height`` image.

    Clamping keeps slice bounds non-negative. A negative start index would
    otherwise wrap around in NumPy slicing and silently yield a wrong crop.
    """
    x0, y0, x1, y1 = (int(round(c)) for c in box[:4])
    x0 = min(max(x0, 0), width)
    x1 = min(max(x1, 0), width)
    y0 = min(max(y0, 0), height)
    y1 = min(max(y1, 0), height)
    return [x0, y0, x1, y1]


def _detect_knee_boxes(model, img, conf, valid_classes):
    """Run ``model`` on ``img`` and return at most two boxes, largest first.

    Only detections whose class id is in ``valid_classes`` are considered.
    Overlapping detections are resolved with ``filter_overlapping_boxes`` at
    IoU threshold 0.0, so any overlap keeps only the larger box. Then the two
    largest survivors are returned as ``[x_min, y_min, x_max, y_max]`` lists.
    """
    # ``img`` is the BGR array from cv2.imread. Ultralytics accepts BGR numpy
    # arrays directly, so the file is not decoded a second time.
    results = model.predict(img, conf=conf)
    if not results or results[0].boxes is None or len(results[0].boxes) == 0:
        return []
    boxes = results[0].boxes
    xyxy = boxes.xyxy.cpu().numpy()
    classes = boxes.cls.cpu().numpy()
    candidates = [
        box.tolist() for box, cls in zip(xyxy, classes) if int(cls) in valid_classes
    ]
    kept = filter_overlapping_boxes(candidates, iou_threshold=0.0)
    return sorted(kept, key=_box_area, reverse=True)[:2]


def main(
    csv_path="/home/acc/Treatment_Modeling/Temporal-Treatment-Modeling/data/pairs_dataset/pairs-dataset.csv",
    model_path="knee_detector_yolo11xb32.pt",
    conf=0.25,
    image_paths=None,
):
    """Crop the left and right knee out of every bilateral scan in the pairs CSV.

    For each image, detections with class ids 0-4 are collected and
    overlapping boxes are suppressed. If exactly two boxes remain, each knee
    is cropped and written to the paths given by ``get_output_paths``. Images
    whose two crops both already exist are skipped, so the script can be
    re-run after an interruption.

    Args:
        csv_path: CSV with ``image_earlier`` / ``image_later`` path columns.
        model_path: Ultralytics YOLO weights file.
        conf: Minimum detection confidence passed to ``model.predict``.
        image_paths: Optional explicit list of images to process (e.g. for
            debugging a single scan). When given, ``csv_path`` is not read.
    """
    if image_paths is None:
        df = pd.read_csv(csv_path)
        image_paths = _collect_image_paths(df)
        print(f"Found {len(image_paths)} unique images.")

    model = YOLO(model_path)

    valid_classes = {0, 1, 2, 3, 4}
    processed_count = 0
    skipped_count = 0  # Already processed images.
    skipped_due_to_boxes = 0  # Images skipped due to not exactly 2 valid boxes.
    failed_count = 0  # Unreadable images, degenerate crops, or failed writes.

    for img_path in image_paths:
        left_out, right_out = get_output_paths(img_path)

        # Skip only when BOTH crops exist. If a previous run died between the
        # two writes, the image is redone instead of being skipped forever.
        if os.path.exists(left_out) and os.path.exists(right_out):
            print(f"Skipping already processed image: {img_path}")
            skipped_count += 1
            continue

        img = cv2.imread(img_path)
        if img is None:
            print(f"Failed to load image: {img_path}")
            failed_count += 1
            continue
        height, width = img.shape[:2]

        pred_boxes = _detect_knee_boxes(model, img, conf, valid_classes)
        if len(pred_boxes) != 2:
            print(f"Image {img_path} skipped: does not have exactly 2 valid boxes after filtering (found {len(pred_boxes)}).")
            skipped_due_to_boxes += 1
            continue

        # The box with the smaller x-center (viewer's left) is labeled the RIGHT
        # knee. NOTE(review): presumably this follows the radiographic display
        # convention (patient's right on the viewer's left); confirm against
        # the dataset.
        right_box, left_box = sorted(pred_boxes, key=_box_center_x)
        left_box = _clip_box(left_box, width, height)
        right_box = _clip_box(right_box, width, height)

        left_crop = img[left_box[1]:left_box[3], left_box[0]:left_box[2]]
        right_crop = img[right_box[1]:right_box[3], right_box[0]:right_box[2]]
        if left_crop.size == 0 or right_crop.size == 0:
            print(f"Image {img_path} skipped: empty crop after clipping boxes to the image.")
            failed_count += 1
            continue

        for out_path in (left_out, right_out):
            out_dir = os.path.dirname(out_path)
            if out_dir:  # os.makedirs("") raises for bare file names.
                os.makedirs(out_dir, exist_ok=True)

        # cv2.imwrite reports failure via its return value, not an exception.
        left_ok = cv2.imwrite(left_out, left_crop)
        right_ok = cv2.imwrite(right_out, right_crop)
        if not (left_ok and right_ok):
            print(f"Failed to write crops for {img_path}")
            failed_count += 1
            continue

        print(f"Processed {img_path} ->\n    Left: {left_out}\n    Right: {right_out}")
        processed_count += 1

    print(f"\nTotal images processed: {processed_count}")
    print(f"Total images skipped (already processed): {skipped_count}")
    print(f"Total images skipped due to not exactly 2 valid boxes: {skipped_due_to_boxes}")
    print(f"Total images failed (load/crop/write errors): {failed_count}")

if __name__ == "__main__":
    # Run the cropping pipeline only when executed as a script, not on import.
    main()
