"""Phase 1: Label proof steps using Gemini with strict quality constraints.

Reads original dataset.jsonl and outputs labeled_steps.jsonl with step
decompositions. Uses hard constraints R1/R3/R4, soft R2/R6, and 4 examples
(2 good + 2 bad).

Usage:
    GOOGLE_API_KEY=... python3 -m step_edit.label_proof_steps \
        --input  datasets_validation/minif2f/dataset.jsonl \
        --output step_edit/data/minif2f/labeled_steps.jsonl
"""
import argparse
import json
import os
import sys
import time
from pathlib import Path

# Parent of the directory that contains this file (``parents[1]``). It is put
# at the front of ``sys.path`` (once) so top-level project packages such as
# ``number_edit`` can be imported regardless of the current working directory.
ROOT = Path(__file__).resolve().parents[1]
_ROOT_ENTRY = str(ROOT)
if _ROOT_ENTRY not in sys.path:
    sys.path.insert(0, _ROOT_ENTRY)

# Prompt template rendered with ``LABEL_PROMPT.format(proof=...)``: the only
# placeholder is ``{proof}``; every literal brace in the JSON/LaTeX examples is
# doubled (``{{`` / ``}}``) so ``str.format`` emits a single brace.
# The examples are written to satisfy the rules they illustrate (R2 length
# exemption for short equations, R4 reconstruction), so the model is not shown
# contradictory "GOOD" cases.
# NOTE(review): "copy text EXACTLY / do not add escape characters" conflicts with
# JSON string escaping of LaTeX backslashes (e.g. a raw "\frac" decodes as a
# form feed + "rac"); presumably robust_json_loads compensates — confirm.
LABEL_PROMPT = """\
You are annotating a natural-language math proof for a step-deletion benchmark.

Your task: split the proof into ordered steps. For each step, identify reasoning
(the justification) and outcome (the conclusion).

==========================================================================
RULES — read carefully before annotating
==========================================================================

R1 (HARD). REASONING must be SUBSTANTIVE.
    - Must be >= 30 characters.
    - Must contain at least one: math expression ($...$), a named theorem or
      technique (e.g. "by Cauchy-Schwarz", "by induction"), or a multi-step
      computation.
    - REJECT (set deletable=false) if reasoning is ONLY a transition word or
      phrase such as: "Thus", "So", "Hence", "Therefore", "Consequently",
      "We have that", "We see that", "We get", "Notice that", "Note that",
      "It follows", "As a result", "In the diagram", "WLOG", "Clearly",
      "Obviously", "We obtain", "This gives us", "That means", "Simplifying",
      "As before", "From this".

R2 (SOFT). OUTCOME should ideally be a readable standalone claim.
    - Preferred: starts with uppercase, ends with period or $, has English context.
    - Also acceptable: formula-only like "$r = 3.$" IF reasoning is substantive.
    - Set deletable=false only if outcome is empty, is a single symbol/digit
      with no equation structure (e.g. just "5" or "$x$"), or is < 10
      characters with no equation (short formulas like "$r = 3.$" are fine).

R3 (HARD). OUTCOME must NOT be an incomplete clause that requires its antecedent.
    - REJECT if outcome starts with a relative/subordinating fragment:
      "which ...", "thereby ...", "whereby ...", "whereas ..."
    - REJECT if outcome starts with an anaphoric pronoun whose referent is
      exclusively in reasoning: "this ...", "that ...", "these ...", "those ..."
      (as sentence subject).
    - "So $x=5$.", "Hence $r=3$.", "Therefore the answer is 42." are OK — these
      are standalone transitions, not incomplete clauses.

R4 (HARD). reasoning_text + outcome_text must reconstruct full_text.
    - full_text = reasoning_text + (whitespace/punctuation) + outcome_text.
    - Connecting words such as "we get" belong at the END of reasoning_text;
      no words may be dropped between the two parts.
    - If you cannot cleanly split, set deletable=false.

R6 (SOFT). Prefer proofs with >= 3 steps.
    - If the proof has only 1 step, set all deletable=false.
    - If the proof has 2 steps, you may still mark one deletable if R1-R4 pass.

Be CONSERVATIVE: if unsure whether R1-R4 are satisfied, set deletable=false.
We prefer fewer high-quality edits over many noisy ones.

==========================================================================
EXAMPLES
==========================================================================

GOOD (deletable=true):
  full_text: "By the AM-GM inequality, $\\frac{{a+b}}{{2}} \\geq \\sqrt{{ab}}$. Therefore $a + b \\geq 2\\sqrt{{ab}}$."
  reasoning_text: "By the AM-GM inequality, $\\frac{{a+b}}{{2}} \\geq \\sqrt{{ab}}$."
  outcome_text: "Therefore $a + b \\geq 2\\sqrt{{ab}}$."

GOOD (formula-only outcome, reasoning is substantive):
  full_text: "Substituting $x = 0$ and $y = 3$ into $r = \\sqrt{{x^2 + y^2}}$, we get $r = 3.$"
  reasoning_text: "Substituting $x = 0$ and $y = 3$ into $r = \\sqrt{{x^2 + y^2}}$, we get"
  outcome_text: "$r = 3.$"

BAD (vacuous reasoning — deletable=false):
  full_text: "Consequently, $S = -14$ and $P = -38$."
  reasoning_text: "Consequently,"
  outcome_text: "$S = -14$ and $P = -38$."

BAD (degenerate outcome — deletable=false):
  full_text: "Simplifying, $013$"
  reasoning_text: "Simplifying,"
  outcome_text: "$013$"

==========================================================================
OUTPUT FORMAT
==========================================================================

Return ONLY valid JSON, no prose, no code fences:

{{
  "problem_name": "...",
  "steps": [
    {{
      "step_idx": 0,
      "full_text": "exact verbatim text from the proof",
      "reasoning_text": "the justification / computation / derivation",
      "outcome_text": "the conclusion",
      "deletable": true,
      "is_last": false
    }},
    ...
  ]
}}

List ALL steps, not just deletable ones.
Do NOT add or remove escape characters — copy text EXACTLY from the input.

--- INFORMAL PROOF ---
{proof}
"""


def label_one_problem(model, problem: dict, max_attempts: int = 3) -> dict:
    """Ask *model* to split one problem's informal proof into labeled steps.

    Args:
        model: Object with a ``generate_content(prompt)`` method whose response
            exposes ``.text`` (``main`` passes a ``genai.GenerativeModel``).
        problem: Dataset row; ``"name"`` and ``"informal_proof"`` are read.
        max_attempts: Number of model calls to try before giving up.

    Returns:
        The object parsed from the model's reply, with ``problem_name`` set to
        the row's name, or ``{"problem_name": ..., "steps": [], "error": ...}``
        if every attempt raised.

    Raises:
        ImportError: if ``number_edit.common_parser`` is unavailable. This now
            happens before any API call instead of being swallowed by the retry
            loop as a "failed attempt" for every problem.
    """
    # Resolved once, outside the retry ``try``, so an environment problem is
    # not mistaken for a transient model failure.
    from number_edit.common_parser import robust_json_loads

    name = problem.get("name", "unknown")
    proof = str(problem.get("informal_proof", "") or "")
    prompt = LABEL_PROMPT.format(proof=proof)

    for attempt in range(1, max_attempts + 1):
        try:
            text = model.generate_content(prompt).text.strip()
            # The model sometimes wraps its JSON in a Markdown code fence
            # despite being told not to; keep only the fenced body.
            if "```json" in text:
                text = text.split("```json", 1)[1].split("```", 1)[0].strip()
            elif "```" in text:
                text = text.split("```", 1)[1].split("```", 1)[0].strip()

            result = robust_json_loads(text)
            result["problem_name"] = name
            return result
        except Exception as e:  # API/safety-block/parse errors: retry
            print(f"  Attempt {attempt} failed for {name}: {e}")
            if attempt < max_attempts:  # no point sleeping after the last try
                time.sleep(2)

    return {
        "problem_name": name,
        "steps": [],
        "error": f"Failed after {max_attempts} attempts",
    }


def _read_done_names(output_path: Path) -> set:
    """Return the problem names that already have a successful record.

    Records containing an ``"error"`` key (written when labeling failed) are
    ignored so that a rerun retries those problems. Blank lines and lines that
    are not valid JSON (e.g. a record cut off by an interrupted run) are
    skipped with a warning rather than aborting the resume.
    """
    done = set()
    with open(output_path, encoding="utf-8") as fin:
        for lineno, line in enumerate(fin, start=1):
            if not line.strip():
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                print(f"  Warning: skipping unparseable line {lineno} of {output_path}")
                continue
            if isinstance(record, dict) and "error" not in record:
                done.add(record.get("problem_name"))
    return done


def _lacks_trailing_newline(path: Path) -> bool:
    """Return True if *path* is a non-empty file whose last byte is not a newline."""
    if not path.exists() or path.stat().st_size == 0:
        return False
    with open(path, "rb") as fin:
        fin.seek(-1, os.SEEK_END)
        return fin.read(1) != b"\n"


def main() -> None:
    """CLI entry point: label each problem in ``--input``, append to ``--output``.

    Output is JSON Lines, one record per problem, flushed after every write so
    an interrupted run can be resumed: problems whose name already has a
    successful record are skipped, while failed (``"error"``) records are
    retried. Requires the ``GOOGLE_API_KEY`` environment variable.
    """
    parser = argparse.ArgumentParser(description="Label proof steps")
    parser.add_argument("--input", required=True, help="Original dataset.jsonl")
    parser.add_argument("--output", required=True, help="Output labeled_steps.jsonl")
    parser.add_argument("--model", default="gemini-2.5-flash")
    parser.add_argument("--limit", type=int, default=0, help="Max problems (0=all)")
    args = parser.parse_args()

    import google.generativeai as genai
    api_key = os.environ.get("GOOGLE_API_KEY", "")
    if not api_key:
        sys.exit("Set GOOGLE_API_KEY environment variable")
    genai.configure(api_key=api_key)
    model = genai.GenerativeModel(args.model)

    with open(args.input, encoding="utf-8") as fin:
        problems = [json.loads(line) for line in fin if line.strip()]
    if args.limit > 0:
        problems = problems[:args.limit]
    # One name per problem, used both for resume matching and as the written
    # problem_name, so unnamed rows ("row_<i>") resume correctly instead of all
    # being recorded under label_one_problem's "unknown" fallback.
    names = [prob.get("name", f"row_{i}") for i, prob in enumerate(problems)]

    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    done = set()
    if output_path.exists():
        done = _read_done_names(output_path)
        n_done = sum(1 for name in names if name in done)
        print(f"Resuming: {n_done}/{len(problems)} already done")
    # A record truncated by a crash has no trailing newline; start a fresh
    # line so the next record is not glued onto the broken one.
    start_new_line = _lacks_trailing_newline(output_path)

    with open(output_path, "a", encoding="utf-8") as fout:
        if start_new_line:
            fout.write("\n")
        for i, (prob, name) in enumerate(zip(problems, names)):
            if name in done:
                continue
            print(f"  [{i + 1}/{len(problems)}] {name}...", end=" ", flush=True)
            result = label_one_problem(model, prob)
            result["problem_name"] = name
            steps = result.get("steps") or []  # tolerate "steps": null
            n_del = sum(1 for s in steps if s.get("deletable"))
            print(f"{len(steps)} steps, {n_del} deletable")
            fout.write(json.dumps(result, ensure_ascii=False) + "\n")
            fout.flush()  # keep the file resumable if the run is killed
            time.sleep(0.5)

    print(f"\nDone! {output_path}")


if __name__ == "__main__":
    main()
