#%%
import json
import random
import csv
import matplotlib.pyplot as plt
from collections import defaultdict

random.seed(42)

# --- Load the COCO-style validation annotations ---
with open("zhiyuan_objv2_val.json", "r") as f:
    full_data = json.load(f)

images, annotations, categories = (
    full_data["images"],
    full_data["annotations"],
    full_data["categories"],
)

# --- Id-keyed lookup tables ---
image_id_to_size = {img["id"]: (img["width"], img["height"]) for img in images}
# Some image records have no "file_name"; remember their ids so they can be reported.
image_id_to_file = {img["id"]: img["file_name"] for img in images if "file_name" in img}
missing_file_name_images = [img["id"] for img in images if "file_name" not in img]
category_id_to_name = {cat["id"]: cat["name"] for cat in categories}

print(f"Total images without file_name: {len(missing_file_name_images)}")

# Step 1: drop crowd and fake boxes (keep iscrowd == 0 and isfake == 0).
filtered_annotations = [
    ann for ann in annotations if ann["iscrowd"] == 0 and ann["isfake"] == 0
]

# Step 2: bucket the remaining boxes by the image they belong to.
image_to_annotations = defaultdict(list)
for ann in filtered_annotations:
    image_to_annotations[ann["image_id"]].append(ann)

# Step 3: map each category to the images in which it has exactly one box.
category_to_images = defaultdict(list)
for img_id, anns in image_to_annotations.items():
    boxes_per_category = {}
    for ann in anns:
        boxes_per_category[ann["category_id"]] = boxes_per_category.get(ann["category_id"], 0) + 1
    for cat_id, count in boxes_per_category.items():
        if count == 1:  # a single box of this category in this image
            category_to_images[cat_id].append(img_id)

# Step 4: bar-plot the 100 categories with the most single-instance images.
# reverse=True keeps the sort stable, so ties stay in first-seen category order.
sorted_categories = sorted(
    category_to_images.items(),
    key=lambda item: len(item[1]),
    reverse=True,
)[:100]

bars = [(category_id_to_name[cid], len(ids)) for cid, ids in sorted_categories]
x_labels = [name for name, _ in bars]
y_values = [n_images for _, n_images in bars]

plt.figure(figsize=(20, 8))
plt.bar(x_labels, y_values)
plt.xticks(rotation=90)
plt.ylabel("Number of Images (Single-instance)")
plt.title("Top 100 Single-instance Images per Category (iscrowd=0, isfake=0)")
plt.tight_layout()
plt.savefig("top100_category_single_instance_count_filtered.png")
plt.show()


#%%
# Step 5: for every category, sample up to 50 of its single-instance images and
# record the box area divided by the full image area.
# Re-seed here so the sample is the same whether this cell runs right after the
# first cell (which seeds with 42 and draws nothing) or is re-run on its own.
random.seed(42)
SAMPLES_PER_CATEGORY = 50
records = []

for cat_id, img_ids in category_to_images.items():
    sample_imgs = random.sample(img_ids, min(SAMPLES_PER_CATEGORY, len(img_ids)))
    for img_id in sample_imgs:
        # Skip images whose record has no file_name.
        if img_id not in image_id_to_file:
            continue

        anns = [ann for ann in image_to_annotations[img_id] if ann["category_id"] == cat_id]
        if not anns:
            continue
        ann = anns[0]  # exactly one, by construction of category_to_images
        width, height = image_id_to_size[img_id]
        # Uses the annotation's own "area" field (assumed to be in pixels, like
        # width * height -- TODO confirm), not an area recomputed from the bbox.
        normalized_area = ann["area"] / (width * height)

        records.append({
            "image_id": img_id,
            "file_name": image_id_to_file[img_id],
            "category_id": cat_id,
            "category_name": category_id_to_name[cat_id],
            "normalized_area": normalized_area,
        })

# Step 6: Save to CSV. UTF-8 matches what pandas.read_csv assumes downstream.
csv_fieldnames = ["image_id", "file_name", "category_id", "category_name", "normalized_area"]

with open("single_instance_sampled_filtered.csv", "w", newline="", encoding="utf-8") as f:
    writer = csv.DictWriter(f, fieldnames=csv_fieldnames)
    writer.writeheader()
    writer.writerows(records)

print("Done! Outputs generated:")
print("- top100_category_single_instance_count_filtered.png")
print("- single_instance_sampled_filtered.csv")
print(f"- Images missing file_name: {len(missing_file_name_images)} images")

#%%
import json

# Rewrite every image's file_name in place: "v1_ex" -> "v1", "v2_ex" -> "v2".
# NOTE: this overwrites the annotation file that the other cells read. The
# padded-image cell later maps "images/v1/" / "images/v2/" back to the *_ex
# directories before checking the files on disk.
import os

annotation_path = "zhiyuan_objv2_val.json"

with open(annotation_path, "r") as f:
    data = json.load(f)

for img in data["images"]:
    if "file_name" in img:
        img["file_name"] = img["file_name"].replace("v1_ex", "v1").replace("v2_ex", "v2")

# Write to a temp file first, then atomically swap it in. Opening the original
# with "w" truncates it immediately, so a failure mid-dump would lose the only copy.
tmp_path = annotation_path + ".tmp"
with open(tmp_path, "w") as f:
    json.dump(data, f, indent=2)
os.replace(tmp_path, annotation_path)

print(f"File name corrections done! Saved to {annotation_path}")

#%%
import pandas as pd

# Apply the same file_name rewrite as the JSON fix ("v1_ex" -> "v1",
# "v2_ex" -> "v2") to the sampled CSV and save it under a new name.
csv_file = "single_instance_sampled_filtered.csv"
fixed_csv_file = "normalized_area_records.csv"
df = pd.read_csv(csv_file)

# regex=False: the patterns are literal substrings, not regular expressions.
df["file_name"] = (
    df["file_name"]
    .str.replace("v1_ex", "v1", regex=False)
    .str.replace("v2_ex", "v2", regex=False)
)

df.to_csv(fixed_csv_file, index=False)

print(f"File name corrections done! Saved to {fixed_csv_file}")

#%%
import pandas as pd
import random, json, cv2, os
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image, ImageDraw
import numpy as np
from tqdm import tqdm

random.seed(42)

csv_file = "normalized_area_records.csv"
df = pd.read_csv(csv_file)
# target_cat = "Ring"
# df = df[df["category_name"].str.lower() == target_cat.lower()]
# Cap n at the number of rows: DataFrame.sample(n=...) without replacement raises
# ValueError when n exceeds len(df), e.g. after the optional category filter above.
selected_samples = df.sample(n=min(3, len(df)), random_state=42)

with open("zhiyuan_objv2_val.json", "r") as f:
    full_data = json.load(f)

# (image_id, category_id) -> annotation, for non-crowd, non-fake boxes only.
# If several boxes share a key, the last one wins. The CSV rows come from
# single-instance images, so for those rows the key is unique.
ann_lookup = {(a["image_id"], a["category_id"]): a
              for a in full_data["annotations"]
              if a["iscrowd"] == 0 and a["isfake"] == 0}

def expand2square(pil_img, bg_color):
    """Centre ``pil_img`` on a square canvas of side ``max(w, h)``.

    Args:
        pil_img: Source PIL image.
        bg_color: Fill colour for the padded area.

    Returns:
        ``(square_img, (x_offset, y_offset))``. The offset is where the source
        image's top-left corner was pasted. For an already-square input the
        same image object is returned with offset ``(0, 0)``.
    """
    w, h = pil_img.size
    if w == h:
        return pil_img, (0, 0)
    side = w if w > h else h
    paste_at = ((side - w) // 2, (side - h) // 2)
    canvas = Image.new(pil_img.mode, (side, side), bg_color)
    canvas.paste(pil_img, paste_at)
    return canvas, paste_at

# Per-channel normalization means. The first and third rows look like the CLIP
# and ImageNet means (TODO confirm which processors these belong to). Their
# average, scaled to 0-255 and truncated, is used as the padding colour.
means = [
    [0.48145466, 0.4578275, 0.40821073],
    [0.48145466, 0.4578275, 0.40821073],
    [0.485, 0.456, 0.406],
    [0.485, 0.456, 0.406],
    [0.5, 0.5, 0.5],
    [0.48145466, 0.4578275, 0.40821073],
    [0.485, 0.456, 0.406],
    [0.5, 0.5, 0.5],
    [0.5, 0.5, 0.5]
]
bg_rgb = tuple(int(c * 255) for c in np.mean(means, axis=0))

output_root = "new"
os.makedirs(output_root, exist_ok=True)
bbox_records = []

for _, row in tqdm(selected_samples.iterrows(), total=len(selected_samples), desc="Process"):
    img_path = row["file_name"]
    img_id, cat_id = row["image_id"], row["category_id"]
    norm_area = row["normalized_area"]

    ann = ann_lookup.get((img_id, cat_id))
    if ann is None or not os.path.exists(img_path):
        print(f"[WARN] 跳过 {img_path}")  # "跳过" = "skipping"
        continue

    bbox = ann["bbox"]  # [x, y, w, h] in pixels of the original image

    # cv2.imread returns None (it does not raise) for unreadable/corrupt files.
    img_bgr = cv2.imread(img_path)
    if img_bgr is None:
        print(f"[WARN] unreadable image, skipping {img_path}")
        continue
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

    # Preview: original image with the box and its normalized area.
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(img_rgb)
    ax.add_patch(patches.Rectangle((bbox[0], bbox[1]), bbox[2], bbox[3],
                                   linewidth=2, edgecolor='red', facecolor='none'))
    ax.text(bbox[0], bbox[1] - 10, f"Normalized Area: {norm_area:.4f}",
            color='yellow', fontsize=12, backgroundcolor='black')
    ax.axis('off')
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across iterations

    # Pad to a square; the box only moves by the paste offset, its size is unchanged.
    pil_img = Image.fromarray(img_rgb)
    padded, (off_x, off_y) = expand2square(pil_img, bg_rgb)
    new_bbox = (bbox[0] + off_x, bbox[1] + off_y, bbox[2], bbox[3])

    draw = ImageDraw.Draw(padded)
    x, y, w_, h_ = new_bbox
    draw.rectangle([x, y, x + w_, y + h_], outline="red", width=2)

    # NOTE(review): only the basename is kept, so same-named files from
    # different source folders would overwrite each other here.
    save_name = os.path.basename(img_path)
    save_path = os.path.join(output_root, save_name)
    padded.save(save_path)

    bbox_records.append({
        "file_name": save_name,
        "padded_bbox": [float(new_bbox[0]), float(new_bbox[1]), float(new_bbox[2]), float(new_bbox[3])],
        "normalized_area": float(norm_area)
    })

with open(os.path.join(output_root, "bbox_records.json"), "w") as f:
    json.dump(bbox_records, f, indent=2)


#%%
import csv
import json
from collections import defaultdict

# === Step 1: Load mean normalized area ===
# The mean is taken over the rows of this CSV only. That is the per-category
# sample drawn by the sampling cell (at most 50 images each), not every image.
mean_area_file = "normalized_area_records.csv"
areas_per_object = defaultdict(list)

with open(mean_area_file, "r") as f:
    for row in csv.DictReader(f):
        areas_per_object[row["category_name"]].append(float(row["normalized_area"]))

# Average the sampled normalized areas per object name.
mean_area_dict = {
    name: sum(values) / len(values)
    for name, values in areas_per_object.items()
}

# === Step 2: Load COCO JSON and count single-instance images ===
with open("zhiyuan_objv2_val.json", "r") as f:
    full_data = json.load(f)

annotations = full_data["annotations"]
categories = full_data["categories"]
cat_id_to_name = {c["id"]: c["name"] for c in categories}

# Keep only non-crowd, non-fake boxes.
filtered_annotations = [
    a for a in annotations if a["iscrowd"] == 0 and a["isfake"] == 0
]

image_to_annotations = defaultdict(list)
for a in filtered_annotations:
    image_to_annotations[a["image_id"]].append(a)

# category id -> set of images containing exactly one box of that category
category_to_single_instance_images = defaultdict(set)
for img_id, anns in image_to_annotations.items():
    counts = {}
    for a in anns:
        counts[a["category_id"]] = counts.get(a["category_id"], 0) + 1
    for cat_id, n_boxes in counts.items():
        if n_boxes == 1:
            category_to_single_instance_images[cat_id].add(img_id)

# === Step 3: Merge and write output ===
# Only categories present in the sampled CSV get a row.
output_file = "365_stats.csv"

with open(output_file, "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["object name", "mean_normalized_area", "number_single_instance_images"])

    for cat_id, img_ids in category_to_single_instance_images.items():
        obj_name = cat_id_to_name[cat_id]
        if obj_name not in mean_area_dict:
            continue
        writer.writerow([obj_name, f"{mean_area_dict[obj_name]:.6f}", len(img_ids)])

print(f"✅ Done! Saved to {output_file}")


#%%
import json
from collections import defaultdict

# Load full annotation JSON
with open("zhiyuan_objv2_val.json", "r") as f:
    full_data = json.load(f)

images = full_data["images"]
annotations = full_data["annotations"]
categories = full_data["categories"]

image_id_to_size = {img["id"]: (img["width"], img["height"]) for img in images}
image_id_to_file = {img["id"]: img["file_name"] for img in images if "file_name" in img}
category_id_to_name = {cat["id"]: cat["name"] for cat in categories}

# Step 1: keep only non-crowd, non-fake boxes.
filtered_annotations = [
    ann for ann in annotations
    if ann["iscrowd"] == 0 and ann["isfake"] == 0
]

# Step 2: group boxes per image.
image_to_annotations = defaultdict(list)
for ann in filtered_annotations:
    image_to_annotations[ann["image_id"]].append(ann)

# Step 3: for each category with exactly one box in an image, keep the image
# when that box covers strictly between 0.2% and 50% of the image.
# Unlike the sampling cell, the area here is recomputed from the bbox (w * h)
# rather than read from ann["area"].
MIN_NORM_AREA, MAX_NORM_AREA = 0.002, 0.5
category_data = defaultdict(list)

for img_id, anns in image_to_annotations.items():
    width, height = image_id_to_size[img_id]
    category_to_anns = defaultdict(list)

    for ann in anns:
        category_to_anns[ann["category_id"]].append(ann)

    for cat_id, cat_anns in category_to_anns.items():
        if len(cat_anns) == 1:
            ann = cat_anns[0]
            x, y, w_box, h_box = ann["bbox"]
            norm_area = (w_box * h_box) / (width * height)
            if MIN_NORM_AREA < norm_area < MAX_NORM_AREA:
                cat_name = category_id_to_name[cat_id]
                category_data[cat_name].append({
                    "image_id": img_id,
                    "normalized_area": round(norm_area, 6)
                })

# Step 4: per-object summary. The mean is over the already-rounded areas.
summary = {}
for cat_name, entries in category_data.items():
    areas = [e["normalized_area"] for e in entries]
    mean_area = sum(areas) / len(areas)
    summary[cat_name] = {
        "mean_normalized_area": round(mean_area, 6),
        "count": len(entries),
        "images": entries
    }

# Step 5: Save to JSON file.
# NOTE(review): the cell that builds object365_no_padded_bbox_output.json loads
# "normalized_area_per_object_detailed.json" (no "_bbx" suffix); confirm which
# of the two files it is meant to consume.
summary_path = "normalized_area_per_object_detailed_bbx.json"
with open(summary_path, "w") as f:
    json.dump(summary, f, indent=2)

print(f"✅ Saved summary for {len(summary)} objects to {summary_path}")


#%%
import os
import json
import numpy as np
from PIL import Image
from typing import List, Literal
from tqdm import tqdm
from collections import defaultdict, Counter
import random

# === Config === input annotations, output record JSON, and image output folder
annotation_json_path, output_json_path, output_folder = (
    "zhiyuan_objv2_val.json",
    "object365_no_padded_bbox_output.json",
    "no_padded_object_images",
)
os.makedirs(output_folder, exist_ok=True)

# === Padding and bbox adjustment ===
def expand2square(pil_img, background_color):
    """Return ``pil_img`` centred on a square canvas filled with ``background_color``.

    Only the shorter side is padded, by ``(long - short) // 2`` on the leading
    edge. An already-square image is returned unchanged (same object).

    Args:
        pil_img: Source PIL image.
        background_color: Fill colour for the padded area.

    Returns:
        The square image. Unlike the offset-returning variant earlier in this
        file, no paste offset is returned.
    """
    w, h = pil_img.size
    if w == h:
        return pil_img
    side = max(w, h)
    # Along the longer dimension the offset is 0; along the shorter it centres.
    offset = ((side - w) // 2, (side - h) // 2)
    canvas = Image.new(pil_img.mode, (side, side), background_color)
    canvas.paste(pil_img, offset)
    return canvas

def adjust_boxes_to_square(
    boxes: List[List[float]],
    fmt: Literal["corner"],
    orig_size: tuple[int, int]
) -> List[List[float]]:
    """Map ``[x, y, w, h]`` pixel boxes onto the square canvas made by ``expand2square``.

    Each box is shifted by the same padding offset ``expand2square`` pastes the
    image at, converted to corner form, divided by the square side
    ``max(W, H)`` and clamped to ``[0, 1]``.

    Args:
        boxes: Boxes as ``[x, y, w, h]`` in pixels of the original image.
        fmt: Output-format tag. Only ``"corner"`` is allowed by the type hint;
            the value is not inspected.
        orig_size: ``(W, H)`` of the original, unpadded image.

    Returns:
        One ``[xmin, ymin, xmax, ymax]`` list per input box, normalized to the
        square side and clamped to ``[0, 1]``.
    """
    W, H = orig_size
    S = max(W, H)
    # Floor division so the shift equals the integer paste offset used by
    # expand2square; true division would be off by half a pixel whenever
    # |W - H| is odd. For the longer dimension the offset is 0.
    dx = (S - W) // 2
    dy = (S - H) // 2

    def _clamp01(v):
        return max(0, min(1, v))

    new_boxes = []
    for x, y, w, h in boxes:
        xmin, ymin = (x + dx) / S, (y + dy) / S
        xmax, ymax = (x + w + dx) / S, (y + h + dy) / S
        new_boxes.append([_clamp01(xmin), _clamp01(ymin), _clamp01(xmax), _clamp01(ymax)])
    return new_boxes


def normalize_boxes(
    boxes: List[List[float]],
    orig_size: tuple[int, int]
) -> List[List[float]]:
    """Convert ``[x, y, w, h]`` pixel boxes to ``[xmin, ymin, xmax, ymax]`` in ``[0, 1]``.

    x-coordinates are divided by the image width and y-coordinates by the
    height, then each value is clamped to ``[0, 1]``. No padding is assumed.

    Args:
        boxes: Boxes as ``[x, y, w, h]`` in pixels.
        orig_size: ``(W, H)`` of the image the boxes refer to.

    Returns:
        One normalized corner-form box per input box.
    """
    img_w, img_h = orig_size

    def _clamp01(v):
        return max(0, min(1, v))

    return [
        [
            _clamp01(x0 / img_w),
            _clamp01(y0 / img_h),
            _clamp01((x0 + bw) / img_w),
            _clamp01((y0 + bh) / img_h),
        ]
        for x0, y0, bw, bh in boxes
    ]

# === Background color ===
# Average of several processors' per-channel means, scaled to 0-255 and
# truncated; used as the padding colour for the square variant.
processor_means = [
    [0.48145466, 0.4578275, 0.40821073],
    [0.48145466, 0.4578275, 0.40821073],
    [0.485, 0.456, 0.406],
    [0.485, 0.456, 0.406],
    [0.5, 0.5, 0.5],
    [0.48145466, 0.4578275, 0.40821073],
    [0.485, 0.456, 0.406],
    [0.5, 0.5, 0.5],
    [0.5, 0.5, 0.5]
]
bg_rgb = tuple(int(channel * 255) for channel in np.mean(processor_means, axis=0))

# === Load and filter data ===
random.seed(42)
MAX_ENTRIES_PER_OBJECT = 700
TARGET_OBJECTS = [
    "Ship", "Couch", "Person", "Mirror", "Chair", "Bread", "Guitar",
    "Umbrella", "Toilet", "Flower", "Knife", "SUV", "Printer",
    "Street Lights", "Shovel", "Lemon", "Microwave", "Kettle", "Vase", "Camera"
]

with open(annotation_json_path, "r") as f:
    full_data = json.load(f)

image_id_to_info = {img["id"]: img for img in full_data["images"]}

# image id -> non-crowd, non-fake annotations (missing flags count as 0)
image_id_to_anns = defaultdict(list)
for ann in full_data["annotations"]:
    if ann.get("iscrowd", 0) == 0 and ann.get("isfake", 0) == 0:
        image_id_to_anns[ann["image_id"]].append(ann)

category_id_to_name = {cat["id"]: cat["name"] for cat in full_data["categories"]}
category_name_to_id = {name: cid for cid, name in category_id_to_name.items()}

# Load area stats
# NOTE(review): the per-object summary cell writes
# "normalized_area_per_object_detailed_bbx.json"; confirm this filename is the
# intended input.
with open("normalized_area_per_object_detailed.json", "r") as f:
    per_obj_stats = json.load(f)

strict_output = {}
insufficient = []

for obj_name in TARGET_OBJECTS:
    if obj_name not in per_obj_stats:
        continue

    entries = per_obj_stats[obj_name]["images"]
    cat_id = category_name_to_id[obj_name]
    # Re-check against the annotations that exactly one box of this category exists.
    valid_entries = [
        e for e in entries
        if sum(a["category_id"] == cat_id for a in image_id_to_anns.get(e["image_id"], [])) == 1
    ]

    if len(valid_entries) >= MAX_ENTRIES_PER_OBJECT:
        sampled_entries = random.sample(valid_entries, MAX_ENTRIES_PER_OBJECT)
    else:
        print(f"[{obj_name}] Only found {len(valid_entries)} valid entries (using all).")
        sampled_entries = valid_entries

    strict_output[obj_name] = {
        "mean_normalized_area": per_obj_stats[obj_name]["mean_normalized_area"],
        "count": len(sampled_entries),
        "images": sampled_entries
    }

# === Process and save images (this variant saves them unpadded) ===
output_records = []
missing_images = []

for obj, info in tqdm(strict_output.items(), desc="Processing objects"):
    cat_id = category_name_to_id[obj]
    for entry in info["images"]:
        img_id = entry["image_id"]
        img_info = image_id_to_info.get(img_id)
        if not img_info or "file_name" not in img_info:
            missing_images.append(f"{img_id} - no file name")
            continue

        # The JSON's file_name points at images/v1/ and images/v2/; map back to
        # the *_ex directories, which is where existence is checked on disk.
        orig_path = img_info["file_name"]
        corrected_path = orig_path.replace("images/v1/", "images/v1_ex/").replace("images/v2/", "images/v2_ex/")
        if not os.path.exists(corrected_path):
            missing_images.append(corrected_path)
            continue

        width, height = img_info["width"], img_info["height"]
        obj_anns = [a for a in image_id_to_anns.get(img_id, []) if a["category_id"] == cat_id]
        if len(obj_anns) != 1:
            continue

        bbox = obj_anns[0]["bbox"]
        try:
            # The context manager releases the file handle. convert() loads the
            # pixels, so corrupt or truncated files fail here.
            with Image.open(corrected_path) as src:
                pil_img = src.convert("RGB")
        except (OSError, ValueError, Image.DecompressionBombError):
            # PIL reports unreadable files as OSError (incl. UnidentifiedImageError).
            # Narrower than a bare except so KeyboardInterrupt still stops the run.
            missing_images.append(corrected_path)
            continue

        padded_img = pil_img  # no square padding in this variant
        # Square-padded alternative:
        # padded_img = expand2square(pil_img, bg_rgb)
        # norm_bbox = adjust_boxes_to_square([bbox], fmt="corner", orig_size=(width, height))[0]
        norm_bbox = normalize_boxes([bbox], orig_size=(width, height))[0]
        save_name = os.path.basename(corrected_path)
        save_path = os.path.join(output_folder, save_name)
        padded_img.save(save_path)

        output_records.append({
            "object": obj,
            "image_path": save_path,
            "bbox": norm_bbox
        })

# === Save outputs ===
with open(output_json_path, "w") as f:
    json.dump(output_records, f, indent=2)
with open("missing_images.txt", "w") as f:
    for m in missing_images:
        f.write(m + "\n")

object_counter = Counter([r["object"] for r in output_records])
print("\n📊 Entry count per object:")
for obj, count in sorted(object_counter.items()):
    print(f"- {obj}: {count} entries")

print(f"\n✅ Total unique images: {len(set(r['image_path'] for r in output_records))}")
print(f"✅ Total object entries: {len(output_records)}")
print(f"🚫 Missing images: {len(missing_images)} saved to missing_images.txt")






#%%
import json
import random
import os
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt

# === Load JSON ===
with open("object365_no_padded_bbox_output.json", "r") as f:
    data = json.load(f)

# Draw the box of one randomly chosen record on its saved image.
if not data:
    # random.choice raises IndexError on an empty list.
    print("❌ No entries found in object365_no_padded_bbox_output.json")
else:
    entry = random.choice(data)

    image_path = entry["image_path"]
    bbox = entry["bbox"]  # [xmin, ymin, xmax, ymax] normalized
    obj_name = entry["object"]

    # === Load image ===
    if not os.path.exists(image_path):
        print(f"❌ Image not found: {image_path}")
    else:
        img = Image.open(image_path).convert("RGB")
        draw = ImageDraw.Draw(img)

        # Scale the normalized corners back to this image's pixel grid.
        W, H = img.size
        xmin = int(bbox[0] * W)
        ymin = int(bbox[1] * H)
        xmax = int(bbox[2] * W)
        ymax = int(bbox[3] * H)

        draw.rectangle([xmin, ymin, xmax, ymax], outline="red", width=3)
        # Keep the label inside the image when the box touches the top edge.
        draw.text((xmin, max(ymin - 10, 0)), obj_name, fill="yellow")

        # === Show image with bbox ===
        plt.figure(figsize=(8, 8))
        plt.imshow(img)
        plt.axis("off")
        plt.title(f"Object: {obj_name}")
        plt.show()

# %%