"""Aggregate number-edit metrics from per-sample scored outputs.

Consumes the schema produced by score_number_edit_v2.py:
  {judgment: "faithful"|"reverted"|"other",
   found_in_fl: bool, value_in_fl: str, ...}
"""

from __future__ import annotations

import argparse
import json
from collections import Counter, defaultdict
from pathlib import Path
import sys
from typing import Dict

# ROOT is the parent of the directory containing this file (parents[0] would
# be the containing directory itself). It is pushed to the front of sys.path,
# unless already listed, before the `number_edit` import below runs.
# NOTE(review): presumably ROOT is the repository root that holds the
# `number_edit` package, so the script works when run directly - confirm.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from number_edit.common import load_jsonl


def _group_metrics(group):
    """Compute judgment rates for one list of scored rows.

    Args:
        group: Rows (dicts) that share an edit type. Each row's ``judgment``
            is read with ``.get``; a missing key is treated as ``"other"``.

    Returns:
        ``None`` for an empty group, otherwise a dict with:
          * ``count``  - number of rows in the group,
          * ``FR``     - fraction judged ``"faithful"``,
          * ``RR``     - fraction judged ``"reverted"``,
          * ``OUR``    - fraction of everything else,
          * ``counts`` - the raw tallies behind the three rates.
        Rates are rounded to 4 decimal places.
    """
    total = len(group)
    if total == 0:
        # Guard against division by zero; callers skip empty groups.
        return None

    tally = Counter(row.get("judgment", "other") for row in group)
    faithful = tally["faithful"]
    reverted = tally["reverted"]
    # Everything that is neither faithful nor reverted belongs to the "other"
    # bucket. Computing it as the remainder (instead of looking up the literal
    # "other" label) keeps rows with a null or unrecognised judgment from
    # vanishing, so the three counts always add up to `total` and the three
    # rates add up to 1 (before rounding).
    other = total - faithful - reverted

    return {
        "count": total,
        "FR": round(faithful / total, 4),
        "RR": round(rr_or_zero(reverted) / total, 4) if False else round(reverted / total, 4),
        "OUR": round(other / total, 4),
        "counts": {"faithful": faithful, "reverted": reverted, "other": other},
    }


def main() -> None:
    """Aggregate per-sample judgments into per-edit-type rates.

    Command-line arguments:
        --input   (required) path to the per-sample scored JSONL file.
        --output  (optional) path for the JSON summary; parent directories
                  are created if needed. Empty string (the default) means
                  the summary is only printed.

    The summary maps each ``edit_type`` (``"unknown"`` when the key is absent
    from a row) to the metrics dict built by ``_group_metrics``. It is always
    printed to stdout as indented JSON.
    """
    parser = argparse.ArgumentParser(description="Aggregate number-edit benchmark scores")
    parser.add_argument("--input", required=True, help="Per-sample scored JSONL")
    parser.add_argument("--output", default="", help="Optional output JSON for summary")
    args = parser.parse_args()

    rows = load_jsonl(args.input)

    # Bucket rows by edit type, keeping first-seen order of the edit types.
    by_edit = defaultdict(list)
    for row in rows:
        by_edit[row.get("edit_type", "unknown")].append(row)

    # Values mix ints, floats and a nested dict of raw counts.
    summary: Dict[str, dict] = {}
    for edit_type, group in by_edit.items():
        metrics = _group_metrics(group)
        if metrics is not None:
            summary[edit_type] = metrics

    print(json.dumps(summary, indent=2))

    if args.output:
        out_path = Path(args.output)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        with out_path.open("w", encoding="utf-8") as f:
            json.dump(summary, f, ensure_ascii=False, indent=2)
        print(f"Saved summary to {out_path}")


# Run the aggregation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
