import os
import cv2
import numpy as np
import pandas as pd
from ultralytics import YOLO
import concurrent.futures
import torch

# -----------------------------
# Utility functions: box geometry, overlap filtering, and output-path helpers
# -----------------------------
def compute_iou(boxA, boxB):
    """Return the Intersection over Union (IoU) of two axis-aligned boxes.

    Both boxes are given as [x_min, y_min, x_max, y_max]. When the union
    area is zero the function returns 0.0 instead of dividing by zero.
    """
    # Corners of the overlapping rectangle (may be inverted if disjoint).
    left = max(boxA[0], boxB[0])
    top = max(boxA[1], boxB[1])
    right = min(boxA[2], boxB[2])
    bottom = min(boxA[3], boxB[3])

    # Clamp at zero so disjoint boxes contribute no overlap.
    overlap = max(0, right - left) * max(0, bottom - top)

    area_a = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
    area_b = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])
    union = area_a + area_b - overlap

    if union == 0:
        return 0.0
    return overlap / union

def yolo_to_box(yolo_label, img_width, img_height):
    """
    Turn a normalized YOLO label [class, x_center, y_center, width, height]
    into absolute corner coordinates [x_min, y_min, x_max, y_max].

    The class entry is ignored; the remaining four values are scaled by the
    image width/height before the corners are derived.
    """
    _, cx_norm, cy_norm, w_norm, h_norm = yolo_label

    # Scale the normalized values to pixel units.
    cx = cx_norm * img_width
    cy = cy_norm * img_height
    half_w = (w_norm * img_width) / 2
    half_h = (h_norm * img_height) / 2

    return [cx - half_w, cy - half_h, cx + half_w, cy + half_h]

def filter_overlapping_boxes(boxes, iou_threshold=0.0):
    """
    Drop boxes that overlap a larger box.

    Boxes are visited from largest to smallest area; a box is kept only if its
    IoU with every box kept so far is not greater than ``iou_threshold``. The
    returned list is therefore ordered by decreasing area.
    """
    def _area(b):
        return (b[2] - b[0]) * (b[3] - b[1])

    survivors = []
    for candidate in sorted(boxes, key=_area, reverse=True):
        overlaps_kept = any(
            compute_iou(candidate, winner) > iou_threshold for winner in survivors
        )
        if not overlaps_kept:
            survivors.append(candidate)
    return survivors

def get_output_paths(input_path):
    """
    Derive the left and right crop paths for a bilateral scan path.

    Every occurrence of "OAI" in the input path is replaced with "YOLO_OAI",
    then "_L" / "_R" is inserted just before the file extension. Returns the
    pair ``(left_path, right_path)``.
    """
    stem, suffix = os.path.splitext(input_path.replace("OAI", "YOLO_OAI"))
    return f"{stem}_L{suffix}", f"{stem}_R{suffix}"

# -----------------------------
# Worker functions for processing an image.
# -----------------------------
# Per-process YOLO model handle. It stays None until init_worker() assigns the
# loaded model, and process_image() reads it to run inference.
global_model = None

def init_worker(model_path, gpu_id):
    """
    Worker initializer: select a CUDA device and load the YOLO model onto it.

    Args:
        model_path: Path to the YOLO weights file passed to ``YOLO(...)``.
        gpu_id: Index of the CUDA device this worker process should use.

    Side effects:
        Sets torch's current CUDA device and assigns the loaded model to the
        module-level ``global_model``.
    """
    global global_model
    torch.cuda.set_device(gpu_id)
    model = YOLO(model_path)
    # set_device() only changes torch's *current* device; it does not move the
    # model. Move it explicitly so inference runs on this worker's GPU.
    # NOTE(review): Ultralytics appears to default to cuda:0 when no device is
    # given, which would put every worker on the same GPU - confirm against the
    # installed Ultralytics version.
    model.to(f"cuda:{gpu_id}")
    global_model = model

def process_image(img_path):
    """
    Split one bilateral scan into left and right knee crops.

    Steps:
      - Skip the image if both output crops already exist.
      - Run YOLO inference and keep boxes whose class is in the valid set.
      - Remove overlapping boxes and keep the two largest.
      - Order the two boxes by horizontal center; because the x-ray is treated
        as mirrored, the box on the image's left is the right knee.
      - Crop both knees and write them to the paths from get_output_paths().

    Args:
        img_path: Path of the input image.

    Returns:
        dict with keys "img_path" and "status". Status is one of
        "processed", "already processed", "failed to load",
        "failed to write", or a string starting with "skipped due to".
    """
    valid_classes = {0, 1, 2, 3, 4}
    result = {"img_path": img_path, "status": None}
    left_out, right_out = get_output_paths(img_path)

    # Require BOTH crops: a lone crop means an earlier run stopped between the
    # two writes, and the pair must be regenerated rather than skipped forever.
    if os.path.exists(left_out) and os.path.exists(right_out):
        result["status"] = "already processed"
        return result

    img = cv2.imread(img_path)
    if img is None:
        result["status"] = "failed to load"
        return result
    height, width = img.shape[:2]

    # Run YOLO inference and collect boxes of the accepted classes.
    results = global_model.predict(img_path, conf=0.25)
    pred_boxes = []
    if results and len(results[0].boxes) > 0:
        detections = results[0].boxes
        xyxy = detections.xyxy.cpu().numpy()
        classes = detections.cls.cpu().numpy()
        pred_boxes = [
            box.tolist() for box, cls in zip(xyxy, classes) if int(cls) in valid_classes
        ]

    # filter_overlapping_boxes() returns its survivors largest-area first, so
    # the two largest boxes are simply the first two.
    pred_boxes = filter_overlapping_boxes(pred_boxes, iou_threshold=0.0)[:2]
    if len(pred_boxes) != 2:
        result["status"] = f"skipped due to {len(pred_boxes)} valid boxes"
        return result

    # Order by horizontal center, then swap the naming because the original
    # x-ray is mirrored.
    right_box, left_box = sorted(pred_boxes, key=lambda box: (box[0] + box[2]) / 2)

    def _crop(box):
        # Round to pixels and clamp to the image so a slightly negative
        # coordinate cannot wrap around in NumPy slicing.
        x0, y0, x1, y1 = (int(round(c)) for c in box)
        x0, x1 = max(0, min(width, x0)), max(0, min(width, x1))
        y0, y1 = max(0, min(height, y0)), max(0, min(height, y1))
        return img[y0:y1, x0:x1]

    left_crop = _crop(left_box)
    right_crop = _crop(right_box)
    if left_crop.size == 0 or right_crop.size == 0:
        # A degenerate box would make cv2.imwrite raise; report it instead.
        result["status"] = "skipped due to empty crop"
        return result

    for path in (left_out, right_out):
        out_dir = os.path.dirname(path)
        if out_dir:  # os.makedirs('') raises for paths with no directory part
            os.makedirs(out_dir, exist_ok=True)

    # cv2.imwrite signals failure through its return value, not an exception.
    if not (cv2.imwrite(left_out, left_crop) and cv2.imwrite(right_out, right_crop)):
        result["status"] = "failed to write"
        return result

    result["status"] = "processed"
    return result

# -----------------------------
# Main function that distributes work across multiple GPUs.
# -----------------------------
def main(
    csv_path="/home/acc/Treatment_Modeling/Temporal-Treatment-Modeling/data/pairs_dataset/pairs-dataset.csv",
    model_path="knee_detector_yolo11xb32.pt",
    gpu_ids=(1, 2, 4, 5, 6, 7),
):
    """
    Crop every image referenced by the pairs CSV, spreading work across GPUs.

    Args:
        csv_path: CSV file with "image_earlier" and "image_later" columns
            holding image paths.
        model_path: YOLO weights file handed to each worker's initializer.
        gpu_ids: CUDA device indices to use; one single-worker process pool is
            created per entry and images are assigned round-robin.

    Raises:
        ValueError: if ``gpu_ids`` is empty.

    Prints the number of unique images found and a per-status summary.
    """
    from collections import Counter
    from contextlib import ExitStack

    gpu_ids = list(gpu_ids)
    if not gpu_ids:
        raise ValueError("gpu_ids must contain at least one GPU id")
    num_gpus = len(gpu_ids)

    df = pd.read_csv(csv_path)
    # Drop missing entries (NaN is not a path) and de-duplicate while keeping
    # first-seen order so the work distribution is reproducible between runs.
    image_paths = (
        pd.concat([df["image_earlier"], df["image_later"]], ignore_index=True)
        .dropna()
        .drop_duplicates()
        .tolist()
    )
    print(f"Found {len(image_paths)} unique images.")

    status_counts = Counter()
    # ExitStack shuts every executor down even if something below raises.
    with ExitStack() as stack:
        futures = {}
        for offset, gpu in enumerate(gpu_ids):
            # One process per GPU; its initializer loads the model once.
            executor = stack.enter_context(
                concurrent.futures.ProcessPoolExecutor(
                    max_workers=1, initializer=init_worker, initargs=(model_path, gpu)
                )
            )
            # Round-robin share: images offset, offset + n, offset + 2n, ...
            for img_path in image_paths[offset::num_gpus]:
                futures[executor.submit(process_image, img_path)] = img_path

        for future in concurrent.futures.as_completed(futures):
            try:
                status = future.result()["status"]
            except Exception as exc:
                # Top-level boundary: one bad image (or a dead worker) should
                # be reported without discarding the rest of the run.
                print(f"Error processing {futures[future]}: {exc!r}")
                status = "error"
            if str(status).startswith("skipped due to"):
                status = "skipped"
            status_counts[status] += 1

    print(f"\nTotal images processed: {status_counts['processed']}")
    print(f"Total images skipped (already processed): {status_counts['already processed']}")
    print(f"Total images skipped due to not exactly 2 valid boxes: {status_counts['skipped']}")
    print(f"Total images failed to load: {status_counts['failed to load']}")
    # Report any remaining outcomes (e.g. worker errors) so nothing is silent.
    reported = {"processed", "already processed", "skipped", "failed to load"}
    for status, count in status_counts.items():
        if status not in reported:
            print(f"Total images with status '{status}': {count}")

if __name__ == "__main__":
    main()
