import json
import random
from collections import defaultdict

def _first_truth(entry):
    """Return the primary ground-truth answer of *entry*, or "" if it has none."""
    truth = entry["ground_truth"]
    if not truth:
        return ""
    # A plain string is already the answer; indexing it would keep only its
    # first character.
    if isinstance(truth, str):
        return truth
    return truth[0]


def _sample_locality(pool, k):
    """Randomly pick up to *k* entries from *pool* and collect their prompts/answers."""
    chosen = random.sample(pool, min(len(pool), k))
    return {
        "prompt": [c["prompt"] for c in chosen],
        "ground_truth": [_first_truth(c) for c in chosen],
    }


def build_locality(input_file: str, output_file: str, n_neigh: int = 2, n_dist: int = 2) -> None:
    """Attach a ``locality`` field to every entry of a JSON dataset.

    The input file is read as a JSON list of dicts; every entry must provide
    the keys ``subject``, ``case_id``, ``prompt`` and ``ground_truth``.
    For each entry two groups of other entries are sampled at random:

    * ``neighborhood``: up to ``n_neigh`` entries with the same ``subject``
      but a different ``case_id``;
    * ``distracting``: up to ``n_dist`` entries with a different ``subject``.

    For each group the sampled prompts and their first ground-truth answer
    ("" when the ground truth is empty) are stored under
    ``entry["locality"][group]["prompt"]`` / ``["ground_truth"]``.

    Args:
        input_file: Path of the JSON file to read.
        output_file: Path of the JSON file to write. It may be the same path
            as ``input_file``; the input is fully loaded before writing.
        n_neigh: Maximum number of same-subject entries to sample.
        n_dist: Maximum number of other-subject entries to sample.

    Raises:
        KeyError: If an entry lacks one of the required keys.
        ValueError: If ``n_neigh`` or ``n_dist`` is negative (raised by
            ``random.sample``).
    """
    with open(input_file, "r", encoding="utf-8") as f:
        data = json.load(f)  # expected to be a list of dicts

    # Group entries by subject so same-subject candidates are a direct lookup.
    subject_to_entries = defaultdict(list)
    for item in data:
        subject_to_entries[item["subject"]].append(item)

    new_data = []
    for item in data:
        subject = item["subject"]

        # Neighborhood: same subject, excluding the entry itself.
        same_subject = [
            other
            for other in subject_to_entries[subject]
            if other["case_id"] != item["case_id"]
        ]
        neighborhood = _sample_locality(same_subject, n_neigh)

        # Distracting: any entry about a different subject.
        other_subjects = [other for other in data if other["subject"] != subject]
        distracting = _sample_locality(other_subjects, n_dist)

        # Add the locality field to the entry (the loaded dict is updated in place).
        item["locality"] = {
            "neighborhood": neighborhood,
            "distracting": distracting,
        }
        new_data.append(item)

    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(new_data, f, ensure_ascii=False, indent=4)

    print(f"处理完成，结果保存到 {output_file}")


# Usage example. Guarded so that importing this module does not trigger the
# rewrite; note that the input file is overwritten in place with the result.
if __name__ == "__main__":
    build_locality("description_editing.json", "description_editing.json")
