import h5py
import numpy as np
import os
import cv2
import pandas as pd
from tqdm import tqdm
import sys

# Paths
# Inputs: the labelled NYU .mat file (opened with h5py further down), the
# "id: name" label text file, and the folder holding the RGB frames, which the
# main loop looks up as image_XXXX.png (zero-padded frame index).
mat_path = "../../NYU_depth/nyu_metadata/nyu_depth_v2_labeled.mat"
label_map_txt = "../../NYU_depth/nyu_metadata/nyu_class_labels.txt"
rgb_folder = "../../NYU_depth/nyu_rgb_images"
# Outputs: one annotated JPEG per object is written into output_folder and the
# per-object metadata table is written to csv_output. Both paths are relative
# to the current working directory.
output_folder = "image_bbox"
csv_output = "bbox_metadata_all.csv"
# Create the output folder up front; exist_ok avoids an error on re-runs.
os.makedirs(output_folder, exist_ok=True)

# Label IDs to exclude: any object whose label ID appears here is skipped by
# the extraction loop and produces neither an image nor a CSV row.
# NOTE(review): the class names behind these IDs are not visible in this file;
# look them up in the label text file before editing the list.
neglect_label_ids = [
    4, 11, 20, 21, 40, 41, 57, 62, 69, 75, 197, 213, 220, 552, 559,
    719, 724, 861, 867, 868, 869, 872, 891, 642, 657, 658, 659, 573, 874, 892, 52, 212, 143, 186
]

# Load label name map - one "<id>: <name>" entry per line, ids starting from 1
# (e.g. "1: book"). Lines without a colon or without a numeric id (blank lines,
# headers, comments) are ignored instead of aborting the whole run.
label_id_to_name = {}
with open(label_map_txt, "r", encoding="utf-8") as f:
    for line in f:
        # partition() never raises, unlike unpacking split(": "), which fails
        # on entries written as "12:book" (colon without a following space).
        idx_text, sep, name_text = line.partition(":")
        idx_text = idx_text.strip()
        if not sep or not idx_text.isdigit():
            continue
        label_id_to_name[int(idx_text)] = name_text.strip()

# Open .mat file once and, for every frame, emit one annotated image plus one
# metadata row per (label_id, instance_id) object.
#
# Results are left in the module-level names `rows` (list of per-object dicts)
# and `new_image_id` (number of object images written).
#
# Within a frame, objects are visited in ascending (label_id, instance_id)
# order, so the numbering of the output images is reproducible between runs.
neglected_labels = frozenset(neglect_label_ids)  # hashed lookup for the inner loop

with h5py.File(mat_path, "r") as f:
    labels_ds = f["labels"]
    instances_ds = f["instances"]
    depths_ds = f["depths"]

    num_images = labels_ds.shape[0] # 1449 images
    rows = []
    new_image_id = 0

    for image_id in tqdm(range(num_images), desc="Processing NYU images"):
        # tqdm.write prints without tearing the progress bar.
        tqdm.write(f"Processing image {image_id}/{num_images}...")

        # Validate the RGB frame first so no HDF5 reads are spent on frames
        # that are going to be skipped anyway.
        rgb_path = os.path.join(rgb_folder, f"image_{image_id:04d}.png")
        if not os.path.exists(rgb_path):
            tqdm.write(f"⚠️ Skipping image {image_id}: RGB not found.")
            continue

        rgb_img = cv2.imread(rgb_path)
        if rgb_img is None:
            tqdm.write(f"⚠️ Skipping image {image_id}: Failed to load.")
            continue

        label_map = np.asarray(labels_ds[image_id])
        instance_map = np.asarray(instances_ds[image_id])
        depth_map = np.asarray(depths_ds[image_id])

        # h5py exposes MATLAB v7.3 arrays with their axes reversed, so the
        # per-frame maps may arrive as (width, height) while cv2 images are
        # (height, width, channels). Align the maps with the RGB frame so that
        # row/column indices really are y/x pixel coordinates of that frame.
        rgb_hw = tuple(rgb_img.shape[:2])
        if label_map.shape != rgb_hw:
            if label_map.T.shape != rgb_hw:
                tqdm.write(
                    f"⚠️ Skipping image {image_id}: label map shape "
                    f"{label_map.shape} does not match RGB shape {rgb_hw}."
                )
                continue
            label_map = label_map.T
            instance_map = instance_map.T
            depth_map = depth_map.T

        # Label 0 / instance 0 mean "no object"; drop those pixels before
        # collecting the distinct (label, instance) pairs.
        labelled = (label_map != 0) & (instance_map != 0)
        if not labelled.any():
            continue
        object_keys = np.unique(
            np.stack((label_map[labelled], instance_map[labelled]), axis=1),
            axis=0,
        )

        # for each object (label id, instance id) draw bbox
        # .tolist() yields plain Python ints, which OpenCV, the dict lookup and
        # the CSV writer all handle without NumPy-scalar surprises.
        for label_id, instance_id in object_keys.tolist():
            if label_id in neglected_labels:
                continue

            mask = (label_map == label_id) & (instance_map == instance_id)

            ys, xs = np.nonzero(mask)
            min_x, max_x = int(xs.min()), int(xs.max())
            min_y, max_y = int(ys.min()), int(ys.max())

            # Inclusive pixel extents, hence the +1.
            bbox_width = max_x - min_x + 1
            bbox_height = max_y - min_y + 1
            bbox_area = bbox_width * bbox_height

            # Draw on a copy so every object gets a clean frame.
            image_copy = rgb_img.copy()
            cv2.rectangle(image_copy, (min_x, min_y), (max_x, max_y), (0, 0, 255), 2)

            out_name = f"{new_image_id:06d}.jpg"
            out_path = os.path.join(output_folder, out_name)
            if not cv2.imwrite(out_path, image_copy):
                # Do not record metadata for an image that does not exist.
                tqdm.write(f"⚠️ Skipping object in image {image_id}: could not write {out_path}.")
                continue

            # Depth of the object = smallest depth value under its mask.
            min_depth = round(float(depth_map[mask].min()), 4)
            label_name = label_id_to_name.get(label_id, "unknown")

            rows.append({
                "new_image_id": new_image_id,
                "old_image_id": image_id,
                "label_id": label_id,
                "instance_id": instance_id,
                "label_name": label_name,
                "depth": min_depth,
                "bbox_width": bbox_width,
                "bbox_height": bbox_height,
                "bbox_area": bbox_area
            })

            new_image_id += 1

# Save CSV
# The column list is pinned so the header row is written even when no objects
# were extracted (an empty DataFrame without columns produces a header-less
# file that cannot be read back), and so the column order is explicit.
csv_columns = [
    "new_image_id",
    "old_image_id",
    "label_id",
    "instance_id",
    "label_name",
    "depth",
    "bbox_width",
    "bbox_height",
    "bbox_area",
]
df = pd.DataFrame(rows, columns=csv_columns)
df.to_csv(csv_output, index=False)

print(f"\nCompleted: {new_image_id} object images saved to '{output_folder}'")
print(f"Metadata written to: {csv_output}")
