"""Stage 2: filter labeled symbols and randomly pick one per source.

Reads labeled_symbols.jsonl (output of label_symbol_roles.py) and writes
selected_symbols.jsonl in the format consumed by build_symbol_edit_unsound.py:

  {"problem_name": ...,
   "statement_edit": {symbol, family, char_offset_start, char_offset_end, context} | null,
   "proof_edit":     {...} | null}

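Selection is a sha256-seeded random choice over the filtered candidates (relation
symbols are preferred whenever any survive filtering), so re-running this stage on
the same labeled input (and the same ``--dataset``) reproduces exactly the same picks.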
"""

from __future__ import annotations

import argparse
import hashlib
import random
import re
import sys
from pathlib import Path

# Put the parent of this file's directory (parents[1]) at the front of sys.path
# so the project-local ``symbol_edit.common`` import resolves when this file is
# run directly as a script (presumably that directory is the repository root).
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from symbol_edit.common import SYMBOL_SWAP, load_jsonl, write_jsonl


# Equality / disequality symbols are never selected as edit targets.
EXCLUDED = {"=", "≠", "\\neq"}

# Relations whose chain membership matters. Operator-family symbols (+/-, times,
# cdot, div) never form chains, so they are not checked.
RELATIONS = {"<", ">", "\\le", "\\ge", "\\leq", "\\geq", "<=", ">=", "≤", "≥"}

# Counts relation operators inside a text region (see is_in_relation_chain).
# The negative lookahead stops ``\le`` / ``\ge`` from matching the prefixes of
# the non-relation LaTeX commands ``\left``, ``\gets`` and ``\genfrac``. Without
# it, ``0 < \left(x\right)`` would count as two relations and be treated as a
# chain. Relation commands such as ``\leq``, ``\leqslant`` and ``\geqq`` still
# match (through their ``\le`` / ``\ge`` prefix), as before.
_REL_PAT = re.compile(r"\\(?!left|gets|genfrac)(?:leq|geq|le|ge)|<=|>=|≤|≥|<|>")


def is_valid_target(symbol: str) -> bool:
    """Return True if ``symbol`` (ignoring surrounding whitespace) may be edited.

    A symbol qualifies when it has an entry in ``SYMBOL_SWAP`` and is not one
    of the equality-type symbols listed in ``EXCLUDED``.
    """
    candidate = symbol.strip()
    return candidate not in EXCLUDED and candidate in SYMBOL_SWAP


def _is_region_boundary(text: str, i: int) -> bool:
    """Return True if ``text[i]`` terminates a math region for chain detection.

    ``\\n`` and ``;`` always terminate. ``.`` terminates unless it is a decimal
    point with a digit on both sides (as in ``0.5``).
    """
    ch = text[i]
    if ch in "\n;":
        return True
    if ch != ".":
        return False
    is_decimal_point = (
        0 < i < len(text) - 1
        and text[i - 1] in "0123456789"
        and text[i + 1] in "0123456789"
    )
    return not is_decimal_point


def is_in_relation_chain(text: str, start: int, end: int, max_span: int = 80) -> bool:
    """Return True iff the candidate at ``[start:end]`` is part of a relation chain.

    A relation chain is a math region containing 2+ relation operators
    (``<``, ``>``, ``\\le``, ``\\ge``, ``≤``, ``≥``) as counted by ``_REL_PAT``.
    Editing one relation in ``0 < p < 15`` produces ``0 > p < 15``, which is
    logically impossible and doesn't cleanly test symbol-level faithfulness,
    so any candidate inside such a chain is dropped at selection time.

    The region grows outward from the candidate on each side until it hits a
    statement terminator or has covered ``max_span`` characters on that side.
    Terminators are ``\\n``, ``;``, and ``.``, except a ``.`` between two
    digits. Decimal points are excluded so that ``0 < 0.5 < x`` is still
    recognised as one chain.

    Offsets that cannot point into ``text`` return False instead of raising.
    This covers negative ``start``, an empty span, and ``start`` at or past
    the end of the text (e.g. labels whose offsets were computed on a
    different text).

    Only relation-family symbols need this check; operator-family symbols
    never form chains, so the caller should skip this function for those.
    """
    if start < 0 or end <= start or start >= len(text):
        return False
    left = start
    while left > 0 and start - left < max_span and not _is_region_boundary(text, left - 1):
        left -= 1
    right = end
    while right < len(text) and right - end < max_span and not _is_region_boundary(text, right):
        right += 1
    return len(_REL_PAT.findall(text[left:right])) >= 2


def is_inside_script_group(text: str, start: int) -> bool:
    """Report whether offset ``start`` lies within a ``_{...}`` or ``^{...}`` body.

    Brace groups may nest. Once a group is found, the whole group is skipped,
    so only groups opened at top level are tested against ``start``. For
    example, the ``+`` inside ``x_{n+1}`` is reported as inside, but the ``+``
    between ``x_{n+1} + y_{m}`` is not. A group with no closing brace extends
    to the end of ``text``.
    """
    if start <= 0:
        return False
    size = len(text)
    cursor = 0
    while cursor < size:
        opens_group = cursor + 1 < size and text[cursor] in "_^" and text[cursor + 1] == "{"
        if not opens_group:
            cursor += 1
            continue
        body_start = cursor + 2
        # Index of the matching closing brace; stays at `size` if unterminated.
        close = size
        depth = 1
        for k in range(body_start, size):
            ch = text[k]
            if ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
                if depth == 0:
                    close = k
                    break
        if body_start <= start < close:
            return True
        cursor = close + 1
    return False


def seeded_rng(problem_name: str, source: str) -> random.Random:
    """Return a ``random.Random`` seeded from ``problem_name`` and ``source``.

    The seed is the first 64 bits (big-endian) of the SHA-256 digest of
    ``"<problem_name>:<source>"``. The same pair always yields the same stream.
    """
    key = f"{problem_name}:{source}".encode("utf-8")
    seed = int.from_bytes(hashlib.sha256(key).digest()[:8], "big")
    return random.Random(seed)


def pick_one(cands: list, problem_name: str, source: str) -> dict | None:
    """Pick one candidate with sha256-seeded randomness.

    Relation candidates (``<``, ``>``, ``\\le``, etc.) are preferred: if any
    relation candidates exist, the random pick is restricted to the relation
    subset and operator candidates are ignored. This compensates for the fact
    that operator-family targets (``+``, ``-``, ``\\times``, ``\\cdot``,
    ``\\div``) vastly outnumber relation-family targets in typical math proofs
    (~90% vs ~10%), so uniform random picks would heavily bias Experiment 3
    toward operator faithfulness while starving relation coverage.
    """
    if not cands:
        return None
    rng = seeded_rng(problem_name, source)
    rel_cands = [
        c for c in cands
        if str(c.get("symbol", "")).strip() in RELATIONS
    ]
    pool = rel_cands if rel_cands else cands
    chosen = rng.choice(pool)
    return {
        "symbol": str(chosen.get("symbol", "")).strip(),
        "family": chosen.get("family", ""),
        "char_offset_start": int(chosen.get("char_offset_start", -1)),
        "char_offset_end": int(chosen.get("char_offset_end", -1)),
        "context": chosen.get("context", ""),
    }


def _filter_by_position(symbols: list, problem: dict) -> tuple[list, int, int]:
    """Drop candidates that sit in a script group or a relation chain.

    Each candidate's offsets are read against the problem's
    ``informal_statement`` when its ``source`` is ``"statement"``, and against
    ``informal_proof`` otherwise.

    Returns:
        ``(kept, n_dropped_script, n_dropped_chain)``. The script-group check
        runs first, so a candidate failing both checks counts as a script drop.
    """
    stmt_text = str(problem.get("informal_statement", "") or "")
    proof_text = str(problem.get("informal_proof", "") or "")
    kept = []
    n_script = n_chain = 0
    for s in symbols:
        text = stmt_text if s.get("source", "") == "statement" else proof_text
        sym = str(s.get("symbol", "")).strip()
        pos_start = int(s.get("char_offset_start", -1))
        pos_end = int(s.get("char_offset_end", -1))
        if is_inside_script_group(text, pos_start):
            n_script += 1
        elif sym in RELATIONS and is_in_relation_chain(text, pos_start, pos_end):
            # Editing one link of a chain yields impossible inequalities
            # such as 0 > p < 15.
            n_chain += 1
        else:
            kept.append(s)
    return kept, n_script, n_chain


def main() -> None:
    """Filter labeled symbols and write one seeded pick per source per problem.

    Reads ``--input`` rows (``problem_name`` plus a ``symbols`` list). It keeps
    only symbols accepted by ``is_valid_target``. When ``--dataset`` is given,
    it also applies the subscript and relation-chain filters against that
    dataset's texts. It then picks at most one statement and one proof
    candidate with ``pick_one``, writes the rows to ``--output``, and prints a
    summary.
    """
    parser = argparse.ArgumentParser(description="Filter + random-pick one symbol per source")
    parser.add_argument("--input", default="./symbol_edit/data/labeled_symbols_fixed.jsonl")
    parser.add_argument("--output", default="./symbol_edit/data/selected_symbols.jsonl")
    parser.add_argument(
        "--dataset",
        default="",
        help="Original dataset JSONL (for subscript and relation-chain filters)",
    )
    args = parser.parse_args()

    rows = load_jsonl(args.input)
    # Keyed by problem name; if names repeat, the last dataset row wins.
    problems_by_name = {r["name"]: r for r in load_jsonl(args.dataset)} if args.dataset else {}

    out_rows = []
    n_stmt = n_proof = n_both_null = 0
    n_filtered_script = n_filtered_chain = 0
    n_missing_in_dataset = 0

    for row in rows:
        name = row["problem_name"]
        symbols = [
            s for s in (row.get("symbols", []) or [])
            if is_valid_target(str(s.get("symbol", "")))
        ]

        if args.dataset:
            if name in problems_by_name:
                symbols, n_script, n_chain = _filter_by_position(symbols, problems_by_name[name])
                n_filtered_script += n_script
                n_filtered_chain += n_chain
            else:
                # There is no source text to check offsets against, so these
                # candidates pass through unfiltered. Count them so the gap
                # shows up in the summary instead of going unnoticed.
                n_missing_in_dataset += 1

        stmt_cands = [s for s in symbols if s.get("source") == "statement"]
        proof_cands = [s for s in symbols if s.get("source") == "proof"]

        stmt_pick = pick_one(stmt_cands, name, "statement")
        proof_pick = pick_one(proof_cands, name, "proof")

        out_rows.append({
            "problem_name": name,
            "statement_edit": stmt_pick,
            "proof_edit": proof_pick,
        })
        if stmt_pick is not None:
            n_stmt += 1
        if proof_pick is not None:
            n_proof += 1
        if stmt_pick is None and proof_pick is None:
            n_both_null += 1

    write_jsonl(args.output, out_rows)

    print("=== Selection Summary ===")
    print(f"Problems:              {len(out_rows)}")
    print(f"statement_edit picks:  {n_stmt}")
    print(f"proof_edit picks:      {n_proof}")
    print(f"Both null:             {n_both_null}")
    print(f"Dropped (subscript):   {n_filtered_script}")
    print(f"Dropped (rel chain):   {n_filtered_chain}")
    if args.dataset:
        print(f"Unfiltered (no text):  {n_missing_in_dataset}")
    print(f"Output:                {args.output}")


if __name__ == "__main__":
    main()
