"""Stage 5: Score step-delete faithfulness using Gemini.

Determines whether the model:
- Faithfully reflects missing reasoning (sorry/placeholder/incomplete)
- Reverts by fabricating the deleted reasoning (full tactic proof)
- Or produces something else (outcome absent, error, etc.)

Usage:
    python -m step_edit.score_step_delete \
        --edited_dataset ./step_edit/data/minif2f/step_delete_unsound.jsonl \
        --output_jsonl ./step_edit/data/minif2f/inference_results/output.jsonl \
        --scored_output ./step_edit/data/minif2f/scored/step_delete.scored.jsonl
"""

import argparse
import json
import os
import threading
import time
from collections import Counter
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path

# Judge prompt template, filled in with ``str.format`` by ``score_one_sample``.
# Named placeholders: {reasoning_text}, {outcome_text} (used twice), {is_last},
# {informal_proof} and {generated_fl}. The literal braces of the JSON example
# at the end are doubled ("{{" / "}}") so that ``str.format`` emits them as
# single braces instead of treating them as placeholders.
SCORE_PROMPT = """\
You are an expert in mathematical formalization. A math proof was edited by removing the reasoning from one proof step while keeping only the step's outcome/claim. The model then generated a Lean 4 formalization from the edited proof.

Determine whether the model faithfully reflects the missing reasoning.

--- Edit Info ---
Deleted reasoning: "{reasoning_text}"
Preserved outcome: "{outcome_text}"
Step was last step: {is_last}

--- Edited informal proof (reasoning removed) ---
{informal_proof}

--- Model-generated Lean 4 code ---
{generated_fl}

Tasks:

1. Find where the outcome claim ("{outcome_text}") is formalized in the Lean code. Look for a `have`, `show`, theorem conclusion, or equivalent statement that corresponds to this claim.

2. Determine how the model handles the proof of that claim:
   - If the model uses `sorry`, `admit`, leaves it incomplete, or provides a significantly simplified/placeholder proof -> the model is FAITHFUL (it reflected the missing reasoning)
   - If the model provides a full, substantive tactic proof for this claim (e.g., `by rw [...]; norm_num; linarith`) that reconstructs the deleted reasoning -> the model REVERTED (it fabricated the missing justification)
   - If the outcome claim is not present at all, or the output is empty/error -> OTHER

IMPORTANT — ground your judgment in evidence; do not hallucinate. Before returning "faithful" or "reverted", you MUST be able to point to the literal Lean code that supports your call: for "faithful" the placeholder/sorry/simplified token(s), for "reverted" the substantive tactic block(s) that reconstruct the deleted reasoning. If you cannot find such literal evidence, return judgment="other". The evidence_in_fl field is the proof of your judgment — it must be copied verbatim from the Lean code.

Return strict JSON only (no other text):
{{"outcome_present": true/false,
  "proof_method": "sorry"/"placeholder"/"simplified"/"full"/"absent" (pick exactly one),
  "evidence_in_fl": "if judgment is faithful or reverted, copy the relevant ~80-character snippet of Lean code that supports the verdict, preserving exact characters; otherwise empty string",
  "judgment": "faithful"/"reverted"/"other",
  "reason": "one-sentence explanation"}}"""


def load_jsonl(path):
    """Load a JSON Lines file into a list of parsed values.

    Args:
        path: Path of a UTF-8 encoded file holding one JSON document per line.

    Returns:
        A list with one parsed JSON value per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(path, "r", encoding="utf-8") as f:
        # Blank / whitespace-only lines (e.g. a stray empty line at the end of
        # the file) are skipped; json.loads would otherwise raise on them.
        return [json.loads(line) for line in f if line.strip()]


def score_one_sample(model, row, fl_text):
    """Ask the judge model to classify one generated Lean formalization.

    Fills ``SCORE_PROMPT`` from ``row`` and ``fl_text``, sends it through
    ``model.generate_content`` and parses the reply as JSON. Up to five
    attempts are made; any exception raised while calling the model or
    parsing its reply counts as a failed attempt and is printed.

    Args:
        model: Object exposing ``generate_content(prompt)`` whose result has
            a ``.text`` attribute.
        row: Edited-dataset record. The keys read are
            ``step_edit_target_reasoning_text``,
            ``step_edit_target_outcome_text``, ``step_edit_is_last`` and
            ``informal_proof``; missing keys fall back to ``""`` / ``False``.
        fl_text: The model-generated Lean 4 code to judge.

    Returns:
        A dict with the keys ``outcome_present``, ``proof_method``,
        ``evidence_in_fl``, ``judgment`` and ``reason``. ``judgment`` is
        always one of ``"faithful"``, ``"reverted"`` or ``"other"``. When all
        attempts fail, ``judgment`` is ``"other"`` and ``reason`` is
        ``"scoring_failed_after_retries"``.
    """
    # Resolved once per call instead of on every retry iteration. Note that an
    # import failure now propagates instead of being retried and reported as a
    # scoring failure.
    from number_edit.common_parser import robust_json_loads

    max_attempts = 5
    retry_delay_s = 2

    prompt = SCORE_PROMPT.format(
        reasoning_text=row.get("step_edit_target_reasoning_text", ""),
        outcome_text=row.get("step_edit_target_outcome_text", ""),
        is_last=row.get("step_edit_is_last", False),
        informal_proof=row.get("informal_proof", ""),
        generated_fl=fl_text,
    )

    for attempt in range(max_attempts):
        try:
            response = model.generate_content(prompt)
            text = response.text.strip()
            # The reply may be wrapped in a Markdown code fence; keep only the
            # content of the first fenced block.
            if "```json" in text:
                text = text.split("```json")[1].split("```")[0].strip()
            elif "```" in text:
                text = text.split("```")[1].strip()
            result = robust_json_loads(text)
            if not isinstance(result, dict):
                # A list/scalar reply cannot carry the expected fields; treat
                # it as a failed attempt with an explicit message.
                raise ValueError(
                    f"expected a JSON object, got {type(result).__name__}"
                )
            judgment = result.get("judgment", "other")
            if judgment not in ("faithful", "reverted", "other"):
                judgment = "other"
            return {
                "outcome_present": result.get("outcome_present", False),
                "proof_method": result.get("proof_method", "absent"),
                "evidence_in_fl": result.get("evidence_in_fl", ""),
                "judgment": judgment,
                "reason": result.get("reason", ""),
            }
        except Exception as e:
            print(f"    Attempt {attempt+1} failed: {e}")
            # Only pause when another attempt follows; sleeping after the
            # last failure would just delay the fallback result.
            if attempt < max_attempts - 1:
                time.sleep(retry_delay_s)

    return {
        "outcome_present": False,
        "proof_method": "absent",
        "evidence_in_fl": "",
        "judgment": "other",
        "reason": "scoring_failed_after_retries",
    }


def main():
    """Score every edited row against its generated Lean output and summarize.

    Command-line arguments:
        --edited_dataset: JSONL file of edited rows (read with ``load_jsonl``).
        --output_jsonl: JSONL file of model outputs; joined to rows by ``name``
            and read from the ``LLM_Output#1`` field.
        --scored_output: JSONL file the scored records are written to. If it
            already exists, the names it contains are skipped (resume) and new
            records are appended.
        --model: Judge model name (default ``gemini-2.5-flash``).

    Environment variables:
        GOOGLE_API_KEY: Required; a missing value raises ``KeyError``.
        SCORE_WORKERS: Number of scoring threads (default 30, minimum 1).

    Prints a judgment summary and a proof-method breakdown over all scored
    records, including those loaded from a previous run.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--edited_dataset", required=True)
    parser.add_argument("--output_jsonl", required=True)
    parser.add_argument("--scored_output", required=True)
    parser.add_argument("--model", default="gemini-2.5-flash")
    args = parser.parse_args()

    import google.generativeai as genai
    from number_edit.common_parser import make_gemini_model
    genai.configure(api_key=os.environ["GOOGLE_API_KEY"])
    model = make_gemini_model(args.model)

    rows = load_jsonl(args.edited_dataset)
    outputs = load_jsonl(args.output_jsonl)
    # Join by `name`, not by index — some output files (notably those produced
    # by parallel async API calls, e.g. Gemini) are not in input order, and a
    # positional zip pairs every row with the WRONG Lean output.
    outputs_by_name = {}
    for o in outputs:
        n = o.get("name", "")
        if n:
            outputs_by_name[n] = o
    missing = [r.get("name", f"row_{i}") for i, r in enumerate(rows)
               if r.get("name") not in outputs_by_name]
    if missing:
        print(f"  WARNING: {len(missing)} input rows have no matching output "
              f"(first 3: {missing[:3]})")

    # Resume: reload whatever a previous run already scored.
    scored_path = Path(args.scored_output)
    scored_path.parent.mkdir(parents=True, exist_ok=True)
    done_names = set()
    scored_rows = []
    if scored_path.exists():
        # Read with the same encoding the file is written with below, and
        # close the handle deterministically.
        with scored_path.open("r", encoding="utf-8") as fin:
            for line in fin:
                if not line.strip():
                    continue  # tolerate blank lines
                obj = json.loads(line)
                scored_rows.append(obj)
                done_names.add(obj.get("name", ""))
        print(f"Resuming: {len(done_names)} already scored")

    pending = []
    for i, row in enumerate(rows):
        name = row.get("name", f"row_{i}")
        if name in done_names:
            continue
        output = outputs_by_name.get(name, {})
        fl_text = str(output.get("LLM_Output#1", "") or "")
        pending.append((i, name, row, fl_text))

    def score_task(task):
        """Score one pending (index, name, row, fl_text) tuple into a record."""
        i, name, row, fl_text = task
        if not fl_text or fl_text.startswith("ERROR"):
            # Nothing to judge: skip the model call. The record carries the
            # same keys as a model-scored one so every output line has one
            # schema.
            result = {
                "outcome_present": False,
                "proof_method": "absent",
                "evidence_in_fl": "",
                "judgment": "other",
                "reason": "empty_or_error_output",
            }
        else:
            result = score_one_sample(model, row, fl_text)
        return {
            "name": name,
            "edit_type": "step_delete",
            "target_step_idx": row.get("step_edit_target_step_idx", -1),
            "outcome_text": row.get("step_edit_target_outcome_text", ""),
            **result,
        }

    # ThreadPoolExecutor rejects max_workers <= 0, so clamp to at least one.
    max_workers = max(1, int(os.environ.get("SCORE_WORKERS", "30")))
    file_lock = threading.Lock()
    print(f"  Parallel scoring with {max_workers} workers on {len(pending)} pending rows")

    # Append when resuming so earlier records are kept; otherwise start fresh.
    with scored_path.open("a" if scored_rows else "w", encoding="utf-8") as fout:
        if pending:
            with ThreadPoolExecutor(max_workers=max_workers) as ex:
                futures = [ex.submit(score_task, t) for t in pending]
                completed = 0
                for fut in as_completed(futures):
                    entry = fut.result()
                    with file_lock:
                        scored_rows.append(entry)
                        fout.write(json.dumps(entry, ensure_ascii=False) + "\n")
                        # Flush per record so an interrupted run can resume.
                        fout.flush()
                    completed += 1
                    if completed % 50 == 0 or completed == len(pending):
                        print(f"  [{completed}/{len(pending)}] scored (latest: {entry['name']} -> {entry['judgment']})")

    # Summary
    counts = Counter(r["judgment"] for r in scored_rows)
    total = len(scored_rows)
    fr = counts.get("faithful", 0)
    rr = counts.get("reverted", 0)
    our = counts.get("other", 0)

    print(f"\n{'='*50}")
    print(f"  Step Delete Results ({total} samples)")
    print(f"{'='*50}")
    print(f"  FR  (faithful):  {fr}/{total} = {fr/total:.1%}" if total else "  No samples")
    print(f"  RR  (reverted):  {rr}/{total} = {rr/total:.1%}" if total else "")
    print(f"  OUR (other):     {our}/{total} = {our/total:.1%}" if total else "")

    # Proof method breakdown
    methods = Counter(r["proof_method"] for r in scored_rows)
    print("\n  Proof method breakdown:")
    for method, count in methods.most_common():
        print(f"    {method}: {count}")


if __name__ == "__main__":
    main()
