#%%
import json
from collections import defaultdict

def process(json_path):
    """Count annotated instances per category for every image in one LVIS file.

    Args:
        json_path: path to an LVIS-format annotation JSON with "images",
            "categories" and "annotations" lists.

    Returns:
        ``{rel_path: {category_name: n_instances, ...}, ...}`` where ``rel_path``
        is the last two "/"-separated segments of the image's ``coco_url``.
        Images without a ``coco_url`` are ignored, and images that have no
        annotations do not appear.

    Raises:
        KeyError: if an annotation on a known image references a category id
            that is missing from ``categories``.
    """
    with open(json_path, "r") as f:
        lvis = json.load(f)

    # image id -> "<second-to-last segment>/<last segment>" of its coco_url
    rel_path_by_id = {}
    for image in lvis["images"]:
        if "coco_url" not in image:
            continue
        rel_path_by_id[image["id"]] = "/".join(image["coco_url"].split("/")[-2:])

    id_to_name = {category["id"]: category["name"] for category in lvis["categories"]}

    counts = {}
    for ann in lvis["annotations"]:
        rel_path = rel_path_by_id.get(ann["image_id"])
        if rel_path is None:
            # Annotation belongs to an image without a coco_url.
            continue
        label = id_to_name[ann["category_id"]]
        per_image = counts.setdefault(rel_path, {})
        per_image[label] = per_image.get(label, 0) + 1
    return counts

# Per-image category counts for both LVIS splits, merged into one dict keyed by
# the image's relative path. If the same key appears in both, the val entry wins.
train = process("../../counting/LVIS/downloads/lvis_v1_train.json")
val = process("../../counting/LVIS/downloads/lvis_v1_val.json")
combined = dict(train)
combined.update(val)

with open("per_image_object_counts_combined.json", "w") as f:
    json.dump(combined, f, indent=2)

#%%
import json

with open("per_image_object_counts_combined.json", "r") as f:
    data = json.load(f)

# Keep only categories that appear exactly once in an image; images left with
# no such category are dropped entirely.
filtered_data = {}
for img_path, objects in data.items():
    singletons = {}
    for name, n in objects.items():
        if n == 1:
            singletons[name] = n
    if singletons:
        filtered_data[img_path] = singletons

with open("per_image_object_counts_combined_filtered.json", "w") as f:
    json.dump(filtered_data, f, indent=2)


#%%
import json
import matplotlib.pyplot as plt

with open("per_image_object_counts_combined_filtered.json", "r") as f:
    data = json.load(f)

# Number of images in which each object appears (after the count==1 filter).
object_to_image_count = {}
for objects in data.values():
    for obj in objects:
        object_to_image_count[obj] = object_to_image_count.get(obj, 0) + 1

# Most frequent first; the sort is stable, so ties keep first-seen order.
sorted_objects = sorted(object_to_image_count.items(), key=lambda kv: kv[1], reverse=True)
top_objects = sorted_objects[:100]

object_names = [name for name, _ in top_objects]
image_counts = [n_images for _, n_images in top_objects]

# Wide bar chart with vertical labels so 100 category names stay legible.
plt.figure(figsize=(18, 6))
plt.bar(object_names, image_counts)
plt.xticks(rotation=90, fontsize=9)
plt.xlabel("Object Name", fontsize=12)
plt.ylabel("Number of Images", fontsize=12)
plt.title("Top 100 Objects by Number of Images", fontsize=14)
plt.tight_layout()
plt.savefig("top100_objects_by_image_count.png")
plt.show()



#%%
import json
from collections import defaultdict

with open("per_image_object_counts_combined_filtered.json", "r") as f:
    data = json.load(f)

# object name -> str(instance count) -> {"count": n_images, "image_ids": [paths]}
result = defaultdict(lambda: defaultdict(lambda: {"count": 0, "image_ids": []}))

for image_path, objects in data.items():
    for obj_name, count in objects.items():
        bucket = result[obj_name][str(count)]  # JSON keys must be strings
        bucket["count"] += 1
        bucket["image_ids"].append(image_path)

with open("object_grouped_by_count_and_image.json", "w") as f:
    json.dump(result, f, indent=2)

#%%
import os
import json
import csv
import random
from tqdm import tqdm

random.seed(42)

def compute_mean_areas(image_ids_with_object, lvis_train, lvis_val, cat_name_to_id):
    """Collect per-annotation normalized areas for (image, object) items.

    For every item ``{"image": rel_path, "object": name, "count": n}``, find
    every annotation of that object's category on that image and record
    ``ann["area"] / (width * height)``. Despite the name, this returns the
    individual records, not means.

    Args:
        image_ids_with_object: iterable of dicts with "image", "object" and
            "count"; the image id is parsed from the file name of "image".
        lvis_train, lvis_val: loaded LVIS annotation dicts (with "images" and
            "annotations").
        cat_name_to_id: category name -> category id.

    Returns:
        List of dicts with keys "image", "object", "count", "normalized_area",
        one per matching annotation. Items whose image id is found in neither
        split are skipped.

    Raises:
        KeyError: if an item's object name is missing from cat_name_to_id.
    """
    # Look images up in train+val combined instead of choosing the split from
    # the folder name: LVIS v1 val also contains images stored under COCO's
    # train2017 folder, so "train2017" in the path does not imply LVIS train.
    # Indexing once also replaces a full scan of all annotations per item.
    id_to_img = {}
    anns_by_img_cat = {}
    for source in (lvis_train, lvis_val):
        for img in source["images"]:
            id_to_img[img["id"]] = img
        for ann in source["annotations"]:
            anns_by_img_cat.setdefault((ann["image_id"], ann["category_id"]), []).append(ann)

    records = []
    for item in tqdm(image_ids_with_object, desc="Calculating areas"):
        img_path = item["image"]
        obj = item["object"]
        count = item["count"]
        image_id = int(img_path.split("/")[-1].replace(".jpg", ""))
        img_info = id_to_img.get(image_id)
        if img_info is None:
            continue
        w, h = img_info["width"], img_info["height"]
        cat_id = cat_name_to_id[obj]
        for ann in anns_by_img_cat.get((image_id, cat_id), []):
            records.append({
                "image": img_path,
                "object": obj,
                "count": count,
                "normalized_area": ann["area"] / (w * h),
            })
    return records

# Per-object, per-count image grouping written by the grouping cell.
with open("object_grouped_by_count_and_image.json", "r") as f:
    all_data = json.load(f)

# Full LVIS v1 annotation files for both splits.
with open("../../counting/LVIS/downloads/lvis_v1_train.json", "r") as f:
    lvis_train = json.load(f)
with open("../../counting/LVIS/downloads/lvis_v1_val.json", "r") as f:
    lvis_val = json.load(f)

# Category name -> id. Both splits must share one category table so a single
# mapping works for images from either split.
cat_train = {c["name"]: c["id"] for c in lvis_train["categories"]}
cat_val = {c["name"]: c["id"] for c in lvis_val["categories"]}
assert cat_train == cat_val
cat_name_to_id = cat_train

# For every (object, count) group take at most 50 images (a seeded random
# sample when the group is larger) and flatten them into
# {"image", "object", "count"} items. The sequence of random.sample calls
# follows the JSON key order, so reordering this loop changes the sample.
used_images = []
for obj, count_dict in tqdm(all_data.items(), desc="Collecting all images"):
    for count_str, info in count_dict.items():
        # Groups without an "image_ids" list are skipped.
        try:
            all_image_ids = info["image_ids"]
        except KeyError:
            continue
        if len(all_image_ids) > 50:
            sampled_ids = random.sample(all_image_ids, 50)
        else:
            sampled_ids = all_image_ids
        for img_path in sampled_ids:
            used_images.append({
                "image": img_path,
                "object": obj,
                "count": int(count_str)
            })

# Persist the sampled list (reloaded by the bbox-visualization cell).
with open("merged_used_image_list.json", "w") as f:
    json.dump(used_images, f, indent=2)

# One CSV row per matching annotation with its image-normalized area.
area_records = compute_mean_areas(used_images, lvis_train, lvis_val, cat_name_to_id)
with open("normalized_area_records.csv", "w", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=["image", "object", "count", "normalized_area"])
    writer.writeheader()
    for row in area_records:
        writer.writerow(row)

#%%
import csv
from collections import defaultdict

input_file = "normalized_area_records.csv"
output_file = "sorted_object_by_mean_area.csv"

# Running sum of normalized areas and number of records, per object.
area_sum = defaultdict(float)
count = defaultdict(int)

with open(input_file, "r") as f:
    for row in csv.DictReader(f):
        name = row["object"]
        area_sum[name] += float(row["normalized_area"])
        count[name] += 1

# Objects ordered by mean normalized area, largest first (stable for ties).
mean_areas = sorted(
    ((name, area_sum[name] / count[name]) for name in area_sum),
    key=lambda pair: pair[1],
    reverse=True,
)

with open(output_file, "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["object", "mean_normalized_area"])
    writer.writerows([name, f"{value:.6f}"] for name, value in mean_areas)

#%%

import os
import json
import csv
import random
from tqdm import tqdm

random.seed(42)

def compute_mean_areas(image_ids_with_object, lvis_train, lvis_val, cat_name_to_id):
    """Collect per-annotation normalized areas for (image, object) items.

    For every item ``{"image": rel_path, "object": name, "count": n}``, find
    every annotation of that object's category on that image and record
    ``ann["area"] / (width * height)``. Despite the name, this returns the
    individual records, not means.

    Args:
        image_ids_with_object: iterable of dicts with "image", "object" and
            "count"; the image id is parsed from the file name of "image".
        lvis_train, lvis_val: loaded LVIS annotation dicts (with "images" and
            "annotations").
        cat_name_to_id: category name -> category id.

    Returns:
        List of dicts with keys "image", "object", "count", "normalized_area",
        one per matching annotation. Items whose image id is found in neither
        split are skipped.

    Raises:
        KeyError: if an item's object name is missing from cat_name_to_id.
    """
    # Look images up in train+val combined instead of choosing the split from
    # the folder name: LVIS v1 val also contains images stored under COCO's
    # train2017 folder, so "train2017" in the path does not imply LVIS train.
    # Indexing once also replaces a full scan of all annotations per item.
    id_to_img = {}
    anns_by_img_cat = {}
    for source in (lvis_train, lvis_val):
        for img in source["images"]:
            id_to_img[img["id"]] = img
        for ann in source["annotations"]:
            anns_by_img_cat.setdefault((ann["image_id"], ann["category_id"]), []).append(ann)

    records = []
    for item in tqdm(image_ids_with_object, desc="Calculating areas"):
        img_path = item["image"]
        obj = item["object"]
        count = item["count"]
        image_id = int(img_path.split("/")[-1].replace(".jpg", ""))
        img_info = id_to_img.get(image_id)
        if img_info is None:
            continue
        w, h = img_info["width"], img_info["height"]
        cat_id = cat_name_to_id[obj]
        for ann in anns_by_img_cat.get((image_id, cat_id), []):
            records.append({
                "image": img_path,
                "object": obj,
                "count": count,
                "normalized_area": ann["area"] / (w * h),
            })
    return records

# NOTE(review): this cell repeats the earlier sampling cell verbatim; re-running
# it rewrites merged_used_image_list.json and normalized_area_records.csv using
# the same seed and inputs.
# Per-object, per-count image grouping written by the grouping cell.
with open("object_grouped_by_count_and_image.json", "r") as f:
    all_data = json.load(f)

# Full LVIS v1 annotation files for both splits.
with open("../../counting/LVIS/downloads/lvis_v1_train.json", "r") as f:
    lvis_train = json.load(f)
with open("../../counting/LVIS/downloads/lvis_v1_val.json", "r") as f:
    lvis_val = json.load(f)

# Category name -> id; both splits must share one category table.
cat_train = {c["name"]: c["id"] for c in lvis_train["categories"]}
cat_val = {c["name"]: c["id"] for c in lvis_val["categories"]}
assert cat_train == cat_val
cat_name_to_id = cat_train

# Take at most 50 images per (object, count) group (seeded random sample when
# larger). The order of random.sample calls follows the JSON key order.
used_images = []
for obj, count_dict in tqdm(all_data.items(), desc="Collecting all images"):
    for count_str, info in count_dict.items():
        # Groups without an "image_ids" list are skipped.
        try:
            all_image_ids = info["image_ids"]
        except KeyError:
            continue
        if len(all_image_ids) > 50:
            sampled_ids = random.sample(all_image_ids, 50)
        else:
            sampled_ids = all_image_ids
        for img_path in sampled_ids:
            used_images.append({
                "image": img_path,
                "object": obj,
                "count": int(count_str)
            })

# Persist the sampled list.
with open("merged_used_image_list.json", "w") as f:
    json.dump(used_images, f, indent=2)

# One CSV row per matching annotation with its image-normalized area.
area_records = compute_mean_areas(used_images, lvis_train, lvis_val, cat_name_to_id)
with open("normalized_area_records.csv", "w", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=["image", "object", "count", "normalized_area"])
    writer.writeheader()
    for row in area_records:
        writer.writerow(row)


#%%
import json
import csv

# === Step 1: per-object, per-count image groups ===
with open("object_grouped_by_count_and_image.json", "r") as f:
    object_summary = json.load(f)

# === Step 2: object -> mean normalized area (header row skipped) ===
mean_area_file = "sorted_object_by_mean_area.csv"
with open(mean_area_file, newline='') as csvfile:
    rows = csv.reader(csvfile)
    next(rows)
    object_to_area = {row[0].strip(): float(row[1]) for row in rows}

# === Step 3: one [object, mean_area, n_images] row per count group ===
output_rows = []
for obj, nested_info in object_summary.items():
    mean_area = object_to_area.get(obj)
    for key, group in nested_info.items():
        count = group.get("count", 0)
        if mean_area is None:
            print(f"⚠️ Warning: '{obj}' not found in normalized area CSV")
            continue
        output_rows.append([obj, mean_area, count])

# === Step 4: most images first, then write ===
output_rows = sorted(output_rows, key=lambda r: r[2], reverse=True)

with open("final_object_stats.csv", "w", newline="") as out_csv:
    writer = csv.writer(out_csv)
    writer.writerow(["object name", "mean_normalized_area", "number_images"])
    writer.writerows(output_rows)

#%%

import os
import json
import random
import cv2
import numpy as np
from PIL import Image, ImageDraw
from tqdm import tqdm

random.seed(42)

# === Padding helper: pad an image to a square canvas ===
def expand2square(pil_img, background_color):
    """Pad *pil_img* to a square, centering it along its shorter side.

    Args:
        pil_img: a PIL image.
        background_color: fill colour for the padded area, passed to
            ``Image.new`` with the image's mode.

    Returns:
        *pil_img* itself if already square; otherwise a new image with side
        ``max(width, height)``.
    """
    width, height = pil_img.size
    if width == height:
        return pil_img
    side = max(width, height)
    canvas = Image.new(pil_img.mode, (side, side), background_color)
    # One of these offsets is always 0 (the long axis fills the canvas).
    canvas.paste(pil_img, ((side - width) // 2, (side - height) // 2))
    return canvas

# === Mean color for background ===
# Per-channel RGB means on a 0-1 scale (these look like the CLIP / ImageNet /
# 0.5 normalization means of several image processors — TODO confirm), averaged
# and scaled to 0-255 for use as the padding fill colour.
processor_means = [
    [0.48145466, 0.4578275, 0.40821073],
    [0.48145466, 0.4578275, 0.40821073],
    [0.485, 0.456, 0.406],
    [0.485, 0.456, 0.406],
    [0.5, 0.5, 0.5],
    [0.48145466, 0.4578275, 0.40821073],
    [0.485, 0.456, 0.406],
    [0.5, 0.5, 0.5],
    [0.5, 0.5, 0.5]
]
mean_array = np.array(processor_means)
average_mean = np.mean(mean_array, axis=0)
background_color = tuple(int(x * 255) for x in average_mean)

# === Paths ===
input_json = "merged_used_image_list.json"
lvis_train_json = "../../counting/LVIS/downloads/lvis_v1_train.json"
lvis_val_json = "../../counting/LVIS/downloads/lvis_v1_val.json"
img_root = "../../counting/LVIS/downloads"
output_dir = "padded_images_with_bbox"
os.makedirs(output_dir, exist_ok=True)

# === Load data ===
with open(input_json, "r") as f:
    used_images = json.load(f)
with open(lvis_train_json, "r") as f:
    lvis_train = json.load(f)
with open(lvis_val_json, "r") as f:
    lvis_val = json.load(f)

# Indexes over both splits combined.
id_to_img = {img["id"]: img for img in lvis_train["images"] + lvis_val["images"]}
all_anns = lvis_train["annotations"] + lvis_val["annotations"]
image_to_anns = {}
for ann in all_anns:
    image_to_anns.setdefault(ann["image_id"], []).append(ann)

# === Find single-bbox images ===
# Keep items whose image has exactly one annotation of any category, so the
# drawn box is unambiguous.
single_bbox_items = []
for item in used_images:
    img_id = int(os.path.splitext(os.path.basename(item["image"]))[0])
    anns = image_to_anns.get(img_id, [])
    if len(anns) == 1:
        single_bbox_items.append((item, anns[0]))

# Up to 3 random examples; random.sample raises ValueError if asked for more
# items than exist, so cap the sample size.
selected = random.sample(single_bbox_items, min(3, len(single_bbox_items)))
bbox_records = []

# === Function to update bbox based on padding logic ===
def update_bbox(orig_w, orig_h, bbox, padded_w, padded_h):
    """Translate a COCO ``[x, y, w, h]`` bbox from the original image into the padded image.

    The original image is assumed to sit centered in the padded canvas, with
    offsets computed by floor division as in ``expand2square``. When no padding
    was applied (padded size equals original size), the bbox is unchanged.

    Args:
        orig_w, orig_h: original image width and height in pixels.
        bbox: ``[x, y, w, h]`` in original-image pixels.
        padded_w, padded_h: size of the image the bbox will be drawn on.

    Returns:
        ``[x, y, w, h]`` in padded-image pixels (width and height unchanged).
    """
    x, y, w, h = bbox
    # Derive the offsets from the real padded size rather than assuming a square
    # pad. Otherwise the box is shifted even when padding is disabled.
    offset_x = (padded_w - orig_w) // 2
    offset_y = (padded_h - orig_h) // 2
    return [x + offset_x, y + offset_y, w, h]

# === Main loop ===
# For each selected (item, annotation), the loop:
#   1. loads the image and draws the bbox;
#   2. optionally pads the image;
#   3. redraws the bbox in the output image's coordinates;
#   4. saves the image and records the bbox that goes with the saved file.
for item, ann in selected:
    img_rel_path = item["image"]
    img_abs_path = os.path.join(img_root, img_rel_path)
    if not os.path.exists(img_abs_path):
        print(f"[WARN] Missing: {img_abs_path}")
        continue

    bbox = ann["bbox"]  # COCO [x, y, width, height] in original-image pixels
    img_cv = cv2.imread(img_abs_path)
    # cv2.imread returns None (it does not raise) for unreadable files.
    if img_cv is None:
        print(f"[WARN] Unreadable: {img_abs_path}")
        continue
    img_rgb = cv2.cvtColor(img_cv, cv2.COLOR_BGR2RGB)
    orig_h, orig_w = img_rgb.shape[:2]

    pil_img = Image.fromarray(img_rgb)
    draw = ImageDraw.Draw(pil_img)
    x, y, w, h = bbox
    draw.rectangle([x, y, x + w, y + h], outline="red", width=3)

    # Padding is currently disabled, so padded_img is the same object as pil_img.
    # padded_img = expand2square(pil_img, background_color)
    padded_img = pil_img
    padded_w, padded_h = padded_img.size
    new_bbox = update_bbox(orig_w, orig_h, bbox, padded_w, padded_h)

    draw_pad = ImageDraw.Draw(padded_img)
    draw_pad.rectangle([new_bbox[0], new_bbox[1], new_bbox[0] + new_bbox[2], new_bbox[1] + new_bbox[3]],
                       outline="red", width=3)

    save_name = os.path.basename(img_abs_path)
    save_path = os.path.join(output_dir, save_name)
    padded_img.save(save_path)

    bbox_records.append({
        "file_name": save_name,
        "bbox": [round(float(v), 2) for v in new_bbox],
        "bbox_format": "COCO [x, y, width, height]",
        "category_id": ann["category_id"],
        "image_id": ann["image_id"],
        "object": item["object"]
    })


# === Save bbox JSON ===
with open(os.path.join(output_dir, "bbox_info.json"), "w") as f:
    json.dump(bbox_records, f, indent=2)

#%%
import json
from collections import defaultdict
from tqdm import tqdm

# === Load files ===
with open("object_grouped_by_count_and_image.json", "r") as f:
    grouped_data = json.load(f)

with open("../../counting/LVIS/downloads/lvis_v1_train.json", "r") as f:
    lvis_train = json.load(f)

with open("../../counting/LVIS/downloads/lvis_v1_val.json", "r") as f:
    lvis_val = json.load(f)

output_json_path = "normalized_area_per_object_lvis_grouped_bbx.json"

# === Build mappings ===
cat_name_to_id = {cat["name"]: cat["id"] for cat in lvis_train["categories"]}

# file_name (last segment of coco_url) → image_info, across both splits
file_name_to_image = {}
for img in lvis_train["images"] + lvis_val["images"]:
    if "coco_url" in img:
        file_name = img["coco_url"].split("/")[-1]
        file_name_to_image[file_name] = img

# image_id → annotations, across both splits
image_id_to_annotations = defaultdict(list)
for ann in lvis_train["annotations"] + lvis_val["annotations"]:
    image_id_to_annotations[ann["image_id"]].append(ann)

# === Process and compute normalized area using bbox ===
# For each object, collect bbox_area / image_area for every matching annotation
# and keep only boxes covering more than 0.2% and less than 50% of the image.
output_json = {}

for obj_name, count_dict in tqdm(grouped_data.items(), desc="Processing objects"):
    cat_id = cat_name_to_id.get(obj_name)
    if cat_id is None:
        continue

    records = []
    for count_str, info in count_dict.items():
        for path in info.get("image_ids", []):
            file_name = path.split("/")[-1]
            img_info = file_name_to_image.get(file_name)
            if not img_info:
                continue

            image_id = img_info["id"]
            w, h = img_info["width"], img_info["height"]

            for ann in image_id_to_annotations.get(image_id, []):
                if ann["category_id"] != cat_id:
                    continue
                # Area from the bbox (not the segmentation area).
                _, _, bw, bh = ann["bbox"]
                norm_area = (bw * bh) / (w * h)
                if 0.002 < norm_area < 0.5:
                    records.append({
                        # Note: holds the relative image path, not a numeric id.
                        "image_id": path,
                        "normalized_area": round(norm_area, 6)
                    })

    if records:
        mean_area = sum(r["normalized_area"] for r in records) / len(records)
        output_json[obj_name] = {
            "mean_normalized_area": round(mean_area, 6),
            "count": len(records),
            "images": records
        }

# === Save output ===
with open(output_json_path, "w") as f:
    json.dump(output_json, f, indent=2)

print(f"✅ Saved {len(output_json)} objects to {output_json_path}")


#%%
import json

# NOTE(review): this reads "..._lvis_grouped.json", while the bbox-based cell
# writes "..._lvis_grouped_bbx.json" — confirm which file is intended.
with open("normalized_area_per_object_lvis_grouped.json", "r") as f:
    data = json.load(f)

# Objects whose record count ("count") is at least 500.
count_500_plus = sum(1 for obj in data.values() if obj["count"] >= 500)

print(f"✅ Number of objects with ≥ 500 images: {count_500_plus}")
print(f"   (out of {len(data)} objects)")
# Recorded outputs from earlier runs. The first was presumably printed from
# len(data), i.e. the total object count, not the ≥ 500 count:
#Number of objects with ≥ 500 images: 1131
#Number of objects with ≥ 500 images: 110
#%%
import json
import random

random.seed(42)

# Categories selected for the final subset.
target_objects = [
    "pizza", "refrigerator", "laptop_computer", "stove", "bathtub",
    "bench", "curtain", "computer_keyboard", "fireplug", "bicycle",
    "basket", "cellular_telephone", "backpack", "surfboard", "tennis_racket",
    "helmet", "baseball_glove", "faucet", "baseball_cap", "handbag"
]

with open("normalized_area_per_object_lvis_grouped_bbx.json", "r") as f:
    full_data = json.load(f)

# For each target (in list order, so the seeded draws are reproducible) sample
# exactly 700 image records. Categories with fewer records are reported, not kept.
filtered_output = {}
insufficient_objects = []

for obj in target_objects:
    if obj not in full_data:
        print(f"⚠️ {obj} not found in input JSON.")
        continue
    images = full_data[obj]["images"]
    available = len(images)
    if available >= 700:
        filtered_output[obj] = {
            "mean_normalized_area": full_data[obj]["mean_normalized_area"],
            "count": 700,
            "images": random.sample(images, 700),
        }
    else:
        insufficient_objects.append((obj, available))

with open("final_top_lvis_700.json", "w") as f:
    json.dump(filtered_output, f, indent=2)

if not insufficient_objects:
    print("✅ All target objects have ≥ 700 images.")
else:
    print("🚫 Objects with fewer than 700 images:")
    for name, n_images in insufficient_objects:
        print(f"- {name}: {n_images} images")


#%%

import os
import json
import cv2
import numpy as np
from PIL import Image
from typing import List, Literal
from typing import List, Tuple
from tqdm import tqdm

# === Config ===
input_json_path = "final_top_lvis_700.json"
train_json_path = "../../counting/LVIS/downloads/lvis_v1_train.json"
val_json_path = "../../counting/LVIS/downloads/lvis_v1_val.json"
img_root = "../../counting/LVIS/downloads"
output_folder = "no_padded_object_images"
os.makedirs(output_folder, exist_ok=True)

# === Padding logic ===
def expand2square(pil_img, background_color):
    """Return *pil_img* padded to a square, centered along its shorter side.

    Args:
        pil_img: a PIL image.
        background_color: fill colour for the padded area.

    Returns:
        *pil_img* unchanged if already square, else a new square image whose
        side is the longer of width and height.
    """
    width, height = pil_img.size
    if width == height:
        return pil_img
    if width > height:
        offset = (0, (width - height) // 2)
    else:
        offset = ((height - width) // 2, 0)
    side = max(width, height)
    square = Image.new(pil_img.mode, (side, side), background_color)
    square.paste(pil_img, offset)
    return square

def adjust_boxes_to_square(
    boxes: List[List[float]],
    fmt: Literal["corner"],
    orig_size: tuple[int, int]
) -> List[List[float]]:
    """Map COCO ``[x, y, w, h]`` pixel boxes onto a centered square pad, normalized.

    Despite ``fmt="corner"``, each input box is read as ``[x, y, width,
    height]``; ``fmt`` itself is not used. The image of size ``orig_size =
    (W, H)`` is assumed to be centered on a square canvas of side
    ``max(W, H)``.

    Returns:
        ``[xmin, ymin, xmax, ymax]`` per box, divided by the square side and
        clamped to [0, 1].
    """
    W, H = orig_size
    side = max(W, H)
    pad_x = (H - W) / 2 if H > W else 0
    pad_y = (W - H) / 2 if W > H else 0

    def clamp(v):
        return max(0, min(1, v))

    result = []
    for x, y, bw, bh in boxes:
        corners = (
            (x + pad_x) / side,
            (y + pad_y) / side,
            (x + bw + pad_x) / side,
            (y + bh + pad_y) / side,
        )
        result.append([clamp(c) for c in corners])
    return result
def normalize_boxes(
    boxes: List[List[float]],
    orig_size: Tuple[int, int]
) -> List[List[float]]:
    """Convert COCO ``[x, y, w, h]`` pixel boxes to normalized corner boxes.

    Args:
        boxes: boxes as ``[x, y, width, height]`` in pixels.
        orig_size: ``(W, H)`` of the image the boxes belong to.

    Returns:
        ``[xmin, ymin, xmax, ymax]`` per box, divided by W (x) or H (y) and
        clamped to [0, 1].
    """
    W, H = orig_size

    def _unit(v):
        return max(0, min(1, v))

    return [
        [_unit(x / W), _unit(y / H), _unit((x + bw) / W), _unit((y + bh) / H)]
        for x, y, bw, bh in boxes
    ]

# === Mean color ===
# Average of several per-channel RGB normalization means (0-1 scale), scaled to
# 0-255. Only used as the fill colour for the (currently disabled) padding below.
processor_means = [
    [0.48145466, 0.4578275, 0.40821073],
    [0.48145466, 0.4578275, 0.40821073],
    [0.485, 0.456, 0.406],
    [0.485, 0.456, 0.406],
    [0.5, 0.5, 0.5],
    [0.48145466, 0.4578275, 0.40821073],
    [0.485, 0.456, 0.406],
    [0.5, 0.5, 0.5],
    [0.5, 0.5, 0.5]
]
bg_rgb = tuple(int(c * 255) for c in np.mean(processor_means, axis=0))

# === Load metadata ===
with open(input_json_path, "r") as f:
    object_data = json.load(f)
with open(train_json_path, "r") as f:
    lvis_train = json.load(f)
with open(val_json_path, "r") as f:
    lvis_val = json.load(f)

# Indexes over both splits combined.
all_images = {img["id"]: img for img in lvis_train["images"] + lvis_val["images"]}
annotations = lvis_train["annotations"] + lvis_val["annotations"]
cat_name_to_id = {cat["name"]: cat["id"] for cat in lvis_train["categories"]}

image_to_anns = {}
for ann in annotations:
    image_to_anns.setdefault(ann["image_id"], []).append(ann)

output_records = []
missing_images = []  # relative paths that could not be resolved, found or read

for obj, info in tqdm(object_data.items(), desc="Processing objects"):
    cat_id = cat_name_to_id.get(obj)
    if cat_id is None:
        continue

    for item in info["images"]:
        # "image_id" holds a relative image path here, not a numeric id.
        file_path = item["image_id"]
        file_name = os.path.basename(file_path)
        image_id = int(file_name.split(".")[0])

        img_info = all_images.get(image_id)
        if not img_info:
            missing_images.append(file_path)
            continue

        width, height = img_info["width"], img_info["height"]
        full_path = os.path.join(img_root, file_path)
        if not os.path.exists(full_path):
            missing_images.append(file_path)
            continue

        # Keep only images with exactly one non-empty annotation of this
        # category, so the stored bbox is unambiguous.
        anns = image_to_anns.get(image_id, [])
        target_anns = [a for a in anns if a["category_id"] == cat_id and a["area"] > 0]
        if len(target_anns) != 1:
            continue

        bbox = target_anns[0]["bbox"]  # [x, y, w, h]
        img_bgr = cv2.imread(full_path)
        if img_bgr is None:
            missing_images.append(file_path)
            continue

        img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
        pil_img = Image.fromarray(img_rgb)

        # Padding is disabled: the image is saved as-is and the bbox is
        # normalized by the annotation's width/height. The commented lines are
        # the padded alternative (square pad + bbox remap).
        # padded = expand2square(pil_img, bg_rgb)
        padded = pil_img
        # padded_bbox = adjust_boxes_to_square([bbox], fmt="corner", orig_size=(width, height))[0]
        padded_bbox = normalize_boxes([bbox], orig_size=(width, height))[0]

        save_path = os.path.join(output_folder, file_name)
        padded.save(save_path)

        output_records.append({
            "object": obj,
            "image_path": save_path,
            "bbox": padded_bbox  # normalized [xmin, ymin, xmax, ymax]
        })

# === Save results ===
with open("no_padded_bbox_output.json", "w") as f:
    json.dump(output_records, f, indent=2)

with open("missing_images.txt", "w") as f:
    for path in missing_images:
        f.write(path + "\n")

# Images are not padded in this configuration, so don't report them as padded.
print(f"✅ Done. Saved {len(output_records)} images with bbox info.")
print(f"🚫 Missing images: {len(missing_images)} saved to missing_images.txt")

#%%
import json, os, random
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image

random.seed(42)
json_path = "no_padded_bbox_output.json"
output_dir = "debug_bbox_outputs"
os.makedirs(output_dir, exist_ok=True)

with open(json_path, "r") as f:
    data = json.load(f)

# Inspect up to 3 random records; random.sample raises ValueError when asked
# for more items than exist, so cap the sample size.
sampled = random.sample(data, min(3, len(data)))

for entry in sampled:
    img_path = entry["image_path"]
    obj = entry["object"]
    bbox = entry["bbox"]  # normalized [xmin, ymin, xmax, ymax]
    # Records have no "image_id" key, so fall back to the file stem.
    image_id = entry.get("image_id", os.path.basename(img_path).split(".")[0])

    try:
        img = Image.open(img_path).convert("RGB")
    except Exception as e:
        print(f"[ERROR] Cannot open image: {img_path} - {e}")
        continue

    if not all(0 <= v <= 1 for v in bbox):
        print(f"[WARNING] Skipping invalid bbox: {bbox}")
        continue

    w, h = img.size
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(img)

    # Back to pixel [x, y, width, height] for matplotlib.
    xmin = bbox[0] * w
    ymin = bbox[1] * h
    box_w = (bbox[2] - bbox[0]) * w
    box_h = (bbox[3] - bbox[1]) * h

    rect = patches.Rectangle((xmin, ymin), box_w, box_h, linewidth=2, edgecolor='red', facecolor='none')
    ax.add_patch(rect)
    ax.text(xmin, max(0, ymin - 10), f"{obj} | ID: {image_id}", color='yellow', fontsize=12, backgroundcolor='black')

    ax.axis('off')
    save_path = os.path.join(output_dir, f"{image_id}_bbox.png")
    plt.savefig(save_path)
    plt.close()
    print(f"[✔] Saved {save_path}")


#%%
import json
import cv2
import matplotlib.pyplot as plt
import matplotlib.patches as patches

# === Inputs ===
target_image_id = 442106
target_category_name = "stove"
train_json_path = "../../counting/LVIS/downloads/lvis_v1_train.json"
val_json_path = "../../counting/LVIS/downloads/lvis_v1_val.json"
img_root = "../../counting/LVIS/downloads"

# === Load JSON ===
with open(train_json_path, "r") as f:
    train_data = json.load(f)
with open(val_json_path, "r") as f:
    val_data = json.load(f)

categories = train_data["categories"]
cat_name_to_id = {cat["name"]: cat["id"] for cat in categories}
cat_id = cat_name_to_id.get(target_category_name)
assert cat_id is not None, f"Category {target_category_name} not found"

# === Combine data ===
all_images = {img["id"]: img for img in train_data["images"] + val_data["images"]}
all_annotations = train_data["annotations"] + val_data["annotations"]

# === Find image info ===
img_info = all_images.get(target_image_id)
assert img_info, f"Image ID {target_image_id} not found"

W, H = img_info["width"], img_info["height"]
# Relative path = last two segments of coco_url.
file_name = "/".join(img_info["coco_url"].split("/")[-2:])
image_path = f"{img_root}/{file_name}"

# === Find matching annotations ===
anns = [a for a in all_annotations if a["image_id"] == target_image_id and a["category_id"] == cat_id]
assert anns, f"No '{target_category_name}' annotation found in this image"

# The image is the same for every matching annotation: load it once.
img = cv2.imread(image_path)
assert img is not None, f"Could not read image: {image_path}"
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

# === Compute bbox area and normalized area ===
for ann in anns:
    bbox = ann["bbox"]  # [x, y, w, h]
    area_from_bbox = bbox[2] * bbox[3]
    norm_area = area_from_bbox / (W * H)
    print(f"✅ BBox = {bbox}")
    print(f"📐 Computed bbox area = {area_from_bbox:.2f}")
    print(f"📐 Normalized area = {norm_area:.6f}")

    # === Visualize ===
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(img)
    ax.add_patch(patches.Rectangle((bbox[0], bbox[1]), bbox[2], bbox[3],
                                   linewidth=2, edgecolor='red', facecolor='none'))
    ax.set_title(f"{target_category_name.capitalize()} bbox (Normalized Area: {norm_area:.4f})")
    ax.axis('off')
    plt.show()

