# cirbench/transform/runner.py (no billing)
from __future__ import annotations
from pathlib import Path
import os, re, json, subprocess, tempfile, shlex, time, hashlib, shutil
from typing import Optional, Tuple
from ..cfg import CIRBenchConfig
from ..utils.api.base import make_runner
from ..registry import discover_cases
from ..utils.evalers import code_size_bytes, runtime_harness, llvm_mca_summary, alive_equiv, assemble_ok, verify_ok
from ..utils.case_select import select_cases
from ..utils.logging_utils import get_logger, one_line, debug_on
from ..utils.runner_common import (
    get_run_dir, get_case_dir, report_mode, resume_mode, next_missing_shot, shot_paths,
    gen_with_retries, persist_model_io, write_text, write_json, fmt_cmd_block,
    extract_ir_and_meta, atk, build_equiv_dict, sha256,
    write_repro_sh, maybe_short_circuit_prompt_only, maybe_materialize_external_io
)
#
# --- prompt template helpers (per-case) --------------------------------------
from typing import Mapping
import re as _re

def _select_mode_and_func(case_id: str, meta: dict) -> Tuple[str, Optional[str]]:
    if case_id.startswith("T004_Module"): return "module", None
    if case_id.startswith("T005_Super"):
        return "function", (meta.get("func") or meta.get("FUNC_NAME") or "kernel_run")
    return "function", "kernel_run"

def _render_prompt(mode: str, raw_ir_for_prompt: str, func_name: Optional[str]) -> str:
    """Render the final prompt text sent to the model without any MODE scaffolding."""
    if mode == "function":
        fn = func_name or "kernel_run"
        return (
            "You are an LLVM IR transform assistant.\n"
            f"- Target function: {fn}\n"
            "- Apply O3-style IR-level optimizations that LLVM would reasonably perform.\n"
            "- Preserve module-level globals, declarations, and attributes as-is.\n"
            "- If no change is needed, reply exactly: No\n\n"
            "Return format:\n"
            "- EXACTLY ONE block: <IR_OUT>...LLVM IR...</IR_OUT> (or the single word: No)\n\n"
            "Given IR:\n<IR>\n" + raw_ir_for_prompt + "\n</IR>\n\n"
            "Return ONLY the transformed function (from 'define' to the matching '}') wrapped in <IR_OUT>...</IR_OUT>."
        )
    else:  # module
        return (
            "You are an LLVM IR transform assistant.\n"
            "- Apply O3-style IR-level optimizations to the entire module.\n"
            "- Preserve declarations and non-semantic metadata that do not affect code generation.\n"
            "- If no change is needed, reply exactly: No\n\n"
            "Return format:\n"
            "- EXACTLY ONE block: <IR_OUT>...LLVM IR...</IR_OUT> (or the single word: No)\n\n"
            "Given IR:\n<IR>\n" + raw_ir_for_prompt + "\n</IR>\n\n"
            "Return ONLY the transformed module wrapped in <IR_OUT>...</IR_OUT>."
        )

def _extract_function_ir(module_ir: str, func_name: str) -> Tuple[str, str]:
    import re

    if not func_name:
        return "", "empty_func_name"

    text = module_ir.replace("\r\n", "\n")

    name = re.escape(func_name)
    pattern = r'^[ \t]*define\b[^\n]*@(?:"%s"|%s)\s*\(' % (name, name)

    m = re.search(pattern, text, flags=re.M)
    if not m:
        return "", f"func_not_found:{func_name}"

    brace_pos = text.find("{", m.end())
    if brace_pos == -1:
        return "", "no_open_brace"

    depth = 0
    i = brace_pos
    n = len(text)
    while i < n:
        ch = text[i]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return text[m.start():i+1], ""
        i += 1

    return "", "no_match_close_brace"


# --- Helper: fallback IR extractor for malformed model outputs ---
def _salvage_ir_from_text(text: str) -> Tuple[Optional[str], bool]:
    """
    Fallback extractor for malformed model outputs (e.g., missing </IR_OUT> or no tags).
    Heuristics:
      1) Prefer content after '<IR_OUT>' (case-insensitive); else use full text.
      2) Strip leading/backtick code fences like ``` or ```llvm.
      3) Trim trailing code fences.
      4) Truncate at the last '}' to close the function/module.
      5) Accept only if it contains 'define' and brace counts match.
    Returns (ir_out, ok).
    """
    if not text:
        return None, False
    txt = text

    # Take content after <IR_OUT> if present
    m = re.search(r'&lt;IR_OUT&gt;|<IR_OUT>', txt, flags=re.I | re.S)
    if m:
        cand = txt[m.end():]
    else:
        cand = txt

    # Strip leading fenced code blocks
    cand = cand.lstrip()
    cand = re.sub(r'^```(?:llvm|ll|ir)?\s*', '', cand, flags=re.I)

    # Drop trailing backticks if any
    cand = re.sub(r'\s*```[\s\r\n]*$', '', cand, flags=re.S)

    # Truncate at the last '}' so we end with a closed body
    j = cand.rfind('}')
    if j != -1:
        cand = cand[:j+1]

    cand = cand.strip()

    # Quick sanity checks
    if 'define' not in cand:
        return None, False
    # Balance braces (coarse but robust for IR bodies/modules)
    if cand.count('{') == 0 or cand.count('{') != cand.count('}'):
        return None, False

    return cand, True


# --- Helper: normalize function name inside IR text ---
def _normalize_function_name(ir_text: str, target_name: str) -> str:
    """Ensure the IR function body defines `@target_name` and updates self-refs.
    Accepts both @name and @"name" spellings. If target is empty or no define is
    found, returns input unchanged.
    """
    if not ir_text or not target_name:
        return ir_text
    # Match the function name on the first define line
    m = re.search(r'(?m)^\s*define\b[^@]*@(?:"([^"]+)"|([A-Za-z0-9_.$-]+))\s*\(', ir_text)
    if not m:
        return ir_text
    old = m.group(1) or m.group(2) or ""
    if not old or old == target_name:
        return ir_text
    # Rename the define symbol itself
    ir2 = re.sub(r'(?m)^(\s*define\b[^@]*@)(?:"[^"]+"|[A-Za-z0-9_.$-]+)', r'\1' + target_name, ir_text, count=1)
    # Rewrite any intra-body references to the old symbol (quoted or unquoted)
    pat = re.compile(r'@(?:"' + re.escape(old) + r'"|' + re.escape(old) + r')(?=[^A-Za-z0-9_.$-]|$)')
    ir2 = re.sub(pat, '@' + target_name, ir2)
    return ir2


# ---- build & run helpers ----------------------------------------------------

def _bench_utils_obj() -> Path:
    """Absolute path to bench_utils.o located under cirbench/transform/bench."""
    return (Path(__file__).parent / "bench" / "bench_utils.o").resolve()

def _ldflags_from_env() -> list[str]:
    import shlex as _sh
    v = os.getenv("CIRBENCH_LDFLAGS", "").strip()
    return _sh.split(v) if v else []

def _run_tool(cmd: list) -> Tuple[int, str, str]:
    """Run *cmd* to completion; return (returncode, stdout, stderr) decoded with errors ignored."""
    p = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    return p.returncode, p.stdout.decode(errors="ignore"), p.stderr.decode(errors="ignore")


def _resolve_clangxx(toolchain, clang) -> str:
    """Pick the C++ driver used for linking when a C++ runtime is required.

    Order of preference:
      1. ``toolchain.clangxx`` if set.
      2. The ``clang++`` sibling of the configured clang (``.../clang`` ->
         ``.../clang++``, ``.../clang-17`` -> ``.../clang++-17``), if it is found
         on PATH or exists on disk.
      3. ``clang++`` resolved from PATH, or the bare name ``clang++``.
    """
    explicit = getattr(toolchain, "clangxx", None)
    if explicit:
        return explicit
    if isinstance(clang, str):
        cand = None
        if clang.endswith("clang"):
            cand = clang + "++"
        else:
            versioned = re.search(r'clang(-\d+(?:\.\d+)*)$', clang)
            if versioned:
                cand = clang[:versioned.start()] + "clang++" + versioned.group(1)
        if cand and (shutil.which(cand) or Path(cand).exists()):
            return cand
    return shutil.which("clang++") or "clang++"


def _compile_ir_to_exe(ir_text: str, out_dir: Path, base: str, cfg: CIRBenchConfig, *, needs_cxx: bool = False) -> dict:
    """Write <base>.ll, run llc -O3 to produce <base>.o, then link <base>.bin.

    Linking uses clang, or a clang++ driver when ``needs_cxx`` is true (C++
    runtime such as iostreams). bench_utils.o is always linked right after the
    object file.
    Rationale: avoid re-running middle-end O2 on .ll; drive codegen via llc -O3 only.

    Env knobs:
      CIRBENCH_LLC_FLAGS    extra llc flags (shell-style quoting, e.g. -relocation-model=pic)
      CIRBENCH_LDFLAGS      extra linker flags (shell-style quoting)
      CIRBENCH_LINK_NO_PIE  "0" disables the default -no-pie link flag

    If the first link fails and "-lm" is not already in the linker flags, the
    link is retried once with "-lm" appended.

    Returns {ok, cmd, rc, out, err, exe, ll}. ``cmd``/``rc``/``out``/``err``
    describe the last command that ran.
    """
    from ..utils.runner_common import write_text as _wt
    ll = out_dir / f"{base}.ll"
    obj = out_dir / f"{base}.o"
    exe = out_dir / f"{base}.bin"
    _wt(ll, ir_text)

    toolchain = getattr(cfg, "toolchain", None)
    llc = getattr(toolchain, "llc", None) or "llc"
    clang = getattr(toolchain, "clang", None) or "clang"
    linker = _resolve_clangxx(toolchain, clang) if needs_cxx else clang
    bench_obj = _bench_utils_obj()

    def _result(ok: bool, cmd: list, rc: int, out: str, err: str) -> dict:
        return {"ok": ok, "cmd": cmd, "rc": rc, "out": out, "err": err, "exe": exe, "ll": ll}

    # 1) llc -O3 to object. Objects default to non-PIC. Extra llc flags (e.g.
    #    -relocation-model=pic) can be injected for toolchains/platforms that need them.
    extra_llc = shlex.split(os.getenv("CIRBENCH_LLC_FLAGS", "").strip())
    llc_cmd = [llc, "-O3", "-filetype=obj", *extra_llc, str(ll), "-o", str(obj)]
    rc1, out1, err1 = _run_tool(llc_cmd)
    if rc1 != 0 or not obj.exists():
        return _result(False, llc_cmd, rc1, out1, err1)

    # 2) Link (no -O flags here). -no-pie is on by default to avoid
    #    'R_X86_64_32 against `.rodata` when making a PIE object'.
    ldflags = _ldflags_from_env()
    link_no_pie = os.getenv("CIRBENCH_LINK_NO_PIE", "1").strip() != "0"
    base_cmd = [linker] + (["-no-pie"] if link_no_pie else []) + [str(obj), str(bench_obj)]
    cmd = base_cmd + ldflags + ["-o", str(exe)]
    rc2, out2, err2 = _run_tool(cmd)

    # Retry with -lm on link failure (e.g. unresolved libm symbols).
    if rc2 != 0 and "-lm" not in ldflags:
        cmd_lm = base_cmd + ldflags + ["-lm", "-o", str(exe)]
        rc3, out3, err3 = _run_tool(cmd_lm)
        if rc3 == 0:
            return _result(True, cmd_lm, rc3, out3, err3)
        return _result(False, cmd_lm, rc3, out3, err2 + "\n--- retry with -lm ---\n" + err3)

    return _result(rc2 == 0, cmd, rc2, out2, err2)

# Emit assembly from a given .ll path
def _emit_asm_from_ll(ll_path: Path, out_dir: Path, base: str, cfg: CIRBenchConfig) -> dict:
    """Compile <base>.s from an existing .ll using llc -O3 -filetype=asm. Returns
    {ok, cmd, rc, out, err, asm}. Does not link, so unresolved externals are fine.
    """
    s = out_dir / f"{base}.s"
    llc = getattr(getattr(cfg, "toolchain", object()), "llc", None) or "llc"
    cmd = [llc, "-O3", "-filetype=asm", str(ll_path), "-o", str(s)]
    p = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    rc = p.returncode
    out = p.stdout.decode(errors="ignore")
    err = p.stderr.decode(errors="ignore")
    return {"ok": rc == 0, "cmd": cmd, "rc": rc, "out": out, "err": err, "asm": s}

def _run_executable(exe: Path, args: list[str], *, stdin_path: Optional[Path] = None, timeout_s: int = 20) -> dict:
    """Run an executable and capture stdout. Returns {ok,rc,out,err,time_ms,cmd}."""
    t0 = time.perf_counter()
    try:
        if stdin_path and stdin_path.exists():
            with open(stdin_path, "rb") as fin:
                p = subprocess.run([str(exe), *args], input=fin.read(), stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=timeout_s)
        else:
            p = subprocess.run([str(exe), *args], stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=timeout_s)
        t1 = time.perf_counter()
        return {
            "ok": p.returncode == 0,
            "rc": p.returncode,
            "out": p.stdout.decode(errors="ignore"),
            "err": p.stderr.decode(errors="ignore"),
            "time_ms": int((t1 - t0) * 1000),
            "cmd": [str(exe), *args] + ([f"< {stdin_path}"] if stdin_path else []),
        }
    except subprocess.TimeoutExpired:
        t1 = time.perf_counter()
        return {"ok": False, "rc": None, "out": "", "err": f"timeout({timeout_s}s)", "time_ms": int((t1 - t0) * 1000), "cmd": [str(exe), *args]}

def _taskset_prefix() -> list[str]:
    """
    Optional 'taskset' binding. Controlled by env CIRBENCH_TASKSET (default '0').
    If 'taskset' is not present, returns [].
    """
    cpus = os.getenv("CIRBENCH_TASKSET", "0").strip()
    ts = shutil.which("taskset")
    return ([ts, "-c", cpus] if ts else [])

# ---- Simple timing helper for T005 ----
def _time_executable_simple(
    exe: Path,
    args: list[str],
    *,
    stdin_path: Optional[Path] = None,
    timeout_s: int = 60
) -> dict:
    """
    Minimalistic wall-clock timer.

      - Runs the program REPEATEDLY with stdout/stderr discarded and sums the elapsed time.
      - After each window the repeat count is multiplied by CIRBENCH_SIMPLE_SCALE. This stops
        when the window's aggregate time reaches the target, the repeat cap is hit, or the
        iteration budget is used up.
      - The command is prefixed with `_taskset_prefix()` (CPU pinning via `taskset -c`,
        controlled by CIRBENCH_TASKSET).
      - The stdin file's bytes are fed on every run. Without one, the child reads from the
        null device instead of inheriting the runner's stdin.

    Env knobs (unparsable values fall back to the default):
      CIRBENCH_SIMPLE_TARGET_MS    (default 500)
      CIRBENCH_SIMPLE_BASE_REPEATS (default 1)
      CIRBENCH_SIMPLE_SCALE        (default 10; values < 1 are treated as 1)
      CIRBENCH_SIMPLE_MAX_ITERS    (default 6; at least one window always runs)
      CIRBENCH_SIMPLE_MAX_REPEATS  (default 100000; values < 1 are treated as 1)

    Returns: {ok, rc, out, err, time_ms, cmd, repeats, time_ms_per_run}
      - ok/rc/err cover every timed run. Any nonzero exit sets ok False, rc to the last
        nonzero code and err to "nonzero-exit:<rc>".
      - time_ms is the TOTAL wall time (ms, at least 1) of the final window. That window
        ran `repeats` times, and cmd ends with "x<repeats>".
      - On timeout: ok False, rc None, time_ms None, err "timeout(<timeout_s>s)".
    NOTE(review): two binaries can stop at different repeat counts, so their time_ms values
    are not directly comparable. Use time_ms_per_run, or time_ms together with repeats.
    """
    def _env_num(name, default, conv):
        raw = os.getenv(name)
        if raw is None:
            return default
        try:
            return conv(raw)
        except ValueError:
            return default

    # Read stdin once if provided.
    stdin_bytes = None
    if stdin_path and Path(stdin_path).exists():
        try:
            stdin_bytes = Path(stdin_path).read_bytes()
        except OSError:
            stdin_bytes = None
    # `input=` and `stdin=` are mutually exclusive in subprocess.run.
    io_kwargs = {"input": stdin_bytes} if stdin_bytes is not None else {"stdin": subprocess.DEVNULL}

    target_ms = _env_num("CIRBENCH_SIMPLE_TARGET_MS", 500.0, float)
    base_repeats = _env_num("CIRBENCH_SIMPLE_BASE_REPEATS", 1, int)
    scale = max(1, _env_num("CIRBENCH_SIMPLE_SCALE", 10, int))
    max_iters = max(1, _env_num("CIRBENCH_SIMPLE_MAX_ITERS", 6, int))
    max_repeats = max(1, _env_num("CIRBENCH_SIMPLE_MAX_REPEATS", 100000, int))

    base_cmd = _taskset_prefix() + [str(exe), *args]
    repeats = max(1, base_repeats)
    measured = repeats          # repeat count of the window actually measured last
    elapsed_s = 0.0
    agg_ms = 0
    ok, last_rc, last_err = True, 0, ""

    for _ in range(max_iters):
        measured = repeats
        t0 = time.perf_counter()
        for _i in range(repeats):
            try:
                rc = subprocess.run(
                    base_cmd,
                    stdout=subprocess.DEVNULL,   # discard to avoid pipe backpressure
                    stderr=subprocess.DEVNULL,   # discard; correctness is checked elsewhere
                    timeout=timeout_s,
                    **io_kwargs,
                ).returncode
            except subprocess.TimeoutExpired:
                return {
                    "ok": False,
                    "rc": None,
                    "out": "",
                    "err": f"timeout({timeout_s}s)",
                    "time_ms": None,
                    "cmd": base_cmd + [f"x{repeats}"],
                    "repeats": repeats,
                    "time_ms_per_run": None,
                }
            if rc != 0:
                ok = False
                last_rc = rc
                last_err = f"nonzero-exit:{rc}"

        elapsed_s = time.perf_counter() - t0
        agg_ms = int(round(elapsed_s * 1000))
        # Stop when the window is big enough to be meaningful, or it cannot grow further.
        if agg_ms >= target_ms or repeats >= max_repeats:
            break
        repeats = min(repeats * scale, max_repeats)

    return {
        "ok": ok,
        "rc": 0 if ok else (last_rc or 1),
        "out": "",
        "err": last_err,
        "time_ms": max(1, agg_ms),
        "cmd": base_cmd + [f"x{measured}"],
        "repeats": measured,
        "time_ms_per_run": (elapsed_s * 1000.0) / measured,
    }

def _norm_text(s: str) -> str:
    return (s or "").strip()

def _checksumeq(a: str, b: str) -> bool:
    return bool(a) and bool(b) and (hashlib.sha256(_norm_text(a).encode("utf-8")).hexdigest() == hashlib.sha256(_norm_text(b).encode("utf-8")).hexdigest())


# Helper to extract benchmark-reported checksum and time from output
def _extract_checksum_and_time(text: str) -> Tuple[Optional[str], Optional[int]]:
    """
    Parse benchmark line like:
    T001_Loops_014,K22,case1,L1,checksum=0x68f50a75b513f733,time=0.000139,loops=100

    Returns (checksum_hex_lower, time_ms) or (None, None) if not found.
    """
    import re
    if not text:
        return None, None
    m = re.search(r'checksum=0x([0-9a-fA-F]+).*?time=([0-9]*\.?[0-9]+)', text, flags=re.S)
    if not m:
        return None, None
    ck = m.group(1).lower()
    try:
        sec = float(m.group(2))
        ms = int(round(sec * 1000))
    except Exception:
        ms = None
    return ck, ms

def _summary_flags(shots):
    """Collapse per-shot results into pass/valid/equiv booleans at k=1 and k=5.

    Each entry is ``bool(atk(shots, k, metric))``, keyed ``"<metric>_at_<k>"``, in the
    order pass, valid, equiv (k=1 before k=5 for each).
    """
    return {
        f"{metric}_at_{k}": bool(atk(shots, k, metric))
        for metric in ("pass", "valid", "equiv")
        for k in (1, 5)
    }

def run_task(cfg: CIRBenchConfig, proj_root: Path):
    logger = get_logger()
    tdir = proj_root / "cirbench" / "transform"
    cases = select_cases(discover_cases(tdir))
    if not cases:
        print("No transform cases found."); return
    from ..utils.api.base import make_runner, select_model_cfg
    runner_cfg = select_model_cfg(cfg)
    runner = make_runner(runner_cfg)
    run_dir = get_run_dir(proj_root)
    env_mode = (os.getenv("CIRBENCH_TRANSFORM_MODE") or "both").strip().lower()  # normal|copilot|both
    run_modes = (["normal", "copilot"] if env_mode in ("both", "all", "dual", "") else [env_mode])

    for case in cases:
        case_dirs = {m: get_case_dir(run_dir, f"transform.{m}", case.id) for m in run_modes}
        meta = case.meta
        module_ir = case.raw_ir.read_text(encoding="utf-8")
        # Select mode and function name for this case
        mode, func_name = _select_mode_and_func(case.id, case.meta)

        # Build prompt strictly based on MODE selection:
        # - function (T001/T002/T003/T005): only feed the target function body to the LLM
        # - module (T004): feed the whole module
        # The replacement/apply stage later still uses the full module_ir.
        if mode == "function":
            _fn = func_name or "kernel_run"
            _fn_ir, _fn_err = _extract_function_ir(module_ir, _fn)
            prompt_ir = _fn_ir if _fn_ir else module_ir
        else:
            prompt_ir = module_ir

        prompt = _render_prompt(mode, prompt_ir, func_name)
        psha = sha256(prompt)
        if report_mode():
            for _m in run_modes:
                d = case_dirs[_m]
                write_text(d/"prompt.txt", prompt)
                write_text(d/"prompt.sha256", psha)

        # === Prompt-only: write prompt and stop before calling LLM ===
        _early = False
        for _m in run_modes:
            _sp = shot_paths(case_dirs[_m], 1); _sp["dir"].mkdir(parents=True, exist_ok=True)
            if maybe_short_circuit_prompt_only(cfg, _sp["dir"], prompt):
                try:
                    write_json(_sp["dir"]/ "metrics.json", {"case_id": case.id, "mode": _m, "shots": []})
                except Exception:
                    pass
                _early = True
        if _early:
            print(f"{case.id}: prompt-only (no LLM)")
            continue

        start_k = 1
        if resume_mode():
            mks = [next_missing_shot(case_dirs[_m], 5) for _m in run_modes]
            start_k = min(mks) if mks else 1
            if all(mk > 5 for mk in mks):
                print(f"[transform] skip {case.id} (already 1..5 for all modes)")
                for _m in run_modes:
                    mdir = shot_paths(case_dirs[_m], 1)["dir"]; mdir.mkdir(parents=True, exist_ok=True)
                    write_json(mdir/"metrics.json", {"case_id": case.id, "mode": _m, "shots": [], "summary": _summary_flags([])})
                continue

        shots_by_mode = {m: [] for m in run_modes}
        for k in range(start_k, 6):

            # persist raw & golden IR snapshot for this shot (for fair diff)
            for _m in run_modes:
                _sp = shot_paths(case_dirs[_m], k); _sp["dir"].mkdir(parents=True, exist_ok=True)
                try:
                    write_text(_sp["dir"]/"raw.ll", module_ir)
                except Exception:
                    pass
                try:
                    if getattr(case, "gold", None) and case.gold.exists():
                        write_text(_sp["dir"]/"golden.ll", case.gold.read_text(encoding="utf-8"))
                except Exception:
                    pass

            # === Prefer external artifacts when provided (from-files) ===
            using_pred = False
            variant_ir = None
            ir_out, ierr, meta_json = None, False, {}
            meta_out = {}
            src_text = None

            # Use the first mode's shot dir to materialize external files; they will be copied/consumed per-mode later.
            _sp0 = shot_paths(case_dirs[run_modes[0]], k); _sp0["dir"].mkdir(parents=True, exist_ok=True)
            _mode_ext = maybe_materialize_external_io(cfg, case.id, _sp0["dir"])

            if _mode_ext == "resp":
                try:
                    src_text = (_sp0["dir"]/ "model.resp.txt").read_text(encoding="utf-8")
                except Exception:
                    src_text = ""
                ir_out, ierr, meta_json = extract_ir_and_meta(src_text)
                if (ierr or not ir_out) and (src_text or "").strip():
                    cand, ok = _salvage_ir_from_text(src_text)
                    if ok and cand:
                        ir_out = cand
                        ierr = False
                meta_out = {"source": "from-files", "kind": "resp"}
                if report_mode():
                    for _m in run_modes:
                        persist_model_io(case_dirs[_m], k, src_text or "", meta_out)

            elif _mode_ext == "pred":
                # External IR already provided as variant.ll under the first mode's shot dir
                _var = _sp0["variant"]
                if _var.exists():
                    variant_ir = _var.read_text(encoding="utf-8")
                    using_pred = True
                    ierr = False
                    meta_out = {"source": "from-files", "kind": "pred"}

            if not (_mode_ext in ("resp", "pred")):
                # Fall back to live LLM
                from ..utils.api.base import make_runner, select_model_cfg
                runner_cfg = select_model_cfg(cfg)
                runner = make_runner(runner_cfg)
                hint = None
                for fn in ("golden.ll",):
                    p = getattr(case, fn.replace(".", "_"), None)
                    if p and p.exists():
                        hint = p.read_text(encoding="utf-8"); break

                outc = gen_with_retries(runner, prompt)
                src_text = outc.text
                ir_out, ierr, meta_json = extract_ir_and_meta(outc.text)
                # Fallback: tolerate missing </IR_OUT> or no tags at all
                if (ierr or not ir_out) and (outc.text or "").strip():
                    cand, ok = _salvage_ir_from_text(outc.text)
                    if ok and cand:
                        ir_out = cand
                        ierr = False
                meta_out = outc.meta or {}
                if report_mode():
                    for _m in run_modes:
                        persist_model_io(case_dirs[_m], k, outc.text, meta_out)

            accepted = False
            model_full_ir = ""
            reject_reason = ""

            # If external 'pred' supplied a full IR, short-circuit acceptance here.
            if using_pred and variant_ir:
                accepted = True
                model_full_ir = variant_ir
                ierr = False
                reject_reason = ""
            else:
                _txt_for_check = (src_text or "")
                if _txt_for_check.strip().lower() == "no":
                    ierr = True
                    reject_reason = "model_said_no"
                elif not ierr and ir_out:
                    if mode == "function":
                        old_func_ir, e2 = _extract_function_ir(module_ir, func_name or "")
                        if (not e2) and old_func_ir:
                            ir_norm = _normalize_function_name(ir_out, (func_name or ""))
                            model_full_ir = module_ir.replace(old_func_ir, ir_norm)
                            accepted = (model_full_ir.strip() != module_ir.strip())
                            if not accepted:
                                reject_reason = "no_textual_change_after_replacement"
                        else:
                            reject_reason = f"target_not_found:{func_name}"
                    else:
                        model_full_ir = ir_out
                        accepted = (model_full_ir.strip() != module_ir.strip())
                        if not accepted:
                            reject_reason = "no_change_module"
                else:
                    ierr = True
                    reject_reason = "parse_error_or_empty_ir_out"

            # Prepare per-mode pipeline IRs (normal: direct; copilot: optional O3 on model IR)
            pipeline_ir_by_mode = {}
            if accepted:
                for _m in run_modes:
                    if _m == "copilot" and getattr(cfg.toolchain, "opt", None):
                        with tempfile.TemporaryDirectory() as _td:
                            _td = Path(_td)
                            _src = _td/"in.ll"; _outll = _td/"o3.ll"
                            _src.write_text(model_full_ir, encoding="utf-8")
                            _cmd = [cfg.toolchain.opt, "-passes=default<O3>", "-S", str(_src), "-o", str(_outll)]
                            _p = subprocess.run(_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
                            if _p.returncode == 0 and _outll.exists():
                                pipeline_ir_by_mode[_m] = _outll.read_text(encoding="utf-8")
                            else:
                                pipeline_ir_by_mode[_m] = model_full_ir
                    else:
                        pipeline_ir_by_mode[_m] = model_full_ir
            else:
                for _m in run_modes:
                    pipeline_ir_by_mode[_m] = module_ir

            # Detect T005 cases early for special handling (stdout/expected-based final equivalence; still run Alive for diagnostics)
            is_t005_case = case.id.startswith("T005_")

            # === Single Alive check on raw model IR (before any copilot O3) ===
            base_alive = {"status": "skip"}
            base_alive_ok = None
            base_alive_tmo = False
            if (not case.id.startswith("T004_Module")):
                if accepted:
                    # Only run Alive if the model_full_ir is syntactically valid
                    _am = assemble_ok(model_full_ir, cfg)
                    _vm = verify_ok(model_full_ir, cfg) if _am.get("ok") else {"ok": False}
                    if _am.get("ok") and _vm.get("ok"):
                        # Dump under the first mode's artifacts dir
                        _sp_base = shot_paths(case_dirs[run_modes[0]], k); _sp_base["dir"].mkdir(parents=True, exist_ok=True)
                        base_alive = alive_equiv(module_ir, model_full_ir, cfg, func=(func_name or meta.get("func")), dump_path=_sp_base["alive_merged"])  # src=baseline, tgt=model_full_ir
                        st = str(base_alive.get("status") or "")
                        ar = base_alive.get("equiv")
                        base_alive_ok = True if ar is True else (False if ar is False else None)
                        if st in ("timeout", "tmo"):
                            base_alive_tmo = True
                    else:
                        base_alive = {"status": "invalid_ir", "note": "model_full_ir failed assemble/verify"}
                        base_alive_ok = False
                        base_alive_tmo = False

            # For each mode, compile/run/evaluate into its own folder
            any_equiv_true = False
            for _m in run_modes:
                _case_dir = case_dirs[_m]
                _sp = shot_paths(_case_dir, k); _sp["dir"].mkdir(parents=True, exist_ok=True)

                # 1) Build executables in artifacts dir
                is_t005 = case.id.startswith("T005_")
                build_raw    = _compile_ir_to_exe(module_ir, _sp["dir"], "raw", cfg, needs_cxx=is_t005)
                build_golden = None
                if getattr(case, "gold", None) and case.gold.exists():
                    build_golden = _compile_ir_to_exe(case.gold.read_text(encoding="utf-8"), _sp["dir"], "golden", cfg, needs_cxx=is_t005)
                build_var    = _compile_ir_to_exe(pipeline_ir_by_mode[_m], _sp["dir"], "variant", cfg, needs_cxx=is_t005) if accepted else None

                # 2) Decide run interface
                stdin_path = (case.root/"input.txt") if is_t005 else None
                expected_out = (case.root/"output.txt") if is_t005 else None
                args = [] if is_t005 else [str(case.meta.get("case", 1)), str(case.meta.get("iter", 3))]

                # 3) Run and collect timings
                if is_t005:
                    # Phase 1: correctness (capture stdout for expected/stdout equality)
                    chk_raw  = _run_executable(build_raw.get("exe"), args, stdin_path=stdin_path) if build_raw.get("ok") else {"ok": False}
                    chk_gold = _run_executable(build_golden.get("exe"), args, stdin_path=stdin_path) if (build_golden and build_golden.get("ok")) else {"ok": False}
                    chk_var  = _run_executable(build_var.get("exe"), args, stdin_path=stdin_path) if (build_var and build_var.get("ok")) else {"ok": False}

                    # Phase 2: timing (discard stdout; repeats ×10 until >= target_ms)
                    tim_raw  = _time_executable_simple(build_raw.get("exe"), args, stdin_path=stdin_path) if build_raw.get("ok") else {"ok": False}
                    tim_gold = _time_executable_simple(build_golden.get("exe"), args, stdin_path=stdin_path) if (build_golden and build_golden.get("ok")) else {"ok": False}
                    tim_var  = _time_executable_simple(build_var.get("exe"), args, stdin_path=stdin_path) if (build_var and build_var.get("ok")) else {"ok": False}

                    # Merge: keep stdout from correctness, time_ms/command from timing
                    def _merge_runs(chk: dict, tim: dict) -> dict:
                        if not isinstance(chk, dict): chk = {"ok": False}
                        if not isinstance(tim, dict): tim = {}
                        out = dict(chk)
                        out["time_ms"] = tim.get("time_ms")
                        out["cmd"] = tim.get("cmd", chk.get("cmd"))
                        if tim.get("err"):
                            out["err"] = ((chk.get("err","") + ("\n--- timing ---\n" + tim["err"])) if chk.get("err") else tim["err"])
                        return out

                    run_raw  = _merge_runs(chk_raw,  tim_raw)
                    run_gold = _merge_runs(chk_gold, tim_gold)
                    run_var  = _merge_runs(chk_var,  tim_var)
                else:
                    run_raw  = _run_executable(build_raw.get("exe"), args, stdin_path=stdin_path) if build_raw.get("ok") else {"ok": False}
                    run_gold = _run_executable(build_golden.get("exe"), args, stdin_path=stdin_path) if (build_golden and build_golden.get("ok")) else {"ok": False}
                    run_var  = _run_executable(build_var.get("exe"), args, stdin_path=stdin_path) if (build_var and build_var.get("ok")) else {"ok": False}

                # Persist outputs for diff
                try:
                    if run_raw.get("out") is not None: write_text(_sp["dir"]/"raw.out", run_raw.get("out",""))
                    if run_gold.get("out") is not None: write_text(_sp["dir"]/"golden.out", run_gold.get("out",""))
                    if run_var.get("out") is not None: write_text(_sp["dir"]/"variant.out", run_var.get("out",""))
                except Exception:
                    pass

                # assemble/verify on this mode's pipeline IR to record diagnostics
                if accepted:
                    a_res = assemble_ok(pipeline_ir_by_mode[_m], cfg)
                    v_res = verify_ok(pipeline_ir_by_mode[_m], cfg) if a_res.get("ok") else {"ok": False, "rc": 1, "stdout":"", "stderr":"", "cmd":""}
                else:
                    a_res = assemble_ok(module_ir, cfg)
                    v_res = verify_ok(module_ir, cfg) if a_res.get("ok") else {"ok": False, "rc": 1, "stdout":"", "stderr":"", "cmd":""}

                # If the variant IR is not syntactically valid, mark as invalid and skip equivalence/Alive decisions
                invalid_variant = (accepted and (not a_res.get("ok") or not v_res.get("ok")))

                # Parse benchmark‑reported checksum/time from program outputs
                raw_ck, raw_bench = (_extract_checksum_and_time(run_raw.get("out", "")) if run_raw.get("ok") else (None, None))
                gld_ck, gld_bench = (_extract_checksum_and_time(run_gold.get("out", "")) if (run_gold and run_gold.get("ok")) else (None, None))
                var_ck, var_bench = (_extract_checksum_and_time(run_var.get("out", "")) if (run_var and run_var.get("ok")) else (None, None))

                # Skip runs cleanly when variant build failed (invalid_variant): force run_var = {"ok": False}
                if invalid_variant:
                    run_var = {"ok": False}

                raw_vs_gld = (raw_ck is not None and gld_ck is not None and raw_ck == gld_ck)
                var_vs_gld = (var_ck is not None and gld_ck is not None and var_ck == gld_ck)
                var_vs_raw = (var_ck is not None and raw_ck is not None and var_ck == raw_ck)
                all_three  = (raw_ck is not None and gld_ck is not None and var_ck is not None and raw_ck == gld_ck == var_ck)

                # Local holder for Alive2 result (for logging)
                ae_log = {}
                # Equivalence decision
                if case.id.startswith("T005_"):
                    # T005: equivalence by program stdout (or expected file), but still run Alive for diagnostics/logging
                    if accepted:
                        if invalid_variant:
                            eqv = build_equiv_dict("invalid_ir", False, method="verify", alive_timeout=False, checksum_equal=False)
                            ae_log = {}
                        else:
                            # Attach Alive diagnostics (already computed above as base_alive)
                            ae_log = base_alive if isinstance(base_alive, dict) else {}
                            alive_tmo = bool(base_alive_tmo)
                            # Gather stdout for equality; normalize via sha256-based _checksumeq
                            raw_txt = (run_raw.get("out", "") if run_raw else "")
                            gld_txt = (run_gold.get("out", "") if run_gold else "")
                            var_txt = (run_var.get("out", "") if run_var else "")
                            try:
                                exp_txt = expected_out.read_text(encoding="utf-8") if (expected_out and expected_out.exists()) else None
                            except Exception:
                                exp_txt = None
                            def _eq(a,b):
                                """Pass the comparison of outputs ``a`` and ``b`` to ``_checksumeq`` and return its result unchanged."""
                                verdict = _checksumeq(a, b)
                                return verdict
                            # Derive equality flags based on available outputs
                            raw_ok = bool(run_raw and run_raw.get("ok"))
                            gld_ok = bool(run_gold and run_gold.get("ok"))
                            var_ok = bool(run_var and run_var.get("ok"))
                            raw_vs_gld = (_eq(raw_txt, gld_txt) if (raw_ok and gld_ok) else False)
                            var_vs_gld = (_eq(var_txt, gld_txt) if (var_ok and gld_ok) else False)
                            var_vs_raw = (_eq(var_txt, raw_txt) if (var_ok and raw_ok) else False)
                            all_three  = (raw_ok and gld_ok and var_ok and _eq(raw_txt, gld_txt) and _eq(var_txt, raw_txt))
                            # Prefer explicit expected output when provided
                            if (exp_txt is not None) and var_ok and _eq(var_txt, exp_txt):
                                eqv = build_equiv_dict("ok", True, method="expected", alive_timeout=alive_tmo, checksum_equal=True)
                            elif all_three or (raw_vs_gld and var_vs_gld):
                                eqv = build_equiv_dict("ok", True, method="stdout", alive_timeout=alive_tmo, checksum_equal=True)
                            else:
                                eqv = build_equiv_dict("checksum_mismatch", False, method="stdout", alive_timeout=alive_tmo, checksum_equal=False)
                    else:
                        eqv = build_equiv_dict("checksum_mismatch", False, method="stdout", alive_timeout=bool(base_alive_tmo), checksum_equal=False)
                elif not case.id.startswith("T004_Module"):
                    if accepted:
                        if invalid_variant:
                            eqv = build_equiv_dict("invalid_ir", False, method="verify", alive_timeout=False, checksum_equal=False)
                            ae_log = {}
                        else:
                            if _m == "normal":
                                # Use the single Alive result computed on raw model_full_ir
                                ae = base_alive
                                ae_log = ae if isinstance(ae, dict) else {}
                                alive_ok = base_alive_ok
                                alive_tmo = base_alive_tmo
                                method = "alive"
                            else:  # copilot inherits normal's Alive result; do not call Alive again
                                ae = {"status": "skip"}
                                ae_log = {}  # do not log Alive2 for copilot
                                alive_ok = base_alive_ok
                                alive_tmo = base_alive_tmo
                                method = "alive(inherited)" if base_alive_ok is not None else "checksum"

                            if alive_ok is True:
                                eqv = build_equiv_dict("ok", True, method=method, alive_timeout=alive_tmo, checksum_equal=all_three)
                            elif alive_ok is False:
                                eqv = build_equiv_dict("ok", False, method=method, alive_timeout=alive_tmo, checksum_equal=all_three)
                            elif all_three:
                                eqv = build_equiv_dict("ok", True, method="checksum", alive_timeout=alive_tmo, checksum_equal=True)
                            else:
                                eqv = build_equiv_dict("checksum_mismatch", False, alive_timeout=alive_tmo, checksum_equal=False)
                    else:
                        # Not accepted — no variant to compare; rely on checksums if available
                        eqv = build_equiv_dict("checksum_mismatch", False, checksum_equal=False)
                else:
                    # T004: rely on checksums only
                    eqv = build_equiv_dict("ok", True, method="checksum", checksum_equal=True) if all_three else \
                          build_equiv_dict("checksum_mismatch", False, method="checksum", checksum_equal=False)
                    ae_log = {}

                # Codesize
                bsz = code_size_bytes(module_ir, cfg)
                vsz = code_size_bytes(pipeline_ir_by_mode[_m], cfg) if accepted else None
                size_ratio = (bsz / vsz) if (isinstance(bsz,int) and isinstance(vsz,int) and vsz and vsz>0) else None

                # Runtime summary — only record when equivalence established
                variant_ok = (eqv.get("ok") is True)
                if is_t005:
                    t_raw = (run_raw.get("time_ms") if (variant_ok and run_raw) else None)
                    t_gld = (run_gold.get("time_ms") if (variant_ok and run_gold) else None)
                    t_var = (run_var.get("time_ms") if (variant_ok and run_var) else None)
                else:
                    t_raw = raw_bench if variant_ok else None
                    t_gld = gld_bench if variant_ok else None
                    t_var = var_bench if variant_ok else None

                sp_gld_raw   = (t_raw / t_gld) if (t_raw and t_gld and t_gld > 0) else None
                sp_model_raw = (t_raw / t_var) if (variant_ok and t_raw and t_var and t_var > 0) else None
                sp_model_gld = (t_gld / t_var) if (variant_ok and t_gld and t_var and t_var > 0) else None

                # llvm-mca: collect micro-architectural summaries only when equivalence is established
                mca = {"raw": {"status": "skip"}, "golden": {"status": "skip"}, "variant": {"status": "skip"}}
                try:
                    if variant_ok:
                        # Emit assembly from the exact .ll files used above
                        if build_raw and build_raw.get("ok") and build_raw.get("ll"):
                            _asm_raw = _emit_asm_from_ll(build_raw.get("ll"), _sp["dir"], "raw", cfg)
                            if _asm_raw.get("ok"):
                                mca["raw"] = llvm_mca_summary(_asm_raw.get("asm"), cfg)
                        if build_golden and build_golden.get("ok") and build_golden.get("ll"):
                            _asm_gld = _emit_asm_from_ll(build_golden.get("ll"), _sp["dir"], "golden", cfg)
                            if _asm_gld.get("ok"):
                                mca["golden"] = llvm_mca_summary(_asm_gld.get("asm"), cfg)
                        if build_var and build_var.get("ok") and build_var.get("ll"):
                            _asm_var = _emit_asm_from_ll(build_var.get("ll"), _sp["dir"], "variant", cfg)
                            if _asm_var.get("ok"):
                                mca["variant"] = llvm_mca_summary(_asm_var.get("asm"), cfg)
                except Exception:
                    # keep 'skip' on failure; avoid crashing the main pipeline
                    pass

                shot = {
                    "k": k,
                    "pass": bool(a_res.get("ok")),
                    "valid": bool(v_res.get("ok")),
                    "equiv": eqv,
                    "runtime": {
                        "t_raw_ms": t_raw, "t_golden_ms": t_gld, "t_variant_ms": t_var,
                        "speedup_golden_raw": sp_gld_raw,
                        "speedup_model_raw": sp_model_raw, "speedup_model_golden": sp_model_gld,
                        "checksum": {"raw": raw_ck, "golden": gld_ck, "variant": var_ck},
                        "checksum_equal": {
                            "all_three": all_three,
                            "variant_vs_golden": var_vs_gld,
                            "variant_vs_raw": var_vs_raw,
                            "raw_vs_golden": raw_vs_gld,
                        },
                    },
                    "mca": mca,
                    "codesize": {"baseline": bsz, "variant": vsz, "ratio": size_ratio},
                    "unparseable": bool(ierr),
                    "latency_ms": meta_out.get("latency_ms"),
                    "tokens": {"in": meta_out.get("prompt_tokens"), "out": meta_out.get("out_tokens")},
                    "raw_meta": meta_out,
                }
                shots_by_mode[_m].append(shot)

                if report_mode():
                    if not accepted:
                        write_text(_sp["dir"]/"reject.txt", f"reason={reject_reason}\nparse_error={ierr}\n")
                    if pipeline_ir_by_mode.get(_m): write_text(_sp["variant"], pipeline_ir_by_mode[_m])
                    log = []
                    # Build logs
                    log.append("## Build raw")
                    log.append(fmt_cmd_block(build_raw.get("cmd"), build_raw.get("rc"), build_raw.get("out",""), build_raw.get("err","")))
                    if build_golden:
                        log.append("## Build golden")
                        log.append(fmt_cmd_block(build_golden.get("cmd"), build_golden.get("rc"), build_golden.get("out",""), build_golden.get("err","")))
                    if build_var:
                        log.append("## Build variant")
                        log.append(fmt_cmd_block(build_var.get("cmd"), build_var.get("rc"), build_var.get("out",""), build_var.get("err","")))
                    # Verify / Alive2
                    log.append("## Verify")
                    log.append(fmt_cmd_block((v_res.get("cmd") or "").split(), v_res.get("rc"), v_res.get("stdout",""), v_res.get("stderr","")))
                    if isinstance(ae_log, dict) and ae_log.get("cmd"):
                        log.append("## Alive2")
                        log.append(fmt_cmd_block(ae_log.get("cmd"), ae_log.get("exit"), ae_log.get("out",""), ae_log.get("err","")))
                    # Run logs with timing
                    if run_raw:
                        log.append(f"## Run raw  {'sys_time_ms' if is_t005 else 'bench_time_ms'}={(run_raw.get('time_ms') if is_t005 else raw_bench)}")
                        log.append(fmt_cmd_block(run_raw.get("cmd"), run_raw.get("rc"), run_raw.get("out",""), run_raw.get("err","")))
                    if run_gold and (build_golden is not None):
                        log.append(f"## Run golden  {'sys_time_ms' if is_t005 else 'bench_time_ms'}={(run_gold.get('time_ms') if is_t005 else gld_bench)}")
                        log.append(fmt_cmd_block(run_gold.get("cmd"), run_gold.get("rc"), run_gold.get("out",""), run_gold.get("err","")))
                    if run_var:
                        log.append(f"## Run variant  {'sys_time_ms' if is_t005 else 'bench_time_ms'}={(run_var.get('time_ms') if is_t005 else var_bench)}")
                        log.append(fmt_cmd_block(run_var.get("cmd"), run_var.get("rc"), run_var.get("out",""), run_var.get("err","")))
                    # T005 expected output check
                    if is_t005 and expected_out and expected_out.exists():
                        try:
                            exp_txt = expected_out.read_text(encoding="utf-8")
                        except Exception:
                            exp_txt = ""
                        ok_exp = _checksumeq(run_var.get("out",""), exp_txt) if run_var else False
                        log.append(f"## T005 expected\nexpected_file={expected_out} match={ok_exp}")
                    write_text(_sp["shotlog"], "\n".join(log) + "\n")
                    write_repro_sh(_sp, "transform", cfg, func=(func_name if mode=="function" else None))

                any_equiv_true = any_equiv_true or (shot["equiv"].get("ok") is True)

            # Early stop on established equivalence in any mode (including T004_Module)
            if any_equiv_true:
                from ..utils.runner_common import mark_early_stop
                for _m in run_modes:
                    mark_early_stop(case_dirs[_m], k, "equiv_true")
                break

        # write per-case metrics.json for each mode
        for _m in run_modes:
            _mdir = shot_paths(case_dirs[_m], 1)["dir"]; _mdir.mkdir(parents=True, exist_ok=True)
            _shots = shots_by_mode[_m]
            write_json(_mdir/"metrics.json", {
                "case_id": case.id,
                "difficulty": (case.meta.get("difficulty") if isinstance(case.meta, dict) else None),
                "mode": _m,
                "shots": _shots,
                "summary": _summary_flags(_shots),
            })
            def _flag(shots, key, k):
                """Return "T" if ``atk(shots, k, key)`` is truthy, otherwise "F" (used in the summary print below)."""
                if atk(shots, k, key):
                    return "T"
                return "F"
            print(f"{case.id}({_m}): Pass@1={_flag(_shots,'pass',1)} Pass@5={_flag(_shots,'pass',5)} | "
                  f"Valid@1={_flag(_shots,'valid',1)} Valid@5={_flag(_shots,'valid',5)} | "
                  f"Equiv@1={_flag(_shots,'equiv',1)} Equiv@5={_flag(_shots,'equiv',5)}")
