"""Stage 1.5: repair Gemini's char offsets using context-based anchoring.

Mirror of number_edit/fix_offsets.py but for the symbol candidate format.
Uses the shared `fix_offset_by_context` helper from number_edit.common.
"""

from __future__ import annotations

import argparse
import sys
from pathlib import Path

# Put the directory one level above this file's folder on sys.path so the
# absolute `number_edit` / `symbol_edit` imports below resolve when the file is
# run directly as a script. Presumably that directory is the project root that
# contains both packages -- confirm against the repository layout.
# This must stay ahead of the package imports that follow.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from number_edit.common import fix_offset_by_context
from symbol_edit.common import load_jsonl, write_jsonl


def _coerce_offset(value: object) -> int:
    """Return ``value`` as an int, or -1 when it cannot be read as one.

    Labeled offsets come from model output, so they may be ``null``, an empty
    string, or otherwise non-numeric. Such values are mapped to -1 (the same
    sentinel used for a missing key) so the candidate is treated as having an
    invalid span rather than aborting the run.
    """
    try:
        return int(value)  # type: ignore[call-overload]
    except (TypeError, ValueError, OverflowError):
        return -1


def main() -> None:
    """Repair symbol character offsets in a labeled JSONL file.

    Command-line arguments:
        --input:  dataset JSONL; each row is indexed by its ``"name"`` field
                  and supplies ``"informal_statement"`` / ``"informal_proof"``.
        --labels: labeled-symbol JSONL to repair.
        --output: destination JSONL for the repaired rows.

    For every labeled row whose ``"problem_name"`` is found in the dataset,
    each entry of ``"symbols"`` is passed to ``fix_offset_by_context`` together
    with the text selected by its ``"source"`` ("statement" or "proof").
    Candidates with no usable text, or for which the helper returns ``None``,
    are dropped; the rest are kept as copies carrying the returned offsets.
    Rows with no matching dataset entry are written through unchanged and do
    not contribute to the counters. A summary is printed at the end.
    """
    parser = argparse.ArgumentParser(description="Fix symbol char offsets via context anchoring")
    parser.add_argument("--input", required=True, help="Original dataset JSONL (for text lookup)")
    parser.add_argument("--labels", default="./symbol_edit/data/labeled_symbols.jsonl")
    parser.add_argument("--output", default="./symbol_edit/data/labeled_symbols_fixed.jsonl")
    args = parser.parse_args()

    # Lookup table: problem name -> dataset row holding the source texts.
    problems = {row["name"]: row for row in load_jsonl(args.input)}
    labeled = load_jsonl(args.labels)

    fixed_rows = []
    n_in = n_out = n_dropped = n_already_ok = n_recovered = 0

    for row in labeled:
        name = row.get("problem_name", "")
        p = problems.get(name)
        if p is None:
            # No source text to check against: pass the row through untouched.
            fixed_rows.append(row)
            continue

        stmt = str(p.get("informal_statement", "") or "")
        proof = str(p.get("informal_proof", "") or "")

        kept = []
        for cand in row.get("symbols", []) or []:
            n_in += 1
            source = cand.get("source", "")
            text = stmt if source == "statement" else proof if source == "proof" else ""
            if not text:
                # Unknown source, or the selected text is empty.
                n_dropped += 1
                continue

            # Treat an explicit null like a missing key; str(None) would
            # otherwise yield the literal symbol "None".
            raw_symbol = cand.get("symbol")
            symbol = "" if raw_symbol is None else str(raw_symbol).strip()
            ctx = cand.get("context")
            if ctx is None:
                ctx = ""
            # Null / non-numeric offsets become -1 instead of raising, so the
            # candidate can still be re-anchored from its context.
            s = _coerce_offset(cand.get("char_offset_start", -1))
            e = _coerce_offset(cand.get("char_offset_end", -1))

            # The incoming span was already right if it is in bounds and
            # slices out exactly the (stripped) symbol.
            was_ok = 0 <= s < e <= len(text) and text[s:e] == symbol
            result = fix_offset_by_context(text, symbol, ctx, s, e)
            if result is None:
                n_dropped += 1
                continue
            new_s, new_e = result
            # Copy so the loaded input row is not mutated.
            cand = dict(cand)
            cand["char_offset_start"] = new_s
            cand["char_offset_end"] = new_e
            if was_ok:
                n_already_ok += 1
            else:
                n_recovered += 1
            kept.append(cand)
            n_out += 1

        new_row = dict(row)
        new_row["symbols"] = kept
        fixed_rows.append(new_row)

    write_jsonl(args.output, fixed_rows)

    print("=== Offset Fix Summary ===")
    print(f"  candidates in:         {n_in}")
    print(f"  candidates out:        {n_out}")
    print(f"  already correct:       {n_already_ok}")
    print(f"  recovered via context: {n_recovered}")
    print(f"  dropped (unrecoverable): {n_dropped}")
    # max(..., 1) avoids division by zero when no candidates were seen.
    print(f"  fix rate:              {100 * n_out / max(n_in, 1):.1f}%")
    print(f"  output: {args.output}")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
