#%%
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

# Which FSC image / density map pair to inspect.
base_filename = "4" 

image_path = f"/fs/scratch/PAS2099/Herz_2/VFM/Dataset/counting/FSC/images_384_VarV2/{base_filename}.jpg"
density_map_path = f"/fs/scratch/PAS2099/Herz_2/VFM/Dataset/counting/FSC/gt_density_map_adaptive_384_VarV2/{base_filename}.npy"

# Load the RGB image as an array and the matching density map.
img = np.array(Image.open(image_path).convert("RGB"))
density_map = np.load(density_map_path)
# Sum of the density map; the map comes from a gt_density_map_* folder, so this
# is presumably the ground-truth count (title wording kept as-is).
estimated_count = density_map.sum()

# Left: raw image.  Right: image with the density map blended on top.
fig, (ax_raw, ax_overlay) = plt.subplots(1, 2, figsize=(12, 5))

ax_raw.imshow(img)
ax_raw.axis('off')
ax_raw.set_title("Original Image")

ax_overlay.imshow(img)
ax_overlay.imshow(density_map, cmap='jet', alpha=0.5)
ax_overlay.axis('off')
ax_overlay.set_title(f"Image + Density Map\nEstimated Count: {estimated_count:.2f}")

fig.tight_layout()
plt.show()


#%%
import os
import json
import numpy as np
from collections import defaultdict

label_file = "ann.txt"
image_dir = "/fs/scratch/PAS2099/Herz_2/VFM/Dataset/counting/FSC/images_384_VarV2"
density_dir = "/fs/scratch/PAS2099/Herz_2/VFM/Dataset/counting/FSC/gt_density_map_adaptive_384_VarV2"

# summary[object_name][str(rounded_count)] -> {"count": n_images, "image_ids": [file names]}
summary = defaultdict(lambda: defaultdict(lambda: {"count": 0, "image_ids": []}))
# Annotation rows dropped because the image or its density map is not on disk.
skipped_missing = 0

with open(label_file, "r") as f:
    for line_no, line in enumerate(f, start=1):
        line = line.strip()
        if not line:
            # Blank / trailing lines used to crash the tuple unpack below.
            continue
        fields = line.split("\t")
        if len(fields) != 2:
            raise ValueError(
                f"{label_file}:{line_no}: expected '<image name>\\t<object name>', got {line!r}"
            )
        img_name, obj = fields
        base_name = os.path.splitext(img_name)[0]
        image_path = os.path.join(image_dir, img_name)
        density_path = os.path.join(density_dir, f"{base_name}.npy")

        if not os.path.exists(image_path) or not os.path.exists(density_path):
            skipped_missing += 1
            continue

        # Object count = sum of the density map, rounded to the nearest integer.
        count = int(round(np.load(density_path).sum()))
        count_str = str(count)

        summary[obj][count_str]["count"] += 1
        summary[obj][count_str]["image_ids"].append(img_name)

total_objects = len(summary)
total_images = sum(
    len(count_data["image_ids"])
    for obj_data in summary.values()
    for count_data in obj_data.values()
)

print(f"- Total Unique Object Names: {total_objects}")
print(f"- Total Images Counted: {total_images}")
print(f"- Skipped (missing image or density map): {skipped_missing}")

with open("summary_output_fsc_with_id.json", "w") as f:
    json.dump(summary, f, indent=2)

#%%
import json

def generate_filtered_summary_json(data, min_threshold=6, max_threshold=12, min_valid_counts=4, count_range=(1, 40)):
    """Reduce a per-object count summary to per-bucket image budgets.

    ``data`` maps object name -> {count_str: {"count": n_images, ...}}, the
    layout of summary_output_fsc_with_id.json.  A count bucket is kept when its
    integer key lies in ``count_range`` (inclusive on both ends) and it holds at
    least ``min_threshold`` images; its value becomes
    ``min(n_images, max_threshold)``.  An object is kept only if at least
    ``min_valid_counts`` of its buckets survive.

    Top-level entries whose value is not a dict (e.g. "_total_images" when an
    already-filtered summary is passed back in), bucket values that are not
    dicts, and bucket keys that are not integers are skipped.

    Returns:
        {object_name: {count_str: n_kept, ..., "total": sum_of_n_kept}, ...,
         "_total_images": grand_total, "_total_objects": n_objects_kept}
    """
    result = {}
    grand_total = 0
    object_count = 0

    for obj, count_dict in data.items():
        if not isinstance(count_dict, dict):
            continue  # bookkeeping scalar, not an object entry
        obj_total = 0
        valid_count_items = {}

        for count_str, info in count_dict.items():
            if not isinstance(info, dict):
                continue  # e.g. a scalar "total" field
            try:
                count_i = int(count_str)
                count = int(info["count"])
            except ValueError:
                continue

            if count_range[0] <= count_i <= count_range[1] and count >= min_threshold:
                # Cap each bucket so frequent counts don't dominate.
                images_to_add = min(count, max_threshold)
                valid_count_items[count_str] = images_to_add
                obj_total += images_to_add

        # Only the count buckets are counted here; "total" is added afterwards.
        if len(valid_count_items) >= min_valid_counts:
            valid_count_items["total"] = obj_total
            result[obj] = valid_count_items
            grand_total += obj_total
            object_count += 1

    result["_total_images"] = grand_total
    result["_total_objects"] = object_count
    return result

# Filter the raw per-object summary into per-bucket image budgets and save it.
with open("summary_output_fsc_with_id.json", "r") as raw_file:
    data = json.load(raw_file)

summary = generate_filtered_summary_json(data)

with open("summary_output_5.json", "w") as out_file:
    json.dump(summary, out_file, indent=2)

#%%
import json
import os
import random
import uuid
import shutil

# Config: output layout for the train/val split (paths relative to the CWD).
root_dir = "fsc"
img_source_dir = "images_384_VarV2"
output_train_img_dir = os.path.join(root_dir, "train/images")
output_val_img_dir = os.path.join(root_dir, "val/images")
for _out_dir in (output_train_img_dir, output_val_img_dir):
    os.makedirs(_out_dir, exist_ok=True)

# Raw summary: object -> count_str -> {"count": ..., "image_ids": [...]}.
with open("summary_output_fsc_with_id.json", "r") as summary_file:
    all_data = json.load(summary_file)

def generate_filtered_summary_json(data, min_threshold=6, max_threshold=12, min_valid_counts=4, count_range=(1, 40)):
    """Select, per object and count bucket, the image ids to use for the split.

    For each object in ``data`` (object -> count_str -> {"count", "image_ids"}),
    a bucket is kept when its integer key is within ``count_range`` (inclusive)
    and its "count" is at least ``min_threshold``; the kept value is the first
    ``max_threshold`` image ids.  A "total" key is ignored and non-integer keys
    are skipped.  Objects with fewer than ``min_valid_counts`` kept buckets are
    dropped.

    Returns:
        {object_name: {count_str: [image ids], ..., "total": n_ids}, ...,
         "_total_images": total ids kept, "_total_objects": objects kept}
    """
    filtered = {}
    images_kept = 0
    objects_kept = 0

    for obj_name, per_count in data.items():
        kept = {}
        n_kept = 0

        for bucket_key, entry in per_count.items():
            if bucket_key == "total":
                continue
            try:
                n_objects = int(bucket_key)
                bucket_size = int(entry["count"])
            except ValueError:
                continue

            in_range = count_range[0] <= n_objects <= count_range[1]
            if not (in_range and bucket_size >= min_threshold):
                continue
            kept[bucket_key] = entry["image_ids"][:max_threshold]
            n_kept += len(kept[bucket_key])

        if len(kept) < min_valid_counts:
            continue
        kept["total"] = n_kept
        filtered[obj_name] = kept
        images_kept += n_kept
        objects_kept += 1

    filtered["_total_images"] = images_kept
    filtered["_total_objects"] = objects_kept
    return filtered

filtered_data = generate_filtered_summary_json(all_data)

# Output accumulators: train conversations, val questions, val answers, and
# image ids that could not be found on disk.
train_json, val_json, val_ans_json, missing_images = [], [], [], []

# Seeds random.sample / random.shuffle used by the split (not uuid.uuid4).
random.seed(42)

def find_and_copy_image(img_id, dest_folder, new_name=None, source_dir=None):
    """Copy the JPEG for ``img_id`` from the source image folder into ``dest_folder``.

    ``img_id`` may be a bare id ("4") or a file name that already ends in
    ".jpg" (e.g. "4.jpg", the form stored from ann.txt in the summary JSON).
    ".jpg" is appended only when missing, so "4.jpg" is not looked up as
    "4.jpg.jpg".

    Args:
        img_id: image id or file name (converted with ``str``).
        dest_folder: existing directory to copy into.
        new_name: file name to use in ``dest_folder``; falls back to the source
            file name when falsy.
        source_dir: directory to read from; defaults to the module-level
            ``img_source_dir``.

    Returns:
        The destination file name, or ``None`` if the source file does not
        exist.  An already-present destination file is not overwritten.
    """
    if source_dir is None:
        source_dir = img_source_dir
    img_str = str(img_id)
    fname = img_str if img_str.lower().endswith(".jpg") else f"{img_str}.jpg"
    src_path = os.path.join(source_dir, fname)
    if not os.path.exists(src_path):
        return None
    dst_name = new_name if new_name else fname
    dst_path = os.path.join(dest_folder, dst_name)
    if not os.path.exists(dst_path):
        # copy2 also preserves file metadata such as modification time.
        shutil.copy2(src_path, dst_path)
    return dst_name

# === Object-Count based 8/2 split ===
# For every (object, count) bucket kept by the filter, sample up to 12 image
# ids, send the first 80% to train and the rest to val.
for obj, count_dict in filtered_data.items():
    if obj.startswith("_"):
        continue  # "_total_images" / "_total_objects" bookkeeping, not an object

    for count_str, image_ids in count_dict.items():
        if count_str == "total":
            continue
        count = int(count_str)
        image_ids = image_ids if isinstance(image_ids, list) else image_ids.get("image_ids", [])
        sampled = random.sample(image_ids, min(len(image_ids), 12))

        # random.sample already returns a random order; the shuffle is kept so
        # the seeded split stays the same as splits produced before.
        random.shuffle(sampled)
        split_idx = int(len(sampled) * 0.8)
        train_ids = sampled[:split_idx]
        val_ids = sampled[split_idx:]

        for img_id in train_ids:
            # Deterministic id from (object, image): uuid4() ignores random.seed,
            # so it would produce new destination names on every re-run and
            # leave duplicate copies behind in the train image folder.
            uid = str(uuid.uuid5(uuid.NAMESPACE_URL, f"fsc/{obj}/{img_id}"))
            copied_name = find_and_copy_image(img_id, output_train_img_dir, new_name=f"{uid}.jpg")
            if copied_name:
                train_json.append({
                    "id": uid,
                    "image": f"fsc/train/images/{copied_name}",
                    "conversations": [
                        {
                            "from": "human",
                            "value": f"<image>\nhow many {obj} are there in the image?"
                        },
                        {
                            "from": "gpt",
                            "value": str(count)
                        }
                    ]
                })
            else:
                missing_images.append(str(img_id))

        for img_id in val_ids:
            # Summary ids may already be file names like "4.jpg"; avoid naming
            # the copy "4.jpg.jpg".
            img_str = str(img_id)
            val_name = img_str if img_str.lower().endswith(".jpg") else f"{img_str}.jpg"
            copied_name = find_and_copy_image(img_id, output_val_img_dir, new_name=val_name)
            if copied_name:
                qid = str(uuid.uuid4())
                question = f"<image>\nhow many {obj} are there in the image?"
                val_json.append({
                    "question_id": qid,
                    "image": f"val/images/{copied_name}",
                    "category": "default",
                    "text": question,
                    "id": qid
                })
                val_ans_json.append({
                    "question_id": qid,
                    "prompt": question,
                    "text": str(count),
                    "answer_id": None,
                    "model_id": None,
                    "metadata": {}
                })
            else:
                missing_images.append(str(img_id))

# === Save Outputs ===
for split_name in ("train", "val"):
    os.makedirs(os.path.join(root_dir, split_name), exist_ok=True)

# (path under root_dir, JSON-serialisable payload)
json_outputs = (
    ("train/train.json", train_json),
    ("val/val.json", val_json),
    ("val/val_ans.json", val_ans_json),
)
for rel_path, payload in json_outputs:
    with open(os.path.join(root_dir, rel_path), "w") as out:
        json.dump(payload, out, indent=2)

# One missing image id per line.
with open(os.path.join(root_dir, "missing_images.txt"), "w") as out:
    out.writelines(f"{mid}\n" for mid in missing_images)

print(f"Total filtered object_ids: {filtered_data.get('_total_objects', 'N/A')}")
print(f"Total images used: {filtered_data.get('_total_images', 'N/A')}")
print(f"Missing images: {len(missing_images)}")


#%%
import json
import matplotlib.pyplot as plt

with open("summary_output_5.json", "r") as fh:
    summary_data = json.load(fh)

# Keep only object entries; keys starting with "_" are bookkeeping totals.
summary_data = {name: buckets for name, buckets in summary_data.items() if not name.startswith("_")}

# Sum, across all object names, the images kept for each object count.
object_count_vs_image_count = {}
for buckets in summary_data.values():
    for bucket_key, n_images in buckets.items():
        if bucket_key == "total":
            continue
        n_objects = int(bucket_key)
        object_count_vs_image_count[n_objects] = object_count_vs_image_count.get(n_objects, 0) + n_images

# Per-object image total, taken from each entry's "total" field.
object_name_vs_image_count = {name: buckets["total"] for name, buckets in summary_data.items()}

# Plot 1: object count vs image count
plt.figure(figsize=(10, 5))
plt.bar(object_count_vs_image_count.keys(), object_count_vs_image_count.values())
plt.xlabel("Object Count (e.g., 1 object, 2 objects...)")
plt.ylabel("Total Image Count")
plt.title("Object Count vs. Image Count")
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

# Plot 2: object name vs image count (top 30 for visibility); the sort is
# stable, so ties keep their original order.
ranked_names = sorted(object_name_vs_image_count, key=object_name_vs_image_count.get, reverse=True)[:30]
sorted_objects = {name: object_name_vs_image_count[name] for name in ranked_names}
plt.figure(figsize=(12, 6))
plt.bar(sorted_objects.keys(), sorted_objects.values())
plt.xlabel("Object Name")
plt.ylabel("Image Count")
plt.title("Top 30 Object Names by Image Count")
plt.xticks(rotation=75, ha='right')
plt.tight_layout()
plt.show()