import json
from collections import defaultdict

# === The 8 orientation labels every retained object must cover ===
DIRECTIONS = [
    "front", "front right", "right", "back right",
    "back", "back left", "left", "front left"
]

# An object is kept only if each direction has at least this many distinct images.
MIN_IMAGES_PER_DIRECTION = 20
# At most this many distinct images are kept per (object, direction) pair.
MAX_IMAGES_PER_DIRECTION = 80


def _group_items(data):
    """Index "general_complex" items as object -> direction -> image -> [items].

    Object and direction names are stripped and lower-cased; image paths are
    stripped. Items of any other ``type`` are ignored. Objects, directions and
    images appear in first-seen order.
    """
    grouped = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
    for item in data:
        if item["type"] != "general_complex":
            continue
        obj = item["category_name"].strip().lower()
        label = item["label"].strip().lower()
        image = item["image"].strip()
        grouped[obj][label][image].append(item)
    return grouped


def select_items(data, min_images: int = MIN_IMAGES_PER_DIRECTION,
                 max_images: int = MAX_IMAGES_PER_DIRECTION):
    """Select items of objects that are well covered in all 8 directions.

    An object qualifies when every entry of ``DIRECTIONS`` is present among
    its labels with at least ``min_images`` distinct images. For each
    qualifying object and direction, at most ``max_images`` images are kept
    and every item attached to a kept image is emitted.

    Args:
        data: Iterable of item dicts with the keys ``type``,
            ``category_name``, ``label`` and ``image``.
        min_images: Minimum number of distinct images per direction.
        max_images: Maximum number of distinct images kept per direction.

    Returns:
        A ``(filtered_items, valid_objects)`` tuple: the selected item dicts
        (ordered by object, then by ``DIRECTIONS``, then by image path) and
        the normalized names of the qualifying objects in first-seen order.

    Raises:
        KeyError: If an item lacks one of the required keys.
    """
    grouped = _group_items(data)
    filtered_items = []
    valid_objects = []

    for obj, by_direction in grouped.items():
        # Must contain all 8 directions, each with enough distinct images.
        # Membership is tested first so the defaultdict is not populated
        # with empty entries for missing directions.
        if not all(
            direction in by_direction
            and len(by_direction[direction]) >= min_images
            for direction in DIRECTIONS
        ):
            continue

        valid_objects.append(obj)

        for direction in DIRECTIONS:
            by_image = by_direction[direction]
            # Sort before truncating so the retained subset is identical on
            # every run; the cap then keeps the first `max_images` paths in
            # lexicographic order.
            images = sorted(by_image)[:max_images]
            for image in images:
                filtered_items.extend(by_image[image])

    return filtered_items, valid_objects


if __name__ == "__main__":
    # === Load the benchmark JSON ===
    with open("/fs/scratch/PAS2099/Jiacheng/EgoOrientBench/output/general_complex_8_direction_filtered.json", "r", encoding="utf-8") as f:
        data = json.load(f)

    # === Collect the entries that satisfy the coverage constraints ===
    filtered_data, valid_objects = select_items(data)

    # === Save the result ===
    output_path = "/fs/scratch/PAS2099/Jiacheng/EgoOrientBench/output/general_complex_8dir_20min_80max.json"
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(filtered_data, f, indent=2)

    print(f"✅ Done! Saved {len(filtered_data)} filtered entries to:\n{output_path}")
    print(f"🎯 Total valid objects: {len(valid_objects)}")
