#!/usr/bin/env python3
"""
Self‑contained Lean‑mathlib compilation harness for the Herald dataset.

• Creates/refreshes a Lake project at `tmpFolder/<RUN_UUID>/TmpProjDir/` (with mathlib and REPL dependencies)
• Reuses the same project to recompile many candidate theorems quickly
"""

import subprocess, os, shutil, sys
from pathlib import Path
import textwrap
import re
import uuid
#from git import Repo
#from lean_dojo import LeanGitRepo, trace
import pandas as pd
import json

# ----------------- CONFIG -----------------
TIMEOUT_SECS = 300  # seconds
PROJECT_NAME = "TmpProjDir"  # Lake package/library name and source sub-directory name
LEAN_VERSION = "v4.15.0"  # mathlib git tag and the expected lean-toolchain version
# leanprover-community/repl commit matching Lean v4.15.0:
# https://github.com/leanprover-community/repl/tree/21966799da3691a0912b5a15193585bd2dd7165d
REPL_VERSION = "21966799da3691a0912b5a15193585bd2dd7165d"
# True  -> a brand-new run folder (random UUID) and project on every run.
# False -> reuse the fixed run folder below across runs.
FORCE_CLEAN_SETUP = False
# Rows for the CSV export (only referenced by the disabled pipeline in main()).
rows_writeCSV = []
# ------------------------------------------

# Every run folder lives under this base directory.
TMP_BASE = Path("tmpFolder")
# A fixed folder id lets repeated runs reuse one prepared project; a random id
# yields a fresh folder each run.
RUN_UUID = str(uuid.uuid4()) if FORCE_CLEAN_SETUP else "8d9a729d-dec6-452d-bb90-d63be139ee52"
RUN_DIR = TMP_BASE / RUN_UUID

PROJECT_DIR = RUN_DIR / PROJECT_NAME
SRC_DIR = PROJECT_DIR / PROJECT_NAME
LEAN_SRC = SRC_DIR / "Basic.lean"

# ---------- helpers ----------
def _preexec_set_mem_limit(mem_limit_bytes):
    """Build a ``preexec_fn`` that caps the child's address space (RLIMIT_AS).

    Both the soft and the hard limit are set to ``mem_limit_bytes``. Once the
    Lean subprocess hits the cap, its allocations fail inside its own process
    instead of the Linux OOM killer taking down the whole pod (parent + child).

    Applying the cap is best-effort: if ``setrlimit`` fails, the error is
    ignored and the command still runs without the cap.
    """
    import resource

    cap = (mem_limit_bytes, mem_limit_bytes)

    def _apply_cap():
        # Executed in the forked child just before exec.
        try:
            resource.setrlimit(resource.RLIMIT_AS, cap)
        except Exception:
            pass

    return _apply_cap


def _coerce_text(data):
    """Return ``data`` as ``str``: ``None`` -> ``""``, bytes decoded as UTF-8 (with replacement)."""
    if data is None:
        return ""
    if isinstance(data, str):
        return data
    try:
        return data.decode("utf-8", errors="replace")
    except Exception:
        return str(data)


def _kill_process_group(proc):
    """SIGKILL ``proc`` and every process in its process group (best-effort)."""
    import signal

    try:
        # ``proc`` was started with start_new_session=True, so its pid is also
        # the id of a process group holding its descendants (e.g. the ``lean``
        # or ``repl`` processes that ``sh``/``lake`` spawn).
        os.killpg(proc.pid, signal.SIGKILL)
    except (AttributeError, OSError):
        pass  # group already gone, or no killpg/SIGKILL on this platform
    try:
        proc.kill()
    except OSError:
        pass


def _collect_after_kill(proc, grace_secs=5.0):
    """Kill ``proc``'s process group and return whatever output it had produced."""
    _kill_process_group(proc)
    try:
        # Retrying communicate() after a TimeoutExpired loses no output.
        stdout, stderr = proc.communicate(timeout=grace_secs)
    except Exception:
        # Something outside the group still holds the pipes: give up on the
        # remaining output rather than hang the caller.
        for stream in (proc.stdout, proc.stderr):
            try:
                if stream is not None:
                    stream.close()
            except Exception:
                pass
        try:
            proc.wait(timeout=grace_secs)
        except Exception:
            pass
        return ""
    return _coerce_text(stdout) + _coerce_text(stderr)


def run_cmd(cmd, cwd=".", timeout=None, mem_limit_bytes=None):
    """Run a command with optional wall-clock + address-space limits.

    Args:
        cmd: argv list passed to ``subprocess.Popen`` (no shell unless ``cmd``
            itself invokes one).
        cwd: working directory for the child.
        timeout: wall-clock limit in seconds, or ``None`` for no limit.
        mem_limit_bytes: RLIMIT_AS cap applied in the child; falsy means no cap.

    Returns:
        ``(ok, combined_output)`` where ``ok`` is True only for exit status 0 and
        ``combined_output`` is stdout followed by stderr. On timeout, death by
        signal, memory exhaustion or a launch error, ``ok`` is False and the
        output starts with a bracketed explanation. Errors are reported rather
        than raised, so the caller should treat them as a failed Lean check,
        not as an evaluation crash.

    The child runs in its own session. On timeout the whole process group is
    SIGKILLed, not just the direct child. Otherwise grandchildren such as the
    ``lean`` process started by ``lake`` keep the output pipes open and the
    timeout never takes effect. KeyboardInterrupt/SystemExit also kill the
    group before propagating.
    """
    try:
        preexec_fn = _preexec_set_mem_limit(mem_limit_bytes) if mem_limit_bytes else None
        proc = subprocess.Popen(
            cmd,
            cwd=cwd,
            text=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            preexec_fn=preexec_fn,
            start_new_session=True,
        )
    except MemoryError:
        return False, "[subprocess raised MemoryError]"
    except Exception as exc:
        return False, f"[subprocess error] {exc}"

    try:
        stdout, stderr = proc.communicate(timeout=timeout)
    except subprocess.TimeoutExpired:
        partial = _collect_after_kill(proc)
        return False, f"[subprocess timed out after {timeout}s]\n{partial}"
    except MemoryError:
        _collect_after_kill(proc)
        return False, "[subprocess raised MemoryError]"
    except Exception as exc:
        _collect_after_kill(proc)
        return False, f"[subprocess error] {exc}"
    except BaseException:
        # The child is in its own session, so a terminal Ctrl-C does not reach
        # it: kill it explicitly before propagating.
        _kill_process_group(proc)
        try:
            proc.wait(timeout=1)
        except Exception:
            pass
        raise

    out = (stdout or "") + (stderr or "")
    # Negative return codes on POSIX: -9 = SIGKILL (often OOM), -11 = SIGSEGV.
    if proc.returncode < 0:
        return False, f"[subprocess killed by signal {-proc.returncode}]\n{out}"
    return proc.returncode == 0, out

def bootstrap_project():
    """Create the Lake project (or refresh it when FORCE_CLEAN_SETUP) and verify its toolchain.

    Steps:
      1. If FORCE_CLEAN_SETUP, delete RUN_DIR.
      2. If the project does not exist yet, run ``lake new PROJECT_NAME``.
      3. (Re)write ``lakefile.lean``, pinning mathlib at LEAN_VERSION and the
         REPL at REPL_VERSION.
      4. Run ``lake update`` and ``lake exe cache get``.
      5. Check that ``lean-toolchain`` reads ``leanprover/lean4:<LEAN_VERSION>``.

    Returns:
        ``(RUN_DIR, PROJECT_DIR, LEAN_SRC)``: the run directory, the project
        directory and the default Lean source path.

    Exits the interpreter (``sys.exit`` with a message) if any ``lake`` step
    fails or the toolchain version does not match.
    """
    if FORCE_CLEAN_SETUP and RUN_DIR.exists():
        print("**  Removing old project...")
        shutil.rmtree(RUN_DIR)

    RUN_DIR.mkdir(parents=True, exist_ok=True)

    if not PROJECT_DIR.exists():
        print("** Initialising empty Lean project in", PROJECT_DIR)
        ok, out = run_cmd(["lake", "new", PROJECT_NAME], cwd=RUN_DIR)
        if not ok:
            sys.exit(f"X  lake new failed:\n{out}")

    # Replace the generated lakefile with one that pins mathlib and the REPL.
    content = f"""
import Lake
open Lake DSL

require mathlib from git
  "https://github.com/leanprover-community/mathlib4" @ "{LEAN_VERSION}"

require REPL from git
  "https://github.com/leanprover-community/repl.git" @ "{REPL_VERSION}"

package «{PROJECT_NAME}» where
-- add package configuration options here

@[default_target]
lean_lib «{PROJECT_NAME}» where
-- add library configuration options here"""

    # Lean reads sources as UTF-8. Be explicit so the «» guillemets survive a
    # non-UTF-8 locale.
    (PROJECT_DIR / "lakefile.lean").write_text(content, encoding="utf-8")

    # Ensure dependencies & cache
    for cmd in (["lake", "update"], ["lake", "exe", "cache", "get"]):
        ok, out = run_cmd(cmd, cwd=PROJECT_DIR)
        if not ok:
            sys.exit(f"X  {' '.join(cmd)} failed:\n{out}")

    toolchainVersion = (PROJECT_DIR / "lean-toolchain").read_text(encoding="utf-8").strip()
    expectedToolchainVer = f"leanprover/lean4:{LEAN_VERSION}"
    # Explicit check rather than `assert`: asserts are stripped under `python -O`,
    # which would silently skip the verification.
    if toolchainVersion != expectedToolchainVer:
        sys.exit(f"✗ Unexpected version: {toolchainVersion}...should be {expectedToolchainVer}")
    print(f"✓ Lean version verified to be {expectedToolchainVer}")
    return RUN_DIR, PROJECT_DIR, LEAN_SRC  # the unique run directory, the project directory, and the lean file path

# Markdown fence language identifiers that can survive as the first line(s) of
# LLM_Output. The inference extractor's Strategy-2 regex r"```\s*(.*?)```" does
# not consume the identifier after the opening fence, so it gets captured as
# line-1 content. Matched case-insensitively by sanitize_lean().
_LEADING_LANG_TAGS = set(
    "tactics lean lean4 leanprover mathlib math haskell ocaml isabelle coq".split()
)


def sanitize_lean(code: str) -> str:
    """Remove markdown language-tag lines left at the start of model output.

    The inference extractor's Strategy-2 regex did not consume the markdown
    fence's language identifier. As a result, LLM_Output can begin with a bare
    line such as ``tactics`` or ``lean4``, which Lean rejects as "unexpected
    identifier; expected command". This function repeatedly peels off the first
    non-blank line while its stripped, lower-cased text is in
    ``_LEADING_LANG_TAGS``.

    Removed:
      * leading BOM characters;
      * each tag line together with any whitespace before it.

    Everything else is returned untouched, so Lean's line/column positions stay
    meaningful. That includes blank lines that follow a removed tag, and all
    comments. Comments are deliberately not stripped: doing so shifted
    positions and broke proofs that otherwise compiled, and Lean already
    handles ``/--`` inside an open ``/-`` block.

    Falsy input (``""`` or ``None``) is returned as-is. Input with no BOM and no
    leading tag is returned unchanged (idempotent).
    """
    if not code:
        return code

    text = code.lstrip("\ufeff")  # BOM only, so line structure is preserved for Lean
    while True:
        candidate = text.lstrip("\n\r\t ")
        head, _newline, tail = candidate.partition("\n")
        if head.strip().lower() not in _LEADING_LANG_TAGS:
            return text
        # Discard the tag line and its newline; loop to catch stacked tags.
        text = tail


def write_basic_lean(header, proof, leanFile):
    """Write ``header`` + sanitized ``proof`` to ``leanFile`` and return the text written.

    Import lines found in ``header`` are stripped of surrounding whitespace and
    placed first, since Lean expects imports before any other command. The
    remaining header lines (``open``, ``variable``, ...) follow verbatim, then
    the proof. Before assembly, ``proof`` goes through ``sanitize_lean``, which
    drops stray leading markdown language-tag lines. No namespace wrapper is
    added.

    Args:
        header: Lean preamble text; ``None``/empty is allowed.
        proof: Lean declaration(s) to check; ``None``/empty is allowed.
        leanFile: destination path (``Path`` or ``str``). Its parent directory
            is created if missing, and any existing file is overwritten.

    Returns:
        The exact Lean source that was written.
    """
    leanFile = Path(leanFile)
    leanFile.parent.mkdir(parents=True, exist_ok=True)

    # Sanitize model output before writing (always-on safety net).
    proof = sanitize_lean(proof or "")

    # Split the header into import lines and everything else (open, variable, ...).
    import_lines = []
    other_header_lines = []
    for line in (header or "").splitlines():
        if re.match(r"^\s*import\s+", line):
            import_lines.append(line.strip())
        else:
            other_header_lines.append(line)

    lean_code = "\n".join([*import_lines, *other_header_lines, proof])
    # Lean reads sources as UTF-8; do not depend on the process locale
    # (proofs routinely contain characters such as ∀, ℝ, →).
    leanFile.write_text(lean_code, encoding="utf-8")
    return lean_code

def check_compiles(project_dir):
    """Run ``lake build`` in ``project_dir``; return ``(success, truncated_error_or_None)``.

    The build is bounded by the TIMEOUT_SECS config value, so a runaway
    elaboration cannot stall the whole evaluation. On a timeout, ``run_cmd``
    reports a failure whose first line starts with
    ``[subprocess timed out after``.

    Returns:
        ``(True, None)`` on success; otherwise ``(False, text)`` where ``text``
        is the first 12 lines of the combined build output.
    """
    ok, out = run_cmd(["lake", "build"], cwd=project_dir, timeout=TIMEOUT_SECS)
    if ok:
        return ok, None
    return ok, "\n".join(out.splitlines()[:12])

def check_repl(lean_file_path, project_dir):
    """Run a one-shot ``lake exe repl`` on ``lean_file_path`` with memory + time caps.

    The caps come from the environment so eval jobs can tune them without
    touching code:
      * ``LEAN_REPL_MEM_GB``: RLIMIT_AS cap in GiB, default 24.
      * ``LEAN_REPL_TIMEOUT_S``: wall-clock limit in seconds, default 300.

    The defaults aim to keep the Lean REPL well under a typical pod's
    64Gi / 128Gi limit, so that a single bad ``nlinarith`` cannot crash the
    whole evaluation pod.

    Args:
        lean_file_path: ``Path`` to the Lean file. It must lie inside
            ``project_dir``, otherwise ``relative_to`` raises ValueError.
        project_dir: Lake project directory the REPL runs in.

    Returns:
        ``(ok, details)``. ``ok`` is True only when the REPL produced a JSON
        reply with no error-severity message and no top-level ``"message"``
        (the REPL's own error reply). ``details`` carries a bracketed tag
        describing how the output was interpreted, followed by the reply or the
        raw output.
    """
    import shlex

    lean_file_path_relative = lean_file_path.relative_to(project_dir).as_posix()
    request = json.dumps({"path": lean_file_path_relative, "allTactics": True})
    # Shell-quote the request so a path containing quotes cannot break the
    # command. printf is used instead of echo, whose backslash handling varies
    # between shells.
    cmd = ["sh", "-c", f"printf '%s\\n' {shlex.quote(request)} | lake exe repl"]

    mem_gb = float(os.environ.get("LEAN_REPL_MEM_GB", "24"))
    timeout_s = float(os.environ.get("LEAN_REPL_TIMEOUT_S", "300"))
    mem_bytes = int(mem_gb * (1024 ** 3))

    _ok_run, out_repl = run_cmd(
        cmd, cwd=project_dir, timeout=timeout_s, mem_limit_bytes=mem_bytes,
    )

    # Locate the first JSON object in the output.
    start = out_repl.find("{")
    if start == -1 or "}" not in out_repl[start:]:
        return False, "[No JSON found in REPL output]" + out_repl
    try:
        # raw_decode parses one JSON value and ignores what follows it, e.g.
        # stderr text that run_cmd appends after stdout.
        repl_json, _end = json.JSONDecoder().raw_decode(out_repl[start:])
    except json.JSONDecodeError:
        return False, "[Failed to parse JSON from REPL output]" + out_repl

    messages = repl_json.get("messages", [])
    errors = [m for m in messages if m.get("severity") == "error"]
    # A top-level "message" key is the REPL's own error reply (bad request,
    # unreadable file, ...), not a successful compilation.
    ok_repl = "message" not in repl_json and not errors
    return ok_repl, "[JSON parsed correctly]" + str(repl_json)


class PersistentLeanREPL:
    """Long-running ``lake exe repl`` process reused across many file checks.

    The non-persistent ``check_repl`` starts a fresh ``lake exe repl`` per call,
    paying ~15–20 s of Lean/mathlib startup each time. For evaluation runs that
    check hundreds of files, this class keeps one REPL alive and streams JSON
    requests to it. Each request is an independent compilation unit (a REPL
    ``path`` command), so there is no semantic carry-over between calls.

    Safety design:
      * Hard wall-clock timeout per request: on a hang the REPL is killed and
        restarted.
      * The REPL runs in its own session/process group and is always killed as
        a group. ``lake exe`` runs the ``repl`` binary as a separate child
        process, so killing only ``lake`` would orphan a possibly still busy
        REPL that keeps its memory.
      * Address-space limit via ``RLIMIT_AS``, so a bad ``nlinarith`` can't eat
        the pod.
      * Periodic restart every ``restart_every`` requests, to bound memory
        accumulation in a single REPL session.
      * Auto-restart if the process dies (BrokenPipeError / EOF).
      * Forced restart if a request is interrupted mid-flight (any exception,
        including KeyboardInterrupt). Otherwise its late reply would be read as
        the answer to the next request.

    This class is NOT thread-safe: use one instance per worker, preferably as a
    context manager so the REPL group is torn down on exit.
    """

    def __init__(self, project_dir, mem_gb=None, timeout_s=None, restart_every=None):
        """Configure limits and start the REPL immediately.

        Args:
            project_dir: Lake project directory the REPL runs in. Request paths
                are sent relative to it.
            mem_gb: RLIMIT_AS cap in GiB (default: env ``LEAN_REPL_MEM_GB``, else 24).
            timeout_s: per-request wall-clock limit in seconds (default: env
                ``LEAN_REPL_TIMEOUT_S``, else 300).
            restart_every: requests served per REPL session before a proactive
                restart (default: env ``LEAN_REPL_RESTART_EVERY``, else 100).

        ``LEAN_REPL_ALL_TACTICS`` (default "true"; "0"/"false"/"no"/"off" disable
        it) controls the ``allTactics`` flag sent with every request.
        """
        self.project_dir = Path(project_dir)
        self.mem_gb = mem_gb if mem_gb is not None else float(os.environ.get("LEAN_REPL_MEM_GB", "24"))
        self.timeout_s = timeout_s if timeout_s is not None else float(os.environ.get("LEAN_REPL_TIMEOUT_S", "300"))
        self.restart_every = restart_every if restart_every is not None else int(os.environ.get("LEAN_REPL_RESTART_EVERY", "100"))
        self.all_tactics = os.environ.get("LEAN_REPL_ALL_TACTICS", "true").strip().lower() not in {"0", "false", "no", "off"}
        self.mem_bytes = int(self.mem_gb * (1024 ** 3))
        self.proc = None
        self.count = 0     # requests answered by the current REPL session
        self.total = 0     # lifetime request count (across restarts)
        self.restarts = 0
        self._start()

    def _start(self):
        """(Re)launch ``lake exe repl`` with the memory cap and a non-blocking stdout."""
        import fcntl
        import resource

        self._kill()

        def _pre():
            # Runs in the child before exec. Best-effort: if the cap cannot be
            # applied, the REPL still starts.
            try:
                resource.setrlimit(resource.RLIMIT_AS, (self.mem_bytes, self.mem_bytes))
            except Exception:
                pass

        # Binary mode + manual utf-8 decode. Using text=True + readline risks
        # blocking on partial lines; we want os.read chunks we can time out on.
        self.proc = subprocess.Popen(
            ["lake", "exe", "repl"],
            cwd=str(self.project_dir),
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=False,
            bufsize=0,
            preexec_fn=_pre,
            # Own session => own process group, so _kill() reaches both lake
            # and the repl binary it spawns.
            start_new_session=True,
        )
        # Make stdout fd non-blocking so os.read never hangs.
        fd = self.proc.stdout.fileno()
        flags = fcntl.fcntl(fd, fcntl.F_GETFL)
        fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)
        self.count = 0
        if self.restarts > 0:
            print(f"[PersistentLeanREPL] restart #{self.restarts} (pid={self.proc.pid})", flush=True)
        else:
            print(f"[PersistentLeanREPL] started (pid={self.proc.pid})", flush=True)

    def _kill(self):
        """Tear down the current REPL and its whole process group (idempotent)."""
        import signal

        proc, self.proc = self.proc, None
        if proc is None:
            return
        try:
            if proc.stdin and not proc.stdin.closed:
                proc.stdin.close()
        except Exception:
            pass
        try:
            # The process was started with start_new_session=True, so its pid
            # is also its process-group id.
            os.killpg(proc.pid, signal.SIGKILL)
        except (AttributeError, OSError):
            pass  # group already gone, or no killpg on this platform
        try:
            proc.kill()
        except OSError:
            pass
        try:
            proc.wait(timeout=3)
        except Exception:
            pass
        try:
            if proc.stdout and not proc.stdout.closed:
                proc.stdout.close()
        except Exception:
            pass

    def check(self, lean_file_path):
        """Check a single Lean file. Returns ``(ok, details_str)`` like ``check_repl``.

        ``ok`` is False in any of these cases:
          * the reply contains an error-severity message;
          * the reply is the REPL's own error object (top-level ``"message"``);
          * the request times out;
          * the REPL dies.
        In the last two cases the REPL is restarted. ``pathlib`` raises
        ValueError if ``lean_file_path`` is not inside ``project_dir``.
        """
        # Restart proactively if we've hit the per-session cap or the process died.
        if self.count >= self.restart_every or self.proc is None or self.proc.poll() is not None:
            self.restarts += 1
            self._start()

        rel = Path(lean_file_path).relative_to(self.project_dir).as_posix()
        # REPL separates requests by a blank line, so send JSON + "\n\n".
        req = (json.dumps({"path": rel, "allTactics": self.all_tactics}) + "\n\n").encode("utf-8")
        try:
            try:
                self.proc.stdin.write(req)
                self.proc.stdin.flush()
            except (BrokenPipeError, OSError) as e:
                self.restarts += 1
                self._start()
                return False, f"[persistent REPL broken pipe: {e}, restarted]"
            return self._read_response()
        except BaseException:
            # Interrupted with a request in flight. Its reply would otherwise be
            # read as the answer to the *next* request, so drop this REPL; the
            # next call restarts it.
            self._kill()
            raise

    def _read_response(self):
        """Read stdout until one complete JSON reply arrives, the REPL dies, or the deadline passes."""
        import select
        import time

        deadline = time.monotonic() + self.timeout_s
        buf = b""
        stdout_fd = self.proc.stdout.fileno()
        decoder = json.JSONDecoder()
        while True:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                self.restarts += 1
                self._start()
                return False, f"[persistent REPL timed out after {self.timeout_s}s]\n{buf.decode('utf-8', 'replace')[:2000]}"
            readable, _, _ = select.select([stdout_fd], [], [], min(remaining, 1.0))
            if not readable:
                if self.proc.poll() is not None:
                    self.restarts += 1
                    self._start()
                    return False, f"[persistent REPL died]\n{buf.decode('utf-8', 'replace')[:2000]}"
                continue
            try:
                chunk = os.read(stdout_fd, 65536)
            except BlockingIOError:
                continue
            if not chunk:
                self.restarts += 1
                self._start()
                return False, f"[persistent REPL EOF]\n{buf.decode('utf-8', 'replace')[:2000]}"
            buf += chunk
            # Parse the first JSON object seen so far. raw_decode raises on a
            # still-incomplete object (keep reading) and ignores any trailing
            # text after a complete one.
            text = buf.decode("utf-8", "replace")
            start = text.find("{")
            if start == -1:
                continue
            try:
                repl_json, _end = decoder.raw_decode(text[start:])
            except json.JSONDecodeError:
                continue
            self.count += 1
            self.total += 1
            return self._response_verdict(repl_json)

    @staticmethod
    def _response_verdict(repl_json):
        """Map a decoded REPL reply to ``(ok, details)``."""
        messages = repl_json.get("messages", [])
        errors = [msg for msg in messages if msg.get("severity") == "error"]
        # A top-level "message" key is the REPL's own error reply (bad request,
        # unreadable file, ...), not a successful compilation.
        ok_repl = "message" not in repl_json and not errors
        return ok_repl, "[JSON parsed correctly]" + str(repl_json)

    def close(self):
        """Kill the REPL process group. Safe to call more than once."""
        self._kill()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()


def check_repl_persistent(lean_file_path, repl):
    """Check ``lean_file_path`` using an already-running persistent REPL.

    The file path is the first argument, as in ``check_repl``, so callers can
    switch between the one-shot and the persistent checker easily. The result
    of ``repl.check`` is returned unchanged.
    """
    run_check = repl.check
    return run_check(lean_file_path)

# -------------- main --------------
def main():
    """Prepare the Lake project and pre-create the numbered placeholder sources."""
    # The returned paths are not used by the live code below (it addresses files
    # via SRC_DIR). The names match those used by the disabled pipeline further
    # down.
    run_dir, project_dir, lean_file_path = bootstrap_project()

    # Create empty Basic_000000.lean ... Basic_300000.lean (every 1000th index)
    # next to Basic.lean. touch() creates missing files and only bumps the mtime
    # of existing ones.
    for index in range(0, 300_001, 1_000):
        target = SRC_DIR / f"Basic_{index:06d}.lean"
        target.touch(exist_ok=True)
        print(f"Created: {target}")

    # NOTE(review): the string below is a disabled evaluation pipeline kept for
    # reference; it is never executed. As written it would not run:
    #   * `ds` and `add_csv_row` are not defined anywhere in this file;
    #   * `passed_compile` / `passed_dag` are incremented without being initialised.
    '''
    for idx, row in enumerate(ds):
        ok_compile, ok_repl = False, False
        herald_id = str(row["id"])
        # Dedent the header and proof to avoid indentation issues
        proof_header = (row["header"]).strip()
        proof_body = (row["formal_proof"]).strip()
        lean_code = write_basic_lean(proof_header, proof_body, lean_file_path)
        ok_compile, out_compile = check_compiles(project_dir)
        status = "✓" if ok_compile else "✗"
        print("\n\n\n   " + "=" * 80)
        print(f"[{status}] row {idx}")

        # Print the Lean source always
        print("   " + "-" * 40)
        print("   >> Lean source:")
        print("   " + "-" * 40)
        print(textwrap.indent(lean_code.strip(), "   "))
        print("   " + "-" * 40)

        if not ok_compile:
            print("   ‣ error:", out_compile.replace("\n", "\n     "))
            ok_repl, out_repl = False, "Skipped as LEAN compilation failed"
        else:
            ok_repl, out_repl = check_repl(lean_file_path, project_dir)
            if not ok_repl:
                print ("   >> REPL failed!")
            else:
                print ("   >> REPL ran successfully!")
                print("   " + "-" * 40)
                print(textwrap.indent(out_repl.strip(), "   "))
                print("   " + "-" * 40)
                # Save only successful DAG rows
                
        add_csv_row(herald_id, lean_code, str(ok_compile), out_compile, str(ok_repl), out_repl)

        passed_compile += (ok_compile == True)
        passed_dag += (ok_repl == True)

    total = len(ds)
    print(f"\n*** Compiled {passed_compile}/{total} successfully.")
    print(f"\n*** Computed DAG for {passed_dag}/{total} successfully.")

    # Construct DataFrame and save as CSV
    df = pd.DataFrame(rows_writeCSV)

    csv_path = "herald_lean_dag.csv"
    df.to_csv(csv_path, encoding="utf-8", index=False)
    print(f"Saved results to {csv_path}")
    '''


if __name__ == "__main__":
    main()
