#%%
import json
from collections import defaultdict

# Species kept by the filter below. The list order is the order in which the
# per-species counts are written to the text report.
ordered_species = [
    "Sciurus niger",
    "Sciurus carolinensis",
    "Odocoileus virginianus",
    "Canis latrans",
    "Procyon lotor",
    "Otospermophilus beecheyi",
    "Sylvilagus floridanus",
    "Odocoileus hemionus",
    "Tamias striatus",
    "Vulpes vulpes",
]
# Set view of the same names, used for membership tests inside the filter loop.
target_species_set = {*ordered_species}

filtered_data = []
species_image_count = defaultdict(int)

with open("mammalia_combined_filtered.json", "r") as f:
    data = json.load(f)

# Keep a record only when:
#   * its file_name has a third "/"-separated component naming a target species,
#   * it carries exactly one annotation with a 4-element bbox,
#   * it has non-zero width and height, and
#   * bbox[2] * bbox[3] divided by the image area lies strictly in (0.002, 0.5).
for entry in data:
    file_path_parts = entry["file_name"].split("/")
    # Paths with fewer than three components have no species component.
    species = file_path_parts[2] if len(file_path_parts) >= 3 else None
    if species not in target_species_set:
        continue

    annotations = entry.get("annotations", [])
    if len(annotations) != 1:
        continue
    ann = annotations[0]

    bbox = ann.get("bbox", [])
    width, height = entry.get("width"), entry.get("height")
    if len(bbox) != 4 or not (width and height):
        continue

    norm_area = (bbox[2] * bbox[3]) / (width * height)
    if not (0.002 < norm_area < 0.5):
        continue

    # Shallow copy of the record with the annotation list rebuilt around the
    # single annotation that passed the checks.
    filtered_data.append({**entry, "annotations": [ann]})
    species_image_count[species] += 1

# Persist the surviving records as pretty-printed JSON.
with open("mammalia_filter.json", "w") as f:
    json.dump(filtered_data, f, indent=2)

# One "<species>: <count>" line per target species, in ordered_species order.
# Species that never passed the filter are reported with a count of 0.
with open("mammalia_filter.txt", "w") as f:
    f.writelines(
        f"{species}: {species_image_count.get(species, 0)}\n"
        for species in ordered_species
    )


#%%
import os
import json
import cv2
import numpy as np
from PIL import Image
from typing import List, Literal
from tqdm import tqdm

# === Config ===
# Annotation file to process (the first cell of this file writes a JSON with this name).
input_json_path = "mammalia_filter.json"
# Directory that each record's file_name is joined onto to locate the image.
img_root = "../data/train_val_images"
# Destination directory for the re-saved images; created up front if absent.
output_folder = "no_padded_object_images"
os.makedirs(output_folder, exist_ok=True)

# === Background color mean (CANNOT BE CHANGED) ===
# NOTE(review): the rows look like per-model image-processor normalisation
# means (RGB, 0-1 range), one row per model — confirm against the models used.
processor_means = [
    [0.48145466, 0.4578275, 0.40821073],
    [0.48145466, 0.4578275, 0.40821073],
    [0.485, 0.456, 0.406],
    [0.485, 0.456, 0.406],
    [0.5, 0.5, 0.5],
    [0.48145466, 0.4578275, 0.40821073],
    [0.485, 0.456, 0.406],
    [0.5, 0.5, 0.5],
    [0.5, 0.5, 0.5]
]
# Column-wise mean of the rows above, scaled by 255 and truncated by int()
# (not rounded); this evaluates to (124, 120, 111). In the code visible here
# it is referenced only by the commented-out expand2square(...) call.
bg_rgb = tuple(int(c * 255) for c in np.mean(processor_means, axis=0))

# === Padding (CANNOT BE CHANGED) ===
def expand2square(pil_img, background_color):
    """Pad a PIL image to a square canvas whose side is the longer image side.

    Args:
        pil_img: Image to pad; its ``size`` and ``mode`` are read.
        background_color: Fill colour passed to ``Image.new`` for the canvas.

    Returns:
        ``pil_img`` itself (not a copy) when it is already square; otherwise a
        new image of the same mode with the original pasted in the middle of
        the shorter axis.
    """
    width, height = pil_img.size
    if width == height:
        return pil_img
    elif width > height:
        # Landscape: square side is `width`; centre vertically. The offset is
        # floored, so an odd difference leaves one extra row at the bottom.
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        # Portrait: square side is `height`; centre horizontally. The offset is
        # floored, so an odd difference leaves one extra column on the right.
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result

def adjust_boxes_to_square(
    boxes: List[List[float]],
    fmt: Literal["corner"],
    orig_size: tuple[int, int]
) -> List[List[float]]:
    """Map pixel boxes into the unit square of a centre-padded square image.

    Args:
        boxes: Boxes given as ``[x, y, w, h]`` in pixels of the original image.
        fmt: Accepted for interface compatibility; not read by this function.
        orig_size: ``(width, height)`` of the original image.

    Returns:
        One ``[xmin, ymin, xmax, ymax]`` list per input box. Coordinates are
        shifted by the padding added on the shorter axis, divided by the
        longer side, and clamped to ``[0, 1]``.
    """
    img_w, img_h = orig_size
    side = max(img_w, img_h)
    # Only the shorter axis is padded; half of the size difference goes on
    # each side (unfloored, so it may be fractional).
    pad_x = 0 if img_w >= img_h else (img_h - img_w) / 2
    pad_y = 0 if img_h >= img_w else (img_w - img_h) / 2

    def _to_unit(pixel, pad):
        # Shift into the square, scale by its side, clamp to [0, 1].
        return max(0, min(1, (pixel + pad) / side))

    return [
        [
            _to_unit(left, pad_x),
            _to_unit(top, pad_y),
            _to_unit(left + box_w, pad_x),
            _to_unit(top + box_h, pad_y),
        ]
        for left, top, box_w, box_h in boxes
    ]
def normalize_boxes(
    boxes: List[List[float]],
    orig_size: tuple[int, int]
) -> List[List[float]]:
    """Convert pixel ``[x, y, w, h]`` boxes to clamped, normalised corners.

    Args:
        boxes: Boxes given as ``[x, y, w, h]`` in pixels.
        orig_size: ``(width, height)`` of the image the boxes refer to.

    Returns:
        One ``[xmin, ymin, xmax, ymax]`` list per input box, with x values
        divided by the width, y values by the height, and every value clamped
        to ``[0, 1]``.
    """
    img_w, img_h = orig_size

    def _clamp01(value):
        return max(0, min(1, value))

    return [
        [
            _clamp01(left / img_w),
            _clamp01(top / img_h),
            _clamp01((left + span_x) / img_w),
            _clamp01((top + span_y) / img_h),
        ]
        for left, top, span_x, span_y in boxes
    ]

# === Load input JSON ===
with open(input_json_path, "r") as f:
    data = json.load(f)

output_records = []
# Paths that were absent on disk or could not be decoded.
missing_images = []

for item in tqdm(data, desc="Processing images"):
    file_name = item["file_name"]
    full_path = os.path.join(img_root, file_name)

    # The object (species) name is the third "/"-separated path component.
    parts = file_name.split('/')
    if len(parts) < 3:
        continue
    object_name = parts[2]  # e.g., "Sciurus niger"

    if not os.path.exists(full_path):
        missing_images.append(full_path)
        continue

    # Only non-crowd annotations with a positive area are exported.
    anns = [a for a in item["annotations"] if a.get("iscrowd", 0) == 0 and a.get("area", 0) > 0]
    if not anns:
        continue

    try:
        img_bgr = cv2.imread(full_path)
        if img_bgr is None:
            # cv2.imread signals an undecodable file by returning None.
            missing_images.append(full_path)
            continue
        img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
        pil_img = Image.fromarray(img_rgb)
    except Exception:
        # Best effort: record the unreadable image and move on. Catching
        # Exception (not a bare except) lets KeyboardInterrupt / SystemExit
        # propagate, so the run can still be stopped with Ctrl-C.
        missing_images.append(full_path)
        continue

    # Padded variant, kept for reference:
    # padded = expand2square(pil_img, bg_rgb)
    padded = pil_img
    width, height = item["width"], item["height"]
    padded_bboxes = normalize_boxes([a["bbox"] for a in anns], orig_size=(width, height))

    # Padded variant, kept for reference:
    # padded_bboxes = adjust_boxes_to_square([a["bbox"] for a in anns], "corner", (width, height))

    # NOTE(review): images are saved under their basename only, so two source
    # files with the same basename in different folders would overwrite each
    # other — confirm basenames are unique in the dataset.
    save_path = os.path.join(output_folder, os.path.basename(file_name))
    padded.save(save_path)

    for ann, bbox in zip(anns, padded_bboxes):
        x1, y1, x2, y2 = bbox
        # Area of the normalised box, as a fraction of the image, to 4 decimals.
        normalized_area = round((x2 - x1) * (y2 - y1), 4)

        output_records.append({
            "object": object_name,
            "image_path": save_path,
            "bbox": bbox,
            "area": normalized_area
        })

# === Save results ===
with open("mammalia_no_padded_bbox_output.json", "w") as f:
    json.dump(output_records, f, indent=2)

with open("missing_images.txt", "w") as f:
    for path in missing_images:
        f.write(path + "\n")

print(f"✅ Done. Saved {len(output_records)} padded bbox entries with area.")
print(f"🚫 Missing images: {len(missing_images)} saved to missing_images.txt")
#%%
import json
import random
from collections import defaultdict

# Cap every object class at `max_per_object` records, sampling at random
# (fixed seed) from classes that exceed the cap and keeping smaller classes whole.
sampled_input_path = "mammalia_no_padded_bbox_output.json"
sampled_output_path = "mammalia_no_padded_bbox_output_700.json"
max_per_object = 700

with open(sampled_input_path, "r") as f:
    data = json.load(f)

# Group records by object name, preserving first-seen order of the names.
object_to_entries = defaultdict(list)
for entry in data:
    object_to_entries[entry["object"]].append(entry)

# Fixed seed so repeated runs pick the same subset.
random.seed(42)

final_output = []
for entries in object_to_entries.values():
    if len(entries) > max_per_object:
        sampled = random.sample(entries, max_per_object)
    else:
        sampled = entries
    final_output.extend(sampled)

with open(sampled_output_path, "w") as f:
    json.dump(final_output, f, indent=2)

# Report the path that was actually written (the message previously named a
# different file, "mammalia_padded_bbox_output_700.json").
print(f"✅ Done. Saved {len(final_output)} entries to {sampled_output_path}")


#%%
import json
from collections import Counter

# Report how many records each object has in the subsampled file.
with open("mammalia_no_padded_bbox_output_700.json", "r") as f:
    data = json.load(f)

# Tally records per object name; keys appear in first-seen order.
object_counts = Counter()
for entry in data:
    object_counts[entry["object"]] += 1

print("📊 Image count per object:")
for obj, count in object_counts.items():
    print(f"- {obj}: {count} images")

total = sum(object_counts.values())
print(f"\n✅ Total entries: {total}")

#%%
import json
import random
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image

# Visual sanity check: draw one randomly chosen record's box on its image.
# random.seed(42)

with open("mammalia_no_padded_bbox_output_700.json", "r") as f:
    data = json.load(f)

sample = random.choice(data)
img_path = sample["image_path"]
bbox = sample["bbox"]
obj_name = sample["object"]

img = Image.open(img_path)
W, H = img.size

# The stored box is [xmin, ymin, xmax, ymax] as fractions of the image size;
# scale it back to pixels for drawing.
left, top = bbox[0] * W, bbox[1] * H
right, bottom = bbox[2] * W, bbox[3] * H

fig, ax = plt.subplots(1)
ax.imshow(img)
ax.add_patch(
    patches.Rectangle(
        (left, top),
        right - left,
        bottom - top,
        linewidth=2,
        edgecolor='r',
        facecolor='none',
    )
)
ax.set_title(f"Object: {obj_name}")
ax.axis("off")
plt.show()