# import json
# from collections import defaultdict

# # === 所需的 8 个朝向选项 ===
# DIRECTIONS = [
#     "front", "front right", "right", "back right",
#     "back", "back left", "left", "front left"
# ]

# # === 读取 benchmark.json ===
# with open("/fs/scratch/PAS2099/Jiacheng/EgoOrientBench/EgoOrientBench/all_data/EgocentricDataset/train_benchmark/benchmark.json", "r") as f:
#     data = json.load(f)

# # === 构建 object → 出现的方向 + 记录问题条目
# object_to_labels = defaultdict(set)
# object_to_items = defaultdict(list)

# for item in data:
#     if item["type"] == "general_complex":
#         obj = item["category_name"].strip().lower()
#         label = item["label"].strip().lower()

#         object_to_labels[obj].add(label)
#         object_to_items[obj].append(item)

# # === 筛选满足 ≥7 个方向的 object
# valid_objects = [
#     obj for obj, label_set in object_to_labels.items()
#     if len(label_set.intersection(DIRECTIONS)) >= 7
# ]

# # === 收集对应条目
# filtered_data = []
# for obj in valid_objects:
#     filtered_data.extend(object_to_items[obj])

# # === 保存结果
# output_path = "/fs/scratch/PAS2099/Jiacheng/EgoOrientBench/output/general_complex_7dir_only.json"
# with open(output_path, "w") as f:
#     json.dump(filtered_data, f, indent=2)

# print(f"✅ Done! Saved {len(filtered_data)} entries for objects with ≥7 directions.")
# print(f"🎯 Total valid objects: {len(valid_objects)}")

import json
import os
from collections import defaultdict

# === The 8 orientation labels an object can be annotated with ===
# Listed going around the circle starting from "front"; this order is also the
# order in which directions are processed and emitted further down.
DIRECTIONS = [
    "front",
    "front right",
    "right",
    "back right",
    "back",
    "back left",
    "left",
    "front left",
]

# === Load benchmark.json ===
# The entries are later indexed by the "type", "category_name", "label" and
# "image" keys, so presumably this is a list of dicts — confirm against the
# dataset if the schema changes.
input_path = "/fs/scratch/PAS2099/Jiacheng/EgoOrientBench/EgoOrientBench/all_data/EgocentricDataset/train_benchmark/benchmark.json"
# JSON text is UTF-8 (RFC 8259); pass the encoding explicitly so the read does
# not depend on the platform/locale default encoding.
with open(input_path, "r", encoding="utf-8") as f:
    data = json.load(f)

# === Build lookup structures over the "general_complex" entries ===
# object name -> direction label -> image path -> list of raw benchmark entries
object_dir_to_image_to_items = defaultdict(
    lambda: defaultdict(lambda: defaultdict(list))
)
# object name -> direction label -> set of distinct image paths
object_dir_image_sets = defaultdict(lambda: defaultdict(set))

for item in data:
    # Only "general_complex" questions take part in the selection.
    if item["type"] != "general_complex":
        continue

    # Object names and labels are normalized (trimmed, lower-cased) so that
    # variants in spelling/case group together; image paths are only trimmed.
    obj = item["category_name"].strip().lower()
    label = item["label"].strip().lower()
    image = item["image"].strip()

    object_dir_image_sets[obj][label].add(image)
    object_dir_to_image_to_items[obj][label][image].append(item)

# === Selection: keep an object only if it covers at least MIN_DIRECTIONS of the
# DIRECTIONS and every covered direction has at least MIN_IMAGES_PER_DIRECTION
# distinct images; directions with more images are capped at
# MAX_IMAGES_PER_DIRECTION. All benchmark entries for each kept image are emitted.
MIN_DIRECTIONS = 7
MIN_IMAGES_PER_DIRECTION = 20
MAX_IMAGES_PER_DIRECTION = 80

filtered_data = []
valid_objects = []

for obj, label_map in object_dir_image_sets.items():
    # Directions from DIRECTIONS (kept in canonical order) that this object has
    # images for; labels outside DIRECTIONS are ignored. `in` on the defaultdict
    # does not create missing keys.
    available_dirs = [d for d in DIRECTIONS if d in label_map]
    if len(available_dirs) < MIN_DIRECTIONS:
        continue

    # Every covered direction must reach the minimum image count.
    if any(len(label_map[d]) < MIN_IMAGES_PER_DIRECTION for d in available_dirs):
        continue

    valid_objects.append(obj)

    for direction in available_dirs:
        # Sort before truncating: iteration order of a set of str depends on
        # per-process hash randomization (PYTHONHASHSEED), so slicing the raw
        # set would keep a different subset of images on every run. For a
        # random but reproducible subset, use
        # random.Random(seed).sample(sorted(...), k) instead.
        image_list = sorted(label_map[direction])[:MAX_IMAGES_PER_DIRECTION]

        for image in image_list:
            filtered_data.extend(object_dir_to_image_to_items[obj][direction][image])

# === Save results ===
output_path = "/fs/scratch/PAS2099/Jiacheng/EgoOrientBench/output/general_complex_7dir_min20_max80_per_direction.json"

# open(..., "w") raises FileNotFoundError when the parent directory is missing,
# so create it first (no-op if it already exists).
output_dir = os.path.dirname(output_path)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

# Write explicitly as UTF-8 so the output does not depend on the locale.
with open(output_path, "w", encoding="utf-8") as f:
    json.dump(filtered_data, f, indent=2)

print(f"✅ Done! Saved {len(filtered_data)} entries to:\n{output_path}")
print(f"🎯 Total valid objects (≥7 directions, ≥20 images per direction, max 80): {len(valid_objects)}")
