"""Evaluation utilities for CIRBench.

Provides helpers to compute code size, run llvm-mca, measure runtime via
harness scripts, check Alive2 equivalence, and perform assemble/verify checks.

Notes:
    Keep comments concise; prefer *why* over *what*. Do not change logic here
    without updating downstream callers.
"""
from __future__ import annotations
from pathlib import Path
import subprocess, shlex, tempfile, os, json, re
from pathlib import Path
from ..cfg import CIRBenchConfig

def code_size_bytes(ir: str, cfg: CIRBenchConfig) -> int:
    """Return the compiled object size in bytes, falling back to the IR byte length.

    The IR is piped to ``clang -c -x ir -`` and the size on disk of the
    resulting object file is reported.

    Args:
        ir: Textual LLVM IR.
        cfg: Configuration; ``cfg.toolchain.clang`` and
            ``cfg.toolchain.llvm_size`` are read.

    Returns:
        The object file size in bytes, or ``len(ir.encode("utf-8"))`` when the
        toolchain is not configured, clang cannot be started or fails, or the
        object file cannot be inspected.
    """
    ir_bytes = ir.encode("utf-8")
    fallback = len(ir_bytes)

    # llvm_size stays part of the gate so the "toolchain configured" condition
    # is unchanged for callers, even though the reported number is the object
    # file size on disk rather than anything parsed from llvm-size output.
    if not (cfg.toolchain.clang and cfg.toolchain.llvm_size):
        return fallback

    with tempfile.TemporaryDirectory() as td:
        obj = Path(td) / "a.o"
        # Argument list instead of a shell string: no quoting pitfalls and no
        # intermediate shell process.
        cmd = [cfg.toolchain.clang, "-c", "-x", "ir", "-", "-o", str(obj)]
        try:
            p = subprocess.run(cmd, input=ir_bytes, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        except OSError:
            # A missing or non-executable clang is treated like a failed
            # compile (the shell used to report this as a non-zero exit).
            return fallback
        if p.returncode != 0:
            return fallback
        try:
            return obj.stat().st_size
        except OSError:
            # clang reported success but left no readable object behind.
            return fallback

def llvm_mca_summary(asm_path: Path, cfg: CIRBenchConfig) -> dict:
    """Run ``llvm-mca --all-stats`` on an assembly file and scrape its headline numbers.

    Args:
        asm_path: Assembly file handed to llvm-mca.
        cfg: Configuration; only ``cfg.toolchain.llvm_mca`` is read.

    Returns:
        ``{"status": "skip"}`` when llvm-mca is not configured or the file does
        not exist.  Otherwise a dict with ``status`` ("ok" only if a
        "Total Cycles" figure was found, else "skip") plus ``iterations``,
        ``cycles``, ``instructions``, ``uops``, ``ipc`` and
        ``block_rthroughput``; each metric is an int, a float, or None when it
        is absent from the output.
    """
    mca_bin = cfg.toolchain.llvm_mca
    if not mca_bin or not asm_path.exists():
        return {"status": "skip"}

    proc = subprocess.run(
        [mca_bin, "--all-stats", str(asm_path)],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    # Only stdout is scraped; the exit code and stderr are not consulted.
    report = proc.stdout.decode(errors="ignore")

    # Result key -> pattern whose first group captures the numeric value.
    metric_patterns = {
        "iterations": r"Iterations:\s+([0-9\.]+)",
        "cycles": r"Total Cycles:\s+([0-9\.]+)",
        "instructions": r"Instructions:\s+([0-9\.]+)",
        "uops": r"Total uOps:\s+([0-9\.]+)",
        "ipc": r"IPC:\s+([0-9\.]+)",
        "block_rthroughput": r"Block RThroughput:\s+([0-9\.]+)",
    }
    found = {key: re.search(pattern, report) for key, pattern in metric_patterns.items()}

    def _as_number(match):
        """Convert a match to int when integral, else float; None if absent or unparsable."""
        if not match:
            return None
        try:
            value = float(match.group(1))
        except Exception:
            return None
        if value.is_integer():
            return int(value)
        return value

    summary = {"status": "ok" if found["cycles"] else "skip"}
    for key, match in found.items():
        summary[key] = _as_number(match)
    return summary
def runtime_harness(hdir: Path, base_ir: Path, variant_ir: Path) -> dict:
    """Run a harness and derive speedups from the `time_ms=` values it prints.

    Args:
        hdir: Either a directory containing ``run.sh`` (executed with the
            directory as working directory) or a single executable file
            (executed from its parent directory).
        base_ir: Not used by this function.
        variant_ir: Not used by this function.

    Returns:
        ``{"status": "skip", ...}`` when there is nothing to run,
        ``{"status": "err", ...}`` on a non-zero exit, fewer than two timings,
        or any exception, and otherwise ``{"status": "ok", "baseline_ms",
        "variant_ms", "speedup"}`` taken from the first and second timing.
        A positive third timing adds ``o3_ms`` and ``speedup_vs_o3``.
    """
    try:
        if not hdir or not hdir.exists():
            return {"status": "skip"}

        if hdir.is_dir():
            runner = hdir / "run.sh"
            if not runner.exists():
                return {"status": "skip", "reason": "no_run_sh"}
            argv, workdir = [str(runner)], str(hdir)
        elif hdir.is_file():
            argv, workdir = [str(hdir)], str(hdir.parent)
        else:
            # Exists but is neither a directory nor a regular file.
            return {"status": "skip"}

        proc = subprocess.run(argv, cwd=workdir, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        stdout_text = proc.stdout.decode(errors="ignore")
        stderr_text = proc.stderr.decode(errors="ignore")
        if proc.returncode != 0:
            return {"status": "err", "ret": proc.returncode, "stderr": stderr_text[:200]}

        # Timings are positional: baseline, variant, then optionally O3.
        samples = re.findall(r"time_ms=([0-9]+)", stdout_text)
        if len(samples) < 2:
            return {"status": "err", "stdout": stdout_text[:200]}

        base_ms = int(samples[0])
        var_ms = int(samples[1])
        o3_ms = int(samples[2]) if len(samples) >= 3 else None

        if base_ms > 0 and var_ms > 0:
            speedup = base_ms / var_ms
        else:
            # A zero timing makes the ratio meaningless (or undefined).
            speedup = None
        result = {"status": "ok", "baseline_ms": base_ms, "variant_ms": var_ms, "speedup": speedup}

        if o3_ms is not None and o3_ms > 0 and var_ms > 0:
            result["o3_ms"] = o3_ms
            result["speedup_vs_o3"] = o3_ms / var_ms
        return result
    except Exception as ex:
        return {"status": "err", "exc": type(ex).__name__, "msg": str(ex)[:200]}

def alive_equiv(ir_a: str, ir_b: str, cfg: CIRBenchConfig, func: str | None = None, dump_path: Path | None = None) -> dict:
    """Check equivalence of one function across two IR modules with Alive2 (`alive-tv`).

    Both inputs are stripped of target lines and metadata and merged into a
    single module in which the compared function from ``ir_a`` is renamed
    ``@src`` and the one from ``ir_b`` is renamed ``@tgt``.  Type definitions,
    globals, declarations and attribute groups are the de-duplicated union of
    both inputs.  Only helper functions reachable (transitively) through
    ``call``/``invoke`` instructions from the compared bodies are copied in;
    when both modules define a helper, the ``ir_b`` definition is used.

    Args:
        ir_a: Textual LLVM IR providing the ``src`` function.
        ir_b: Textual LLVM IR providing the ``tgt`` function.
        cfg: Configuration; only ``cfg.toolchain.alive_tv`` is read.
        func: Name of the function to compare.  When omitted, ``kernel_run``
            then ``main`` are tried, then the function defined in both modules
            whose text in ``ir_a`` is longest.
        dump_path: If given, the merged module is written there (parent
            directories are created) and kept for inspection.  Otherwise a
            temporary directory is used and removed before returning, so the
            reported ``merged_path`` no longer exists in that case.

    Returns:
        A dict with keys ``status`` ("ok", "timeout", "err" or "skip"),
        ``equiv`` (True, False or None), ``cmd``, ``exit``, ``out``, ``tail``
        (last 8 lines of output), ``err``, ``reason``, ``reason_line`` and
        ``merged_path``.
    """

    def _ret(status: str, equiv=None, cmd=None, exit_code=None, out: str = "", tail: str = "", err: str = "",
             reason: str = "unknown", reason_line: str = "", merged_path: str | None = None) -> dict:
        """Build the result dict with a fixed set of keys."""
        return {"status": status, "equiv": equiv, "cmd": cmd, "exit": exit_code, "out": out, "tail": tail, "err": err,
                "reason": reason, "reason_line": reason_line, "merged_path": merged_path}

    if not cfg.toolchain.alive_tv:
        return _ret("skip", equiv=None, err="alive-tv not configured", reason="skip")

    # Header of a function definition up to and including the `(` that opens
    # its parameter list; group 1 is the function name.
    define_re = re.compile(r'(?m)^\s*define\s+[^@]*@([A-Za-z0-9_.$-]+)\s*\(')
    # `call`/`invoke` as an instruction keyword anywhere on a line, so that
    # `%r = call ...` and `tail call ...` are recognised.  Requiring whitespace
    # (or line start) before it keeps value names such as `%call` from matching.
    call_kw_re = re.compile(r'(?:^|\s)(?:call|invoke)\s')
    symbol_re = re.compile(r'@(?:"([^"]+)"|([A-Za-z0-9_.$-]+))')

    # ---- Strip non-semantic module noise but keep attributes table ----
    def _strip(txt: str) -> str:
        lines_out = []
        for line in txt.splitlines():
            s = line.strip()
            if s.startswith("target datalayout") or s.startswith("target triple"):
                continue
            if s.startswith("!llvm.") or (s.startswith("!") and not s.startswith("!0") and not s.startswith("!1")):
                continue
            # Drop metadata attachments such as `, !dbg !12` from instructions.
            line = re.sub(r',?\s*![A-Za-z0-9_.-]+\s*!?[0-9A-Za-z_.-]+', '', line)
            lines_out.append(line)
        return "\n".join(lines_out)

    A1 = _strip(ir_a)
    B1 = _strip(ir_b)

    # ---- Collect typedefs, declares, attributes, globals (union) ----
    def _collect_ordered(pattern, texts):
        """Return first-seen-order, de-duplicated matches of `pattern` joined by newlines."""
        seen = set()
        out = []
        for txt in texts:
            for m in re.finditer(pattern, txt, flags=re.M):
                line = m.group(0)
                key = line.strip()
                if key in seen:
                    continue
                seen.add(key)
                out.append(line)
        return "\n".join(out)

    TYPES_RE = r'^\s*%(?:"[^"]+"|[A-Za-z0-9_.$-]+)\s*=\s*type\b.*$'
    DECL_RE  = r'^\s*declare\s+.+$'
    ATTR_RE  = r'^\s*attributes\s+#\d+\s*=\s*\{.*$'

    typedefs       = _collect_ordered(TYPES_RE, (A1, B1))
    declares       = _collect_ordered(DECL_RE,  (A1, B1))
    attributes_tbl = _collect_ordered(ATTR_RE,  (A1, B1))

    def _collect_globals_and_constants(texts):
        """Return de-duplicated `@name = ...` lines from all texts, in first-seen order."""
        seen = set()
        out = []
        for txt in texts:
            for line in txt.splitlines():
                s = line.strip()
                if not s or s.startswith(";"):
                    continue
                if s.startswith("@") and "=" in s:
                    if s in seen:
                        continue
                    seen.add(s)
                    out.append(line)
        return "\n".join(out)

    decl_globals = _collect_globals_and_constants((A1, B1))

    # ---- Parse ALL function bodies (name -> define...{...}) ----
    def _body_open_brace(txt: str, header_start: int, params_start: int) -> int:
        """Locate the `{` that opens a function body.

        Braces may also occur in a struct return type or in parameter types,
        so the search starts only after the parenthesis closing the parameter
        list.  `params_start` is the index just past the opening `(`.  If the
        parentheses never balance, fall back to the first `{` after the header.
        """
        depth = 1
        for i in range(params_start, len(txt)):
            ch = txt[i]
            if ch == '(':
                depth += 1
            elif ch == ')':
                depth -= 1
                if depth == 0:
                    return txt.find('{', i)
        return txt.find('{', header_start)

    def _parse_defines(txt: str) -> dict:
        """Map each defined function name to its full `define ... { ... }` text."""
        res = {}
        for m in define_re.finditer(txt):
            start = m.start()
            name = m.group(1)
            brace = _body_open_brace(txt, start, m.end())
            if brace == -1:
                continue
            # Walk to the matching close brace; nested braces (struct types
            # and literals inside the body) are balanced and just nest.
            depth = 0
            end = -1
            for i in range(brace, len(txt)):
                ch = txt[i]
                if ch == '{':
                    depth += 1
                elif ch == '}':
                    depth -= 1
                    if depth == 0:
                        end = i
                        break
            if end == -1:
                continue
            res[name] = txt[start:end + 1]
        return res

    defsA = _parse_defines(A1)
    defsB = _parse_defines(B1)

    def _direct_callees(body: str) -> set[str]:
        """
        Extract symbol names referenced on call/invoke lines of a function body.
        Supports @name and @"name" forms. Returns a set of symbol names without quotes.
        Every @symbol on such a line is collected (callee and global arguments alike);
        names without a definition are ignored later.
        """
        names = set()
        for line in body.splitlines():
            ls = line.lstrip()
            if not (ls.startswith(("call", "invoke")) or call_kw_re.search(line)):
                continue
            for m in symbol_re.finditer(line):
                sym = m.group(1) or m.group(2)
                if sym:
                    names.add(sym)
        return names

    def _reachable_helpers(start_funcs: set[str], defsA: dict, defsB: dict, exclude: set[str]) -> list[str]:
        """
        Compute the transitive closure of helper functions reachable from `start_funcs`.
        Preference when both A and B define the same helper: choose B (tgt) definition.
        Returns a list of function bodies (strings) in a stable, sorted name order.
        """
        visited = set()
        pending = [n for n in sorted(start_funcs) if n not in exclude]
        found = {}
        while pending:
            name = pending.pop()
            if name in visited or name in exclude:
                continue
            visited.add(name)
            body = defsB.get(name) or defsA.get(name)
            if not body:
                # No definition present (likely an external declaration).
                continue
            found[name] = body
            for cal in _direct_callees(body):
                if cal not in visited and cal not in exclude:
                    pending.append(cal)
        # Traversal order is irrelevant: output is ordered by function name.
        return [found[n] for n in sorted(found)]

    # ---- Choose target function ----
    tgt_name = func
    if not tgt_name:
        for cand in ("kernel_run", "main"):
            if cand in defsA or cand in defsB:
                tgt_name = cand
                break
    if not tgt_name:
        common = set(defsA.keys()) & set(defsB.keys())
        if common:
            tgt_name = max(common, key=lambda n: len(defsA.get(n, "")))
        elif defsA:
            # Cannot succeed (nothing is common); only makes the error below
            # name a concrete function.
            tgt_name = max(defsA.keys(), key=lambda n: len(defsA[n]))
    if not tgt_name or tgt_name not in defsA or tgt_name not in defsB:
        return _ret("err", equiv=None, err="missing_target_function", reason="parse_error", reason_line=f"target='{tgt_name}' not in both modules")

    a_body = defsA[tgt_name]
    b_body = defsB[tgt_name]

    def _rename_define(body: str, new_name: str) -> str:
        """Rename the function in the `define` header of `body`."""
        return re.sub(r'(^\s*define\s+[^@]*@)[A-Za-z0-9_.$-]+', r'\1' + new_name, body, count=1, flags=re.M)

    def _rewrite_self_refs(body: str, orig: str, new_sym: str) -> str:
        """
        Replace any intra-body reference to the original function name with the new symbol.
        Handles both @name and @"name" forms and ensures we don't match longer symbol suffixes.
        """
        pat = re.compile(r'@(?:"' + re.escape(orig) + r'"|' + re.escape(orig) + r')(?=[^A-Za-z0-9_.$-]|$)')
        return re.sub(pat, '@' + new_sym, body)

    a_ren = _rewrite_self_refs(_rename_define(a_body, "src"), tgt_name, "src")
    b_ren = _rewrite_self_refs(_rename_define(b_body, "tgt"), tgt_name, "tgt")

    # ---- Select ONLY helpers reachable from src/tgt (transitive closure) ----
    # Start from the callees of the original target bodies in A and B; the
    # target itself is excluded because it is emitted separately as src/tgt.
    start_callees = _direct_callees(a_body) | _direct_callees(b_body)
    helpers = _reachable_helpers(start_callees, defsA, defsB, exclude={tgt_name})

    # ---- Compose final module for Alive2 ----
    module_txt = "\n\n".join(x for x in [
        typedefs, decl_globals, declares, *(helpers or []), a_ren, b_ren, attributes_tbl
    ] if x and x.strip())

    # ---- Decide output path: dump to `dump_path` if provided, else temp ----
    merged_path: Path | None = None
    if dump_path is not None:
        dump_path.parent.mkdir(parents=True, exist_ok=True)
        dump_path.write_text(module_txt, encoding="utf-8")
        merged_path = dump_path
        f = dump_path
        temp_ctx = None
    else:
        temp_ctx = tempfile.TemporaryDirectory()
        f = Path(temp_ctx.name) / "input.ll"
        f.write_text(module_txt, encoding="utf-8")
        merged_path = f

    # ---- Classify alive-tv output ----
    # Checked in this order; the first entry with any matching substring wins.
    signals = (
        (("seems to be correct",), "equiv_true", True),
        (("doesn't verify", "mismatch", "counterexample", "prove failed"), "equiv_false", False),
        (("timeout",), "timeout", None),
        (("out of memory",), "oom", None),
        (("max. memory exceeded", "max memory exceeded"), "memory_exceeded", None),
        (("increasing the unroll factor",), "unroll_hint", None),
        (("could not read bitcode", "error:"), "parse_error", None),
    )

    def _signal(lowered: str):
        """Return (reason, equiv) for the first matching signal in lower-cased text, else None."""
        for needles, reason, equiv in signals:
            if any(k in lowered for k in needles):
                return reason, equiv
        return None

    def _classify(out_text: str, tail: str) -> tuple[str, bool | None, str]:
        """Return (reason, equiv, reason_line) for the combined tool output."""
        # Prefer the most recent verdict: scan the last 12 non-blank lines
        # from the bottom up before looking at the output as a whole.
        nonblank = [l for l in out_text.splitlines() if l.strip()]
        for l in reversed(nonblank[-12:]):
            hit = _signal(l.lower())
            if hit:
                return (hit[0], hit[1], l)
        hit = _signal(out_text.lower())
        if hit:
            reason, equiv = hit
            return (reason, equiv, "seems to be correct" if reason == "equiv_true" else "")
        return ("unknown", None, tail)

    # ---- Run alive-tv ----
    try:
        cmd = [cfg.toolchain.alive_tv, str(f), "-src-fn=src", "-tgt-fn=tgt", "--quiet"]
        p = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        out = p.stdout.decode(errors="ignore") + p.stderr.decode(errors="ignore")
        tail = "\n".join(out.splitlines()[-8:])

        reason, equiv, reason_line = _classify(out, tail)

        if reason in ("equiv_true", "equiv_false"):
            status = "ok"
        elif reason in ("timeout", "oom", "memory_exceeded"):
            status = "timeout"
        else:
            # parse_error, unroll_hint and unknown signals are all treated as
            # errors to force the checksum fallback upstream.
            status = "err"

        return _ret(status, equiv=equiv, cmd=cmd, exit_code=p.returncode, out=out, tail=tail,
                    reason=reason, reason_line=reason_line, merged_path=str(merged_path))
    except Exception as ex:
        return _ret("err", equiv=None, err=f"{type(ex).__name__}: {str(ex)[:200]}", reason="exception", merged_path=str(merged_path) if merged_path else None)
    finally:
        # Best-effort removal of the temp dir if we created one; a cleanup
        # failure must not mask the result being returned.
        try:
            if temp_ctx is not None:
                temp_ctx.cleanup()
        except Exception:
            pass

def assemble_ok(ir_text: str, cfg: CIRBenchConfig) -> dict:
    """Check that IR assembles with `llvm-as`.

    The check passes when the tool exits with 0 or when its combined output
    contains "assembly parsed" (case-insensitive).  Without a configured
    ``llvm_as`` a loose textual heuristic is used instead: the text must
    contain both "define" and "ret".

    Returns:
        {"ok": bool, "stdout": str, "stderr": str, "rc": int, "cmd": str}
    """
    assembler = getattr(cfg.toolchain, "llvm_as", None)
    if not assembler:
        # Non-strict fallback: no parsing, just a plausibility check.
        looks_like_ir = bool(ir_text and "define" in ir_text and "ret" in ir_text)
        return {
            "ok": looks_like_ir,
            "stdout": "",
            "stderr": "[fallback: no llvm-as]",
            "rc": 0 if looks_like_ir else 1,
            "cmd": "",
        }

    with tempfile.TemporaryDirectory() as tmp:
        workdir = Path(tmp)
        src = workdir / "v.ll"
        src.write_text(ir_text, encoding="utf-8")
        argv = [assembler, str(src), "-o", str(workdir / "a.bc")]
        proc = subprocess.run(argv, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        stdout_text = proc.stdout.decode(errors="ignore")
        stderr_text = proc.stderr.decode(errors="ignore")
        combined = f"{stdout_text}\n{stderr_text}".lower()
        succeeded = proc.returncode == 0 or "assembly parsed" in combined
        return {
            "ok": succeeded,
            "stdout": stdout_text,
            "stderr": stderr_text,
            "rc": proc.returncode,
            "cmd": " ".join(map(shlex.quote, argv)),
        }

def verify_ok(ir_text: str, cfg: CIRBenchConfig) -> dict:
    """Check IR with `opt -passes=verify`; passes only on exit code 0.

    Without a configured ``opt`` the verdict of `assemble_ok` is reused, which
    means no semantic verification takes place in that case.

    Returns:
        {"ok": bool, "stdout": str, "stderr": str, "rc": int, "cmd": str}
    """
    opt_bin = getattr(cfg.toolchain, "opt", None)
    if not opt_bin:
        assembled = assemble_ok(ir_text, cfg)
        passed = bool(assembled.get("ok"))
        return {
            "ok": passed,
            "stdout": "",
            "stderr": "[fallback: no opt]",
            "rc": 0 if assembled.get("ok") else 1,
            "cmd": "",
        }

    with tempfile.TemporaryDirectory() as tmp:
        src = Path(tmp) / "v.ll"
        src.write_text(ir_text, encoding="utf-8")
        argv = [opt_bin, "-passes=verify", str(src), "-disable-output"]
        proc = subprocess.run(argv, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        stdout_text = proc.stdout.decode(errors="ignore")
        stderr_text = proc.stderr.decode(errors="ignore")
        return {
            "ok": proc.returncode == 0,
            "stdout": stdout_text,
            "stderr": stderr_text,
            "rc": proc.returncode,
            "cmd": " ".join(map(shlex.quote, argv)),
        }
