import os
import json

# — Configuration —
# Input JSON with all validation records. main() iterates over the loaded
# value and reads record["label"] from each item, so the file should hold a
# list of objects that each carry a string "label".
IN_JSON         = "/fs/scratch/PAS2099/Jiacheng/place365/output/places365_val_data.json"
# Output JSON for fine-grained records: the input records whose label has the
# form "prefix/suffix" and whose prefix occurs with more than one distinct
# suffix. main() creates the parent directory if it does not exist.
OUT_FINE_JSON   = "/fs/scratch/PAS2099/Jiacheng/place365/output/places365_val_finegrained.json"
# — End Configuration —

def _collect_suffixes_by_prefix(records):
    """Group label suffixes by prefix for labels shaped like "prefix/suffix".

    Labels are split on the first "/" only, so anything after it (further
    slashes included) is the suffix. Labels without a "/" are ignored.

    Returns:
        A dict mapping each prefix to the set of suffixes seen with it.
    """
    prefix_to_suffixes = {}
    for record in records:
        prefix, sep, suffix = record["label"].partition("/")
        if sep:
            prefix_to_suffixes.setdefault(prefix, set()).add(suffix)
    return prefix_to_suffixes


def main(in_json=None, out_json=None):
    """Copy the records with fine-grained labels into their own JSON file.

    A label of the form "prefix/suffix" is fine-grained when its prefix occurs
    with more than one distinct suffix across the whole input. Matching
    records are written in their original order, and a summary of the
    detected categories is printed to stdout.

    Args:
        in_json: Path of the JSON file to read. It must contain a list of
            records that each have a string "label". Defaults to IN_JSON.
        out_json: Path of the JSON file to write. Defaults to OUT_FINE_JSON.

    Raises:
        KeyError: If a record has no "label" key.
    """
    # Resolve defaults at call time so the module constants stay the single
    # source of configuration for the no-argument call.
    if in_json is None:
        in_json = IN_JSON
    if out_json is None:
        out_json = OUT_FINE_JSON

    # 1. Load the full validation data. JSON is UTF-8 by specification, so do
    #    not depend on the platform's locale encoding.
    with open(in_json, "r", encoding="utf-8") as f:
        data = json.load(f)

    # 2. Build a mapping from prefix -> set of suffixes
    prefix_to_suffixes = _collect_suffixes_by_prefix(data)

    # 3. Identify which prefixes are fine-grained (more than one distinct suffix)
    fine_prefixes = {
        prefix
        for prefix, suffixes in prefix_to_suffixes.items()
        if len(suffixes) > 1
    }

    # 4. Keep records whose label is "prefix/suffix" with a fine-grained prefix.
    #    Every suffix of such a prefix was collected from the same data, so
    #    checking the prefix alone selects exactly the fine-grained labels.
    fine_records = []
    for rec in data:
        prefix, sep, _ = rec["label"].partition("/")
        if sep and prefix in fine_prefixes:
            fine_records.append(rec)

    # 5. Write the fine-grained records to a separate JSON file.
    #    dirname is "" for a bare filename, which os.makedirs rejects.
    out_dir = os.path.dirname(out_json)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    # ensure_ascii=False emits raw non-ASCII characters, so the encoding must
    # be explicit to avoid locale-dependent encode errors.
    with open(out_json, "w", encoding="utf-8") as f:
        json.dump(fine_records, f, indent=2, ensure_ascii=False)

    # 6. Print out all detected fine-grained attributes
    print("Detected fine-grained categories:")
    for prefix in sorted(fine_prefixes):
        suffix_list = sorted(prefix_to_suffixes[prefix])
        print(f"  {prefix}: {', '.join(suffix_list)}")

    print(f"\nTotal records in fine-grained JSON: {len(fine_records)}")
    print(f"Saved to: {out_json}")

# Run the extraction only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    main()
