"""Stage 1: label ALL editable target symbols in a problem.

For each problem, Gemini returns:
  {
    "problem_name": "...",
    "symbols": [
      {"symbol": "\\leq", "family": "relation"|"operator",
       "source": "statement"|"proof",
       "char_offset_start": int, "char_offset_end": int, "context": "..."},
      ...
    ]
  }

No role labels. Excludes = and ≠ by prompt. Downstream select stage filters
and randomly picks one per source.
"""

import argparse
import json
import os
import sys
import time
from pathlib import Path

# ROOT is the directory one level above the folder containing this file.
# It is placed at the front of sys.path (once) -- presumably so that sibling
# project packages can be imported when this file is run as a script.
ROOT = Path(__file__).resolve().parents[1]
if sys.path.count(str(ROOT)) == 0:
    sys.path.insert(0, str(ROOT))

# Prompt template sent to the model for each problem.
#
# It is filled in with old-style ``%`` formatting and has exactly two ``%s``
# placeholders, in order: the informal statement, then the informal proof.
# Because of that, any literal percent sign added to this text must be
# written as ``%%``.
#
# This is a regular (non-raw) string literal, so backslashes are doubled:
# ``\\leq`` reaches the model as ``\leq``, and the ``\\\\leq`` on the JSON
# example line reaches it as ``\\leq`` (i.e. already JSON-escaped).
LABEL_PROMPT = """\
You are a math annotation assistant. Given a math problem's informal statement
and informal proof, list EVERY occurrence of the following target symbols that
is SAFELY EDITABLE for perturbation. For each, specify whether it appears in
the STATEMENT or the PROOF.

TARGET SYMBOLS (LaTeX command forms are the same symbol written differently):
  relation: >  <  >=  <=    (or \\gt \\lt \\geq \\leq \\ge \\le)
  operator: +  -  ×  ·  ÷   (or \\times \\cdot \\div)

EXCLUDE: = ≠ \\neq, and any symbol not in the list above.

DEFINITION OF "safely editable":
- Changing the symbol should yield a plausibly wrong but well-formed claim or
  computation.

DO NOT label:
- Definitional: the + in "Let f(x) = x + 1" when 1 is just a definitional term.
- Structural: + in subscripts like x_{n+1}, - in superscripts like x^{2-k}.
- Inside \\frac, \\sqrt, or other LaTeX command arguments.

OUTPUT: one JSON object with exactly this shape:
{
  "problem_name": "...",
  "symbols": [
    {
      "symbol":             "<exact substring as it appears, e.g. \\\\leq or ≤>",
      "family":             "relation" or "operator",
      "source":             "statement" or "proof",
      "char_offset_start":  <int offset in the respective source text>,
      "char_offset_end":    <int offset in the respective source text>,
      "context":            "<~20 words around the symbol, as it appears>"
    }
  ]
}

Rules:
- List each occurrence SEPARATELY with its own offsets.
- "symbols" may be empty if nothing is safely editable.
- Be precise with char_offset_start / char_offset_end so that
  text[start:end] == symbol. For multi-char symbols like \\leq (4 chars), span
  is start to start+4. Offsets are within the respective source text.
- Only output the JSON object — no prose, no code fences.

--- INFORMAL STATEMENT ---
%s

--- INFORMAL PROOF ---
%s
"""


def _strip_code_fence(text: str) -> str:
    """Return the payload of the first Markdown code fence in *text*.

    A ```` ```json ```` fence is preferred over a bare ```` ``` ```` fence.
    If *text* contains no fence it is returned unchanged.
    """
    for opener in ("```json", "```"):
        if opener in text:
            return text.split(opener)[1].split("```")[0].strip()
    return text


def label_one_problem(
    model,
    problem: dict,
    *,
    max_attempts: int = 3,
    retry_delay: float = 2.0,
) -> dict:
    """Ask the model to label the editable symbols of one problem.

    Args:
        model: Object exposing ``generate_content(prompt)`` whose return value
            has a ``.text`` attribute (main() passes a
            ``google.generativeai.GenerativeModel``).
        problem: Problem record. The keys ``name``, ``informal_statement`` and
            ``informal_proof`` are read; missing or falsy texts become ``""``.
        max_attempts: How many times to call the model before giving up.
        retry_delay: Seconds to wait between two attempts.

    Returns:
        ``{"problem_name": ..., "symbols": [...]}`` on success. If every
        attempt fails, ``symbols`` is empty and an ``"error"`` key is added.

    Raises:
        ImportError: If ``number_edit.common_parser`` cannot be imported.
    """
    # Resolve the parser once, before any request is made. An import failure
    # is a setup problem, not a transient one: retrying it would only waste
    # API calls and mislabel the problem as "failed".
    from number_edit.common_parser import robust_json_loads

    name = problem.get("name", "unknown")
    stmt = str(problem.get("informal_statement", "") or "")
    proof = str(problem.get("informal_proof", "") or "")

    prompt = LABEL_PROMPT % (stmt, proof)

    for attempt in range(1, max_attempts + 1):
        try:
            response = model.generate_content(prompt)
            # The prompt asks for bare JSON, but a fenced reply is tolerated.
            payload = _strip_code_fence(response.text.strip())
            result = robust_json_loads(payload)
            if not isinstance(result, dict):
                raise ValueError(
                    f"expected a JSON object, got {type(result).__name__}"
                )
            symbols = result.get("symbols") or []
            if not isinstance(symbols, list):
                raise ValueError(
                    f"'symbols' must be a list, got {type(symbols).__name__}"
                )
            return {
                "problem_name": name,
                "symbols": symbols,
            }
        except Exception as e:
            # Deliberately broad: API errors, blocked responses and malformed
            # output are all handled the same way -- log and try again.
            print(f"  Attempt {attempt} failed for {name}: {e}")
            if attempt < max_attempts:
                # No point in sleeping after the final attempt.
                time.sleep(retry_delay)

    return {
        "problem_name": name,
        "symbols": [],
        "error": f"Failed after {max_attempts} attempts",
    }


def main():
    """Command-line entry point: label every problem and print a summary.

    Flags:
        --input   JSONL file with one problem object per line.
        --output  JSONL file that results are appended to (one row per problem).
        --model   Gemini model name passed to ``genai.GenerativeModel``.
        --limit   If > 0, only the first ``limit`` problems are considered.

    The run is resumable: any ``problem_name`` already present in the output
    file is skipped. The Gemini client (and therefore ``GOOGLE_API_KEY``) is
    only needed when at least one problem is still unlabeled.

    Raises:
        KeyError: If labeling is required and ``GOOGLE_API_KEY`` is not set.
    """
    parser = argparse.ArgumentParser(description="Label all editable symbols per problem")
    parser.add_argument("--input", default="./datasets_validation/minif2f/dataset.jsonl")
    parser.add_argument("--output", default="./symbol_edit/data/labeled_symbols.jsonl")
    parser.add_argument("--model", default="gemini-2.5-flash")
    parser.add_argument("--limit", type=int, default=0)
    args = parser.parse_args()

    # Path.parent is "." for a bare filename, so this also works for
    # "--output out.jsonl", where os.makedirs(os.path.dirname(...)) would be
    # handed an empty string and raise.
    Path(args.output).parent.mkdir(parents=True, exist_ok=True)

    # Explicit UTF-8: rows are written with ensure_ascii=False and the texts
    # contain non-ASCII math symbols, so the platform default is not safe.
    with open(args.input, encoding="utf-8") as f:
        problems = [json.loads(line) for line in f if line.strip()]
    if args.limit > 0:
        problems = problems[: args.limit]
    print(f"Loaded {len(problems)} problems from {args.input}")

    done_names = set()
    if os.path.exists(args.output):
        with open(args.output, encoding="utf-8") as f:
            for line in f:
                try:
                    done_names.add(json.loads(line).get("problem_name", ""))
                except (ValueError, AttributeError, TypeError):
                    # Best-effort resume: ignore blank, truncated or
                    # otherwise malformed rows instead of aborting.
                    continue
        print(f"Resuming: {len(done_names)} already labeled")

    pending = []
    for i, problem in enumerate(problems):
        name = problem.get("name", f"problem_{i}")
        if name not in done_names:
            pending.append((i, name, problem))

    model = None
    if pending:
        # Only touch the third-party client when there is work to do.
        import google.generativeai as genai
        genai.configure(api_key=os.environ["GOOGLE_API_KEY"])
        model = genai.GenerativeModel(args.model)

    with open(args.output, "a", encoding="utf-8") as fout:
        for i, name, problem in pending:
            print(f"[{i+1}/{len(problems)}] Labeling {name}...")
            result = label_one_problem(model, problem)
            # Store the row under the same name the resume check looks up;
            # otherwise a problem without a "name" key is saved as "unknown"
            # and never recognised as done on the next run.
            result["problem_name"] = name
            fout.write(json.dumps(result, ensure_ascii=False) + "\n")
            # Flush per row so an interrupted run loses at most one result.
            fout.flush()
            time.sleep(0.5)

    # Re-read the whole output so the summary also covers earlier runs; be as
    # tolerant of bad rows here as the resume scan above.
    rows = []
    with open(args.output, encoding="utf-8") as f:
        for line in f:
            try:
                row = json.loads(line)
            except ValueError:
                continue
            if isinstance(row, dict):
                rows.append(row)

    symbols = [s for r in rows for s in (r.get("symbols") or [])]
    total_cands = len(symbols)
    stmt_cands = sum(
        1 for s in symbols if isinstance(s, dict) and s.get("source") == "statement"
    )
    proof_cands = sum(
        1 for s in symbols if isinstance(s, dict) and s.get("source") == "proof"
    )

    print("\n=== Summary ===")
    print(f"Problems labeled:         {len(rows)}")
    print(f"Total symbol candidates:  {total_cands}")
    print(f"  statement:              {stmt_cands}")
    print(f"  proof:                  {proof_cands}")


# Run the labeling CLI only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
