#%%
import json
from collections import defaultdict

def process(json_path):
    """Count annotations per category for every image in an annotation file.

    The JSON file at ``json_path`` must provide ``images``, ``categories``
    and ``annotations`` lists. Each image is identified by the last two
    components of its ``coco_url`` (``"<folder>/<file>"``); images without a
    ``coco_url``, and annotations that point at them, are ignored.

    Returns:
        ``{image_path: {category_name: number_of_annotations}}`` built from
        plain dicts.
    """
    with open(json_path, "r") as handle:
        dataset = json.load(handle)

    # image id -> "<folder>/<file>" taken from the tail of the COCO URL
    rel_path_by_id = {}
    for image in dataset["images"]:
        if "coco_url" in image:
            url_parts = image["coco_url"].split("/")
            rel_path_by_id[image["id"]] = "/".join(url_parts[-2:])

    name_by_category_id = {}
    for category in dataset["categories"]:
        name_by_category_id[category["id"]] = category["name"]

    counts_by_image = {}
    for annotation in dataset["annotations"]:
        rel_path = rel_path_by_id.get(annotation["image_id"])
        if rel_path is None:
            # Annotation refers to an image we have no path for.
            continue
        label = name_by_category_id[annotation["category_id"]]
        per_image = counts_by_image.setdefault(rel_path, {})
        per_image[label] = per_image.get(label, 0) + 1

    return counts_by_image

# Per-image object counts for both LVIS splits, merged into one mapping.
# If an image key appears in both, the val entry replaces the train entry.
train = process("downloads/lvis_v1_train.json")
val = process("downloads/lvis_v1_val.json")
combined = dict(train)
combined.update(val)

with open("per_image_object_counts_combined.json", "w") as f:
    json.dump(combined, f, indent=2)

#%%
import json
from collections import defaultdict

with open("per_image_object_counts_combined.json", "r") as f:
    data = json.load(f)

# result[object][str(count)] -> how many images contain exactly `count`
# instances of `object`, together with the paths of those images.
result = defaultdict(lambda: defaultdict(lambda: {"count": 0, "image_ids": []}))

for image_path, objects in data.items():
    for obj_name, count in objects.items():
        bucket = result[obj_name][str(count)]
        bucket["count"] += 1
        bucket["image_ids"].append(image_path)

with open("object_grouped_by_count_and_image.json", "w") as f:
    json.dump(result, f, indent=2)


# %%
import json
import random

random.seed(42)


def generate_filtered_summary_json(data, min_threshold=15, max_threshold=30, min_valid_counts=5, max_valid_count=5, count_range=(1, 40)):
    """Select objects that have enough well-populated count buckets.

    ``data`` maps an object name to ``{count_str: {"count": n_images, ...}}``.

    A bucket is *valid* when ``count_range[0] <= int(count_str) <=
    count_range[1]`` and ``n_images >= min_threshold``; it contributes
    ``min(n_images, max_threshold)`` images. Objects with fewer than
    ``min_valid_counts`` valid buckets are dropped. Objects with more are
    reduced to exactly ``min_valid_counts`` buckets chosen with
    ``random.sample``, so the result depends on the global ``random`` state.

    Buckets whose key is not an integer string, or whose ``"count"`` entry is
    missing or not convertible to ``int``, are skipped.

    ``max_valid_count`` is accepted for backward compatibility but is
    currently unused.

    Returns:
        ``{object: {count_str: images_to_use, ..., "total": sum}}`` plus the
        bookkeeping keys ``"_total_images"`` (sum of all totals) and
        ``"_total_objects"`` (number of selected objects).
    """
    result = {}
    grand_total = 0
    object_count = 0

    for obj, count_dict in data.items():
        valid_count_items = {}

        for count_str, info in count_dict.items():
            try:
                count_i = int(count_str)
                count = int(info["count"])
            except (KeyError, TypeError, ValueError):
                # Malformed bucket: skip it rather than abort the whole run.
                continue

            if count_range[0] <= count_i <= count_range[1] and count >= min_threshold:
                valid_count_items[count_str] = min(count, max_threshold)

        if len(valid_count_items) < min_valid_counts:
            continue

        if len(valid_count_items) > min_valid_counts:
            # Only touch the RNG when there is a real choice to make, so an
            # object with exactly `min_valid_counts` buckets keeps its order
            # and does not consume random state.
            sampled_keys = random.sample(list(valid_count_items.keys()), min_valid_counts)
            valid_count_items = {k: valid_count_items[k] for k in sampled_keys}

        obj_total = sum(valid_count_items.values())
        valid_count_items["total"] = obj_total
        result[obj] = valid_count_items
        grand_total += obj_total
        object_count += 1

    result["_total_images"] = grand_total
    result["_total_objects"] = object_count
    return result



# Summarise the grouped counts with the default thresholds and store the
# result for the following cells.
with open("object_grouped_by_count_and_image.json", "r") as f:
    data = json.loads(f.read())

summary = generate_filtered_summary_json(data)

with open("summary_output_5.json", "w") as f:
    json.dump(summary, f, indent=2)
# %%

#   "_total_images": 6121,
#   "_total_objects": 45
import json
import matplotlib.pyplot as plt

with open("summary_output_5.json", "r") as f:
    data = json.load(f)

# Besides one entry per object, the summary stores run-level totals under
# "_total_images" / "_total_objects". Their values are plain ints, not
# per-count dicts, so calling .get()/.items() on them would fail; drop every
# "_"-prefixed bookkeeping key before plotting.
per_object = {
    name: counts
    for name, counts in data.items()
    if not name.startswith("_")
}

# Object vs Total Image Count
object_names = list(per_object)
total_counts = [counts.get("total", 0) for counts in per_object.values()]

plt.figure(figsize=(12, 6))
plt.bar(object_names, total_counts)
plt.xticks(rotation=90)
plt.ylabel("Total Image Count")
plt.title("Object vs Total Image Count")
plt.tight_layout()
plt.show()

# Number of selected images per "objects in the image" count, summed over
# all objects; the per-object "total" entry is not a count bucket.
distribution = {}
for obj_counts in per_object.values():
    for count, num in obj_counts.items():
        if count == "total":
            continue
        distribution[count] = distribution.get(count, 0) + num

# Keys are stringified ints; sort numerically for the x axis.
x = sorted(int(k) for k in distribution)
y = [distribution[str(k)] for k in x]

plt.figure(figsize=(8, 6))
plt.bar(x, y)
plt.xlabel("Object Count per Image")
plt.ylabel("Number of Images")
plt.title("Object Count vs Image Count")
plt.tight_layout()
plt.show()


#%%
import os
import json
import csv
import random
from tqdm import tqdm

random.seed(42)


def generate_filtered_summary_json(data, min_threshold=15, max_threshold=30, min_valid_counts=5, max_valid_count=5, count_range=(1, 40)):
    """Select objects that have enough well-populated count buckets.

    ``data`` maps an object name to ``{count_str: {"count": n_images, ...}}``.

    A bucket is *valid* when ``count_range[0] <= int(count_str) <=
    count_range[1]`` and ``n_images >= min_threshold``; it contributes
    ``min(n_images, max_threshold)`` images. Objects with fewer than
    ``min_valid_counts`` valid buckets are dropped. Objects with more are
    reduced to exactly ``min_valid_counts`` buckets chosen with
    ``random.sample``, so the result depends on the global ``random`` state.

    Buckets whose key is not an integer string, or whose ``"count"`` entry is
    missing or not convertible to ``int``, are skipped.

    ``max_valid_count`` is accepted for backward compatibility but is
    currently unused.

    Returns:
        ``{object: {count_str: images_to_use, ..., "total": sum}}`` plus the
        bookkeeping keys ``"_total_images"`` (sum of all totals) and
        ``"_total_objects"`` (number of selected objects).
    """
    result = {}
    grand_total = 0
    object_count = 0

    for obj, count_dict in data.items():
        valid_count_items = {}

        for count_str, info in count_dict.items():
            try:
                count_i = int(count_str)
                count = int(info["count"])
            except (KeyError, TypeError, ValueError):
                # Malformed bucket: skip it rather than abort the whole run.
                continue

            if count_range[0] <= count_i <= count_range[1] and count >= min_threshold:
                valid_count_items[count_str] = min(count, max_threshold)

        if len(valid_count_items) < min_valid_counts:
            continue

        if len(valid_count_items) > min_valid_counts:
            # Only touch the RNG when there is a real choice to make, so an
            # object with exactly `min_valid_counts` buckets keeps its order
            # and does not consume random state.
            sampled_keys = random.sample(list(valid_count_items.keys()), min_valid_counts)
            valid_count_items = {k: valid_count_items[k] for k in sampled_keys}

        obj_total = sum(valid_count_items.values())
        valid_count_items["total"] = obj_total
        result[obj] = valid_count_items
        grand_total += obj_total
        object_count += 1

    result["_total_images"] = grand_total
    result["_total_objects"] = object_count
    return result


def _image_id_from_path(img_path):
    """Return the integer image id encoded in a ``.../<digits>.jpg`` path."""
    return int(img_path.split("/")[-1].replace(".jpg", ""))


def _index_lvis_split(split, wanted_image_ids):
    """Index one loaded annotation split, restricted to ``wanted_image_ids``.

    Returns ``(images, annotations)``: ``images`` maps an image id to its
    image record, ``annotations`` maps ``(image_id, category_id)`` to the
    matching annotation records in file order.
    """
    images = {}
    for img in split["images"]:
        if img["id"] in wanted_image_ids:
            # Keep the first record for an id, like a first-match linear scan.
            images.setdefault(img["id"], img)

    annotations = {}
    for ann in split["annotations"]:
        if ann["image_id"] in wanted_image_ids:
            key = (ann["image_id"], ann["category_id"])
            annotations.setdefault(key, []).append(ann)
    return images, annotations


def compute_mean_areas(image_ids_with_object, lvis_train, lvis_val, cat_name_to_id):
    """Collect the normalized area of every matching annotation.

    Despite the name, no averaging happens here: the function returns one
    record per annotation, and averaging is left to the consumer.

    Args:
        image_ids_with_object: iterable of dicts with the keys ``"image"``
            (path ending in ``<image id>.jpg``), ``"object"`` (category
            name) and ``"count"``.
        lvis_train, lvis_val: loaded annotation dicts providing ``"images"``
            (with ``id``/``width``/``height``) and ``"annotations"`` (with
            ``image_id``/``category_id``/``area``).
        cat_name_to_id: mapping from category name to category id.

    Returns:
        A list of ``{"image", "object", "count", "normalized_area"}`` dicts,
        where ``normalized_area`` is ``ann["area"] / (width * height)``.
        Items whose image is found in neither split are skipped.

    Raises:
        ValueError: if an image path does not end in ``<digits>.jpg``.
        KeyError: if an item's object name is not in ``cat_name_to_id``.
    """
    items = list(image_ids_with_object)
    wanted_ids = {_image_id_from_path(item["image"]) for item in items}

    # Build the lookup tables once instead of rescanning every image and
    # annotation for each requested item.
    train_index = _index_lvis_split(lvis_train, wanted_ids)
    val_index = _index_lvis_split(lvis_val, wanted_ids)

    mean_areas = []
    for item in tqdm(items, desc="Calculating areas"):
        img_path = item["image"]
        obj = item["object"]
        count = item["count"]
        image_id = _image_id_from_path(img_path)

        # The folder name in the path is only a hint for which split to try
        # first. NOTE(review): LVIS v1 val appears to include images hosted
        # under train2017 -- confirm against the dataset; the fallback keeps
        # such images instead of silently dropping them.
        if "train2017" in img_path:
            candidates = (train_index, val_index)
        else:
            candidates = (val_index, train_index)

        img_info = None
        anns_by_key = {}
        for images, annotations in candidates:
            if image_id in images:
                img_info = images[image_id]
                anns_by_key = annotations
                break
        if not img_info:
            continue

        w, h = img_info["width"], img_info["height"]
        cat_id = cat_name_to_id[obj]
        for ann in anns_by_key.get((image_id, cat_id), ()):
            mean_areas.append({
                "image": img_path,
                "object": obj,
                "count": count,
                "normalized_area": ann["area"] / (w * h),
            })
    return mean_areas


# Load the grouped counts and derive the filtered per-object summary
with open("object_grouped_by_count_and_image.json", "r") as f:
    all_data = json.load(f)

filtered_summary = generate_filtered_summary_json(all_data)

with open("downloads/lvis_v1_train.json", "r") as f:
    lvis_train = json.load(f)
with open("downloads/lvis_v1_val.json", "r") as f:
    lvis_val = json.load(f)

# Both splits are expected to share one category name -> id mapping.
cat_train = {}
for c in lvis_train["categories"]:
    cat_train[c["name"]] = c["id"]
cat_val = {}
for c in lvis_val["categories"]:
    cat_val[c["name"]] = c["id"]
assert cat_train == cat_val
cat_name_to_id = cat_train

used_images = []
print(len(filtered_summary))
for obj, count_dict in tqdm(filtered_summary.items(), desc="Sampling images"):
    # Keys starting with "_" are summary bookkeeping, not object names.
    if obj.startswith("_"):
        continue
    for count_str, count_value in count_dict.items():
        if count_str == "total":
            continue
        try:
            all_image_ids = all_data[obj][count_str]["image_ids"]
        except KeyError:
            continue
        # Never ask for more images than the bucket actually holds.
        sampled = random.sample(all_image_ids, min(count_value, len(all_image_ids)))
        used_images.extend(
            {"image": img_path, "object": obj, "count": int(count_str)}
            for img_path in sampled
        )

# Save sampled image list
with open("merged_used_image_list.json", "w") as f:
    json.dump(used_images, f, indent=2)

# Compute area
area_records = compute_mean_areas(used_images, lvis_train, lvis_val, cat_name_to_id)
with open("normalized_area_records.csv", "w", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=["image", "object", "count", "normalized_area"])
    writer.writeheader()
    writer.writerows(area_records)


#%%
import csv
from collections import defaultdict

input_file = "normalized_area_records.csv"
output_file = "sorted_object_by_mean_area.csv"

# Running sum and number of area records per object.
area_sum = defaultdict(float)
count = defaultdict(int)

with open(input_file, "r") as f:
    for row in csv.DictReader(f):
        obj = row["object"]
        area_sum[obj] += float(row["normalized_area"])
        count[obj] += 1

# Largest mean normalized area first.
mean_areas = sorted(
    ((obj, total / count[obj]) for obj, total in area_sum.items()),
    key=lambda pair: pair[1],
    reverse=True,
)

with open(output_file, "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["object", "mean_normalized_area"])
    writer.writerows([obj, f"{mean:.6f}"] for obj, mean in mean_areas)


#%%
import json
import csv

json_file = "summary_output_5.json"
csv_file = "sorted_object_by_mean_area.csv"
output_file = "top_objects_up_to_4000.json"

with open(json_file, "r") as f:
    obj_to_total_count = json.load(f)

with open(csv_file, "r") as f:
    area_sorted_objects = [
        (row["object"], float(row["mean_normalized_area"]))
        for row in csv.DictReader(f)
    ]

area_sorted_objects.sort(key=lambda pair: pair[1], reverse=True)

selected = {}
image_total = 0
target = 4000

# Walk objects from largest to smallest mean area and stop at the first one
# that would push the running image total past the target.
for obj, area in area_sorted_objects:
    if obj in obj_to_total_count:
        total_images = obj_to_total_count[obj].get("total", 0)
        if image_total + total_images > target:
            break
        selected[obj] = obj_to_total_count[obj]
        image_total += total_images

# Output result
with open(output_file, "w") as f:
    json.dump(selected, f, indent=2)

with open("selected_objects.txt", "w") as f:
    f.writelines(obj + "\n" for obj in selected)

print(f"Selected {len(selected)} objects with total {image_total} images.")


#%%
import json
import uuid
import os
import shutil
import random
from collections import defaultdict

# Fixed seed: the shuffle that drives the train/val split below uses `random`.
random.seed(42)

# Inputs produced by the earlier cells.
json_path = "merged_used_image_list.json"
selected_txt = "selected_objects.txt"

# Output records, filled in while the images are copied.
train_json, val_json, val_ans_json, missing_images = [], [], [], []

# Image destinations for the two splits.
output_train_img_dir = "lvis/train/images"
output_val_img_dir = "lvis/val/images"
os.makedirs(output_train_img_dir, exist_ok=True)
os.makedirs(output_val_img_dir, exist_ok=True)

def find_and_copy_image(img_path, dest_folder, new_name=None, src_root="downloads"):
    """Copy ``<src_root>/<img_path>`` into ``dest_folder``.

    Args:
        img_path: image path relative to ``src_root``.
        dest_folder: existing directory that receives the copy.
        new_name: file name for the copy; when empty or ``None`` the base
            name of ``img_path`` is used.
        src_root: directory the relative ``img_path`` is resolved against.

    Returns:
        The file name of the copy inside ``dest_folder``, or ``None`` when
        the source is not an existing regular file (nothing is copied then).
    """
    src_path = os.path.join(src_root, img_path)
    # isfile rather than exists: a directory at src_path would make copy2 raise.
    if not os.path.isfile(src_path):
        return None

    dst_name = new_name if new_name else os.path.basename(img_path)
    shutil.copy2(src_path, os.path.join(dest_folder, dst_name))
    return dst_name

# === Load data ===
with open(json_path, "r") as f:
    all_data = json.load(f)

with open(selected_txt, "r") as f:
    selected_objects = {line.strip() for line in f if line.strip()}

# === Group by (object, count)
grouped = defaultdict(list)
for item in all_data:
    obj, count, img_path = item["object"], item["count"], item["image"]
    if obj in selected_objects:
        grouped[(obj, count)].append(img_path)

# === For each (object, count), do 8:2 split
train_items = []
val_items = []

for (obj, count), img_list in grouped.items():
    random.shuffle(img_list)
    split_idx = int(len(img_list) * 0.8)
    train_items.extend((obj, count, img) for img in img_list[:split_idx])
    val_items.extend((obj, count, img) for img in img_list[split_idx:])

# === Process train set
# Every train sample is copied under a fresh UUID file name.
for obj, count, img_path in train_items:
    uid = str(uuid.uuid4())
    copied_name = find_and_copy_image(img_path, output_train_img_dir, new_name=f"{uid}.jpg")
    if not copied_name:
        missing_images.append(img_path)
        continue
    train_json.append({
        "id": uid,
        "image": f"lvis/train/images/{copied_name}",
        "conversations": [
            {"from": "human", "value": f"<image>\nhow many {obj} are there in the image?"},
            {"from": "gpt", "value": str(count)}
        ]
    })

# === Process val set
# Val images keep their original file name; each question gets its own UUID.
for obj, count, img_path in val_items:
    copied_name = find_and_copy_image(img_path, output_val_img_dir, new_name=os.path.basename(img_path))
    if not copied_name:
        missing_images.append(img_path)
        continue
    qid = str(uuid.uuid4())
    question = f"<image>\nhow many {obj} are there in the image?"
    val_json.append({
        "question_id": qid,
        "image": f"lvis/val/images/{os.path.basename(img_path)}",
        "category": "default",
        "text": question,
        "id": qid
    })
    val_ans_json.append({
        "question_id": qid,
        "prompt": question,
        "text": str(count),
        "answer_id": None,
        "model_id": None,
        "metadata": {}
    })

# === Save outputs
os.makedirs("lvis/train", exist_ok=True)
os.makedirs("lvis/val", exist_ok=True)

with open("lvis/train/train.json", "w") as f:
    json.dump(train_json, f, indent=2)

with open("lvis/val/val.json", "w") as f:
    json.dump(val_json, f, indent=2)

with open("lvis/val/val_ans.json", "w") as f:
    json.dump(val_ans_json, f, indent=2)

with open("lvis/missing_images.txt", "w") as f:
    f.writelines(mid + "\n" for mid in missing_images)

print(f"✅ Total train images: {len(train_json)}")
print(f"✅ Total val images: {len(val_json)}")
print(f"🚫 Missing images: {len(missing_images)}")

