import json
import random
import csv
import matplotlib.pyplot as plt
from collections import defaultdict
from itertools import combinations
from tqdm import tqdm

random.seed(42)

# === Load the full JSON ===
with open("../object365/zhiyuan_objv2_val.json", "r") as f:
    full_data = json.load(f)

images, annotations, categories = (
    full_data["images"],
    full_data["annotations"],
    full_data["categories"],
)

# === Build mappings ===
# image id -> (width, height)
image_id_to_size = {}
for img in images:
    image_id_to_size[img["id"]] = (img["width"], img["height"])

# image id -> file name; ids of images without a "file_name" entry are
# collected separately instead of raising.
image_id_to_file = {
    entry["id"]: entry["file_name"] for entry in images if "file_name" in entry
}
missing_file_name_images = [
    entry["id"] for entry in images if "file_name" not in entry
]

# category id -> category name
category_id_to_name = {}
for cat in categories:
    category_id_to_name[cat["id"]] = cat["name"]

# === Step 1 + Step 2: Filter annotations (iscrowd=0 and isfake=0) and
# group the surviving ones by image, in a single pass ===
filtered_annotations = []
image_to_annotations = defaultdict(list)
for ann in annotations:
    if ann["iscrowd"] == 0 and ann["isfake"] == 0:
        filtered_annotations.append(ann)
        image_to_annotations[ann["image_id"]].append(ann)

# === Step 3: For each image, keep only categories that appear exactly once ===
# category id -> set of image ids in which that category has one instance
category_to_images = defaultdict(set)

for img_id, anns in image_to_annotations.items():
    instances_per_category = defaultdict(int)
    for ann in anns:
        instances_per_category[ann["category_id"]] += 1

    single_instance_categories = [
        cat_id for cat_id, n in instances_per_category.items() if n == 1
    ]
    for cat_id in single_instance_categories:
        category_to_images[cat_id].add(img_id)

# === Step 5: Collect unique images with at least one single-instance category ===
unique_single_instance_images = set()
for single_instance_ids in category_to_images.values():
    unique_single_instance_images |= single_instance_ids

# === Step 6: Define bbox non-overlap function ===
def is_nonoverlapping(bbox1, bbox2):
    """Return True when the two boxes are separated on BOTH axes.

    Each bbox is unpacked as ``(x_min, y_min, width, height)``.

    The boxes must have disjoint (or merely touching) x-ranges *and*
    disjoint (or merely touching) y-ranges. Two boxes that sit side by
    side at the same height therefore return False even though they do
    not intersect.

    NOTE(review): this is stricter than a plain "do not overlap" test,
    which would only need separation on one axis -- confirm the stricter
    rule is intended.
    """
    left_a, top_a, width_a, height_a = bbox1
    left_b, top_b, width_b, height_b = bbox2

    right_a, bottom_a = left_a + width_a, top_a + height_a
    right_b, bottom_b = left_b + width_b, top_b + height_b

    # Guard: the x-ranges must not share any interior extent.
    if not (right_a <= left_b or right_b <= left_a):
        return False

    # x is separated; the result now depends only on the y-ranges.
    return bottom_a <= top_b or bottom_b <= top_a

# === Step 7: Find valid non-overlapping object pairs ===
# valid_pairs collects (image_id, annotation_a, annotation_b) triples;
# image_pair_counts tracks how many such pairs each image contributed.
valid_pairs = []
image_pair_counts = defaultdict(int)

for img_id in tqdm(unique_single_instance_images, desc="🔍 Checking object pairs per image"):
    anns = image_to_annotations.get(img_id, [])

    # Count how many instances of each category this image contains.
    category_freq = defaultdict(int)
    for ann in anns:
        category_freq[ann["category_id"]] += 1

    # Keep only the annotations whose category occurs exactly once here.
    valid_anns = [ann for ann in anns if category_freq[ann["category_id"]] == 1]

    # Every annotation in valid_anns has a category that is unique within
    # the image, so two different annotations can never share a category;
    # no same-category check is needed inside the loop.
    for ann1, ann2 in combinations(valid_anns, 2):
        if is_nonoverlapping(ann1["bbox"], ann2["bbox"]):
            valid_pairs.append((img_id, ann1, ann2))
            image_pair_counts[img_id] += 1

# === Output CSV with both orderings and bbox size filtering ===
output_csv = "valid_nonoverlapping_pairs.csv"
# Minimum bbox area as a fraction of the image area.
min_area_ratio = 0.002  # 0.2%

def get_center(bbox):
    """Return the ``(center_x, center_y)`` tuple of a bbox.

    The bbox is unpacked as ``(x, y, width, height)``; the center is the
    origin shifted by half the width and half the height.
    """
    left, top, width, height = bbox
    center_x = left + width / 2
    center_y = top + height / 2
    return center_x, center_y

def get_area_ratio(bbox, img_size):
    """Return the bbox area divided by the image area.

    The bbox is unpacked as four values of which the last two are the
    width and height; ``img_size[0]`` and ``img_size[1]`` are multiplied
    to get the image area. Raises ZeroDivisionError when that product
    is zero.
    """
    _, _, box_width, box_height = bbox
    box_area = box_width * box_height
    image_area = img_size[0] * img_size[1]
    return box_area / image_area

def get_spatial_relation(arche_center, target_center):
    """Describe where the target center lies relative to the arche center.

    Returns one of "left up", "right up", "left bottom", "right bottom",
    or "ambiguous" when the target is not strictly offset on both axes
    (for example when the two centers share an x or a y coordinate).

    A smaller y is reported as "up" -- presumably image coordinates with
    y growing downward; verify against the annotation format.
    """
    ax, ay = arche_center
    tx, ty = target_center

    strictly_offset_x = tx < ax or tx > ax
    strictly_offset_y = ty < ay or ty > ay
    if not (strictly_offset_x and strictly_offset_y):
        return "ambiguous"

    # Past the guard, "not less than" means "strictly greater than".
    if ty < ay:
        return "left up" if tx < ax else "right up"
    return "left bottom" if tx < ax else "right bottom"

rows_written = 0

with open(output_csv, "w", newline="") as f:
    writer = csv.writer(f)
    header = [
        "image_path", "arche_object", "target_object",
        "arche_bbox", "target_bbox", "arche_center",
        "target_center", "spatial_answer"
    ]
    writer.writerow(header)

    for img_id, ann1, ann2 in valid_pairs:
        image_path = image_id_to_file.get(img_id)
        img_size = image_id_to_size.get(img_id)

        # Skip pairs whose image has no known file name or size.
        if not (image_path and img_size):
            continue

        # Emit each pair twice, once with each object acting as the arche.
        for arche, target in ((ann1, ann2), (ann2, ann1)):
            arche_bbox, target_bbox = arche["bbox"], target["bbox"]

            # Drop the row when either box is below the minimum area share.
            area_ratios = [
                get_area_ratio(box, img_size) for box in (arche_bbox, target_bbox)
            ]
            if any(ratio < min_area_ratio for ratio in area_ratios):
                continue

            arche_center = get_center(arche_bbox)
            target_center = get_center(target_bbox)
            relation = get_spatial_relation(arche_center, target_center)

            row = [
                image_path,
                category_id_to_name[arche["category_id"]],
                category_id_to_name[target["category_id"]],
                arche_bbox,
                target_bbox,
                arche_center,
                target_center,
                relation,
            ]
            writer.writerow(row)
            rows_written += 1

print(f"✅ Saved {rows_written} valid pairs (with both orderings and bbox size ≥ 0.2%) to {output_csv}")
