import pandas as pd
import ast
from itertools import combinations
from tqdm import tqdm
import csv

# === Config ===
input_csv = "../LVIS/unique/lvis_val_annotations_unique.csv"  # one row per annotation
output_csv = "lvis_pairwise_spatial_relations.csv"
min_area_ratio = 0.002  # 0.2% -- minimum box area as a fraction of the estimated image area

# === Load CSV and parse bbox column ===
# The "bbox" column holds a Python-literal string (e.g. "[x, y, w, h]");
# ast.literal_eval turns it back into a sequence without executing code.
df = pd.read_csv(input_csv)
df["bbox"] = df["bbox"].map(ast.literal_eval)

# === Group by image ===
# Every pair considered later comes from a single image.
grouped = df.groupby("image_url")

# === Helpers ===
def get_center(bbox):
    """Return the ``(cx, cy)`` midpoint of an ``(x, y, w, h)`` box."""
    x0, y0, width, height = bbox
    cx = x0 + width / 2
    cy = y0 + height / 2
    return cx, cy

def get_area(bbox):
    """Return the area ``w * h`` of an ``(x, y, w, h)`` box (position is ignored)."""
    _x, _y, width, height = bbox
    area = width * height
    return area

def bboxes_overlap(b1, b2):
    """Return True if two axis-aligned ``(x, y, w, h)`` boxes intersect.

    ``(x, y)`` is the minimum corner of each box, and the box spans
    ``[x, x + w]`` by ``[y, y + h]``. Overlap requires the intervals to
    intersect on BOTH axes. Boxes that only touch along an edge or corner
    (zero-area contact) are not considered overlapping.
    """
    x1, y1, w1, h1 = b1
    x2, y2, w2, h2 = b2

    # Define box edges
    x1_min, x1_max = x1, x1 + w1
    y1_min, y1_max = y1, y1 + h1
    x2_min, x2_max = x2, x2 + w2
    y2_min, y2_max = y2, y2 + h2

    # Strict inequalities so that touching edges do not count as overlap.
    x_overlap = x1_min < x2_max and x2_min < x1_max
    y_overlap = y1_min < y2_max and y2_min < y1_max

    # Rectangles intersect only if their projections overlap on both axes;
    # overlap on a single axis (e.g. vertically stacked boxes) is disjoint.
    return x_overlap and y_overlap

def get_spatial_relation(c1, c2):
    """Describe where point ``c2`` lies relative to ``c1`` as a diagonal quadrant.

    A negative y-offset is labelled "up", matching image coordinates where y
    grows downward. Returns "left up", "right up", "left bottom" or
    "right bottom". Returns None when the points share an x or y coordinate,
    because no single diagonal direction applies.
    """
    dx = c2[0] - c1[0]
    dy = c2[1] - c1[1]

    # Sign of each offset: -1, 0 or +1. A non-comparable value such as NaN
    # falls out as 0 and therefore maps to None.
    sx = int(dx > 0) - int(dx < 0)
    sy = int(dy > 0) - int(dy < 0)

    quadrants = {
        (-1, -1): "left up",
        (1, -1): "right up",
        (-1, 1): "left bottom",
        (1, 1): "right bottom",
    }
    return quadrants.get((sx, sy))

# === Output setup ===
# One CSV row per directed (arche, target) pair that survives every filter.
fieldnames = [
    "image_url", "arche_object", "target_object",
    "arche_bbox", "target_bbox",
    "arche_center", "target_center",
    "spatial_answer",
]

# Filtering-funnel counters. All of them count *directed* (arche, target)
# pairs. Each unordered object pair is examined in both directions, so every
# stage is a subset of the one before and the summary reads as a funnel.
total_rows = 0
total_combinations = 0
after_bbox_filter = 0
after_area_filter = 0
after_center_filter = 0

with open(output_csv, "w", newline="", encoding="utf-8") as csvfile:
    writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
    writer.writeheader()

    for image_url, group in tqdm(grouped, desc="🔍 Processing images"):
        # === Filter out duplicate-category objects ===
        # Keep only objects whose category name appears exactly once in this
        # image, so each name identifies a single instance.
        category_counts = group["name"].value_counts()
        unique_objects = group[group["name"].isin(category_counts[category_counts == 1].index)]

        if len(unique_objects) < 2:
            continue

        objects = unique_objects.to_dict("records")

        # Estimate image size from the furthest box extents (x + w, y + h)
        # over ALL annotations in the image, not just the unique ones.
        # NOTE(review): this is an estimate; the real image size, if it is
        # available elsewhere, would make the area filter more accurate.
        image_width = max(bbox[0] + bbox[2] for bbox in group["bbox"])
        image_height = max(bbox[1] + bbox[3] for bbox in group["bbox"])
        image_area = image_width * image_height
        # Loop-invariant threshold, hoisted out of the pair loop.
        min_area = min_area_ratio * image_area

        for obj1, obj2 in combinations(objects, 2):
            # Evaluate the pair in both directions. The overlap and area
            # checks are symmetric, but the spatial label depends on order.
            for arche, target in ((obj1, obj2), (obj2, obj1)):
                total_combinations += 1

                bbox1 = arche["bbox"]
                bbox2 = target["bbox"]

                # Drop pairs that bboxes_overlap() reports as overlapping.
                if bboxes_overlap(bbox1, bbox2):
                    continue
                after_bbox_filter += 1

                # Drop pairs where either box is small relative to the image.
                if get_area(bbox1) < min_area or get_area(bbox2) < min_area:
                    continue
                after_area_filter += 1

                # Drop pairs whose centers share an x or y coordinate
                # (get_spatial_relation returns None for those).
                center1 = get_center(bbox1)
                center2 = get_center(bbox2)
                spatial = get_spatial_relation(center1, center2)
                if spatial is None:
                    continue
                after_center_filter += 1

                writer.writerow({
                    "image_url": image_url,
                    "arche_object": arche["name"],
                    "target_object": target["name"],
                    "arche_bbox": bbox1,
                    "target_bbox": bbox2,
                    "arche_center": center1,
                    "target_center": center2,
                    "spatial_answer": spatial
                })
                total_rows += 1

# === Summary ===
print(f"\n✅ Done. Total rows written: {total_rows}")
print(f"🔢 Total directed object pairs before filtering: {total_combinations}")
print(f"✅ After bbox overlap filter: {after_bbox_filter}")
print(f"✅ After bbox area filter: {after_area_filter}")
print(f"✅ After center dx/dy filter: {after_center_filter}")