"""Stage 1: label ALL editable numeric literals in a problem.

For each problem, Gemini returns:
  {
    "problem_name": "...",
    "numbers": [
      {"value": "24", "source": "statement"|"proof",
       "char_offset_start": int, "char_offset_end": int, "context": "..."},
      ...
    ]
  }

No role/priority labels. Quality gates (skip structural/subscript/definitional
numbers) are enforced by the prompt. Downstream select stage filters and picks.
"""

import argparse
import json
import os
import sys
import time
from pathlib import Path

# Directory two levels above this file. It is put at the front of sys.path
# (once) so that the function-scope `number_edit.*` import used further down
# presumably resolves when this file is executed directly as a script.
ROOT = Path(__file__).resolve().parents[1]
_root_entry = str(ROOT)
if _root_entry not in sys.path:
    sys.path.insert(0, _root_entry)

# Prompt template sent to the model for every problem.
#
# It is rendered with old-style %-formatting in label_one_problem, as
# `LABEL_PROMPT % (stmt, proof)`: the first `%s` receives the informal
# statement and the second the informal proof. Consequences for editors:
#   - any literal percent sign added to the template must be written `%%`,
#     otherwise rendering raises at runtime;
#   - the template must keep exactly two `%s` placeholders, in this order.
# The string is not raw, so `\\frac` / `\\sqrt` below reach the model as
# `\frac` / `\sqrt`.
LABEL_PROMPT = """\
You are a math annotation assistant. Given a math problem's informal statement
and informal proof, list EVERY numeric literal that is SAFELY EDITABLE for
perturbation. For each, specify whether it appears in the STATEMENT or the PROOF.

DEFINITION OF "safely editable":
- A concrete numeric literal in digit form: integer, decimal, percent, or simple
  fraction like "1/3". Words like "nine" or "twice" are NOT editable.
- Changing this number should yield a plausibly wrong but well-formed problem.

DO NOT label:
- Structural numbers: subscripts (the 1 in x_1), superscripts (the 2 in x^2),
  indices, dimensions.
- Numbers that are part of a formula DEFINITION rather than a value (the 1 in
  "Let f(x) = x + 1" when 1 is just a definitional term, not the intended value
  to perturb).
- Numbers inside LaTeX commands like \\frac, \\sqrt when they're formatting
  arguments rather than actual parameter values.

OUTPUT: one JSON object with exactly this shape:
{
  "problem_name": "...",
  "numbers": [
    {
      "value":              "<exact substring as it appears in text>",
      "source":             "statement" or "proof",
      "char_offset_start":  <int offset in the respective source text>,
      "char_offset_end":    <int offset in the respective source text>,
      "context":            "<~20 words around the number, as it appears>"
    }
  ]
}

Rules:
- The SAME number value may appear multiple times — list each occurrence
  separately with its own offsets and source.
- "numbers" may be empty if nothing is safely editable.
- Be precise with char_offset_start / char_offset_end so that
  text[start:end] == value. Offsets are within the respective source text
  (NOT the combined text).
- Only output the JSON object — no prose, no code fences.

--- INFORMAL STATEMENT ---
%s

--- INFORMAL PROOF ---
%s
"""


def _strip_code_fence(text: str) -> str:
    """Return the payload of the first Markdown code fence in *text*, if any.

    A "```json" fence is preferred over a bare "```" fence. When no fence is
    present the stripped text is returned unchanged.
    """
    text = text.strip()
    for fence in ("```json", "```"):
        if fence in text:
            return text.split(fence, 1)[1].split("```", 1)[0].strip()
    return text


def label_one_problem(
    model,
    problem: dict,
    max_attempts: int = 3,
    retry_delay: float = 2.0,
) -> dict:
    """Ask the model for the editable numeric literals of one problem.

    Args:
        model: Object exposing ``generate_content(prompt)`` whose return value
            has a ``.text`` attribute (the model reply).
        problem: Problem record. Reads ``name``, ``informal_statement`` and
            ``informal_proof``; missing or falsy text fields become "".
        max_attempts: How many times to call the model before giving up.
        retry_delay: Seconds to wait between failed attempts.

    Returns:
        ``{"problem_name": ..., "numbers": [...]}`` on success. After
        ``max_attempts`` failures, the same shape with an empty ``numbers``
        list and an extra ``"error"`` message.

    Raises:
        ImportError: If ``number_edit.common_parser`` cannot be imported. This
            is a setup problem that retrying cannot fix, so it is not counted
            as a failed attempt.
    """
    # Imported once, outside the retry loop and outside the try block, so a
    # broken installation fails immediately instead of burning every attempt.
    from number_edit.common_parser import robust_json_loads

    name = problem.get("name", "unknown")
    stmt = str(problem.get("informal_statement", "") or "")
    proof = str(problem.get("informal_proof", "") or "")

    prompt = LABEL_PROMPT % (stmt, proof)

    for attempt in range(1, max_attempts + 1):
        try:
            response = model.generate_content(prompt)
            result = robust_json_loads(_strip_code_fence(response.text))
            if not isinstance(result, dict):
                raise ValueError(
                    f"expected a JSON object, got {type(result).__name__}"
                )
            numbers = result.get("numbers") or []
            if not isinstance(numbers, list):
                raise ValueError(
                    f'"numbers" must be a list, got {type(numbers).__name__}'
                )
            return {"problem_name": name, "numbers": numbers}
        except Exception as e:  # model/API and parsing errors are all retryable
            print(f"  Attempt {attempt} failed for {name}: {e}")
            # No point in waiting after the final attempt.
            if attempt < max_attempts:
                time.sleep(retry_delay)

    return {
        "problem_name": name,
        "numbers": [],
        "error": f"Failed after {max_attempts} attempts",
    }


def main():
    """Label every problem in the input JSONL and append results to the output.

    The run is resumable: any ``problem_name`` already present in the output
    file is skipped (including rows that carry an ``"error"`` key, so failed
    problems are not retried automatically). After labeling, the whole output
    file is re-read and a summary of candidate counts is printed.

    Raises:
        KeyError: If the ``GOOGLE_API_KEY`` environment variable is not set.
    """
    parser = argparse.ArgumentParser(description="Label all editable numbers per problem")
    parser.add_argument("--input", default="./datasets_validation/minif2f/dataset.jsonl")
    parser.add_argument("--output", default="./number_edit/data/labeled_numbers.jsonl")
    parser.add_argument("--model", default="gemini-2.5-flash")
    parser.add_argument("--limit", type=int, default=0)
    args = parser.parse_args()

    # dirname is "" for a bare filename, and os.makedirs("") raises
    # FileNotFoundError, so only create the directory when there is one.
    out_dir = os.path.dirname(args.output)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # Explicit UTF-8 everywhere: rows are written with ensure_ascii=False, so
    # relying on the locale's default encoding can corrupt or reject the text.
    with open(args.input, encoding="utf-8") as f:
        # Tolerate blank lines (e.g. a trailing newline at end of file).
        problems = [json.loads(line) for line in f if line.strip()]
    if args.limit > 0:
        problems = problems[: args.limit]
    print(f"Loaded {len(problems)} problems from {args.input}")

    import google.generativeai as genai
    genai.configure(api_key=os.environ["GOOGLE_API_KEY"])
    model = genai.GenerativeModel(args.model)

    done_names = set()
    if os.path.exists(args.output):
        with open(args.output, encoding="utf-8") as f:
            for line in f:
                try:
                    done_names.add(json.loads(line).get("problem_name", ""))
                except (ValueError, AttributeError, TypeError):
                    # Best effort: skip truncated / non-object / odd rows left
                    # behind by an interrupted run.
                    continue
        print(f"Resuming: {len(done_names)} already labeled")

    with open(args.output, "a", encoding="utf-8") as fout:
        for i, problem in enumerate(problems):
            name = problem.get("name", f"problem_{i}")
            if name in done_names:
                continue
            print(f"[{i+1}/{len(problems)}] Labeling {name}...")
            result = label_one_problem(model, problem)
            # Store the row under the same name the resume check uses;
            # otherwise unnamed problems are saved as "unknown" and get
            # relabeled on every resume.
            result["problem_name"] = name
            fout.write(json.dumps(result, ensure_ascii=False) + "\n")
            fout.flush()  # keep the file resumable if the run is interrupted
            time.sleep(0.5)

    # Re-read the full output (previous runs included) for the summary, with
    # the same tolerance for malformed rows as the resume loader above.
    rows = []
    with open(args.output, encoding="utf-8") as f:
        for line in f:
            try:
                row = json.loads(line)
            except ValueError:
                continue
            if isinstance(row, dict):
                rows.append(row)

    total_cands = 0
    stmt_cands = 0
    proof_cands = 0
    for row in rows:
        numbers = row.get("numbers") or []
        total_cands += len(numbers)
        for entry in numbers:
            # Model output is not guaranteed to be a list of objects.
            source = entry.get("source") if isinstance(entry, dict) else None
            if source == "statement":
                stmt_cands += 1
            elif source == "proof":
                proof_cands += 1

    print("\n=== Summary ===")
    print(f"Problems labeled:         {len(rows)}")
    print(f"Total number candidates:  {total_cands}")
    print(f"  statement:              {stmt_cands}")
    print(f"  proof:                  {proof_cands}")


if __name__ == "__main__":
    main()
