#https://aistudio.baidu.com/datasetdetail/177162/0
#%%
import tarfile
import os

# Unpack the dataset archives that live under `base_dir`.
base_dir = 'dataset'
files_to_extract = ['annotations.tar', 'images_test.tar', 'images_train.tar']
# Each archive is unpacked into a folder named after the archive minus ".tar".
output_dirs = {
    archive_name: os.path.splitext(archive_name)[0]
    for archive_name in files_to_extract
}

for filename in files_to_extract:
    tar_path = os.path.join(base_dir, filename)
    extract_dir = os.path.join(base_dir, output_dirs[filename])
    # The target folder is created up front, before the archive is opened.
    os.makedirs(extract_dir, exist_ok=True)

    try:
        print(f'Extracting {filename} to {extract_dir} ...')
        with tarfile.open(tar_path, 'r', errorlevel=1) as archive:
            # Members are extracted one at a time so that a ReadError on a
            # single member is reported and skipped instead of aborting the
            # whole archive.
            for entry in archive.getmembers():
                try:
                    archive.extract(entry, path=extract_dir)
                except tarfile.ReadError:
                    print(f"ReadError on {entry.name} in {filename}, skipping.")
        print(f"Finished {filename}")
    except Exception as err:
        # Any other failure (opening or extracting) ends this archive only;
        # the loop carries on with the next one.
        print(f"Error opening {filename}: {err}")

print("All extract attempts completed.")

#%%
import json
import os

# Sanity check: does every image named in train.json exist on disk?
json_path = 'dataset/train.json'
image_root = 'dataset'

with open(json_path, 'r') as f:
    data = json.load(f)

# Each annotation's "name" is joined onto image_root to get the file path.
listed_paths = [
    os.path.join(image_root, item["name"])
    for item in data.get("annotations", [])
]
all_images = len(listed_paths)
missing_images = [path for path in listed_paths if not os.path.exists(path)]

print(f"\nTotal images listed: {all_images}")
print(f"Missing images: {len(missing_images)}\n")

if not missing_images:
    print("All image files exist.")
else:
    print("Missing image files:")
    print("\n".join(missing_images))

#%%
import os
import cv2
import json
import matplotlib.pyplot as plt

# Draw the annotated points of one annotation on top of one test image.
image_path = "dataset/images/test/467c1798d0c882951ac390b1cd95b335.jpg"
json_path = "dataset/test.json"
with open(json_path, "r") as f:
    data = json.load(f)

# First annotation whose "image id" field equals 0, or None if there is none.
# NOTE(review): the annotation is selected by id while the image is selected
# by a hard-coded path; confirm that id 0 really belongs to this image and
# that "image id" (with a space) is the key used by the dataset.
target_ann = next(
    (ann for ann in data.get("annotations", []) if ann.get("image id") == 0),
    None,
)

if not os.path.exists(image_path):
    raise FileNotFoundError(f"Image not found: {image_path}")
img = cv2.imread(image_path)
if img is None:
    # cv2.imread signals an unreadable/corrupt file by returning None
    # instead of raising, which would otherwise surface as an obscure
    # error inside cvtColor.
    raise ValueError(f"Image could not be decoded: {image_path}")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

if target_ann is None:
    raise ValueError(f"No annotation with 'image id' == 0 in {json_path}")

# "locations" is read as a flat [x0, y0, x1, y1, ...] list. Pairing the even
# and odd slices with zip ignores a trailing unpaired value instead of
# raising IndexError.
locations = target_ann["locations"]
for x, y in zip(locations[0::2], locations[1::2]):
    # cv2.circle needs integer pixel coordinates; annotations may be floats.
    cv2.circle(img, (int(round(x)), int(round(y))), 4, (255, 0, 0), -1)

plt.figure(figsize=(12, 8))
plt.imshow(img)
plt.title("Visualized Annotation Points")
plt.axis("off")
plt.show()


#%%
import os
import json
from collections import defaultdict

def load_annotations(json_path):
    """Load an annotation file and index its entries by image file name.

    Args:
        json_path: Path to a JSON file with a top-level "annotations" list
            whose entries each carry a "name" path.

    Returns:
        A dict mapping the base name of each entry's "name" to the entry
        itself. When two entries share a base name, the later one wins.

    Raises:
        KeyError: If "annotations" or an entry's "name" is missing.
    """
    with open(json_path, 'r') as handle:
        payload = json.load(handle)

    by_filename = {}
    for entry in payload["annotations"]:
        by_filename[os.path.basename(entry["name"])] = entry
    return by_filename

def process_folder(image_folder, annotations_dict, prefix):
    """Group the annotated JPEG images of a folder by their people count.

    Args:
        image_folder: Directory whose entries are scanned (non-recursive).
        annotations_dict: Mapping of image file name -> annotation dict; each
            annotation's "locations" is treated as a flat list of x, y values,
            so the people count is ``len(locations) // 2``.
        prefix: Split label prepended to every image id ("<prefix>/<stem>").

    Returns:
        A defaultdict keyed by the people count as a string. Each value is
        ``{"count": <number of images>, "image_ids": [<prefixed ids>]}``.
        Files that are not ``.jpg`` (case-insensitive) or that have no entry
        in ``annotations_dict`` are skipped.
    """
    jpg_suffix = ".jpg"
    people_count_dict = defaultdict(lambda: {"count": 0, "image_ids": []})

    # os.listdir returns entries in arbitrary order; sorting keeps the
    # resulting id lists (and anything sliced or shuffled from them later)
    # reproducible across machines and runs.
    for fname in sorted(os.listdir(image_folder)):
        if not fname.lower().endswith(jpg_suffix):
            continue

        if fname not in annotations_dict:
            continue  # skip images without annotation

        ann = annotations_dict[fname]
        num_people = len(ann["locations"]) // 2
        key = str(num_people)
        # Strip exactly the matched extension. A plain replace('.jpg', '')
        # would miss upper-case extensions and would also remove ".jpg"
        # occurring in the middle of a file name.
        stem = fname[:-len(jpg_suffix)]
        image_id = f"{prefix}/{stem}"

        people_count_dict[key]["count"] += 1
        people_count_dict[key]["image_ids"].append(image_id)

    return people_count_dict
# Build one people-count summary covering both the train and test splits.
base_dir = "dataset"

train_img_dir = os.path.join(base_dir, "images/train")
test_img_dir = os.path.join(base_dir, "images/test")
train_json_path = os.path.join(base_dir, "train.json")
test_json_path = os.path.join(base_dir, "test.json")

# Both annotation files are loaded before any folder is scanned.
train_ann = load_annotations(train_json_path)
test_ann = load_annotations(test_json_path)

train_result = process_folder(train_img_dir, train_ann, "train")
test_result = process_folder(test_img_dir, test_ann, "test")

# Merge the per-split buckets: counts add up, id lists are concatenated
# (train ids first, then test ids).
final_summary = defaultdict(lambda: {"count": 0, "image_ids": []})
for split_result in (train_result, test_result):
    for people_num, info in split_result.items():
        merged = final_summary[people_num]
        merged["count"] += info["count"]
        merged["image_ids"] += info["image_ids"]

output_path = "people_summary_filtered.json"
with open(output_path, "w") as f:
    json.dump({"people": final_summary}, f, indent=2)

print(f"Done. Saved with 'train/' and 'test/' prefixes to '{output_path}'")


#%%
import json
import matplotlib.pyplot as plt

# Plot the distribution stored in people_summary_filtered.json.
with open("people_summary_filtered.json", "r") as f:
    data = json.load(f)

people_buckets = data["people"]

# (people per image, number of images) pairs, limited to at most 50 people.
bar_points = [
    (int(count_str), info["count"])
    for count_str, info in people_buckets.items()
    if int(count_str) <= 50
]
x_people_count = [people for people, _ in bar_points]
y_image_count = [images for _, images in bar_points]

plt.figure(figsize=(10, 5))
plt.bar(x_people_count, y_image_count)
plt.xlabel("Number of People in Image")
plt.ylabel("Number of Images")
plt.title("How Many Images Have N People (N ≤ 50)")
plt.grid(True)
plt.tight_layout()
plt.show()

# Total head count over every bucket (not only those with <= 50 people):
# people per image times the number of such images.
total_people = 0
for count_str, info in people_buckets.items():
    total_people += int(count_str) * info["count"]

plt.figure(figsize=(4, 5))
plt.bar(["people"], [total_people])
plt.ylabel("Total People Across All Images")
plt.title("Total People in All Images")
plt.grid(True, axis="y")
plt.tight_layout()
plt.show()

#%%
import json

def generate_filtered_summary_json(data, min_threshold=25, max_threshold=50, min_valid_counts=5, count_range=(1, 40)):
    """Summarise how many images each object would contribute per count bucket.

    Args:
        data: Mapping of object name -> {count string -> {"count": n, ...}}.
        min_threshold: A bucket is kept only if it has at least this many
            images.
        max_threshold: Cap on the number of images taken from one bucket.
        min_valid_counts: An object is kept only if at least this many of its
            buckets were kept.
        count_range: Inclusive (low, high) bounds on the bucket's count value.

    Returns:
        A dict with one entry per kept object, mapping each kept count string
        to the capped image number plus a "total" key, and two extra
        top-level keys: "_total_images" (sum of all object totals) and
        "_total_objects" (number of kept objects). Buckets whose key or
        "count" cannot be converted with int() (ValueError) are skipped.
    """
    summary = {}

    for obj, count_dict in data.items():
        kept = {}

        for count_str, info in count_dict.items():
            try:
                bucket_value = int(count_str)
                available = int(info["count"])
            except ValueError:
                continue

            if not (count_range[0] <= bucket_value <= count_range[1]):
                continue
            if available < min_threshold:
                continue

            kept[count_str] = min(available, max_threshold)

        if len(kept) < min_valid_counts:
            continue

        # Sum is taken before "total" itself is inserted into the dict.
        kept["total"] = sum(kept.values())
        summary[obj] = kept

    grand_total = sum(per_object["total"] for per_object in summary.values())
    object_count = len(summary)

    summary["_total_images"] = grand_total
    summary["_total_objects"] = object_count
    return summary

# Read the full summary, reduce it to the filtered per-bucket numbers and
# store the result next to it.
with open("people_summary_filtered.json", "r") as source:
    data = json.load(source)

summary = generate_filtered_summary_json(data)

with open("summary_output_5.json", "w") as target:
    target.write(json.dumps(summary, indent=2))

#%%
import json
import os
import shutil
import random
import uuid

# Inputs: the people summary and the folders holding the source images.
json_path = "people_summary_filtered.json"
image_dirs = ["dataset/images/train", "dataset/images/test"]

# Outputs: <root_dir>/train/images and <root_dir>/val/images.
root_dir = "people"
train_img_dir = os.path.join(root_dir, "train/images")
val_img_dir = os.path.join(root_dir, "val/images")
for out_dir in (train_img_dir, val_img_dir):
    os.makedirs(out_dir, exist_ok=True)

with open(json_path, "r") as f:
    full_data = json.load(f)

# Only the "people" section is used; a missing section yields an empty dict.
all_data = full_data.get("people", {})

# Accumulators filled while the images are split and copied.
train_json = []
val_json = []
val_ans_json = []
missing_images = []

def find_and_copy_image(prefixed_id, dest_folder, new_name=None):
    """Copy the image behind a prefixed id into ``dest_folder``.

    Args:
        prefixed_id: Id of the form "<split>/<image id>", e.g. 'train/abc123'
            or 'test/xyz789'.
        dest_folder: Directory the image is copied into.
        new_name: Optional file name for the copy; when omitted (or empty)
            the source file name is kept.

    Returns:
        The file name of the copy inside ``dest_folder``, or None when no
        matching source file exists (a "[Missing]" line is printed then).

    Raises:
        ValueError: If ``prefixed_id`` does not contain exactly one "/".

    Only the module-level ``image_dirs`` entries whose path ends with the
    split name are searched. ".jpg" is tried in all of them before ".png".
    An existing destination file is left untouched.
    """
    split, img_id = prefixed_id.split("/")
    candidate_dirs = [folder for folder in image_dirs if folder.endswith(split)]

    for ext in (".jpg", ".png"):
        fname = f"{img_id}{ext}"
        for folder in candidate_dirs:
            src_path = os.path.join(folder, fname)
            if not os.path.exists(src_path):
                continue

            dst_name = new_name or fname
            dst_path = os.path.join(dest_folder, dst_name)
            if not os.path.exists(dst_path):
                shutil.copy2(src_path, dst_path)
            return dst_name

    print(f"[Missing] {prefixed_id} not found.")
    return None

def generate_filtered_summary_json(data, min_threshold=25, max_threshold=50, min_valid_counts=5, count_range=(1, 40)):
    """Select the count buckets (and image ids) used to build the dataset.

    Args:
        data: Mapping of count string -> {"count": n, "image_ids": [...]}.
        min_threshold: A bucket is kept only if its "count" is at least this.
        max_threshold: At most this many ids (the first ones in list order)
            are taken from a kept bucket.
        min_valid_counts: Accepted for signature compatibility; not used by
            this implementation.
        count_range: Inclusive (low, high) bounds on the bucket's count value.

    Returns:
        A dict with one ``{"count": k, "image_ids": [...]}`` entry per kept
        bucket, plus "_total_images" (number of selected ids) and
        "_total_objects" (number of kept buckets). Keys that int() rejects
        with ValueError are skipped.
    """
    selected = {}

    for count_str, info in data.items():
        try:
            bucket_value = int(count_str)
        except ValueError:
            continue

        within_range = count_range[0] <= bucket_value <= count_range[1]
        if not within_range or info["count"] < min_threshold:
            continue

        # Slicing produces a new list, so the caller's id list is not shared.
        kept_ids = info["image_ids"][:max_threshold]
        selected[count_str] = {"count": len(kept_ids), "image_ids": kept_ids}

    grand_total = sum(group["count"] for group in selected.values())
    object_count = len(selected)

    selected["_total_images"] = grand_total
    selected["_total_objects"] = object_count
    return selected

# Pick the buckets to use, split each one 80/20 into train/val, copy the
# images and collect the matching JSON records.
filtered_data = generate_filtered_summary_json(all_data)

random.seed(42)

# The same prompt is used for every train and val sample.
question = "<image>\nhow many people are there in the image?"

for count_str, info in filtered_data.items():
    if count_str.startswith("_"):
        continue  # "_total_images" / "_total_objects" are bookkeeping keys

    image_ids = info["image_ids"]
    count = int(count_str)
    answer = str(count)

    # Shuffle in place, then cut: first 80% -> train, remainder -> val.
    random.shuffle(image_ids)
    split_idx = int(len(image_ids) * 0.8)
    train_ids = image_ids[:split_idx]
    val_ids = image_ids[split_idx:]

    # Train images are stored under a fresh uuid-based file name.
    for img_id in train_ids:
        uid = str(uuid.uuid4())
        copied_name = find_and_copy_image(img_id, train_img_dir, f"{uid}.jpg")
        if not copied_name:
            missing_images.append(img_id)
            continue
        train_json.append({
            "id": uid,
            "image": f"people/train/images/{copied_name}",
            "conversations": [
                {"from": "human", "value": question},
                {"from": "gpt", "value": answer}
            ]
        })

    # Val images keep their source file name; question and answer records
    # share one question id.
    for img_id in val_ids:
        copied_name = find_and_copy_image(img_id, val_img_dir)
        if not copied_name:
            missing_images.append(img_id)
            continue
        qid = str(uuid.uuid4())
        val_json.append({
            "question_id": qid,
            "image": copied_name,
            "category": "default",
            "text": question,
            "id": qid
        })
        val_ans_json.append({
            "question_id": qid,
            "prompt": question,
            "text": answer,
            "answer_id": None,
            "model_id": None,
            "metadata": {}
        })

for split_name in ("train", "val"):
    os.makedirs(os.path.join(root_dir, split_name), exist_ok=True)

json_outputs = (
    ("train/train.json", train_json),
    ("val/val.json", val_json),
    ("val/val_ans.json", val_ans_json),
)
for rel_path, records in json_outputs:
    with open(os.path.join(root_dir, rel_path), "w") as f:
        json.dump(records, f, indent=2)

# One missing id per line.
with open(os.path.join(root_dir, "missing_images.txt"), "w") as f:
    f.writelines(f"{mid}\n" for mid in missing_images)

print(f"Total filtered people groups: {filtered_data.get('_total_objects', 'N/A')}")
print(f"Total images used: {filtered_data.get('_total_images', 'N/A')}")
print(f"Missing images: {len(missing_images)} (see missing_images.txt)")

