#%%
import json

# Build a COCO-style subset of the training annotations that contains only the
# categories of a single supercategory, and write it to
# "<SUPERCATEGORY>_only_train.json".

# Supercategory to keep. The output file name is derived from this value, so
# with "Aves" the subset is written to "Aves_only_train.json".
SUPERCATEGORY = "Aves"

# JSON text is UTF-8 by specification; do not rely on the platform default.
with open("train_2017_bboxes/train_2017_bboxes.json", "r", encoding="utf-8") as f:
    data = json.load(f)

# Step 1: Collect the ids of every category in the chosen supercategory.
# `.get` skips a category entry that has no "supercategory" key instead of
# aborting the whole run with a KeyError.
selected_cat_ids = {
    cat["id"]
    for cat in data["categories"]
    if cat.get("supercategory") == SUPERCATEGORY
}

# Step 2: Keep only the annotations that belong to one of those categories.
selected_annotations = [
    ann for ann in data["annotations"] if ann["category_id"] in selected_cat_ids
]

# Step 3: Ids of the images that still have at least one annotation.
valid_image_ids = {ann["image_id"] for ann in selected_annotations}

# Step 4: Keep only the images referenced by the remaining annotations.
selected_images = [img for img in data["images"] if img["id"] in valid_image_ids]

# Step 5: Keep only the category records of the chosen supercategory.
selected_categories = [
    cat for cat in data["categories"] if cat["id"] in selected_cat_ids
]

# Step 6: Save the subset using the same three top-level keys as the input.
filtered_data = {
    "images": selected_images,
    "annotations": selected_annotations,
    "categories": selected_categories,
}

with open(f"{SUPERCATEGORY}_only_train.json", "w", encoding="utf-8") as f:
    json.dump(filtered_data, f, indent=2)


#%%
# Output recorded from earlier runs (presumably of the final print in this cell):
#   Mammalia: "Done! 27066 images saved."
#   Aves:     "Done! 222950 images saved."
import json
from collections import defaultdict

# Merge the train and val subsets, drop crowd annotations, and regroup the
# result as one record per image carrying that image's annotations.

# Load both splits in a fixed order: train first, then val.
splits = []
for split_path in ("Aves_only_train.json", "Aves_only_val.json"):
    with open(split_path, "r") as f:
        splits.append(json.load(f))
train_data, val_data = splits

merged_images = [*train_data["images"], *val_data["images"]]
merged_annotations = [*train_data["annotations"], *val_data["annotations"]]

# De-duplicate categories by id; when an id occurs in both splits the val
# record replaces the train record.
category_map = {}
for split in (train_data, val_data):
    for cat in split["categories"]:
        category_map[cat["id"]] = cat
merged_categories = list(category_map.values())

# Only annotations whose "iscrowd" flag equals 0 are kept.
filtered_annotations = [
    annotation for annotation in merged_annotations if annotation["iscrowd"] == 0
]

# Images that still own at least one annotation after the crowd filter.
valid_image_ids = {annotation["image_id"] for annotation in filtered_annotations}
filtered_images = [
    image for image in merged_images if image["id"] in valid_image_ids
]
image_id_to_info = {image["id"]: image for image in filtered_images}

# Group a trimmed copy of each annotation under its image id. Only these
# four fields are carried over, in this order.
kept_fields = ("category_id", "iscrowd", "bbox", "area")
image_annotations = defaultdict(list)
for annotation in filtered_annotations:
    trimmed = {field: annotation[field] for field in kept_fields}
    image_annotations[annotation["image_id"]].append(trimmed)

# One output record per image, in the order the image ids were first seen
# among the filtered annotations.
final_output = [
    {
        "image_id": image_id,
        "file_name": image_id_to_info[image_id]["file_name"],
        "width": image_id_to_info[image_id]["width"],
        "height": image_id_to_info[image_id]["height"],
        "annotations": anns,
    }
    for image_id, anns in image_annotations.items()
]

with open("Aves_combined_filtered.json", "w") as f:
    json.dump(final_output, f, indent=2)

print(f"Done! {len(final_output)} images saved.")


#%%
import json

# Extract the images that carry exactly one bounding-box annotation and
# store them in their own JSON file.


def _has_single_bbox(sample):
    """Return True when *sample* has exactly one entry in "annotations"."""
    return len(sample["annotations"]) == 1


with open("Aves_combined_filtered.json", "r") as f:
    all_data = json.load(f)

single_bbox_images = [entry for entry in all_data if _has_single_bbox(entry)]

# Report how many images qualified.
print(len(single_bbox_images))

with open("single_bbox_images_aves.json", "w") as f:
    json.dump(single_bbox_images, f, indent=2)

#%%
# 23445  (presumably the single-bbox image count printed by the previous cell on an earlier run; confirm)
import os
import json
import random
import cv2
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image, ImageDraw
from tqdm import tqdm

# Directory that receives the annotated, square-padded preview images;
# created up front so the later save calls cannot fail on a missing folder.
output_root = "new_aves"
os.makedirs(output_root, exist_ok=True)

# Fixed seed so the random sample selection below is repeatable across runs.
random.seed(42)

def expand2square(pil_img, bg_color):
    """Pad *pil_img* onto a square canvas filled with *bg_color*.

    The canvas side equals the longer image side and the original image is
    pasted centred on it. When the padding is odd, floor division leaves the
    extra pixel on the right/bottom.

    Args:
        pil_img: Image object exposing ``size`` as ``(width, height)`` and
            ``mode``.
        bg_color: Fill colour passed to ``Image.new`` for the padded area.

    Returns:
        A ``(image, (left, top))`` tuple, where ``(left, top)`` is the
        position at which the original was pasted. An already-square input
        is returned as the same object with offset ``(0, 0)``.
    """
    width, height = pil_img.size
    if width != height:
        side = max(width, height)
        left = (side - width) // 2
        top = (side - height) // 2
        canvas = Image.new(pil_img.mode, (side, side), bg_color)
        canvas.paste(pil_img, (left, top))
        return canvas, (left, top)
    # Already square: nothing to pad, hand back the input untouched.
    return pil_img, (0, 0)

# Three distinct per-channel mean triples (values in the 0-1 range), each of
# which appears three times in the table below.
# NOTE(review): the values coincide with the commonly published CLIP,
# ImageNet and 0.5 normalization means; presumably one row per model
# preprocessor. Confirm.
_MEAN_ROW_A = (0.48145466, 0.4578275, 0.40821073)
_MEAN_ROW_B = (0.485, 0.456, 0.406)
_MEAN_ROW_C = (0.5, 0.5, 0.5)

# Nine rows in the original order; each row is its own list object.
means = [
    list(row)
    for row in (
        _MEAN_ROW_A,
        _MEAN_ROW_A,
        _MEAN_ROW_B,
        _MEAN_ROW_B,
        _MEAN_ROW_C,
        _MEAN_ROW_A,
        _MEAN_ROW_B,
        _MEAN_ROW_C,
        _MEAN_ROW_C,
    )
]

# Average the rows channel-wise, scale to 0-255 and truncate to plain ints
# to obtain the padding colour as an (R, G, B) tuple.
channel_means = np.asarray(means).mean(axis=0)
bg_rgb = tuple(int(value * 255) for value in channel_means)

# Pick a few random single-box samples whose image file exists on disk, draw
# the bounding box and its normalized area on each, pad the result to a
# square with `expand2square` / `bg_rgb`, and save it under `output_root`.

IMAGE_ROOT = "../data/train_val_images"  # folder the "file_name" entries are relative to
NUM_SAMPLES = 3   # number of distinct preview images to produce
MAX_TRIES = 50    # upper bound on random draws, so missing files cannot loop forever

with open("single_bbox_images_aves.json", "r", encoding="utf-8") as f:
    all_data = json.load(f)

selected_samples = []
selected_ids = set()
tries = 0
# `all_data and ...` guards random.choice against IndexError on an empty list.
while all_data and len(selected_samples) < NUM_SAMPLES and tries < MAX_TRIES:
    tries += 1
    sample = random.choice(all_data)
    # random.choice draws with replacement; skip an image that was already
    # picked so the previews are distinct.
    if sample["image_id"] in selected_ids:
        continue
    img_path = os.path.join(IMAGE_ROOT, sample["file_name"])
    if os.path.exists(img_path):
        selected_samples.append(sample)
        selected_ids.add(sample["image_id"])

if len(selected_samples) < NUM_SAMPLES:
    print(f"[WARN] Only found {len(selected_samples)} of {NUM_SAMPLES} requested samples.")

# No hard-coded `total`: tqdm takes the length from the list, which may be
# shorter than NUM_SAMPLES.
for sample in tqdm(selected_samples, desc="Processing"):
    img_path = os.path.join(IMAGE_ROOT, sample["file_name"])
    width = sample["width"]
    height = sample["height"]

    # Defensive re-check in case the file disappeared after selection.
    if not os.path.exists(img_path):
        print(f"[WARN] Image not found: {img_path}")
        continue

    # convert() returns a new, fully loaded image, so the source file handle
    # can be closed as soon as the conversion is done.
    with Image.open(img_path) as src:
        pil_img = src.convert("RGB")
    draw = ImageDraw.Draw(pil_img)

    for ann in sample["annotations"]:
        # The bbox is unpacked as (x, y, box width, box height).
        x, y, w_box, h_box = ann["bbox"]
        # Box area as a fraction of the image area from the JSON metadata.
        normalized_area = ann["area"] / (width * height)

        draw.rectangle([x, y, x + w_box, y + h_box], outline="red", width=4)
        # Place the label just above the box, clamped so that a box touching
        # the top edge does not push the text off the canvas.
        draw.text((x, max(0, y - 15)), f"{normalized_area:.4f}", fill="yellow")

    # Boxes are drawn before padding, so the paste offset is not needed.
    padded_img, _ = expand2square(pil_img, bg_rgb)

    save_path = os.path.join(output_root, os.path.basename(img_path))
    padded_img.save(save_path)

#%%
import json
from collections import Counter
import matplotlib.pyplot as plt

# Count single-bbox images per species and plot the 100 most frequent ones.

with open("single_bbox_images_aves.json", "r") as f:
    data = json.load(f)

# The second-to-last path component of "file_name" is used as the species
# label; paths with fewer than four "/"-separated components are ignored.
split_paths = (entry["file_name"].split("/") for entry in data)
species_counter = Counter(
    parts[-2] for parts in split_paths if len(parts) >= 4
)

# Transpose the (species, count) pairs into two parallel tuples for plotting.
top_100 = species_counter.most_common(100)
species, counts = zip(*top_100)

plt.figure(figsize=(20, 6))
plt.bar(species, counts)
plt.title("Top 100 Species by Image Count (Single BBox Only)")
plt.xlabel("Species")
plt.ylabel("Image Count")
# Vertical tick labels keep the 100 species names readable.
plt.xticks(rotation=90, ha='right')
plt.tight_layout()
plt.show()




#%%
import json
import csv
from collections import defaultdict

# Summarise, per object name, the mean bounding-box area (as a fraction of
# the image area) over all images that carry exactly one annotation.

input_json = "Aves_combined_filtered.json"
output_csv = "aves_mean_normalized_area_per_object.csv"

with open(input_json, "r") as f:
    data = json.load(f)

object_to_areas = defaultdict(list)

# Only images with exactly one annotation contribute to the statistics.
single_box_items = (entry for entry in data if len(entry["annotations"]) == 1)
for entry in single_box_items:
    path_parts = entry["file_name"].split("/")
    if len(path_parts) < 4:
        continue
    # The third path component is used as the object name.
    object_name = path_parts[2]
    image_area = entry["width"] * entry["height"]
    norm_area = entry["annotations"][0]["area"] / image_area
    object_to_areas[object_name].append(norm_area)

# One CSV row per object: mean rounded to 6 decimals plus the sample count.
output_rows = [
    {
        "object": obj,
        "mean_normalized_area": round(sum(areas) / len(areas), 6),
        "number_images": len(areas),
    }
    for obj, areas in object_to_areas.items()
]

csv_columns = ["object", "mean_normalized_area", "number_images"]
with open(output_csv, "w", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=csv_columns)
    writer.writeheader()
    writer.writerows(output_rows)

print(f"✅ Saved summary for {len(output_rows)} objects to {output_csv}")
