#%%
import ast
import random
from pathlib import Path

import h5py
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image

# === Config ===
csv_path = "nyu_strictly_nonoverlapping_pairs.csv"
# NOTE: kept as configuration only; this cell does not read the .mat file.
mat_path = "../NYU_depth/nyu_metadata/nyu_depth_v2_labeled.mat"
image_root = "/fs/scratch/PAS2099/Lemeng/NYU_depth/nyu_rgb_images/"
num_to_plot = 5
random.seed(42)  # df.sample below is seeded separately via random_state


def _draw_bbox(ax, bbox, label, color):
    """Draw one labelled, unfilled bounding box on *ax*.

    *bbox* is unpacked as ``(y1, y2, x1, x2)``, the column order this script
    reads from the CSV. The rectangle spans x1..x2 horizontally and y1..y2
    vertically in the plotted image's pixel coordinates.
    """
    y1, y2, x1, x2 = bbox
    ax.add_patch(
        patches.Rectangle(
            (x1, y1), x2 - x1, y2 - y1,
            linewidth=2, edgecolor=color, facecolor="none",
        )
    )
    # Caption sits just above the box's top-left corner.
    ax.text(x1, y1 - 5, label, color=color, fontsize=12, weight="bold")


# === Load CSV ===
df = pd.read_csv(csv_path)
sampled_rows = df.sample(n=min(num_to_plot, len(df)), random_state=42)

# === Plot each sampled pair ===
for _, row in sampled_rows.iterrows():
    image_id = int(row["image_id"])
    # Path join avoids the doubled "/" caused by image_root's trailing slash.
    image_path = Path(image_root) / f"image_{image_id:04d}.png"

    # Load and rotate RGB image (90 degrees clockwise). PIL's ROTATE_270 is a
    # 270-degree counter-clockwise turn, i.e. 90 degrees clockwise.
    try:
        with Image.open(image_path) as src:
            # convert()/transpose() return new images, so img outlives src.
            img = src.convert("RGB").transpose(Image.Transpose.ROTATE_270)
    except OSError as e:  # missing or unreadable/unidentified image file
        print(f"❌ Failed to load or rotate image: {image_path} — {e}")
        continue

    # Bounding boxes are stored as Python-literal strings in the CSV.
    arche_bbox = ast.literal_eval(row["arche_bbox"])
    target_bbox = ast.literal_eval(row["target_bbox"])

    fig, ax = plt.subplots(1, figsize=(8, 6))
    ax.imshow(img)
    _draw_bbox(ax, arche_bbox, f"Arche: {row['arche_object']}", "red")
    _draw_bbox(ax, target_bbox, f"Target: {row['target_object']}", "blue")

    # Show spatial relation
    ax.set_title(f"🖼 Image ID: {image_id} | 📌 Spatial: {row['spatial_answer']}")
    ax.axis("off")
    plt.tight_layout()
    plt.show()
    # Release the figure so repeated runs don't accumulate open figures.
    plt.close(fig)
# %%
