#%%
import os
import shutil
import xml.etree.ElementTree as ET
from collections import defaultdict

annotation_dir = "Annotations/Annotations/Horizontal Bounding Boxes"
backup_dir = "Annotations_backup"
os.makedirs(backup_dir, exist_ok=True)

# For every annotation file, drop ALL objects whose class name occurs more than
# once in that file (every instance of the repeated class is removed, not just
# the extras). The original XML is copied into `backup_dir` first.
for filename in os.listdir(annotation_dir):
    if not filename.endswith(".xml"):
        continue

    xml_path = os.path.join(annotation_dir, filename)
    backup_path = os.path.join(backup_dir, filename)

    # Back up only once: re-running this cell would otherwise replace the
    # pristine original with the already-filtered XML from the previous run.
    if not os.path.exists(backup_path):
        shutil.copy2(xml_path, backup_path)

    tree = ET.parse(xml_path)
    root = tree.getroot()

    name_to_objects = defaultdict(list)
    for obj in root.findall("object"):
        name = obj.find("name").text.strip()
        name_to_objects[name].append(obj)

    removed_any = False
    for name, objs in name_to_objects.items():
        if len(objs) > 1:
            for obj in objs:
                root.remove(obj)
            removed_any = True

    # Leave untouched files byte-identical instead of re-serializing them.
    if removed_any:
        tree.write(xml_path, encoding="utf-8", xml_declaration=True)

print("Done: removed all repeated object types (when repeated) and backed up originals.")

#%%
import os
import xml.etree.ElementTree as ET

annotation_dir = "Annotations/Annotations/Horizontal Bounding Boxes"

deleted_count = 0

# Remove annotation files that contain no <object> elements at all.
# Parse/remove failures are reported and skipped (best-effort pass).
xml_names = [name for name in os.listdir(annotation_dir) if name.endswith(".xml")]

for filename in xml_names:
    xml_path = os.path.join(annotation_dir, filename)

    try:
        has_objects = bool(ET.parse(xml_path).getroot().findall("object"))
        if not has_objects:
            os.remove(xml_path)
            deleted_count += 1
            print(f"🗑️ Deleted: {filename} (no objects)")
    except Exception as err:
        print(f"⚠️ Failed to process {filename}: {err}")

print(f"✅ Done: Deleted {deleted_count} XML files with no objects.")


#%%
import os
import xml.etree.ElementTree as ET

annotation_dir = "Annotations/Annotations/Horizontal Bounding Boxes"
min_area = 0.002
max_area = 0.5

# Keep only objects whose box covers between `min_area` and `max_area` of the
# image (fraction of width * height, bounds inclusive). Files left with no
# objects are deleted; the others are rewritten in place.
for filename in os.listdir(annotation_dir):
    if not filename.endswith(".xml"):
        continue

    xml_path = os.path.join(annotation_dir, filename)
    tree = ET.parse(xml_path)
    root = tree.getroot()

    size_node = root.find("size")
    img_w = int(size_node.find("width").text)
    img_h = int(size_node.find("height").text)
    image_area = img_w * img_h

    rejected = []
    for obj in root.findall("object"):
        box = obj.find("bndbox")
        x0, y0, x1, y1 = (
            int(box.find(tag).text) for tag in ("xmin", "ymin", "xmax", "ymax")
        )
        fraction = (x1 - x0) * (y1 - y0) / image_area
        if not (min_area <= fraction <= max_area):
            rejected.append(obj)

    for obj in rejected:
        root.remove(obj)

    if root.findall("object"):
        tree.write(xml_path, encoding="utf-8", xml_declaration=True)
    else:
        os.remove(xml_path)
        print(f"🗑️ Deleted {filename} (no valid objects)")

print("✅ Normalized area filtering done.")

#%%
import os
import xml.etree.ElementTree as ET
import json
from collections import defaultdict

annotation_dir = "Annotations/Annotations/Horizontal Bounding Boxes"
output_json = "filtered_annotation_summary.json"

# Per class name: number of objects seen and the image filename of each one,
# stored under the "1" key.
result = defaultdict(lambda: {"1": {"count": 0, "image_ids": []}})

xml_files = (name for name in os.listdir(annotation_dir) if name.endswith(".xml"))
for filename in xml_files:
    root = ET.parse(os.path.join(annotation_dir, filename)).getroot()
    image_filename = root.find("filename").text.strip()

    for obj in root.findall("object"):
        bucket = result[obj.find("name").text.strip()]["1"]
        bucket["image_ids"].append(image_filename)
        bucket["count"] += 1

with open(output_json, "w") as f:
    json.dump(result, f, indent=2)

print(f"Saved summary to {output_json}")

#%%
import json
import matplotlib.pyplot as plt

with open("filtered_annotation_summary.json", "r") as f:
    data = json.load(f)

# Classes with more than 500 entries, in the order they appear in the JSON.
frequent = [
    (name, inner["1"]["count"])
    for name, inner in data.items()
    if inner["1"]["count"] > 500
]
class_names = [name for name, _ in frequent]
image_counts = [count for _, count in frequent]

plt.figure(figsize=(12, 6))
plt.bar(class_names, image_counts)
plt.xticks(rotation=90, ha='right')
plt.xlabel("Class Name")
plt.ylabel("Number of Images")
plt.title("Number of Images per Class (Count > 500)")
plt.tight_layout()
plt.show()

#%%
import csv
import json

input_json = "filtered_annotation_summary.json"
output_csv = "filtered_annotation_summary.csv"

with open(input_json, "r") as f:
    data = json.load(f)

# One row per class: its name and its object count from the summary JSON.
with open(output_csv, "w", newline="") as csvfile:
    writer = csv.writer(csvfile)
    writer.writerow(["object", "count"])
    writer.writerows(
        [class_name, details["1"]["count"]] for class_name, details in data.items()
    )

print(f"Saved to {output_csv}")



#%%
import os
import xml.etree.ElementTree as ET
import numpy as np
from PIL import Image, ImageDraw
from collections import defaultdict

annotation_dir = "Annotations/Annotations/Horizontal Bounding Boxes"
image_dir = "JPEGImages-trainval/JPEGImages-trainval"
test_image_dir = "JPEGImages-test/JPEGImages-test"
output_dir = "min_area_visuals"
os.makedirs(output_dir, exist_ok=True)

# Per-channel RGB means (presumably normalization means of several image
# processors — confirm). Their average, scaled to 0-255 and truncated to int,
# becomes the padding colour for square images.
processor_means = [
    [0.48145466, 0.4578275, 0.40821073],
    [0.48145466, 0.4578275, 0.40821073],
    [0.485, 0.456, 0.406],
    [0.485, 0.456, 0.406],
    [0.5, 0.5, 0.5],
    [0.48145466, 0.4578275, 0.40821073],
    [0.485, 0.456, 0.406],
    [0.5, 0.5, 0.5],
    [0.5, 0.5, 0.5]
]
mean_array = np.array(processor_means)
average_mean = mean_array.mean(axis=0)
background_color = tuple(int(channel * 255) for channel in average_mean)

def expand2square(pil_img, background_color):
    """Pad ``pil_img`` onto a square canvas of side ``max(width, height)``.

    The original is centred along its shorter side at an integer offset of
    ``(side - length) // 2``, so an odd leftover pixel ends up at the
    bottom/right. Padding uses ``background_color``. An already-square image
    is returned as-is (the same object, not a copy).
    """
    width, height = pil_img.size
    if width == height:
        return pil_img

    side = max(width, height)
    canvas = Image.new(pil_img.mode, (side, side), background_color)
    offset = ((side - width) // 2, (side - height) // 2)
    canvas.paste(pil_img, offset)
    return canvas

def adjust_boxes_to_square(boxes, orig_size):
    """Map pixel boxes onto the square canvas built by ``expand2square``.

    Args:
        boxes: iterable of ``[x, y, w, h]`` boxes (top-left corner plus size)
            in pixel coordinates of the original image.
        orig_size: ``(W, H)`` of the original image.

    Returns:
        A list of ``[xmin, ymin, xmax, ymax]`` boxes normalized by
        ``S = max(W, H)`` and clamped to ``[0, 1]``.
    """
    W, H = orig_size
    S = max(W, H)
    # expand2square pastes the image at the integer offset (S - side) // 2.
    # Shift the boxes by the same floored amount; true division would place
    # them half a pixel off whenever S - side is odd.
    dx = (S - W) // 2
    dy = (S - H) // 2

    new_boxes = []
    for ax, ay, w, h in boxes:
        xmin = (ax + dx) / S
        ymin = (ay + dy) / S
        xmax = (ax + w + dx) / S
        ymax = (ay + h + dy) / S
        new_boxes.append([max(0, min(1, v)) for v in (xmin, ymin, xmax, ymax)])
    return new_boxes

class_to_areas = defaultdict(list)

# For every object, record its box area as a fraction of the image area,
# together with the source image, its pixel box, and the image size,
# grouped by class name.
for filename in filter(lambda name: name.endswith(".xml"), os.listdir(annotation_dir)):
    root = ET.parse(os.path.join(annotation_dir, filename)).getroot()
    image_filename = root.find("filename").text.strip()
    size = root.find("size")
    W, H = int(size.find("width").text), int(size.find("height").text)
    total_area = W * H

    for obj in root.findall("object"):
        name = obj.find("name").text.strip()
        bbox = obj.find("bndbox")
        xmin, ymin, xmax, ymax = (
            int(bbox.find(tag).text) for tag in ("xmin", "ymin", "xmax", "ymax")
        )
        class_to_areas[name].append({
            "area": (xmax - xmin) * (ymax - ymin) / total_area,
            "file": image_filename,
            "box": [xmin, ymin, xmax, ymax],
            "size": (W, H),
        })

# For each class, draw its three smallest (by normalized area) boxes on the
# square-padded image and save each result to `output_dir`.
for class_name, entries in class_to_areas.items():
    entries.sort(key=lambda rec: rec["area"])
    for rank, entry in enumerate(entries[:3], start=1):
        # Look in trainval first, then in test.
        candidates = (
            os.path.join(image_dir, entry["file"]),
            os.path.join(test_image_dir, entry["file"]),
        )
        image_path = next((p for p in candidates if os.path.exists(p)), None)
        if image_path is None:
            print(f"❌ Image not found: {entry['file']}")
            continue

        padded_img = expand2square(Image.open(image_path).convert("RGB"), background_color)
        W, H = entry["size"]
        x0, y0, x1, y1 = entry["box"]
        norm_box = adjust_boxes_to_square([[x0, y0, x1 - x0, y1 - y0]], (W, H))[0]

        canvas_w, canvas_h = padded_img.size
        pxmin, pxmax = int(norm_box[0] * canvas_w), int(norm_box[2] * canvas_w)
        pymin, pymax = int(norm_box[1] * canvas_h), int(norm_box[3] * canvas_h)

        draw = ImageDraw.Draw(padded_img)
        draw.rectangle([pxmin, pymin, pxmax, pymax], outline="red", width=3)
        draw.text((pxmin, pymin - 10), f"{class_name}", fill="red")

        save_name = f"{class_name}_{rank}_{entry['file']}"
        padded_img.save(os.path.join(output_dir, save_name))

print("✅ Done saving top-3 min area images per class.")

#%%
import os
import xml.etree.ElementTree as ET
import json
from collections import defaultdict

annotation_dir = "Annotations/Annotations/Horizontal Bounding Boxes"
output_json = "mean_normalized_area_per_class.json"

area_sums = defaultdict(float)
area_counts = defaultdict(int)

# Accumulate per-class sums of (box area / image area), visiting files and
# objects in the same order as os.listdir / findall return them.
for filename in os.listdir(annotation_dir):
    if not filename.endswith(".xml"):
        continue

    root = ET.parse(os.path.join(annotation_dir, filename)).getroot()
    total_area = int(root.find("size/width").text) * int(root.find("size/height").text)

    for obj in root.findall("object"):
        name = obj.find("name").text.strip()
        bbox = obj.find("bndbox")
        xmin, ymin, xmax, ymax = [
            int(bbox.find(tag).text) for tag in ("xmin", "ymin", "xmax", "ymax")
        ]
        area_sums[name] += (xmax - xmin) * (ymax - ymin) / total_area
        area_counts[name] += 1

# Mean area, rounded to 6 decimals, only for classes with more than 500 objects.
mean_areas = {}
for name, area_total in area_sums.items():
    if area_counts[name] > 500:
        mean_areas[name] = round(area_total / area_counts[name], 6)

with open(output_json, "w") as f:
    json.dump(mean_areas, f, indent=2)

for cls, mean in mean_areas.items():
    print(f"{cls:30s} -> mean normalized area: {mean}")

print(f"\nSaved mean normalized area per class to {output_json}")

#%%
import json

input_json = "filtered_annotation_summary.json"
output_json = "filtered_annotation_summary_filtered.json"

with open(input_json, "r") as f:
    data = json.load(f)

# Class names kept for the downstream dataset.
target_objects = set(
    "groundtrackfield dam golffield basketballcourt airport "
    "trainstation bridge stadium overpass baseballfield".split()
)

# Keep the matching classes in their original JSON order.
filtered = {}
for name, entry in data.items():
    if name in target_objects:
        filtered[name] = entry

with open(output_json, "w") as f:
    json.dump(filtered, f, indent=2)

print(f"✅ Saved {len(filtered)} objects to {output_json}")

#%%
import json
import random

random.seed(42)

input_json = "filtered_annotation_summary_filtered.json"
output_json = "filtered_annotation_summary_top700.json"
# Maximum number of image ids kept per class. This one constant drives both
# the sampling and the final message, so the two cannot drift apart.
max_per_class = 700

with open(input_json, "r") as f:
    data = json.load(f)

output_data = {}

for obj, levels in data.items():
    if "1" not in levels:
        continue
    entry = levels["1"]
    count = entry["count"]
    image_ids = entry["image_ids"]

    if count > max_per_class:
        # Random subset; reproducible because the RNG is seeded above.
        image_ids = random.sample(image_ids, max_per_class)
        count = max_per_class

    output_data[obj] = {
        "1": {
            "count": count,
            "image_ids": image_ids
        }
    }

with open(output_json, "w") as f:
    json.dump(output_data, f, indent=2)

print(f"✅ Saved top-{max_per_class} limited version to {output_json}")

#%%
import os
import json
import xml.etree.ElementTree as ET
from tqdm import tqdm
from PIL import Image
import numpy as np
from typing import List, Literal

def expand2square(pil_img, background_color):
    """Return ``pil_img`` padded to a square of side ``max(width, height)``.

    The image is placed at the integer offset ``(side - length) // 2`` along
    its shorter axis (an odd leftover pixel goes to the bottom/right). The
    rest of the canvas is filled with ``background_color``. Square inputs are
    returned unchanged (the same object).
    """
    width, height = pil_img.size
    if width == height:
        return pil_img

    side = width if width > height else height
    left = (side - width) // 2
    top = (side - height) // 2
    canvas = Image.new(pil_img.mode, (side, side), background_color)
    canvas.paste(pil_img, (left, top))
    return canvas

def adjust_boxes_to_square(
    boxes: List[List[float]],
    fmt: Literal["corner"],
    orig_size: tuple[int, int]
) -> List[List[float]]:
    """Map pixel boxes onto the square canvas built by ``expand2square``.

    Args:
        boxes: ``[x, y, w, h]`` boxes (top-left corner plus size) in pixel
            coordinates of the original image.
        fmt: only ``"corner"`` is accepted by the type hint. It is currently
            unused and kept for call-site compatibility.
        orig_size: ``(W, H)`` of the original image.

    Returns:
        ``[xmin, ymin, xmax, ymax]`` boxes normalized by ``S = max(W, H)`` and
        clamped to ``[0, 1]``.
    """
    W, H = orig_size
    S = max(W, H)
    # expand2square pastes the image at the integer offset (S - side) // 2.
    # Shift the boxes by the same floored amount; true division would place
    # them half a pixel off whenever S - side is odd.
    dx = (S - W) // 2
    dy = (S - H) // 2

    new_boxes = []
    for ax, ay, w, h in boxes:
        xmin = (ax + dx) / S
        ymin = (ay + dy) / S
        xmax = (ax + w + dx) / S
        ymax = (ay + h + dy) / S
        new_boxes.append([max(0, min(1, v)) for v in (xmin, ymin, xmax, ymax)])
    return new_boxes

# Per-channel RGB means (presumably normalization means of several image
# processors — confirm). Their average, scaled to 0-255 and truncated, is the
# padding colour used by expand2square below.
processor_means = [
    [0.48145466, 0.4578275, 0.40821073],
    [0.48145466, 0.4578275, 0.40821073],
    [0.485, 0.456, 0.406],
    [0.485, 0.456, 0.406],
    [0.5, 0.5, 0.5],
    [0.48145466, 0.4578275, 0.40821073],
    [0.485, 0.456, 0.406],
    [0.5, 0.5, 0.5],
    [0.5, 0.5, 0.5]
]
channel_means = np.array(processor_means).mean(axis=0)
bg_rgb = tuple(int(c * 255) for c in channel_means)

# Inputs and outputs for the padded-image bbox dataset.
input_json = "filtered_annotation_summary_top700.json"
xml_folder = "Annotations/Annotations/Horizontal Bounding Boxes"
jpeg_test = "JPEGImages-test/JPEGImages-test"
jpeg_trainval = "JPEGImages-trainval/JPEGImages-trainval"
output_folder = "padded_object_images"
output_json = "filtered_bbox_output.json"
os.makedirs(output_folder, exist_ok=True)

with open(input_json, "r") as f:
    category_to_images = json.load(f)

# One record per object of `category`: the class, where its padded image was
# written, and the box normalized to that square image.
results = []
# Padded images already written during this run. An image listed under several
# categories only needs to be decoded, padded and re-encoded once.
written_padded = set()

for category, data in tqdm(category_to_images.items(), desc="Processing categories"):
    image_ids = data["1"]["image_ids"]
    for img_name in image_ids:
        img_id = os.path.splitext(img_name)[0]
        xml_path = os.path.join(xml_folder, f"{img_id}.xml")

        test_path = os.path.join(jpeg_test, img_name)
        trainval_path = os.path.join(jpeg_trainval, img_name)

        # The test split is checked first, then trainval.
        if os.path.exists(test_path):
            raw_img_path = test_path
        elif os.path.exists(trainval_path):
            raw_img_path = trainval_path
        else:
            print(f"[Missing Image] {img_name} not found.")
            continue

        if not os.path.exists(xml_path):
            print(f"[Missing XML] {xml_path}")
            continue

        try:
            tree = ET.parse(xml_path)
            root = tree.getroot()
            width = int(root.find("size/width").text)
            height = int(root.find("size/height").text)

            padded_path = os.path.join(output_folder, img_name)
            if padded_path not in written_padded:
                image = Image.open(raw_img_path).convert("RGB")
                padded = expand2square(image, bg_rgb)
                padded.save(padded_path)
                # Marked only after a successful save, so a failure is retried.
                written_padded.add(padded_path)

            for obj in root.findall("object"):
                name = obj.find("name").text.strip()
                if name != category:
                    continue
                bbox = obj.find("bndbox")
                xmin = int(bbox.find("xmin").text)
                ymin = int(bbox.find("ymin").text)
                xmax = int(bbox.find("xmax").text)
                ymax = int(bbox.find("ymax").text)
                w = xmax - xmin
                h = ymax - ymin

                norm_bbox = adjust_boxes_to_square([[xmin, ymin, w, h]], fmt="corner", orig_size=(width, height))[0]

                results.append({
                    "object": category,
                    "image_path": padded_path,
                    "bbox": norm_bbox
                })
        except Exception as e:
            # Best-effort: report and move on to the next image.
            print(f"[Error] Failed on {img_id}: {e}")
            continue

with open(output_json, "w") as f:
    json.dump(results, f, indent=2)

print(f"✅ Done. Total entries: {len(results)}")

#%%
import json
import random
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image

with open("filtered_bbox_output.json", "r") as f:
    data = json.load(f)

# Pick one random record and draw its normalized box, scaled back to pixels,
# on the saved padded image.
sample = random.choice(data)
img = Image.open(sample["image_path"])
W, H = img.size

x1, y1, x2, y2 = (coord * scale for coord, scale in zip(sample["bbox"], (W, H, W, H)))

fig, ax = plt.subplots(1)
ax.imshow(img)
ax.add_patch(
    patches.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=2, edgecolor='r', facecolor='none')
)
plt.axis("off")
plt.tight_layout()
plt.show()
