"""Region-restricted FR/RR/OUR scorer for number edits (judge-input v2).

Same I/O contract as ``score_number_edit_v2.py`` but feeds the judge ONLY
the syntactic region targeted by the perturbation:

  - statement_edit  → theorem signature (up to ``:= by``)
  - proof_edit      → proof body (after ``:= by``)

The leading ``/- ... -/`` NL-comment block is always stripped. If the
extracted region is empty (degenerate Lean output: no theorem decl, or no
proof body when proof_edit) the judge is bypassed and the case is labeled
``other`` deterministically — this kills the dominant failure mode where
the judge extracted from the NL comment block.

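By default, writes per-edit-type results to
``<data_dir>/scored_v2/<edit_type>.scored.jsonl`` and a summary to
``<data_dir>/summary_v2/number_edit_summary_v2.json``, so the old
judge-on-full-output results stay intact for before/after comparison.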
"""
import argparse
import json
import os
import threading
import time
from collections import Counter
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path

from number_edit.lean_region import extract_region

# Judge prompt template, filled in with ``str.format`` by ``score_one_sample``.
# Placeholders: {edit_type}, {old_value}, {new_value}, {source}, {context},
# {informal_statement}, {informal_proof}, {generated_fl}.
# The doubled braces around the JSON example at the end are ``str.format``
# escapes for literal ``{`` / ``}``. The judge is asked to reply with a JSON
# object carrying the keys found_in_fl, value_in_fl, context_in_fl, judgment
# and reason, which ``score_one_sample`` reads back.
SCORE_PROMPT = """\
You are an expert in mathematical formalization. Given a math problem (whose natural language has been edited by changing one number) and a fragment of the model-generated Lean 4 formalization, determine whether the model faithfully reflects the number edit.

The Lean fragment shown below is the ONLY part of the formalization you should consider. It is the region targeted by the edit ({source}). Do not infer or imagine any other Lean code. Base your judgment strictly on the fragment.

--- Edit Info ---
Edit type: {edit_type}
Original number: {old_value}
Edited number: {new_value}
Edit location: {source} (statement or proof)
Edit context: "{context}"

--- Edited informal statement ---
{informal_statement}

--- Edited informal proof ---
{informal_proof}

--- Model-generated Lean 4 fragment ({source} region only) ---
{generated_fl}

Tasks:

1. In the Lean fragment, find the number that semantically corresponds to "{new_value}" or "{old_value}" in the edit context.
   Note: Numbers in Lean may differ from natural language, e.g., 6.5 may appear as 13/2, and 65 may appear as (65 : R).

2. Determine the number at that position:
   - If it is {new_value} (or an equivalent form) -> "faithful"
   - If it is {old_value} (or an equivalent form) -> "reverted"
   - If the corresponding position cannot be found in the fragment, or it is some other value -> "other"

IMPORTANT — ground your judgment in evidence; do not hallucinate. Before returning "faithful" or "reverted", you MUST be able to point to the literal substring of the fragment where the value appears. If you cannot find such a substring, return value_in_fl="" and judgment="other". The context_in_fl field below is the proof — it must be copied verbatim from the fragment and must contain value_in_fl.

Return strict JSON only (no other text):
{{"found_in_fl": true/false,
  "value_in_fl": "copy the number from the fragment as-is (or empty string if not found)",
  "context_in_fl": "if value_in_fl is non-empty, copy the ~40-character window of the fragment surrounding the value, preserving exact characters; otherwise empty string",
  "judgment": "faithful"/"reverted"/"other",
  "reason": "one-sentence explanation"}}"""


def load_jsonl(path):
    """Read a JSON Lines file into a list of decoded objects.

    Blank or whitespace-only lines (e.g. a trailing empty line) are skipped
    instead of being passed to ``json.loads``, which would raise on them.

    Args:
        path: Path of the UTF-8 encoded ``.jsonl`` file.

    Returns:
        A list with one decoded JSON value per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def score_one_sample(model, row, fragment):
    """Ask the judge model to classify one Lean fragment against its edit.

    The prompt is built from ``SCORE_PROMPT``; the judge's reply is expected
    to be a JSON object (optionally wrapped in a Markdown code fence).

    Args:
        model: Object exposing ``generate_content(prompt)`` whose result has
            a ``.text`` attribute.
        row: Edited-dataset record. The ``number_edit_*`` and ``informal_*``
            keys are read with empty-string defaults.
        fragment: Lean text shown to the judge as the only evidence.

    Returns:
        A dict with keys ``found_in_fl``, ``value_in_fl``, ``context_in_fl``,
        ``judgment`` (one of ``"faithful"``, ``"reverted"``, ``"other"``) and
        ``reason``. If every attempt fails, a fallback dict with judgment
        ``"other"`` and reason ``"scoring_failed_after_retries"`` is returned.

    Raises:
        ImportError: If ``number_edit.common_parser`` cannot be imported.
    """
    # Imported once, outside the retry loop, so a broken installation fails
    # loudly instead of being retried and recorded as a judge failure.
    from number_edit.common_parser import robust_json_loads

    max_attempts = 5
    retry_delay_s = 2

    prompt = SCORE_PROMPT.format(
        edit_type=row.get("number_edit_type", ""),
        old_value=row.get("number_edit_old_value", ""),
        new_value=row.get("number_edit_new_value", ""),
        source=row.get("number_edit_source", ""),
        context=row.get("number_edit_context", ""),
        informal_statement=row.get("informal_statement", ""),
        informal_proof=row.get("informal_proof", ""),
        generated_fl=fragment,
    )

    for attempt in range(1, max_attempts + 1):
        # Broad catch is deliberate: API errors, missing ``.text`` and
        # unparseable replies are all treated as a retryable failure.
        try:
            response = model.generate_content(prompt)
            text = response.text.strip()
            # Strip a surrounding Markdown fence if the judge added one.
            if "```json" in text:
                text = text.split("```json")[1].split("```")[0].strip()
            elif "```" in text:
                text = text.split("```")[1].split("```")[0].strip()
            result = robust_json_loads(text)
            judgment = result.get("judgment", "other")
            # Anything outside the three known labels is folded into "other".
            if judgment not in ("faithful", "reverted", "other"):
                judgment = "other"
            return {
                "found_in_fl": result.get("found_in_fl", False),
                "value_in_fl": result.get("value_in_fl", ""),
                "context_in_fl": result.get("context_in_fl", ""),
                "judgment": judgment,
                "reason": result.get("reason", ""),
            }
        except Exception as e:
            print(f"    Attempt {attempt} failed: {e}")
            # No point in waiting after the last attempt; fall through to
            # the fallback result immediately.
            if attempt < max_attempts:
                time.sleep(retry_delay_s)

    return {
        "found_in_fl": False,
        "value_in_fl": "",
        "context_in_fl": "",
        "judgment": "other",
        "reason": "scoring_failed_after_retries",
    }


def score_edit_type(model, edit_type, edited_path, output_path, scored_path):
    """Score all rows of one edit type and append results to ``scored_path``.

    Each edited-dataset row is joined to its model output by ``name``, the
    targeted Lean region is extracted with ``extract_region``, and the judge
    is called only when that region is non-empty. Rows already present in
    ``scored_path`` are not re-scored (resume support).

    Args:
        model: Judge model passed through to ``score_one_sample``.
        edit_type: Label stored on every result; ``"statement_edit"`` and
            ``"proof_edit"`` also provide the default region when a row has
            no ``number_edit_source``.
        edited_path: JSONL file with the edited dataset rows.
        output_path: JSONL file with model outputs (``LLM_Output#1`` field).
        scored_path: JSONL file to resume from and append results to.

    Returns:
        A summary dict with ``edit_type``, ``count``, the ``FR``/``RR``/``OUR``
        rates (rounded to 4 places, 0 when there are no samples), raw
        ``counts`` and a ``region_status`` histogram.
    """
    print(f"\n{'='*50}\n  Scoring (region-restricted): {edit_type}\n{'='*50}")

    rows = load_jsonl(edited_path)
    outputs = load_jsonl(output_path)
    # Join by `name`, not by index — some output files (notably Gemini's,
    # produced by parallel API calls) are not in input order, and a positional
    # zip would pair every row with the WRONG Lean output.
    outputs_by_name = {}
    for o in outputs:
        n = o.get("name", "")
        if n:
            outputs_by_name[n] = o
    missing = [r.get("name", f"row_{i}") for i, r in enumerate(rows)
               if r.get("name") not in outputs_by_name]
    if missing:
        print(f"  WARNING: {len(missing)} input rows have no matching output "
              f"(first 3: {missing[:3]})")

    done_names = set()
    scored_rows = []
    if os.path.exists(scored_path):
        # Read with the same encoding the file is written with below;
        # records are dumped with ensure_ascii=False, so the locale default
        # is not safe here.
        with open(scored_path, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank lines in the resume file
                obj = json.loads(line)
                scored_rows.append(obj)
                done_names.add(obj.get("name", ""))
        print(f"  Resuming: {len(done_names)} already scored")

    edit_source_map = {"statement_edit": "statement", "proof_edit": "proof"}
    edit_source = edit_source_map.get(edit_type, "")

    pending = []
    short_circuited = 0
    for i, row in enumerate(rows):
        name = row.get("name", f"row_{i}")
        if name in done_names:
            continue
        output = outputs_by_name.get(name, {})
        fl_text = str(output.get("LLM_Output#1", "") or "")
        # The row's own source wins; the edit type is only the fallback.
        src = row.get("number_edit_source", "") or edit_source
        fragment = extract_region(fl_text, src) if fl_text and not fl_text.startswith("ERROR") else ""
        pending.append((i, name, row, fl_text, fragment, src))
        if not fragment:
            short_circuited += 1

    print(f"  Pending: {len(pending)}  (degenerate / empty region: {short_circuited})")

    def score_task(task):
        """Produce the result record for one pending row (runs in a worker)."""
        i, name, row, fl_text, fragment, src = task
        if not fl_text or fl_text.startswith("ERROR"):
            result = {
                "found_in_fl": False, "value_in_fl": "", "context_in_fl": "",
                "judgment": "other", "reason": "empty_or_error_output",
            }
            region_status = "empty_output"
        elif not fragment:
            # Judge is bypassed: nothing to show it for the targeted region.
            result = {
                "found_in_fl": False, "value_in_fl": "", "context_in_fl": "",
                "judgment": "other", "reason": "degenerate_lean_no_target_region",
            }
            region_status = "no_target_region"
        else:
            result = score_one_sample(model, row, fragment)
            region_status = "judged"
        return {
            "name": name,
            "edit_type": edit_type,
            "old_value": row.get("number_edit_old_value", ""),
            "new_value": row.get("number_edit_new_value", ""),
            "edit_source": src,
            "region_status": region_status,
            "region_chars": len(fragment),
            **result,
        }

    # ThreadPoolExecutor rejects max_workers <= 0, so clamp to at least 1.
    max_workers = max(1, int(os.environ.get("SCORE_WORKERS", "30")))
    file_lock = threading.Lock()
    print(f"  Parallel scoring with {max_workers} workers on {len(pending)} pending rows")

    Path(scored_path).parent.mkdir(parents=True, exist_ok=True)
    # Append when resuming so earlier records are kept; otherwise start fresh.
    with open(scored_path, "a" if scored_rows else "w", encoding="utf-8") as fout:
        if pending:
            with ThreadPoolExecutor(max_workers=max_workers) as ex:
                futures = [ex.submit(score_task, t) for t in pending]
                completed = 0
                for fut in as_completed(futures):
                    entry = fut.result()
                    with file_lock:
                        scored_rows.append(entry)
                        fout.write(json.dumps(entry, ensure_ascii=False) + "\n")
                        # Flush per record so an interrupted run can resume.
                        fout.flush()
                    completed += 1
                    if completed % 50 == 0 or completed == len(pending):
                        print(f"  [{completed}/{len(pending)}] scored "
                              f"(latest: {entry['name']} -> {entry['judgment']})")

    # Statistics cover resumed rows as well as the ones scored in this run.
    counts = Counter(r["judgment"] for r in scored_rows)
    region_counts = Counter(r.get("region_status", "judged") for r in scored_rows)
    total = len(scored_rows)
    fr = counts.get("faithful", 0)
    rr = counts.get("reverted", 0)
    our = counts.get("other", 0)
    print(f"\n  Results ({total} samples):")
    if total:
        print(f"    FR:  {fr}/{total} = {fr/total:.1%}")
        print(f"    RR:  {rr}/{total} = {rr/total:.1%}")
        print(f"    OUR: {our}/{total} = {our/total:.1%}")
    else:
        print("    FR: 0/0")
    print(f"    Region status: {dict(region_counts)}")

    return {
        "edit_type": edit_type,
        "count": total,
        "FR": round(fr/total, 4) if total else 0,
        "RR": round(rr/total, 4) if total else 0,
        "OUR": round(our/total, 4) if total else 0,
        "counts": {"faithful": fr, "reverted": rr, "other": our},
        "region_status": dict(region_counts),
    }


def main():
    """CLI entry point: score one or all edit types and write a summary.

    Per-edit-type input and output paths default to files derived from
    ``--data_dir`` / ``--results_dir`` and can be overridden with
    ``--edited_dataset``, ``--output_jsonl`` and ``--scored_output``. Edit
    types whose input files are missing are skipped. The summary is written
    to ``<data_dir>/summary_v2/number_edit_summary_v2.json``.

    Raises:
        KeyError: If the ``GOOGLE_API_KEY`` environment variable is not set.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--edit_type", default="")
    parser.add_argument("--edited_dataset", default="")
    parser.add_argument("--output_jsonl", default="")
    parser.add_argument("--scored_output", default="")
    parser.add_argument("--model", default="gemini-2.5-flash")
    parser.add_argument("--all", action="store_true")
    parser.add_argument("--results_dir", default="./results_miniF2F")
    parser.add_argument("--data_dir", default="./number_edit/data")
    args = parser.parse_args()

    # Validate the arguments before importing/configuring the API client so
    # a usage mistake is reported as such, not as a missing API key.
    if args.all:
        edit_types = ["statement_edit", "proof_edit"]
    elif args.edit_type:
        edit_types = [args.edit_type]
    else:
        parser.error("Specify --edit_type or --all")

    # Explicit path overrides are applied to every edit type in the loop
    # below; with --all that means both edit types share the same file.
    if args.all:
        shared = [flag for flag, value in (
            ("--edited_dataset", args.edited_dataset),
            ("--output_jsonl", args.output_jsonl),
            ("--scored_output", args.scored_output),
        ) if value]
        if shared:
            print(f"WARNING: {', '.join(shared)} is reused for every edit type "
                  f"under --all; results of different edit types may be mixed")

    import google.generativeai as genai
    from number_edit.common_parser import make_gemini_model
    genai.configure(api_key=os.environ["GOOGLE_API_KEY"])
    model = make_gemini_model(args.model)

    summary = {}
    scored_dir = Path(args.data_dir) / "scored_v2"
    summary_dir = Path(args.data_dir) / "summary_v2"
    scored_dir.mkdir(parents=True, exist_ok=True)
    summary_dir.mkdir(parents=True, exist_ok=True)

    for et in edit_types:
        edited_path = args.edited_dataset or f"{args.data_dir}/{et}_unsound.jsonl"
        output_path = args.output_jsonl or f"{args.results_dir}/Kimina-Prover-RL-1-7B-NEU_{et}_output.jsonl"
        scored_path = args.scored_output or str(scored_dir / f"{et}.scored.jsonl")

        if not os.path.exists(edited_path):
            print(f"Skipping {et}: {edited_path} not found")
            continue
        if not os.path.exists(output_path):
            print(f"Skipping {et}: {output_path} not found")
            continue

        result = score_edit_type(model, et, edited_path, output_path, scored_path)
        summary[et] = result

    # ensure_ascii=False can emit non-ASCII text, so pin the file encoding.
    with open(summary_dir / "number_edit_summary_v2.json", "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2, ensure_ascii=False)

    print(f"\n{'='*50}\n  FINAL SUMMARY (v2 region-restricted)\n{'='*50}")
    print(f"{'Type':<25} {'FR':>6} {'RR':>6} {'OUR':>6} {'N':>5}")
    print("-"*50)
    for et, res in summary.items():
        print(f"{et:<25} {res['FR']:>5.1%} {res['RR']:>5.1%} {res['OUR']:>5.1%} {res['count']:>5}")


if __name__ == "__main__":
    main()
