"""Supplement TC + StmtSC + ProofSC using SC v3 (Decoupled).

Driver for the only supported SC method. Reads inference output (LLM_Output#k
fields), runs Lean type-check + 2-call decoupled SC judge, writes per-sample
results to a sharded jsonl.

Output paths use the `_tcsc_v3_shard{N}of{T}.jsonl` convention.
"""
from __future__ import annotations

import argparse
import json
import os
import re as _re
import sys
import time
from pathlib import Path

# Import-path bootstrap. `parents[2]` is the directory two levels above the one
# containing this file (presumably the repository root — confirm against the
# repo layout). It is prepended to sys.path before the `llm_inference` import
# further down runs.
_ROOT = Path(__file__).resolve().parents[2]
if str(_ROOT) not in sys.path:
    sys.path.insert(0, str(_ROOT))

# `LEAN_interaction` is added to sys.path as well, ahead of the top-level
# `from checkLEAN import ...` that follows. These statements must stay above
# those imports.
_LEAN_DIR = _ROOT / "LEAN_interaction"
if str(_LEAN_DIR) not in sys.path:
    sys.path.insert(0, str(_LEAN_DIR))
from checkLEAN import check_repl, write_basic_lean, PersistentLeanREPL  # noqa: E402

from llm_inference.sc_combined_v3 import run_sc_v3 as run_sc  # noqa: E402


def load_jsonl(path: Path) -> list[dict]:
    """Decode a JSON Lines file into a list, one entry per non-blank line.

    Every line is stripped of surrounding whitespace first; lines that are
    empty after stripping are ignored. The remaining lines are decoded with
    ``json.loads`` in file order, so a malformed line raises
    ``json.JSONDecodeError`` and nothing is returned.
    """
    with open(path, "r", encoding="utf-8") as handle:
        stripped_lines = (raw.strip() for raw in handle)
        return [json.loads(text) for text in stripped_lines if text]


def run_tc_for_row(llm_output: str, lean_file_path: Path, project_dir: Path,
                   repl: "PersistentLeanREPL | None" = None) -> tuple[bool, str]:
    """Type-check one generated Lean snippet and return ``(passed, error_text)``.

    ``llm_output`` is handed to ``write_basic_lean`` together with
    ``lean_file_path``, and that file is then checked either through the
    supplied persistent ``repl`` (``repl.check``) or, when ``repl`` is None,
    through the one-shot ``check_repl(lean_file_path, project_dir)``.

    Returns:
        ``(True, "")`` when the check reports success; otherwise
        ``(False, msg)`` where ``msg`` is the first 500 characters of the
        checker output, or a ``"[tc crash] ..."`` description if the check
        itself raised.

    Raises:
        KeyboardInterrupt: propagated so that Ctrl-C actually stops the batch
            instead of being recorded as a failed type-check. Every other
            exception is deliberately converted into a failed result so one
            bad sample cannot abort a long run.
    """
    try:
        write_basic_lean("", llm_output, lean_file_path)
        if repl is not None:
            ok, out = repl.check(lean_file_path)
        else:
            ok, out = check_repl(lean_file_path, project_dir)
        # Keep only a short excerpt of the checker output for failed samples.
        return bool(ok), "" if ok else (out or "")[:500]
    except KeyboardInterrupt:
        # A user interrupt is not a property of the sample; let it escape.
        raise
    except BaseException as e:
        # Best-effort: anything else the Lean tooling raises becomes a
        # per-sample failure message rather than a crash of the driver.
        return False, f"[tc crash] {type(e).__name__}: {e}"


def _read_done_names(output_path: Path) -> set:
    """Return the ``name`` of every parseable record already in *output_path*.

    Lines that are not valid JSON (e.g. a record truncated by a killed run) or
    that do not decode to a dict are skipped with a warning, so one damaged
    line does not discard the whole resume state.
    """
    done = set()
    with open(output_path, "r", encoding="utf-8") as f:
        for lineno, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue
            try:
                rec = json.loads(line)
            except json.JSONDecodeError:
                print(f"[supplement_tc_sc_v3] resume: skipping unparseable line {lineno}")
                continue
            if isinstance(rec, dict):
                done.add(rec.get("name", ""))
    return done


def _lacks_trailing_newline(path: Path) -> bool:
    """Return True if *path* is non-empty and its last byte is not a newline."""
    with open(path, "rb") as f:
        f.seek(0, os.SEEK_END)
        if f.tell() == 0:
            return False
        f.seek(-1, os.SEEK_END)
        return f.read(1) != b"\n"


def main():
    """CLI entry point: score one shard of an inference jsonl with TC and/or SC.

    Steps, in order:
      1. Parse arguments and select this shard's rows (row ``i`` belongs to
         shard ``i % total_shards + 1``).
      2. Read any existing output file to find rows already scored (resume).
      3. Prepare the per-shard Lean scratch file, the Gemini model (modes
         ``full``/``sc``) and the Lean REPL (modes ``full``/``tc``).
      4. For each pending row and each ``LLM_Output#k`` sample, add
         ``LLM_Syntax?#k`` / ``LLM_SyntaxError#k`` (TC) and the fields returned
         by ``run_sc`` suffixed with ``#k`` (SC), then append the enriched row
         to the output jsonl and flush.

    Environment variables read: ``LEAN_SLOT_BASE`` (default ``470000``),
    ``LEAN_REPL_MODE`` (default ``persistent``) and ``GOOGLE_API_KEY``
    (required when SC is enabled).

    Exits via ``argparse`` error on invalid shard/sample-key arguments and via
    ``sys.exit`` when SC is requested without ``GOOGLE_API_KEY``.
    """
    p = argparse.ArgumentParser(description="Supplement TC + SC v3 evaluator")
    p.add_argument("--input", required=True)
    p.add_argument("--output", required=True)
    p.add_argument("--project_dir", required=True)
    p.add_argument("--shard_index", type=int, default=1)
    p.add_argument("--total_shards", type=int, default=1)
    p.add_argument("--gemini_model", default="gemini-2.5-flash")
    p.add_argument("--sample_key", default="auto")
    p.add_argument("--mode", choices=("full", "tc", "sc"), default="full",
                   help="full = TC + SC (default); tc = TC only (no Gemini, no SC fields); "
                        "sc = SC only (no Lean REPL, no TC fields)")
    args = p.parse_args()

    # Fail fast on arguments that would otherwise crash mid-run
    # (total_shards=0 -> ZeroDivisionError, key without '#' -> IndexError)
    # or silently select no rows (shard_index outside 1..total_shards).
    if args.total_shards < 1 or not (1 <= args.shard_index <= args.total_shards):
        p.error(f"--shard_index must be in 1..--total_shards "
                f"(got {args.shard_index}/{args.total_shards})")
    if args.sample_key != "auto" and "#" not in args.sample_key:
        p.error("--sample_key must be 'auto' or a field name containing '#', "
                "e.g. LLM_Output#1")

    do_tc = args.mode in ("full", "tc")
    do_sc = args.mode in ("full", "sc")

    input_path = Path(args.input)
    output_path = Path(args.output)
    project_dir = Path(args.project_dir)

    rows = load_jsonl(input_path)
    # shard_index is 1-based; rows are dealt round-robin across shards.
    my_rows = [r for i, r in enumerate(rows)
               if i % args.total_shards == args.shard_index - 1]
    print(f"[supplement_tc_sc_v3] shard {args.shard_index}/{args.total_shards} "
          f"→ {len(my_rows)}/{len(rows)} rows")

    # Resume support: rows whose name already appears in the output are skipped.
    # NOTE: rows without a "name" field get a positional fallback name below
    # that is never written to the output, so they are not matched on resume.
    done_names = set()
    needs_newline = False
    if output_path.exists():
        try:
            done_names = _read_done_names(output_path)
            # A run killed mid-write can leave a partial last line; appending
            # straight after it would corrupt the next record as well.
            needs_newline = _lacks_trailing_newline(output_path)
            print(f"[supplement_tc_sc_v3] resume: {len(done_names)} already scored")
        except (OSError, UnicodeDecodeError) as e:
            print(f"[supplement_tc_sc_v3] resume-read failed: {e}; starting fresh")

    # Each shard gets its own Lean scratch file, numbered from LEAN_SLOT_BASE.
    slot_base = int(os.environ.get("LEAN_SLOT_BASE", "470000"))
    basic_idx = slot_base + args.shard_index * 1000
    shard_lean_file = project_dir / "TmpProjDir" / f"Basic_{basic_idx:06d}.lean"
    shard_lean_file.parent.mkdir(parents=True, exist_ok=True)
    shard_lean_file.touch(exist_ok=True)
    print(f"[supplement_tc_sc_v3] Lean file: {shard_lean_file}")

    print(f"[supplement_tc_sc_v3] mode={args.mode}  (do_tc={do_tc} do_sc={do_sc})")

    model = None
    if do_sc:
        # Imported lazily so TC-only runs do not need the Gemini dependencies.
        import google.generativeai as genai
        from number_edit.common_parser import make_gemini_model
        api_key = os.environ.get("GOOGLE_API_KEY", "")
        if not api_key:
            sys.exit("Set GOOGLE_API_KEY")
        genai.configure(api_key=api_key)
        model = make_gemini_model(args.gemini_model)
        print(f"[supplement_tc_sc_v3] Gemini: {args.gemini_model}")

    sample_key_re = _re.compile(r"LLM_Output#\d+")
    total = len(my_rows)
    # Rows still to score in this run; drives the ETA so that rows skipped by
    # resume do not inflate the reported rate.
    remaining = sum(1 for j, r in enumerate(my_rows)
                    if r.get("name", f"row_{j}") not in done_names)
    scored = 0

    persistent_repl = None
    try:
        if do_tc:
            repl_mode = os.environ.get("LEAN_REPL_MODE", "persistent").lower()
            if repl_mode == "persistent":
                persistent_repl = PersistentLeanREPL(project_dir=project_dir)
                print(f"[supplement_tc_sc_v3] Lean REPL: persistent "
                      f"(restart every {persistent_repl.restart_every})")
            else:
                print("[supplement_tc_sc_v3] Lean REPL: oneshot")

        output_path.parent.mkdir(parents=True, exist_ok=True)
        with open(output_path, "a", encoding="utf-8") as fout:
            if needs_newline:
                fout.write("\n")
            t_start = time.time()
            for i, row in enumerate(my_rows):
                name = row.get("name", f"row_{i}")
                if name in done_names:
                    continue
                if args.sample_key == "auto":
                    # All LLM_Output#<k> fields, ordered by numeric k.
                    sample_keys = sorted(
                        (k for k in row if sample_key_re.fullmatch(k)),
                        key=lambda k: int(k.split("#")[1]),
                    )
                else:
                    sample_keys = [args.sample_key]

                informal_stmt = row.get("informal_statement", "") or ""
                informal_proof = row.get("informal_proof", "") or ""

                enriched = dict(row)
                per_sample_summary = []
                for key in sample_keys:
                    idx = key.split("#")[1]
                    llm_output = str(row.get(key, "") or "")

                    tc_ok = False
                    tc_err = ""
                    if do_tc:
                        if llm_output and not llm_output.startswith("ERROR"):
                            tc_ok, tc_err = run_tc_for_row(
                                llm_output, shard_lean_file, project_dir,
                                repl=persistent_repl)
                        else:
                            tc_ok, tc_err = False, "(empty or parse-error output, skipped Lean REPL)"
                        enriched[f"LLM_Syntax?#{idx}"] = "yes" if tc_ok else "no"
                        enriched[f"LLM_SyntaxError#{idx}"] = tc_err if not tc_ok else ""

                    if do_sc:
                        # If TC was skipped, default tc_passes=False (only used by
                        # ValidProofSC/FullyCorrect aggregation).
                        sc = run_sc(
                            informal_statement=informal_stmt,
                            informal_proof=informal_proof,
                            generated_fl=llm_output,
                            tc_passes=tc_ok,
                            model=model,
                        )
                        # Every returned field is stored under "<field>#<idx>".
                        for field, val in sc.items():
                            enriched[f"{field}#{idx}"] = val

                    bits = []
                    if do_tc:
                        bits.append(f"TC={enriched.get(f'LLM_Syntax?#{idx}','?')}")
                    if do_sc:
                        bits.append(f"StmtSC={enriched.get(f'LLM_StmtSC?#{idx}','?')}")
                        bits.append(f"ProofSC={enriched.get(f'LLM_ProofSC?#{idx}','?')}")
                    per_sample_summary.append(f"#{idx}: " + " ".join(bits))

                # Flush per row so a crash loses at most the row in flight.
                fout.write(json.dumps(enriched, ensure_ascii=False) + "\n")
                fout.flush()

                scored += 1
                remaining -= 1
                elapsed = time.time() - t_start
                processed = i + 1
                rate = scored / elapsed if elapsed > 0 else 0
                eta_min = remaining / rate / 60 if rate > 0 else 0
                print(f"  [{processed}/{total}] {name} ({len(sample_keys)} samples)  "
                      f"{' | '.join(per_sample_summary[:2])}{' ...' if len(per_sample_summary) > 2 else ''}  "
                      f"(rate={rate:.3f}/s ETA={eta_min:.1f}min)",
                      flush=True)
    finally:
        # Always shut the persistent REPL down, including when scoring raised.
        if persistent_repl is not None:
            try:
                print(f"[supplement_tc_sc_v3] persistent REPL stats: "
                      f"total={persistent_repl.total} restarts={persistent_repl.restarts}")
            finally:
                persistent_repl.close()

    print(f"[supplement_tc_sc_v3] shard {args.shard_index}/{args.total_shards} done")


if __name__ == "__main__":
    main()
