"""Stage 2: filter labeled numbers and randomly pick one per source.

Reads labeled_numbers.jsonl (output of label_numeric_roles.py, containing a
flat list of candidates per problem) and writes selected_numbers.jsonl in the
format consumed by build_number_edit_unsound.py:

  {"problem_name": ...,
   "statement_edit": {old_value, char_offset_start, char_offset_end, context} | null,
   "proof_edit":     {...} | null}

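Selection is a sha256-seeded random choice over the filtered candidates (seeded by
problem name and source), so re-running this stage on the same labeled input
produces the same selections. Candidates that are fragments of a larger number
are dropped only when the original dataset is supplied via --dataset.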
"""

from __future__ import annotations

import argparse
import hashlib
import random
import re
import sys
from pathlib import Path

ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from number_edit.common import load_jsonl, write_jsonl


def is_numeric_literal(value: str) -> bool:
    """Return True if *value*, ignoring surrounding whitespace, is a plain number.

    Accepted forms (each may carry a leading minus sign):
      * integers            e.g. "12", "-3"
      * decimals            e.g. "2.5" (digits required on both sides of ".")
      * fractions           e.g. "3/4", "3 / -4" (whitespace allowed around "/")
      * percentages         e.g. "50%", "12.5%"
    """
    candidate = value.strip()
    accepted_forms = (
        r"-?\d+",
        r"-?\d+\.\d+",
        r"-?\d+\s*/\s*-?\d+",
        r"-?\d+(?:\.\d+)?%",
    )
    return any(re.fullmatch(form, candidate) for form in accepted_forms)


def is_isolated_number(text: str, start: int, end: int, value: str) -> bool:
    """Reject candidates whose span is a *fragment* of a larger numeric token.

    Example bad label from Gemini: value="2" at offset inside "12" — the
    extracted digit is really the low-order digit of 12, and perturbing it
    would change 12 to e.g. 14 instead of producing a meaningful edit of 2.

    Rule: the char right before `start` must not be a digit or ".", and the
    char right after `end` must not be a digit or a decimal point. A "." right
    after the span only counts as a decimal point when a digit follows it;
    otherwise it is sentence punctuation (e.g. "... equals 12.") and the span
    is kept.

    Args:
        text: Source text that the offsets index into.
        start: Inclusive start offset of the candidate span.
        end: Exclusive end offset of the candidate span.
        value: The labeled value; not used by this check.

    Returns:
        False if the span is a fragment of a larger number. True otherwise,
        including when the offsets are out of range or empty. Validation of
        invalid offsets is left to replace_span.
    """
    if start < 0 or end > len(text) or start >= end:
        return True  # let replace_span do its own validation
    prev_ch = text[start - 1] if start > 0 else ""
    next_ch = text[end] if end < len(text) else ""
    # Character after next_ch, used to tell a decimal point ("12.5") from a
    # sentence-ending period ("12.").
    after_next_ch = text[end + 1] if end + 1 < len(text) else ""
    if prev_ch.isdigit() or prev_ch == ".":
        return False
    if next_ch.isdigit():
        return False
    if next_ch == "." and after_next_ch.isdigit():
        return False
    return True


def seeded_rng(problem_name: str, source: str) -> random.Random:
    """Return a Random generator deterministically seeded from (problem_name, source).

    The seed is the first 64 bits, read big-endian, of the SHA-256 digest of
    the UTF-8 string "<problem_name>:<source>". Equal arguments always yield
    generators producing the same sequence.
    """
    seed_key = f"{problem_name}:{source}".encode("utf-8")
    leading_bytes = hashlib.sha256(seed_key).digest()[:8]
    return random.Random(int.from_bytes(leading_bytes, "big"))


def pick_one(cands: list, problem_name: str, source: str) -> dict | None:
    """Pick one candidate deterministically and convert it into an edit record.

    Returns None when *cands* is empty (or None). Otherwise one candidate is
    chosen with ``seeded_rng(problem_name, source)``, so the same inputs always
    give the same pick. The record has keys ``old_value`` (stripped string),
    ``char_offset_start`` / ``char_offset_end`` (ints, -1 if missing) and
    ``context`` ("" if missing).
    """
    if not cands:
        return None
    winner = seeded_rng(problem_name, source).choice(cands)
    old_value = str(winner.get("value", "")).strip()
    span_start = int(winner.get("char_offset_start", -1))
    span_end = int(winner.get("char_offset_end", -1))
    return {
        "old_value": old_value,
        "char_offset_start": span_start,
        "char_offset_end": span_end,
        "context": winner.get("context", ""),
    }


def _load_problems_by_name(dataset_path: str) -> dict:
    """Index the original dataset rows by their ``name`` field.

    Returns an empty dict when *dataset_path* is empty. If several rows share
    a name, the last one wins.
    """
    if not dataset_path:
        return {}
    return {r["name"]: r for r in load_jsonl(dataset_path)}


def _drop_fragments(numbers: list, problem: dict) -> tuple[list, int]:
    """Remove candidates whose span is a fragment of a larger numeric token.

    Candidates with ``source == "statement"`` are checked against the
    problem's ``informal_statement``; all others against ``informal_proof``.

    Returns:
        (kept candidates in their original order, number of candidates dropped).
    """
    stmt_text = str(problem.get("informal_statement", "") or "")
    proof_text = str(problem.get("informal_proof", "") or "")
    kept = []
    for n in numbers:
        text = stmt_text if n.get("source", "") == "statement" else proof_text
        start = int(n.get("char_offset_start", -1))
        end = int(n.get("char_offset_end", -1))
        if is_isolated_number(text, start, end, str(n.get("value", ""))):
            kept.append(n)
    return kept, len(numbers) - len(kept)


def main() -> None:
    """CLI entry point: filter labeled candidates and pick one number per source.

    For each row of --input:
      1. keep candidates whose ``value`` is a numeric literal;
      2. drop fragment spans, but only when --dataset is given and contains
         the problem;
      3. make a seeded pick among statement candidates and among proof
         candidates.
    Writes one output row per input row to --output and prints a summary to
    stdout. Warnings about a disabled or partial fragment filter go to stderr.
    """
    parser = argparse.ArgumentParser(description="Filter + random-pick one number per source")
    parser.add_argument("--input", default="./number_edit/data/labeled_numbers_fixed.jsonl")
    parser.add_argument("--output", default="./number_edit/data/selected_numbers.jsonl")
    parser.add_argument("--dataset", default="", help="Original dataset JSONL (for fragment filter)")
    args = parser.parse_args()

    rows = load_jsonl(args.input)
    # The original problem text is needed for the isolated-number check.
    # Without it the fragment filter cannot run, so warn instead of silently
    # reporting "Dropped (fragment): 0".
    problems_by_name = _load_problems_by_name(args.dataset)
    if not args.dataset:
        print("WARNING: --dataset not given; fragment filter disabled.", file=sys.stderr)

    out_rows = []
    n_stmt = n_proof = n_both_null = 0
    n_filtered_fragment = 0
    n_missing_problem = 0

    for row in rows:
        name = row["problem_name"]
        numbers = [
            n
            for n in (row.get("numbers", []) or [])
            if is_numeric_literal(str(n.get("value", "")))
        ]

        # Drop candidates that are fragments of a larger numeric token.
        if problems_by_name:
            problem = problems_by_name.get(name)
            if problem is None:
                n_missing_problem += 1
            else:
                numbers, n_dropped = _drop_fragments(numbers, problem)
                n_filtered_fragment += n_dropped

        stmt_cands = [n for n in numbers if n.get("source") == "statement"]
        proof_cands = [n for n in numbers if n.get("source") == "proof"]

        stmt_pick = pick_one(stmt_cands, name, "statement")
        proof_pick = pick_one(proof_cands, name, "proof")

        out_rows.append({
            "problem_name": name,
            "statement_edit": stmt_pick,
            "proof_edit": proof_pick,
        })
        if stmt_pick:
            n_stmt += 1
        if proof_pick:
            n_proof += 1
        if not stmt_pick and not proof_pick:
            n_both_null += 1

    write_jsonl(args.output, out_rows)

    print("=== Selection Summary ===")
    print(f"Problems:              {len(out_rows)}")
    print(f"statement_edit picks:  {n_stmt}")
    print(f"proof_edit picks:      {n_proof}")
    print(f"Both null:             {n_both_null}")
    print(f"Dropped (fragment):    {n_filtered_fragment}")
    print(f"Output:                {args.output}")
    if n_missing_problem:
        print(
            f"WARNING: {n_missing_problem} problem(s) not found in --dataset; "
            "fragment filter skipped for them.",
            file=sys.stderr,
        )


if __name__ == "__main__":
    main()
