#%%
import json
from collections import defaultdict

# Sanity check: tally how many records in the raw bbox dump mention each object.
with open("aves_no_padded_bbox_output_700.json", "r") as f:
    data = json.load(f)

object_counts = defaultdict(int)
for item in data:
    obj = item["object"]
    object_counts[obj] = object_counts[obj] + 1

# Report the totals in the order each object label was first seen.
for obj, count in object_counts.items():
    print(f"{obj}: {count} images")
    
import os
import json
import shutil
import uuid
import random
import hashlib
from collections import defaultdict
from tqdm import tqdm

def round_bbox(bbox):
    """Return a new list holding each value of ``bbox`` rounded to 2 decimals."""
    rounded = []
    for coord in bbox:
        rounded.append(round(coord, 2))
    return rounded

def ensure_dirs():
    """Create the train and val image output folders if they do not exist yet."""
    for folder in ("Inat_aves_no/train/images", "Inat_aves_no/val/images"):
        os.makedirs(folder, exist_ok=True)

def get_md5(image_path):
    """Return the hex MD5 digest of a file's contents.

    The file is hashed in fixed-size chunks so that large images are never
    loaded into memory in one piece.

    Args:
        image_path: Path of the file to hash.

    Returns:
        The 32-character hexadecimal digest, or ``None`` when the file cannot
        be hashed (missing, unreadable, a directory, or an invalid path value).
    """
    try:
        digest = hashlib.md5()
        with open(image_path, 'rb') as f:
            for chunk in iter(lambda: f.read(1 << 20), b""):
                digest.update(chunk)
    except (OSError, TypeError, ValueError):
        # OSError covers missing/unreadable files; TypeError/ValueError cover
        # path values that ``open`` rejects (e.g. ``None``). Callers treat
        # ``None`` as "skip this image", so this stays best-effort.
        return None
    return digest.hexdigest()

def load_input(json_path):
    """Load and return the parsed contents of the JSON file at ``json_path``.

    Args:
        json_path: Path of the JSON file to read.

    Returns:
        Whatever ``json.load`` produces for the file.

    Raises:
        SystemExit: With exit code 1 (after printing an error line) when
            ``json_path`` does not exist.
    """
    if not os.path.exists(json_path):
        print(f"[ERROR] JSON file not found: {json_path}")
        # Raise SystemExit directly: the ``exit`` builtin is injected by the
        # ``site`` module for interactive use and is not guaranteed to exist
        # (or to behave the same way) in every interpreter or notebook kernel.
        raise SystemExit(1)
    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with open(json_path, "r", encoding="utf-8") as f:
        return json.load(f)

def _dedupe_and_group(input_data):
    """Group readable records by object, keeping one per (image MD5, object)."""
    seen_md5_object = set()
    object_groups = defaultdict(list)
    for item in input_data:
        image_path = item["image_path"]
        if not os.path.exists(image_path):
            continue
        md5 = get_md5(image_path)
        if not md5:
            continue
        key = (md5, item["object"])
        if key in seen_md5_object:
            continue
        seen_md5_object.add(key)
        # Remember the digest on the record; the split step keys on it.
        item["md5"] = md5
        object_groups[item["object"]].append(item)
    return object_groups


def _stratified_order(object_groups):
    """Shuffle each object group (seed 42) and list 80% heads before 20% tails."""
    head_items, tail_items = [], []
    random.seed(42)
    for items in object_groups.values():
        random.shuffle(items)
        split_idx = int(len(items) * 0.8)
        head_items.extend(items[:split_idx])
        tail_items.extend(items[split_idx:])
    return head_items + tail_items


def _split_by_md5(items):
    """Assign each distinct MD5 to train (p=0.8) or val and partition ``items``."""
    md5_to_split = {}
    random.seed(42)
    for item in items:
        # One draw per distinct MD5, in first-seen order, so every record that
        # shares an image lands in the same split.
        if item["md5"] not in md5_to_split:
            md5_to_split[item["md5"]] = "train" if random.random() < 0.8 else "val"
    final_train = [item for item in items if md5_to_split[item["md5"]] == "train"]
    final_val = [item for item in items if md5_to_split[item["md5"]] == "val"]
    return final_train, final_val


def _file_id(image_path):
    """Return the text after the last ``_`` of the file name, minus ``.jpg``."""
    file_name = os.path.basename(image_path)
    return file_name.split("_")[-1].replace(".jpg", "")


def _bbox_area(bbox):
    """Return the area of an ``[x1, y1, x2, y2]`` box rounded to 4 decimals."""
    x1, y1, x2, y2 = bbox
    return round((x2 - x1) * (y2 - y1), 4)


def _copy_or_report(src_path, dst_path, split_label):
    """Copy ``src_path`` to ``dst_path``; print a notice and return False if absent."""
    if not os.path.exists(src_path):
        print(f"[Missing {split_label} Image] Skipping: {src_path}")
        return False
    shutil.copy(src_path, dst_path)
    return True


def _dump_json(path, payload):
    """Write ``payload`` to ``path`` as JSON indented by two spaces."""
    with open(path, "w") as f:
        json.dump(payload, f, indent=2)


def process_and_save(json_path):
    """Build the ``Inat_aves_no`` train/val dataset from a bbox annotation dump.

    Each record of the input JSON is read for its ``image_path``, ``object``
    and four-value ``bbox``. Records whose image is missing or unreadable are
    dropped, duplicates of the same (image MD5, object) pair are removed, and
    the images are copied under ``Inat_aves_no/{train,val}/images``. Three
    files are written: ``train/train.json``, ``val/val.json`` and
    ``val/val_ans.json``. Progress and summary lines are printed.

    NOTE(review): every record's final split comes from the per-MD5 random
    draw, so the earlier per-object 80/20 cut only influences the order in
    which MD5s receive their draw. The resulting split is therefore not
    stratified per object. This is kept as-is so existing splits remain
    reproducible; confirm whether that is intended.

    Args:
        json_path: Path of the input JSON file (a list of records).

    Raises:
        SystemExit: Propagated from ``load_input`` when ``json_path`` is missing.
    """
    print(f"📂 Loading from {json_path}")
    ensure_dirs()
    input_data = load_input(json_path)

    print("🔍 Deduplicating and grouping by object...")
    object_groups = _dedupe_and_group(input_data)
    ordered_items = _stratified_order(object_groups)

    print("🔁 Resolving MD5 conflicts with random assignment...")
    final_train, final_val = _split_by_md5(ordered_items)

    train_data, val_data, val_ans_data = [], [], []
    missing_train_images = 0
    missing_val_images = 0

    print("🚀 Copying train images...")
    for item in tqdm(final_train):
        file_id = _file_id(item["image_path"])
        dst_image_path = f"Inat_aves_no/train/images/{file_id}.jpg"
        if not _copy_or_report(item["image_path"], dst_image_path, "Train"):
            missing_train_images += 1
            continue

        train_data.append({
            "id": file_id,
            "image": dst_image_path,
            "category": "Inat_aves_no",
            "area": _bbox_area(item["bbox"]),
            "conversations": [
                {
                    "from": "human",
                    "value": f"<image>\nProvide bounding box coordinate for {item['object']}."
                },
                {
                    "from": "gpt",
                    "value": str(round_bbox(item["bbox"]))
                }
            ]
        })

    print("🚀 Copying val images...")
    for item in tqdm(final_val):
        file_id = _file_id(item["image_path"])
        dst_path = f"Inat_aves_no/val/images/{file_id}.jpg"
        if not _copy_or_report(item["image_path"], dst_path, "Val"):
            missing_val_images += 1
            continue

        # Only build the question/answer pair for images that were copied.
        uuid_str = str(uuid.uuid4())
        prompt = f"<image>\nProvide bounding box coordinate for {item['object']}."

        val_data.append({
            "question_id": uuid_str,
            "image": f"{file_id}.jpg",
            "category": "Inat_aves_no",
            "text": prompt,
            "id": uuid_str
        })

        val_ans_data.append({
            "question_id": uuid_str,
            "prompt": prompt,
            "text": str(round_bbox(item["bbox"])),
            "area": _bbox_area(item["bbox"]),
            "answer_id": None,
            "model_id": None,
        })

    _dump_json("Inat_aves_no/train/train.json", train_data)
    _dump_json("Inat_aves_no/val/val.json", val_data)
    _dump_json("Inat_aves_no/val/val_ans.json", val_ans_data)

    print("\n✅ Done.")
    print(f"📊 Total Train: {len(train_data)} | Missing: {missing_train_images}")
    print(f"📊 Total Val:   {len(val_data)} | Missing: {missing_val_images}")

# Script entry point: build the train/val dataset from the padded-bbox dump.
if __name__ == "__main__":
    json_file = "aves_no_padded_bbox_output_700.json"
    process_and_save(json_file)

#%%

import json
import os
import hashlib
from collections import defaultdict

def get_md5(image_path):
    """Return the hex MD5 digest of a file's contents.

    The file is hashed in fixed-size chunks so that large images are never
    loaded into memory in one piece.

    Args:
        image_path: Path of the file to hash.

    Returns:
        The 32-character hexadecimal digest, or ``None`` when the file cannot
        be hashed (missing, unreadable, a directory, or an invalid path value).
    """
    try:
        digest = hashlib.md5()
        with open(image_path, 'rb') as f:
            for chunk in iter(lambda: f.read(1 << 20), b""):
                digest.update(chunk)
    except (OSError, TypeError, ValueError):
        # OSError covers missing/unreadable files; TypeError/ValueError cover
        # path values that ``open`` rejects (e.g. ``None``). Callers treat
        # ``None`` as "skip this image", so this stays best-effort.
        return None
    return digest.hexdigest()

def extract_object_name(text):
    """Return the lower-cased object name found in a prompt string.

    Everything after the FIRST ``"for "`` is taken as the object name, so a
    name that itself contains ``"for "`` is kept whole. Surrounding whitespace
    and leading/trailing periods are removed. When ``text`` has no ``"for "``
    the whole stripped, lower-cased text is returned.

    Args:
        text: Prompt text such as
            ``"Provide bounding box coordinate for Blue Jay."``.

    Returns:
        The normalized object name.
    """
    _, marker, remainder = text.partition("for ")
    if not marker:
        return text.strip().lower()
    return remainder.strip().strip(".").lower()

def build_hash_object_map(json_path, image_dir, is_train=True):
    """Map each image MD5 to the ``(object name, image path)`` pairs using it.

    Records come from the JSON list at ``json_path``; each record's ``image``
    value is joined onto ``image_dir``. The prompt is read from
    ``conversations[0]["value"]`` when ``is_train`` is true and from ``text``
    otherwise. Records whose image cannot be hashed, or whose prompt cannot be
    read, are skipped.

    Returns:
        A ``defaultdict(list)`` keyed by MD5 hex digest.
    """
    with open(json_path) as fh:
        records = json.load(fh)

    digest_to_entries = defaultdict(list)
    for record in records:
        image_file = os.path.join(image_dir, record["image"])
        digest = get_md5(image_file)
        if not digest:
            continue
        try:
            prompt = record["conversations"][0]["value"] if is_train else record["text"]
            label = extract_object_name(prompt)
            digest_to_entries[digest].append((label, image_file))
        except Exception:
            # Best-effort: a malformed record is ignored rather than fatal.
            continue
    return digest_to_entries

def check_leakage(train_map, val_map):
    """Report validation samples whose image content and object also occur in train.

    Both arguments map an MD5 digest to a list of ``(object, image_path)``
    pairs. For every validation pair, the first training pair with the same
    digest and the same object is recorded. A summary is printed and the
    records are written to ``leakage_same_object.json``.
    """
    leak_details = []
    matched_hashes = set()

    for digest, val_entries in val_map.items():
        if digest not in train_map:
            continue
        for val_obj, val_path in val_entries:
            # First training entry sharing this digest AND this object, if any.
            hit = next(
                (
                    (train_obj, train_path)
                    for train_obj, train_path in train_map[digest]
                    if train_obj == val_obj
                ),
                None,
            )
            if hit is None:
                continue
            matched_hashes.add(digest)
            leak_details.append({
                "md5": digest,
                "train_image": hit[1],
                "val_image": val_path,
                "object": val_obj
            })

    leak_count = len(leak_details)

    print("\n✅ Leakage check completed:")
    print(f"📌 {leak_count} validation samples are duplicated in the training set (same image + same object).")
    print(f"🖼️ These correspond to {len(matched_hashes)} unique image contents.")

    with open("leakage_same_object.json", "w") as f:
        json.dump(leak_details, f, indent=2)

def check_same_image_diff_object(train_map, val_map):
    """Report validation samples that share image content with train under another object.

    Both arguments map an MD5 digest to a list of ``(object, image_path)``
    pairs. For every validation pair, the first training pair with the same
    digest but a different object is recorded. A summary is printed and the
    records are written to ``conflict_different_object.json``.
    """
    conflict_details = []

    for digest, val_entries in val_map.items():
        if digest not in train_map:
            continue
        for val_obj, val_path in val_entries:
            # First training entry sharing this digest but naming another object.
            hit = next(
                (
                    (train_obj, train_path)
                    for train_obj, train_path in train_map[digest]
                    if train_obj != val_obj
                ),
                None,
            )
            if hit is None:
                continue
            other_obj, other_path = hit
            conflict_details.append({
                "md5": digest,
                "train_image": other_path,
                "train_object": other_obj,
                "val_image": val_path,
                "val_object": val_obj
            })

    conflict_count = len(conflict_details)

    print("\n⚠️ Conflict check completed:")
    print(f"🔀 {conflict_count} validation samples have the same image as training but different object names.")

    with open("conflict_different_object.json", "w") as f:
        json.dump(conflict_details, f, indent=2)

# Inputs for the leakage check: the two dataset JSON files and the folders
# their "image" values are joined onto.
train_json_path = "Inat_aves_no/train/train.json"
val_json_path = "Inat_aves_no/val/val.json"
# Left empty because process_and_save writes train "image" values that already
# include the "Inat_aves_no/train/images/" prefix.
train_image_dir = ""
# Needed because process_and_save writes val "image" values as bare file names.
val_image_dir = "Inat_aves_no/val/images"

print("📂 Building hash-object map for training set...")
train_map = build_hash_object_map(train_json_path, train_image_dir, is_train=True)

print("📂 Building hash-object map for validation set...")
val_map = build_hash_object_map(val_json_path, val_image_dir, is_train=False)

# Same image + same object across splits, then same image + different object.
check_leakage(train_map, val_map)
check_same_image_diff_object(train_map, val_map)
#%%
import json
import re
from collections import defaultdict
import matplotlib.pyplot as plt

# Count how often each object is asked about in the training prompts.
json_path = "Inat_aves_no/train/train.json"
with open(json_path, "r") as f:
    data = json.load(f)

object_counter = defaultdict(int)
for entry in data:
    for conv in entry.get("conversations", []):
        # Only the human turn carries the prompt with the object name.
        if conv["from"] != "human":
            continue
        match = re.search(r"Provide bounding box coordinate for (.+?)\.", conv["value"])
        if not match:
            continue
        obj_name = match.group(1).strip().lower()
        object_counter[obj_name] += 1

objects = list(object_counter)
counts = [object_counter[label] for label in objects]

# One bar per object name, labels rotated so long names stay readable.
plt.figure(figsize=(14, 6))
plt.bar(objects, counts)
plt.xticks(rotation=90, ha="right")
plt.xlabel("Object Name")
plt.ylabel("Number of Images")
plt.title("Histogram of Object Occurrences in train.json")
plt.tight_layout()
plt.show()
#%%
import json
import re
from collections import defaultdict
import matplotlib.pyplot as plt

# Count how often each object is asked about in the validation answer file.
json_path = "Inat_aves_no/val/val_ans.json"
with open(json_path, "r") as f:
    data = json.load(f)

object_counter = defaultdict(int)
for item in data:
    prompt = item.get("prompt", "")
    match = re.search(r"Provide bounding box coordinate for (.+?)\.", prompt)
    if not match:
        continue
    obj_name = match.group(1).strip().lower()
    object_counter[obj_name] += 1

objects = list(object_counter)
counts = [object_counter[label] for label in objects]

# One bar per object name, labels rotated so long names stay readable.
plt.figure(figsize=(14, 6))
plt.bar(objects, counts)
plt.xticks(rotation=90, ha="right")
plt.xlabel("Object Name")
plt.ylabel("Number of Images")
plt.title("Histogram of Object Occurrences in val.json")
plt.tight_layout()
plt.show()

# %%
