import h5py
import numpy as np
import pandas as pd
from itertools import combinations
from collections import Counter
from tqdm import tqdm

# Set random seed for reproducibility.
# NOTE(review): the visible script makes no np.random calls; the seed only
# matters if randomness is added later or used by code not shown here.
np.random.seed(42)

# === Config ===
# Label ids that are filtered out before any pair is formed.
neglect_label_ids = [
    4, 11, 20, 21, 40, 41, 57, 62, 69, 75, 197, 213, 220, 552, 559,
    719, 724, 861, 867, 868, 869, 872, 891, 642, 657, 658, 659, 573, 874, 892, 52, 212, 143, 186
]
# Minimum bounding-box area, as a fraction of the frame area, that both boxes
# of a pair must reach.
min_area_ratio = 0.02  # 2%

# === Load label map ===
# Parses lines of the form "<id>: <class name>" into {id: class name}.
# Lines without a colon are ignored; a non-integer id raises ValueError.
# The explicit encoding keeps the parse independent of the platform locale.
label_map = {}
with open("nyu_class_labels.txt", "r", encoding="utf-8") as f:
    for line in f:
        if ":" in line:
            idx, name = line.strip().split(":", 1)
            label_map[int(idx)] = name.strip()

# === Helper functions ===
def boxes_strictly_nonoverlapping(box1, box2):
    """Return True only if the two boxes are separated on BOTH axes.

    Boxes are ``(y_min, y_max, x_min, x_max)`` tuples. Boxes that share a column
    range, or that share a row range, are rejected even if they do not actually
    intersect. As a result, one box lies strictly diagonal to the other.
    """
    top1, bottom1, left1, right1 = box1
    top2, bottom2, left2, right2 = box2

    # Any shared column means the pair fails immediately.
    if not (right1 < left2 or right2 < left1):
        return False
    # Columns are disjoint; the rows must be disjoint as well.
    return bottom1 < top2 or bottom2 < top1

def get_center(box):
    """Return the ``(row, col)`` midpoint of a ``(y_min, y_max, x_min, x_max)`` box."""
    top, bottom, left, right = box
    center_row = (top + bottom) / 2
    center_col = (left + right) / 2
    return (center_row, center_col)

def determine_relation(c1, c2):
    """Describe where point ``c2`` lies relative to ``c1`` as "<horizontal> <vertical>".

    Points are ``(row, col)``. A larger row counts as "bottom" and a larger col
    counts as "right". A zero difference on an axis falls back to "up" / "left".
    """
    delta_row = c2[0] - c1[0]
    delta_col = c2[1] - c1[1]

    vertical = "up"
    if delta_row > 0:
        vertical = "bottom"

    horizontal = "left"
    if delta_col > 0:
        horizontal = "right"

    return horizontal + " " + vertical

def get_box_area(box):
    """Return the pixel count covered by an inclusive ``(y_min, y_max, x_min, x_max)`` box.

    Both bounds are inclusive pixel indices, so a box whose min equals its max
    spans one pixel on that axis. That is why 1 is added per side; without it,
    a single-pixel box would report an area of 0.
    """
    y_min, y_max, x_min, x_max = box
    return (y_max - y_min + 1) * (x_max - x_min + 1)

# === Per-image helpers ===
def _single_instance_boxes(label_frame, instance_frame, excluded_labels, known_labels):
    """Return ``{label: bbox}`` for usable labels that occur as a single instance.

    A label is usable when it is non-zero, not in ``excluded_labels`` and a key
    of ``known_labels``. It is kept only if exactly one instance id co-occurs
    with it anywhere in the frame. Each bbox is the inclusive pixel extent
    ``(y_min, y_max, x_min, x_max)`` of that instance's mask.

    Keys are ordered by each label's first appearance in a row-major scan of
    the frames. This order determines the pair order and therefore the CSV
    row order.
    """
    # Distinct (label, instance) keys in first-seen order. ``tolist`` yields
    # plain Python scalars, which are much cheaper to hash than numpy scalars.
    pixel_keys = dict.fromkeys(
        zip(label_frame.ravel().tolist(), instance_frame.ravel().tolist())
    )

    instances_per_label = {}
    for label, inst in pixel_keys:
        if label != 0 and label not in excluded_labels and label in known_labels:
            instances_per_label.setdefault(label, []).append(inst)

    boxes = {}
    for label, insts in instances_per_label.items():
        if len(insts) != 1:
            continue  # label appears as several instances; skip it entirely
        rows, cols = np.where((label_frame == label) & (instance_frame == insts[0]))
        # Non-empty by construction: the key was collected from these pixels.
        boxes[label] = (int(rows.min()), int(rows.max()), int(cols.min()), int(cols.max()))
    return boxes


# === Load .mat file ===
# Build the exclusion set once, so membership is O(1) instead of a list scan per key.
excluded_label_ids = set(neglect_label_ids)

with h5py.File("../NYU_depth/nyu_metadata/nyu_depth_v2_labeled.mat", "r") as f:
    labels_all = f["labels"]
    instances_all = f["instances"]
    num_images = labels_all.shape[0]
    # These are axes 1 and 2 of the stored frames, before the rotation below.
    # Only their product is used, so the height/width naming does not matter.
    height, width = labels_all.shape[1], labels_all.shape[2]
    image_area = height * width
    min_box_area = min_area_ratio * image_area

    results = []

    for image_id in tqdm(range(num_images), desc="🔍 Generating non-overlapping pairs"):
        # NOTE(review): np.rot90(m, k=-1) equals np.fliplr(m.T), i.e. a transpose
        # *plus* a left-right mirror. Suppose the stored frames are (width, height),
        # as is usual when h5py reads a MATLAB v7.3 file. Then a plain ``.T`` gives
        # the unmirrored image, and this rotation swaps every "left"/"right" answer.
        # Confirm this against how the RGB frames paired with this CSV were exported.
        label_frame = np.rot90(labels_all[image_id], k=-1)
        instance_frame = np.rot90(instances_all[image_id], k=-1)

        boxes = _single_instance_boxes(
            label_frame, instance_frame, excluded_label_ids, label_map
        )

        for obj1, obj2 in combinations(boxes, 2):
            box1, box2 = boxes[obj1], boxes[obj2]
            # Both filters are symmetric in the two boxes. Each unordered pair is
            # therefore checked once and, if it passes, emitted in both directions.
            if not boxes_strictly_nonoverlapping(box1, box2):
                continue
            if get_box_area(box1) < min_box_area or get_box_area(box2) < min_box_area:
                continue

            for arche, target in ((obj1, obj2), (obj2, obj1)):
                arche_box, target_box = boxes[arche], boxes[target]
                center_arche = get_center(arche_box)
                center_target = get_center(target_box)
                relation = determine_relation(center_arche, center_target)

                results.append({
                    "image_id": int(image_id),
                    "arche_object": label_map[arche],
                    "target_object": label_map[target],
                    "arche_bbox": arche_box,
                    "target_bbox": target_box,
                    # Centers are (row, col) internally; the CSV stores (x, y).
                    "arche_center": (round(center_arche[1], 2), round(center_arche[0], 2)),
                    "target_center": (round(center_target[1], 2), round(center_target[0], 2)),
                    "spatial_answer": relation,
                })

# === Save to CSV ===
df = pd.DataFrame(results)
df.to_csv("strictly_nonoverlapping_pairs.csv", index=False)

print(f"✅ Saved {len(results)} rows to strictly_nonoverlapping_pairs.csv")
