#!/usr/bin/env python3
"""
Medical image region detection script
Detect regions using Grounding DINO based on positive_caption
Retain at most the 4 highest-confidence detection boxes per image
"""

import os
import json
import cv2
import torch
import random
from pathlib import Path
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from groundingdino.util.inference import load_model, load_image, predict


def _read_jsonl_records(jsonl_file):
    """Return the JSON objects stored one-per-line in ``jsonl_file``.

    Blank lines are skipped silently.  Lines that are not valid JSON, or that
    decode to something other than a JSON object, are reported and skipped so
    that one corrupt record does not discard the whole folder.
    """
    records = []
    with open(jsonl_file, 'r', encoding='utf-8') as f:
        for line_no, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError as e:
                print(f"  Skipping malformed line {line_no} in {jsonl_file}: {e}")
                continue
            if not isinstance(record, dict):
                # Later code calls ``.get`` on every record.
                print(f"  Skipping non-object line {line_no} in {jsonl_file}")
                continue
            records.append(record)
    return records


def _select_top_detections(boxes, logits, phrases, max_boxes):
    """Keep at most ``max_boxes`` detections, highest confidence first.

    Returns ``(boxes, scores, phrases)`` where ``scores`` is always a 1-D
    tensor holding one confidence value per box.
    """
    # NOTE(review): upstream Grounding DINO's ``predict`` already returns one
    # score per box (1-D); a 2-D (box, token) matrix is reduced here as well
    # in case the installed version returns raw logits. Confirm against the
    # installed package.
    scores = logits if logits.dim() == 1 else logits.max(dim=1)[0]
    if len(boxes) > max_boxes:
        keep = scores.argsort(descending=True)[:max_boxes]
        boxes = boxes[keep]
        scores = scores[keep]
        phrases = [phrases[i] for i in keep.tolist()]
    return boxes, scores, phrases


def _cxcywh_to_pixel_xyxy(box, width, height):
    """Convert one normalized (cx, cy, w, h) box to pixel (x1, y1, x2, y2).

    NOTE(review): upstream Grounding DINO's ``predict`` emits boxes as
    normalized centre/size, which is what this conversion assumes — confirm
    against the installed package.
    """
    cx, cy, w, h = box.tolist()
    x1 = int((cx - w / 2) * width)
    y1 = int((cy - h / 2) * height)
    x2 = int((cx + w / 2) * width)
    y2 = int((cy + h / 2) * height)
    return x1, y1, x2, y2


def _save_detection_figure(image_source, boxes, scores, phrases, save_path):
    """Draw the detections on ``image_source`` and write the figure to disk."""
    colors = ['red', 'blue', 'green', 'yellow']
    height, width = image_source.shape[:2]

    fig, ax = plt.subplots(1, 1, figsize=(12, 8))
    try:
        # NOTE(review): upstream ``load_image`` reads via PIL and returns an
        # RGB array, so it is shown as-is (a BGR->RGB swap would exchange the
        # red and blue channels) — confirm against the installed package.
        ax.imshow(image_source)

        for i, (box, score, phrase) in enumerate(zip(boxes, scores.tolist(), phrases)):
            # Cycle the palette so a larger max_boxes cannot index past it.
            color = colors[i % len(colors)]
            x1, y1, x2, y2 = _cxcywh_to_pixel_xyxy(box, width, height)

            ax.add_patch(patches.Rectangle(
                (x1, y1), x2 - x1, y2 - y1,
                linewidth=3, edgecolor=color, facecolor='none'
            ))
            ax.text(x1, y1 - 10, f"{phrase} ({score:.2f})",
                    fontsize=10, color=color,
                    bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8))

        ax.set_title(f"Detected {len(boxes)} regions", fontsize=14)
        ax.axis('off')

        fig.tight_layout()
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
    finally:
        # Always release the figure; otherwise a failed save leaks one figure
        # per image over the whole run.
        plt.close(fig)


def main(*, dataset_root="/root/autodl-tmp/dataset", output_dir="./DINO_result",
         config_path="groundingdino/config/GroundingDINO_SwinT_OGC.py",
         checkpoint_path="weights/groundingdino_swint_ogc.pth",
         samples_per_folder=20, max_boxes=4,
         box_threshold=0.3, text_threshold=0.25):
    """Run Grounding DINO on random samples of every dataset folder.

    For each sub-directory ``<name>`` of ``dataset_root`` the file
    ``<name>/<name>_en.jsonl`` is read; up to ``samples_per_folder`` records
    are drawn at random, and for each record its ``image`` (looked up under
    ``<name>/images/``) is searched for the regions named in
    ``positive_caption``.  An annotated JPEG per image is written to
    ``output_dir/<name>/`` and a summary is printed at the end.

    Args:
        dataset_root: Directory whose sub-directories are the dataset folders.
        output_dir: Directory that receives the annotated images.
        config_path: Grounding DINO model config file.
        checkpoint_path: Grounding DINO checkpoint file.
        samples_per_folder: Maximum number of records sampled per folder.
        max_boxes: Maximum number of boxes kept (highest confidence) per image.
        box_threshold: ``box_threshold`` forwarded to ``predict``.
        text_threshold: ``text_threshold`` forwarded to ``predict``.

    Raises:
        OSError: If ``dataset_root`` cannot be listed (missing, not a
            directory, or not readable).
    """
    # Enumerate the dataset first so a bad path fails before the (expensive)
    # model load and before anything is written to disk.
    subfolders = [f for f in Path(dataset_root).iterdir() if f.is_dir()]
    print(f"Found {len(subfolders)} folders")
    if not subfolders:
        print(f"No dataset folders in {dataset_root}; nothing to do")
        return

    print("Loading Grounding DINO model...")
    model = load_model(config_path, checkpoint_path)
    print("Model loaded successfully!")

    output_root = Path(output_dir)
    output_root.mkdir(parents=True, exist_ok=True)

    total_processed = 0
    total_detected = 0

    for subfolder in subfolders:
        print(f"\nProcessing: {subfolder.name}")

        # Load English report data
        jsonl_file = subfolder / f"{subfolder.name}_en.jsonl"
        if not jsonl_file.exists():
            print(f"JSONL file not found: {jsonl_file}")
            continue

        data = _read_jsonl_records(jsonl_file)
        if not data:
            print(f"No data in {jsonl_file}")
            continue

        # Randomly select samples
        samples = random.sample(data, min(samples_per_folder, len(data)))

        output_subfolder = output_root / subfolder.name
        output_subfolder.mkdir(parents=True, exist_ok=True)

        folder_detected = 0

        for item in samples:
            image_name = item.get('image', '')
            positive_caption = item.get('positive_caption', '')

            # Both fields must be non-empty strings: one builds a path, the
            # other is the text prompt.
            if not (isinstance(image_name, str) and image_name
                    and isinstance(positive_caption, str) and positive_caption):
                continue

            image_path = subfolder / "images" / image_name
            if not image_path.exists():
                continue

            # Per-image boundary: one unreadable image or failed inference is
            # reported and must not abort the rest of the batch.
            try:
                image_source, image = load_image(str(image_path))

                # Turn comma/semicolon separated terms into the
                # period-separated phrases used as the text prompt.
                text_prompt = positive_caption.replace(', ', '. ').replace('; ', '. ')

                boxes, logits, phrases = predict(
                    model=model,
                    image=image,
                    caption=text_prompt,
                    box_threshold=box_threshold,
                    text_threshold=text_threshold
                )

                boxes, scores, phrases = _select_top_detections(
                    boxes, logits, phrases, max_boxes
                )

                # Path.stem keeps dotted names distinct ("a.v1.png" vs
                # "a.v2.png") and drops any directory part of the name.
                save_path = output_subfolder / f"{Path(image_name).stem}_result.jpg"
                _save_detection_figure(image_source, boxes, scores, phrases, save_path)

                total_processed += 1
                if len(boxes) > 0:
                    folder_detected += 1
                    total_detected += 1
                    print(f"  ✓ {image_name}: {len(boxes)} regions")
                else:
                    print(f"  ○ {image_name}: No regions")

            except Exception as e:
                print(f"  ✗ Error: {image_name}: {str(e)}")
                continue

        print(f"{subfolder.name}: {folder_detected}/{len(samples)} with detections")

    print("\nSummary:")
    print(f"Total processed: {total_processed}")
    print(f"With detections: {total_detected}")
    if total_processed:
        print(f"Success rate: {total_detected / total_processed * 100:.1f}%")
    else:
        # Avoid ZeroDivisionError when every folder/image was skipped.
        print("Success rate: n/a (no images processed)")
    print(f"Results in: {output_dir}")


# Run the batch detection only when executed as a script, not when imported.
if __name__ == "__main__":
    main()