"""Region-restricted FR/RR/OUR scorer for symbol edits (judge-input v2).

Same I/O contract as ``score_symbol_edit.py`` but feeds the judge ONLY the
syntactic region targeted by the perturbation (signature for statement_edit,
proof body for proof_edit). The leading ``/- ... -/`` NL-comment block is
stripped; degenerate Lean outputs short-circuit to "other" without a judge
call.

Writes ``<data_dir>/scored_v2/<edit_type>.scored.jsonl`` by default (or the
path given via ``--scored_output``) so the old judge-on-full-output results stay
intact for before/after comparison.
"""
import argparse
import json
import os
import threading
import time
from collections import Counter
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path

from number_edit.lean_region import extract_region

# Judge prompt template, filled in with ``str.format`` by ``score_one_sample``.
# Placeholders: edit_type, old_symbol, new_symbol, family, source, context,
# informal_statement, informal_proof, generated_fl.
# Literal braces of the JSON schema at the end are doubled (``{{``/``}}``) so
# ``str.format`` emits them unchanged; backslashes are doubled because this is
# not a raw string, so the judge sees single-backslash LaTeX such as ``\geq``.
SCORE_PROMPT = """\
You are an expert in mathematical formalization. Given a math problem (whose natural language has been edited by changing one math symbol) and a fragment of the model-generated Lean 4 formalization, determine whether the model faithfully reflects the symbol edit.

The Lean fragment shown below is the ONLY part of the formalization you should consider. It is the region targeted by the edit ({source}). Do not infer or imagine any other Lean code. Base your judgment strictly on the fragment.

--- Edit Info ---
Edit type: {edit_type}
Original symbol: {old_symbol}
Edited symbol: {new_symbol}
Symbol family: {family}
Edit location: {source} (statement or proof)
Edit context: "{context}"

--- Edited informal statement ---
{informal_statement}

--- Edited informal proof ---
{informal_proof}

--- Model-generated Lean 4 fragment ({source} region only) ---
{generated_fl}

Tasks:

1. In the Lean fragment, find the math symbol that corresponds to the edited symbol in the edit context.
   Symbol mapping from LaTeX to Lean:
   - \\geq / >= -> Lean: >=
   - \\leq / <= -> Lean: <=
   - > -> Lean: >
   - < -> Lean: <
   - \\neq / != -> Lean: != or Ne
   - + -> Lean: +
   - - -> Lean: -
   - \\times / \\cdot -> Lean: *
   - \\div -> Lean: /
   Note: a > b in Lean may be written as b < a (equivalent flip). If the edit is > -> <, this counts as faithful.

2. Determine the symbol at that position:
   - If it is the Lean equivalent of {new_symbol} -> "faithful"
   - If it is the Lean equivalent of {old_symbol} -> "reverted"
   - If the corresponding position cannot be found in the fragment, or other cases -> "other"

IMPORTANT — ground your judgment in evidence; do not hallucinate. Before returning "faithful" or "reverted", you MUST be able to point to the literal substring of the fragment where the symbol appears. If you cannot find such a substring, return value_in_fl="" and judgment="other". The context_in_fl field below is the proof — it must be copied verbatim from the fragment and must contain value_in_fl.

Return strict JSON only (no other text):
{{"found_in_fl": true/false,
  "value_in_fl": "copy the symbol from the fragment as-is (or empty string if not found)",
  "context_in_fl": "if value_in_fl is non-empty, copy the ~40-character window of the fragment surrounding the symbol, preserving exact characters; otherwise empty string",
  "judgment": "faithful"/"reverted"/"other",
  "reason": "one-sentence explanation"}}"""


def load_jsonl(path):
    """Load a JSON Lines file into a list of parsed values.

    Blank or whitespace-only lines (for example a trailing empty line at the
    end of the file) are skipped instead of raising ``json.JSONDecodeError``.

    Args:
        path: Path of the UTF-8 encoded ``.jsonl`` file to read.

    Returns:
        A list with one parsed JSON value per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def score_one_sample(model, row, fragment):
    """Ask the judge model whether ``fragment`` reflects the symbol edit.

    The prompt is built from ``SCORE_PROMPT`` using the ``symbol_edit_*`` and
    ``informal_*`` fields of ``row`` (missing fields become empty strings) and
    the Lean ``fragment``. The reply is stripped of an optional Markdown code
    fence and parsed with ``robust_json_loads``. Up to five attempts are made,
    sleeping two seconds between failed attempts.

    Args:
        model: Judge model object exposing ``generate_content(prompt)`` whose
            result has a ``.text`` attribute.
        row: Dict describing the edited sample.
        fragment: Lean text shown to the judge.

    Returns:
        A dict with keys ``found_in_fl``, ``value_in_fl``, ``context_in_fl``,
        ``judgment`` and ``reason``. ``judgment`` is always one of
        ``"faithful"``, ``"reverted"`` or ``"other"``; any other value from the
        judge is coerced to ``"other"``. If every attempt fails, a fallback
        dict with ``reason="scoring_failed_after_retries"`` is returned.

    Raises:
        ImportError: If ``number_edit.common_parser`` cannot be imported.
    """
    # Imported once, outside the retry loop: a missing or broken parser module
    # is a setup error and must surface immediately instead of being retried
    # and then recorded as an ordinary "other" judgment.
    from number_edit.common_parser import robust_json_loads

    prompt = SCORE_PROMPT.format(
        edit_type=row.get("symbol_edit_type", ""),
        old_symbol=row.get("symbol_edit_old_symbol", ""),
        new_symbol=row.get("symbol_edit_new_symbol", ""),
        family=row.get("symbol_edit_family", ""),
        source=row.get("symbol_edit_source", ""),
        context=row.get("symbol_edit_context", ""),
        informal_statement=row.get("informal_statement", ""),
        informal_proof=row.get("informal_proof", ""),
        generated_fl=fragment,
    )

    max_attempts = 5
    for attempt in range(max_attempts):
        try:
            response = model.generate_content(prompt)
            text = response.text.strip()
            # Keep only the payload of the first fenced block, if any.
            if "```json" in text:
                text = text.split("```json")[1].split("```")[0].strip()
            elif "```" in text:
                text = text.split("```")[1].split("```")[0].strip()
            result = robust_json_loads(text)
            if not isinstance(result, dict):
                raise ValueError(
                    f"judge returned {type(result).__name__}, expected a JSON object"
                )
            judgment = result.get("judgment", "other")
            if judgment not in ("faithful", "reverted", "other"):
                judgment = "other"
            return {
                "found_in_fl": result.get("found_in_fl", False),
                "value_in_fl": result.get("value_in_fl", ""),
                "context_in_fl": result.get("context_in_fl", ""),
                "judgment": judgment,
                "reason": result.get("reason", ""),
            }
        except Exception as e:
            # Deliberately broad: API errors, blocked responses and malformed
            # replies are all treated as retryable.
            print(f"    Attempt {attempt+1} failed: {e}")
            # No point in waiting after the final attempt.
            if attempt + 1 < max_attempts:
                time.sleep(2)

    return {
        "found_in_fl": False, "value_in_fl": "", "context_in_fl": "",
        "judgment": "other", "reason": "scoring_failed_after_retries",
    }


def score_edit_type(model, edit_type, edited_path, output_path, scored_path):
    """Score all samples of one edit type and return aggregate FR/RR/OUR.

    Edited-dataset rows are joined to model outputs by ``name``. For each row
    not already present in ``scored_path`` the target region is extracted from
    the ``LLM_Output#1`` field with ``extract_region`` and sent to the judge;
    empty/``ERROR``-prefixed outputs and empty regions are recorded as
    ``"other"`` without a judge call. Each result is appended to
    ``scored_path`` as soon as it is available, so an interrupted run can be
    resumed.

    Args:
        model: Judge model passed through to ``score_one_sample``.
        edit_type: Edit type label; ``"statement_edit"`` and ``"proof_edit"``
            provide a fallback source (``"statement"`` / ``"proof"``) for rows
            without ``symbol_edit_source``.
        edited_path: JSONL file with the edited dataset rows.
        output_path: JSONL file with the model outputs.
        scored_path: JSONL file to append per-sample results to (created if
            missing; existing entries are treated as already scored).

    Returns:
        A dict with ``edit_type``, ``count``, the rounded rates ``FR``, ``RR``
        and ``OUR``, the raw ``counts`` and the ``region_status`` histogram,
        computed over resumed and newly scored rows together.
    """
    print(f"\n{'='*50}\n  Scoring (region-restricted): {edit_type}\n{'='*50}")

    rows = load_jsonl(edited_path)
    outputs = load_jsonl(output_path)
    # Join by `name`, not by index — some output files (notably Gemini's,
    # produced by parallel API calls) are not in input order, and a positional
    # zip would pair every row with the WRONG Lean output.
    outputs_by_name = {}
    for o in outputs:
        n = o.get("name", "")
        if n:
            outputs_by_name[n] = o
    missing = [r.get("name", f"row_{i}") for i, r in enumerate(rows)
               if r.get("name") not in outputs_by_name]
    if missing:
        print(f"  WARNING: {len(missing)} input rows have no matching output "
              f"(first 3: {missing[:3]})")

    done_names = set()
    scored_rows = []
    if os.path.exists(scored_path):
        # Read with the same encoding the file is written with below; the
        # entries are dumped with ensure_ascii=False and may hold non-ASCII.
        with open(scored_path, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue
                obj = json.loads(line)
                scored_rows.append(obj)
                done_names.add(obj.get("name", ""))
        print(f"  Resuming: {len(done_names)} already scored")

    edit_source_map = {"statement_edit": "statement", "proof_edit": "proof"}
    edit_source = edit_source_map.get(edit_type, "")

    pending = []
    short_circuited = 0
    for i, row in enumerate(rows):
        name = row.get("name", f"row_{i}")
        if name in done_names:
            continue
        output = outputs_by_name.get(name, {})
        fl_text = str(output.get("LLM_Output#1", "") or "")
        src = row.get("symbol_edit_source", "") or edit_source
        fragment = extract_region(fl_text, src) if fl_text and not fl_text.startswith("ERROR") else ""
        pending.append((i, name, row, fl_text, fragment, src))
        if not fragment:
            short_circuited += 1

    print(f"  Pending: {len(pending)}  (degenerate / empty region: {short_circuited})")

    def score_task(task):
        """Score one pending sample; runs in a worker thread."""
        i, name, row, fl_text, fragment, src = task
        if not fl_text or fl_text.startswith("ERROR"):
            result = {
                "found_in_fl": False, "value_in_fl": "", "context_in_fl": "",
                "judgment": "other", "reason": "empty_or_error_output",
            }
            region_status = "empty_output"
        elif not fragment:
            result = {
                "found_in_fl": False, "value_in_fl": "", "context_in_fl": "",
                "judgment": "other", "reason": "degenerate_lean_no_target_region",
            }
            region_status = "no_target_region"
        else:
            # Tell the judge the same source that was used to extract the
            # fragment; the raw row field may be empty when `src` came from
            # the edit-type fallback.
            judge_row = {**row, "symbol_edit_source": src}
            result = score_one_sample(model, judge_row, fragment)
            region_status = "judged"
        return {
            "name": name,
            "edit_type": edit_type,
            "old_symbol": row.get("symbol_edit_old_symbol", ""),
            "new_symbol": row.get("symbol_edit_new_symbol", ""),
            "family": row.get("symbol_edit_family", ""),
            "edit_source": src,
            "region_status": region_status,
            "region_chars": len(fragment),
            **result,
        }

    # ThreadPoolExecutor rejects max_workers <= 0, so clamp the env override.
    max_workers = max(1, int(os.environ.get("SCORE_WORKERS", "30")))
    print(f"  Parallel scoring with {max_workers} workers on {len(pending)} pending rows")

    Path(scored_path).parent.mkdir(parents=True, exist_ok=True)
    with open(scored_path, "a" if scored_rows else "w", encoding="utf-8") as fout:
        if pending:
            with ThreadPoolExecutor(max_workers=max_workers) as ex:
                futures = [ex.submit(score_task, t) for t in pending]
                # Results are consumed and written by this (the calling)
                # thread only, so no lock is needed around the file or list.
                for completed, fut in enumerate(as_completed(futures), 1):
                    entry = fut.result()
                    scored_rows.append(entry)
                    fout.write(json.dumps(entry, ensure_ascii=False) + "\n")
                    # Flush per entry so an interrupted run loses nothing.
                    fout.flush()
                    if completed % 50 == 0 or completed == len(pending):
                        print(f"  [{completed}/{len(pending)}] scored "
                              f"(latest: {entry['name']} -> {entry['judgment']})")

    counts = Counter(r["judgment"] for r in scored_rows)
    region_counts = Counter(r.get("region_status", "judged") for r in scored_rows)
    total = len(scored_rows)
    fr = counts.get("faithful", 0)
    rr = counts.get("reverted", 0)
    our = counts.get("other", 0)
    print(f"\n  Results ({total} samples):")
    if total:
        print(f"    FR:  {fr}/{total} = {fr/total:.1%}")
        print(f"    RR:  {rr}/{total} = {rr/total:.1%}")
        print(f"    OUR: {our}/{total} = {our/total:.1%}")
    else:
        print("    FR: 0/0")
    print(f"    Region status: {dict(region_counts)}")

    return {
        "edit_type": edit_type, "count": total,
        "FR": round(fr/total, 4) if total else 0,
        "RR": round(rr/total, 4) if total else 0,
        "OUR": round(our/total, 4) if total else 0,
        "counts": {"faithful": fr, "reverted": rr, "other": our},
        "region_status": dict(region_counts),
    }


def main():
    """Command-line entry point: score one or all edit types and summarize.

    Parses the CLI, builds the Gemini judge model from ``GOOGLE_API_KEY``,
    runs ``score_edit_type`` for each requested edit type whose input files
    exist, writes ``<data_dir>/summary_v2/symbol_edit_summary_v2.json`` and
    prints a summary table.

    Exits with a usage error (via ``parser.error``) when neither
    ``--edit_type`` nor ``--all`` is given, or when ``GOOGLE_API_KEY`` is not
    set.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--edit_type", default="")
    parser.add_argument("--edited_dataset", default="")
    parser.add_argument("--output_jsonl", default="")
    parser.add_argument("--scored_output", default="")
    parser.add_argument("--model", default="gemini-2.5-flash")
    parser.add_argument("--all", action="store_true")
    parser.add_argument("--data_dir", default="./symbol_edit/data")
    parser.add_argument("--results_dir", default="./results_miniF2F")
    args = parser.parse_args()

    # Validate the invocation before importing/configuring the API client so
    # usage mistakes fail fast with a readable message.
    if not (args.all or args.edit_type):
        parser.error("Specify --edit_type or --all")
    api_key = os.environ.get("GOOGLE_API_KEY")
    if not api_key:
        parser.error("GOOGLE_API_KEY environment variable is not set")

    from symbol_edit.common import EDIT_TYPES
    edit_types = list(EDIT_TYPES) if args.all else [args.edit_type]

    if len(edit_types) > 1 and (args.edited_dataset or args.output_jsonl
                                or args.scored_output):
        # The explicit paths below override the per-edit-type defaults, so
        # every edit type would read/append to the very same file.
        print("WARNING: --edited_dataset/--output_jsonl/--scored_output are "
              "single paths and will be reused for every edit type under --all")

    import google.generativeai as genai
    from number_edit.common_parser import make_gemini_model
    genai.configure(api_key=api_key)
    model = make_gemini_model(args.model)

    summary = {}
    scored_dir = Path(args.data_dir) / "scored_v2"
    summary_dir = Path(args.data_dir) / "summary_v2"
    scored_dir.mkdir(parents=True, exist_ok=True)
    summary_dir.mkdir(parents=True, exist_ok=True)

    for et in edit_types:
        edited_path = args.edited_dataset or f"{args.data_dir}/{et}_unsound.jsonl"
        output_path = args.output_jsonl or f"{args.results_dir}/Kimina-Prover-RL-1-7B-SEU_{et}_output.jsonl"
        scored_path = args.scored_output or str(scored_dir / f"{et}.scored.jsonl")

        if not os.path.exists(edited_path):
            print(f"Skipping {et}: {edited_path} not found")
            continue
        if not os.path.exists(output_path):
            print(f"Skipping {et}: {output_path} not found")
            continue

        result = score_edit_type(model, et, edited_path, output_path, scored_path)
        summary[et] = result

    # ensure_ascii=False may emit non-ASCII text, so pin the file encoding
    # instead of relying on the platform default.
    with open(summary_dir / "symbol_edit_summary_v2.json", "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2, ensure_ascii=False)

    print(f"\n{'='*50}\n  FINAL SUMMARY (v2 region-restricted)\n{'='*50}")
    print(f"{'Type':<25} {'FR':>6} {'RR':>6} {'OUR':>6} {'N':>5}")
    print("-"*50)
    for et, res in summary.items():
        print(f"{et:<25} {res['FR']:>5.1%} {res['RR']:>5.1%} {res['OUR']:>5.1%} {res['count']:>5}")

# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
