"""Aggregate symbol-edit metrics from per-sample scored outputs."""

from __future__ import annotations

import argparse
import json
from collections import Counter, defaultdict
from pathlib import Path
import sys
from typing import Dict

# Parent of the directory that contains this file.
ROOT = Path(__file__).resolve().parents[1]
# Prepend it to sys.path unless it is already listed. Presumably this is what
# lets the `symbol_edit` import below resolve when the file is run directly as
# a script -- TODO confirm against the repository layout.
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from symbol_edit.common import load_jsonl


def main() -> None:
    parser = argparse.ArgumentParser(description="Aggregate symbol-edit benchmark scores")
    parser.add_argument("--input", required=True, help="Per-sample scored JSONL")
    parser.add_argument("--output", default="", help="Optional output JSON for summary")
    parser.add_argument(
        "--break_by_family",
        action="store_true",
        help="Also report FR/RR/OUR broken down by symbol family (relation vs operator)",
    )
    args = parser.parse_args()

    rows = load_jsonl(args.input)

    summary: Dict[str, Dict[str, float]] = {}
    by_edit = defaultdict(list)
    for row in rows:
        by_edit[row["edit_type"]].append(row)

    def _metrics(group):
        counts = Counter(r.get("judgment", "other") for r in group)
        total = len(group)
        if total == 0:
            return None
        fr = counts.get("faithful", 0)
        rr = counts.get("reverted", 0)
        our = counts.get("other", 0)
        return {
            "count": total,
            "FR": round(fr / total, 4),
            "RR": round(rr / total, 4),
            "OUR": round(our / total, 4),
            "counts": {"faithful": fr, "reverted": rr, "other": our},
        }

    for edit_type, group in by_edit.items():
        m = _metrics(group)
        if m is None:
            continue
        summary[edit_type] = m
        if args.break_by_family:
            by_family = defaultdict(list)
            for r in group:
                by_family[r.get("family") or "unknown"].append(r)
            summary[edit_type]["by_family"] = {
                fam: _metrics(grp) for fam, grp in by_family.items()
            }

    print(json.dumps(summary, indent=2))

    if args.output:
        out_path = Path(args.output)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        with out_path.open("w", encoding="utf-8") as f:
            json.dump(summary, f, ensure_ascii=False, indent=2)
        print(f"Saved summary to {out_path}")


# Run the aggregation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
