#%%
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image
import ast
import os
import requests
from io import BytesIO
import random

# === Config ===
csv_path = "ran50_balanced_target_samples.csv"
# Alternative dataset (swap in by uncommenting):
# csv_path = "ran_min50_max200_balanced_target_samples.csv"
num_to_plot = 5
# One seed drives both Python's RNG and the pandas row sampling below, so the
# two cannot drift apart when the value is changed.
seed = 55
random.seed(seed)

# Base paths: root directories that a row's relative `image_path` is joined to.
object365_base = "/fs/scratch/PAS2099/Lemeng/object365/"
nyu_base = "/fs/scratch/PAS2099/Lemeng/NYU_depth/nyu_rgb_images/"

# === Load CSV ===
df = pd.read_csv(csv_path)

# === Count unique labels in 'unified_target' ===
num_unique_labels = df["unified_target"].nunique()
unique_labels = df["unified_target"].unique()

# === Print result ===
print(f"✅ Number of unique unified_target labels: {num_unique_labels}")

# === Sample rows to plot ===
# Cap the sample size at the number of available rows.
sampled_rows = df.sample(n=min(num_to_plot, len(df)), random_state=seed)

# === Plot each sampled row: image plus arche (red) and target (blue) boxes ===
# Rows whose image cannot be loaded or whose boxes cannot be parsed are
# reported and skipped; the loop keeps going with the next row.
for _, row in sampled_rows.iterrows():
    source = str(row["source"]).lower()
    # Reset both per iteration so the failure message below can never hit an
    # unbound name or report a stale path/URL left over from an earlier row.
    image_path = None
    image_url = None

    try:
        if "object365" in source:
            image_path = os.path.join(object365_base, row["image_path"])
            with Image.open(image_path) as raw_img:
                img = raw_img.convert("RGB")

        elif "nyu" in source:
            image_path = os.path.join(nyu_base, row["image_path"])
            with Image.open(image_path) as raw_img:
                img = raw_img.convert("RGB")
            # NYU images are rotated before display.
            # NOTE(review): the boxes below are drawn without any rotation;
            # confirm the CSV bbox coordinates refer to the rotated frame.
            img = img.transpose(Image.Transpose.ROTATE_270)

        elif "lvis" in source:
            image_url = row["image_url"]
            if image_url.startswith("http"):
                response = requests.get(image_url, timeout=10)
                # Fail on 4xx/5xx here instead of handing an error page to PIL.
                response.raise_for_status()
                with Image.open(BytesIO(response.content)) as raw_img:
                    img = raw_img.convert("RGB")
            else:
                image_path = image_url  # Treat as local file
                with Image.open(image_path) as raw_img:
                    img = raw_img.convert("RGB")
        else:
            print("❌ Unknown source format")
            continue

    except Exception as e:
        # Deliberately broad: any load failure (missing file, network error,
        # bad cell value) should skip this row rather than abort the run.
        print(f"❌ Failed to load image: {image_path or image_url} — {e}")
        continue

    # Parse bounding boxes. Each cell holds a Python-literal sequence that is
    # used as (x, y, width, height) by the Rectangle patches below.
    try:
        arche_bbox = ast.literal_eval(row["arche_bbox"])
        target_bbox = ast.literal_eval(row["target_bbox"])
        # Unpack inside the try so a box without exactly four values is
        # skipped like any other malformed box instead of crashing mid-plot.
        arche_x, arche_y, arche_w, arche_h = arche_bbox
        target_x, target_y, target_w, target_h = target_bbox
    except (KeyError, ValueError, SyntaxError, TypeError) as e:
        print(f"❌ Failed to parse bboxes — {e}")
        continue

    # Plot
    fig, ax = plt.subplots(1, figsize=(8, 6))
    ax.imshow(img)

    # Arche (red); the label sits 5 px above the box's top-left corner.
    ax.add_patch(
        patches.Rectangle(
            (arche_x, arche_y), arche_w, arche_h,
            linewidth=2, edgecolor="red", facecolor='none',
        )
    )
    ax.text(arche_x, arche_y - 5, f"Arche: {row['arche_object']}", color="red", fontsize=12)

    # Target (blue)
    ax.add_patch(
        patches.Rectangle(
            (target_x, target_y), target_w, target_h,
            linewidth=2, edgecolor="blue", facecolor='none',
        )
    )
    ax.text(target_x, target_y - 5, f"Target: {row['target_object']}", color="blue", fontsize=12)

    # Title
    ax.set_title(f"Spatial: {row['spatial_answer']} | Source: {source}")
    ax.axis("off")
    plt.tight_layout()
    plt.show()

# %%
