from __future__ import annotations

import argparse
import json
import os
import random
import re
import subprocess
import tempfile
import time
from collections import Counter
from pathlib import Path
from typing import Any

try:
    from execution_audit_common import classify_execution_error, is_executed_class, status_for_error_class
    from proofstate_common import load_jsonl, normalize_entry, prepare_representation_rows, theorem_split
    from run_search import candidate_rankings, parse_strategies
except ImportError:
    from scripts.execution_audit_common import classify_execution_error, is_executed_class, status_for_error_class
    from scripts.proofstate_common import load_jsonl, normalize_entry, prepare_representation_rows, theorem_split
    from scripts.run_search import candidate_rankings, parse_strategies


def module_from_file(file_name: str) -> str:
    """Turn a Lean source path into the dotted module name used by ``import``.

    A path under ``MathlibSubset/Mathlib/`` loses its ``MathlibSubset/`` prefix so it
    maps onto ``Mathlib.*``. A single trailing ``.lean`` is dropped and every ``/``
    becomes ``.``.
    """
    subset_prefix = "MathlibSubset/"
    if file_name.startswith(subset_prefix + "Mathlib/"):
        relative = file_name[len(subset_prefix):]
    else:
        relative = file_name
    suffix = ".lean"
    stem = relative[: -len(suffix)] if relative.endswith(suffix) else relative
    return stem.replace("/", ".")


def namespace_from_theorem(theorem: str) -> str:
    """Return everything before the last ``.`` of a theorem name, or "" if it has none."""
    enclosing, _separator, _leaf = theorem.rpartition(".")
    return enclosing


def binder_from_context(line: str) -> str | None:
    """Convert one local-context line (``names : type``) into a Lean binder.

    Whitespace is collapsed first. Returns None for blank lines, goal lines starting with
    ``⊢``, lines without `` : ``, let-style entries whose type contains `` := ``, and
    names containing the inaccessible-name markers ``✝`` / ``†``. A name part that is
    already bracketed (``[...]``) is returned as-is. Names starting with ``inst`` or
    ``_inst`` become instance binders ``[...]``; everything else becomes ``(...)``.
    """
    normalized = " ".join(str(line).split())
    if not normalized or normalized.startswith("⊢"):
        return None
    head, separator, body = normalized.partition(" : ")
    if not separator or " := " in body:
        return None
    if any(marker in head for marker in ("✝", "†")):
        return None
    if head.startswith("[") and head.endswith("]"):
        return head
    if head.startswith(("inst", "_inst")):
        return f"[{head} : {body}]"
    return f"({head} : {body})"


def lean_preamble(imports: str) -> list[str]:
    """Return the header lines of a generated Lean file.

    The header imports ``imports``, raises the heartbeat limits and disables the
    unused-tactic linter. It ends with a blank line.
    """
    option_settings = (
        "maxHeartbeats 400000",
        "synthInstance.maxHeartbeats 80000",
        "linter.unusedTactic false",
    )
    return [f"import {imports}", "", *(f"set_option {setting}" for setting in option_settings), ""]


def render_candidate_block(row: dict[str, Any], tactic: str, marker: str = "") -> list[str]:
    """Render one candidate as an ``example`` block reproducing the row's proof state.

    Binders come from ``row["local_context"]`` via ``binder_from_context`` (unusable
    entries are skipped). The goal is ``row["main_goal"]``. The example is wrapped in
    ``namespace``/``end`` when the theorem name has a namespace. Each non-blank tactic
    line is indented by two spaces. An empty tactic renders as ``skip``. The proof always
    ends with ``try all_goals sorry``.

    Note: the returned list entries are not single physical lines. The tactic entry
    (and possibly the goal) can contain embedded newlines.
    """
    namespace = namespace_from_theorem(row["theorem"])

    context_binders: list[str] = []
    for entry in row.get("local_context", []):
        converted = binder_from_context(entry)
        if converted:
            context_binders.append(converted)

    goal_text = str(row.get("main_goal", "")).strip()
    indented = [f"  {text}" if text.strip() else "" for text in str(tactic).splitlines()]
    proof_body = "\n".join(indented) or "  skip"

    block: list[str] = []
    if marker:
        block.append(f"/- CANDIDATE {marker} -/")
    if namespace:
        block.extend([f"namespace {namespace}", ""])
    block.extend(
        [
            f"example {' '.join(context_binders)} : {goal_text} := by",
            proof_body,
            "  try all_goals sorry",
            "",
        ]
    )
    if namespace:
        block.extend([f"end {namespace}", ""])
    return block


def render_candidate_file(row: dict[str, Any], tactic: str) -> str:
    """Return a complete Lean file checking one candidate tactic.

    The file is the preamble importing the row's module followed by the candidate's
    ``example`` block.
    """
    header = lean_preamble(module_from_file(row["file"]))
    candidate = render_candidate_block(row, tactic)
    return "\n".join([*header, *candidate])


def lean_invocation(workdir: Path, direct_lean: bool) -> tuple[list[str], dict[str, str] | None]:
    """Return the command prefix and environment used to run Lean in ``workdir``.

    By default Lean runs through ``lake env lean`` with the inherited environment (None).
    With ``direct_lean`` the ``lean`` binary is invoked directly. A copied environment
    then gets ``LEAN_PATH`` set to, in order:

    * the project's ``.lake/build/lib/lean``;
    * each package's ``.lake/build/lib/lean`` under ``.lake/packages`` (sorted);
    * any non-empty entries already present in ``LEAN_PATH``.
    """
    if direct_lean:
        project_lib = workdir / ".lake" / "build" / "lib" / "lean"
        package_libs = sorted((workdir / ".lake" / "packages").glob("*/.lake/build/lib/lean"))
        inherited = [Path(entry) for entry in os.environ.get("LEAN_PATH", "").split(os.pathsep) if entry]
        search_path = [project_lib, *package_libs, *inherited]
        environment = dict(os.environ)
        environment["LEAN_PATH"] = os.pathsep.join(map(str, search_path))
        return ["lean"], environment
    return ["lake", "env", "lean"], None


def lean_accepts(
    row: dict[str, Any],
    tactic: str,
    workdir: Path,
    timeout: int,
    direct_lean: bool = False,
) -> dict[str, Any]:
    """Check a single candidate tactic with Lean and return the outcome.

    The candidate file is written to a temporary ``.lean`` file inside ``workdir`` and
    compiled with the command from ``lean_invocation``. The file is removed afterwards.

    Returns a dict with ``status``, ``accepted``, ``error`` and ``error_class``:

    * On completion, ``status`` is "executed" and ``accepted`` is True iff Lean exited
      with code 0. ``error`` holds the last 4000 characters of the combined
      stdout/stderr when not accepted.
    * If Lean exceeds ``timeout`` seconds, ``status`` is "timeout".
    """
    source = render_candidate_file(row, tactic)
    with tempfile.NamedTemporaryFile("w", suffix=".lean", dir=workdir, encoding="utf-8", delete=False) as handle:
        handle.write(source)
        candidate_path = Path(handle.name)
    try:
        command, environment = lean_invocation(workdir, direct_lean)
        completed = subprocess.run(
            [*command, str(candidate_path)],
            cwd=workdir,
            env=environment,
            capture_output=True,
            text=True,
            timeout=timeout,
        )
    except subprocess.TimeoutExpired as exc:
        return {"status": "timeout", "accepted": False, "error": str(exc), "error_class": "timeout"}
    else:
        combined = "\n".join(filter(None, (completed.stdout, completed.stderr)))
        ok = completed.returncode == 0
        return {
            "status": "executed",
            "accepted": ok,
            "error": "" if ok else combined[-4000:],
            "error_class": classify_execution_error("executed", ok, combined),
        }
    finally:
        candidate_path.unlink(missing_ok=True)


def batched(items: list[Any], batch_size: int) -> list[list[Any]]:
    """Split ``items`` into consecutive chunks of at most ``batch_size`` elements.

    A non-positive ``batch_size`` disables batching: the result is ``[items]``.
    """
    if batch_size <= 0:
        return [items]
    chunks: list[list[Any]] = []
    for offset in range(0, len(items), batch_size):
        chunks.append(items[offset : offset + batch_size])
    return chunks


def lean_error_messages(output: str) -> list[tuple[int, str]]:
    """Extract Lean error messages as ``(line_number, message_text)`` pairs.

    A message starts at a line containing ``:<line>:<col>: error|warning|info:``.
    Following lines without such a header are continuation lines of that message.
    Only ``error`` messages are returned; warning and info messages, including their
    continuation lines, are dropped.
    """
    header_re = re.compile(r":(\d+):\d+: (error|warning|info):")
    collected: list[tuple[int, str]] = []
    error_line: int | None = None
    buffer: list[str] = []
    for text in output.splitlines():
        found = header_re.search(text)
        if found is None:
            if error_line is not None:
                buffer.append(text)
            continue
        # A new header closes whatever error message was being collected.
        if error_line is not None:
            collected.append((error_line, "\n".join(buffer)))
        is_error = found.group(2) == "error"
        error_line = int(found.group(1)) if is_error else None
        buffer = [text] if is_error else []
    if error_line is not None:
        collected.append((error_line, "\n".join(buffer)))
    return collected


def lean_accepts_batch(
    pending: list[tuple[dict[str, Any], dict[str, Any]]],
    workdir: Path,
    timeout: int,
    direct_lean: bool,
) -> None:
    """Compile a batch of candidate tactics in one Lean run and update each record in place.

    ``pending`` holds ``(record, row)`` pairs. Every candidate is rendered as its own
    block in a single temporary ``.lean`` file. The file's import is taken from the first
    row's ``file``, so callers should batch rows from the same file. Lean error messages
    are mapped back to candidates by the physical line range of each block.

    Each record gets ``status``, ``accepted``, ``error`` and ``error_class``:

    * Errors inside the record's own block: rejected, classified from those errors.
    * No own errors, but some error outside every block (e.g. in the preamble) or a
      non-zero exit with no parseable error: rejected as an environment failure.
    * Otherwise: accepted (``error_class="executed_accept"``).
    * The whole run exceeds ``timeout`` seconds: every record is marked ``timeout``.

    The temporary file is always removed.
    """
    if not pending:
        return
    module = module_from_file(pending[0][1]["file"])
    lines = lean_preamble(module)
    spans: list[tuple[dict[str, Any], int, int]] = []
    for index, (record, row) in enumerate(pending):
        block = render_candidate_block(row, record["candidate_tactic"], marker=str(index))
        # Block entries can contain embedded newlines (multi-line tactics or goals).
        # Lean reports positions in physical lines, so count those, not list entries.
        # Splitting on "\n" and re-joining below leaves the file text unchanged.
        physical_lines = "\n".join(block).split("\n")
        start = len(lines) + 1  # 1-based line number of this block's first line
        lines.extend(physical_lines)
        spans.append((record, start, len(lines)))
    source = "\n".join(lines)
    with tempfile.NamedTemporaryFile("w", suffix=".lean", dir=workdir, encoding="utf-8", delete=False) as f:
        f.write(source)
        tmp_path = Path(f.name)
    try:
        cmd, env = lean_invocation(workdir, direct_lean)
        proc = subprocess.run(
            cmd + [str(tmp_path)],
            cwd=workdir,
            env=env,
            text=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            timeout=timeout,
        )
        output = "\n".join(part for part in [proc.stdout, proc.stderr] if part)
        messages = lean_error_messages(output)
        # Errors per span, indexed like ``spans``.
        span_errors: list[list[str]] = [[] for _ in spans]
        global_errors: list[str] = []
        for line_no, message in messages:
            for span_index, (_, span_start, span_end) in enumerate(spans):
                if span_start <= line_no <= span_end:
                    span_errors[span_index].append(message)
                    break
            else:
                # Not inside any candidate block (e.g. preamble/import errors).
                global_errors.append(message)
        global_failure = bool(global_errors) or (proc.returncode != 0 and not messages)
        for (record, _, _), errors in zip(spans, span_errors):
            if errors:
                error_text = "\n".join(errors)
                error_class = classify_execution_error("executed", False, error_text)
                record.update(
                    {
                        "status": status_for_error_class(error_class),
                        "accepted": False,
                        "error": error_text[-4000:],
                        "error_class": error_class,
                    }
                )
            elif global_failure:
                error_class = classify_execution_error("environment_error", False, output)
                record.update(
                    {
                        "status": status_for_error_class(error_class, default_status="environment_error"),
                        "accepted": False,
                        "error": output[-4000:],
                        "error_class": error_class,
                    }
                )
            else:
                record.update({"status": "executed", "accepted": True, "error": "", "error_class": "executed_accept"})
    except subprocess.TimeoutExpired as exc:
        for record, _, _ in spans:
            record.update({"status": "timeout", "accepted": False, "error": str(exc), "error_class": "timeout"})
    finally:
        tmp_path.unlink(missing_ok=True)


def execute_pending_batches(
    pending: list[tuple[dict[str, Any], dict[str, Any]]],
    records: list[dict[str, Any]],
    workdir: Path,
    timeout: int,
    batch_size: int,
    direct_lean: bool,
    progress_output: Path | None,
) -> None:
    """Execute all pending candidates in per-file batches, updating records in place.

    Pairs are grouped by the row's ``file``; groups are processed in sorted file order
    and split with ``batched``. If a multi-candidate batch leaves any record with status
    ``environment_error`` or ``reconstruction_error``, each such record is reset to
    ``queued`` and re-run on its own. This covers failures caused by another block in
    the same file. A progress snapshot is written before the first batch and after
    every batch.
    """
    retry_statuses = {"environment_error", "reconstruction_error"}
    per_file: dict[str, list[tuple[dict[str, Any], dict[str, Any]]]] = {}
    for pair in pending:
        per_file.setdefault(pair[1]["file"], []).append(pair)
    batches = [chunk for file_name in sorted(per_file) for chunk in batched(per_file[file_name], batch_size)]
    started = time.time()
    write_progress(records, progress_output, len(batches), 0, started)
    for finished, batch in enumerate(batches, start=1):
        lean_accepts_batch(batch, workdir, timeout, direct_lean)
        if len(batch) > 1:
            suspects = [pair for pair in batch if pair[0].get("status") in retry_statuses]
            for record, row in suspects:
                record.update({"status": "queued", "accepted": None, "error": "", "error_class": ""})
                lean_accepts_batch([(record, row)], workdir, timeout, direct_lean)
        write_progress(records, progress_output, len(batches), finished, started)


def write_jsonl(rows: list[dict[str, Any]], path: Path) -> None:
    """Write ``rows`` to ``path`` as UTF-8 JSON Lines with sorted keys.

    Missing parent directories are created.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as handle:
        handle.writelines(json.dumps(entry, ensure_ascii=False, sort_keys=True) + "\n" for entry in rows)


def write_progress(
    records: list[dict[str, Any]],
    path: Path | None,
    total_batches: int,
    completed_batches: int,
    start_time: float,
) -> None:
    """Write a JSON progress snapshot of a running execution to ``path``.

    Does nothing when ``path`` is None. A record counts as completed once its status is
    neither "queued" nor missing. Throughput and ETA are derived from completed records
    and the time elapsed since ``start_time``; ``eta_seconds`` is None until the rate is
    positive.

    The snapshot is first written to a sibling ``<name>.tmp`` file and then moved over
    ``path``. A concurrent reader polling the progress file therefore never sees a
    truncated or half-written JSON document.
    """
    if path is None:
        return
    status_counts = Counter(str(record.get("status")) for record in records)
    error_class_counts = Counter(str(record.get("error_class")) for record in records if record.get("error_class"))
    accepted_counts = Counter(str(record.get("accepted")) for record in records)
    completed_records = sum(1 for record in records if record.get("status") not in {"queued", None})
    elapsed = time.time() - start_time
    rate = completed_records / elapsed if elapsed > 0 else 0.0
    remaining = len(records) - completed_records
    payload = {
        "accepted_counts": dict(accepted_counts),
        "completed_batches": completed_batches,
        "completed_records": completed_records,
        "elapsed_seconds": elapsed,
        "eta_seconds": remaining / rate if rate > 0 else None,
        "records_per_second": rate,
        "remaining_records": remaining,
        "status_counts": dict(status_counts),
        "error_class_counts": dict(error_class_counts),
        "total_batches": total_batches,
        "total_records": len(records),
        "updated_at": time.strftime("%Y-%m-%d %H:%M:%S %Z"),
    }
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = path.with_name(f"{path.name}.tmp")
    tmp_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2, sort_keys=True), encoding="utf-8")
    # Path.replace (os.replace) is atomic on POSIX when both paths share a filesystem.
    tmp_path.replace(path)


def build_records(
    rows: list[dict[str, Any]],
    strategies: list[str],
    representation: str,
    test_ratio: float,
    seed: int,
    max_k: int,
    sample_size: int,
    family_weight: float,
    family_model: str,
    execute: bool,
    workdir: Path,
    timeout: int,
    batch_size: int,
    direct_lean: bool,
    progress_output: Path | None,
) -> list[dict[str, Any]]:
    """Rank candidate tactics for sampled held-out proof states and optionally execute them.

    Rows are split with ``theorem_split``. The test rows are shuffled with a
    ``random.Random(seed)`` and truncated to ``sample_size`` (0 keeps all). For every
    strategy, the top ``max_k`` candidates per query become one record in state
    "queued". With ``execute``, all records are then checked in Lean through
    ``execute_pending_batches``, which updates them in place.

    NOTE(review): Lean rendering uses the *prepared* test rows, not the raw sampled
    rows. This assumes ``prepare_representation_rows`` keeps ``file``, ``theorem``,
    ``local_context`` and ``main_goal`` intact — TODO confirm.
    NOTE(review): the representation is re-prepared for every strategy. If
    ``prepare_representation_rows`` is pure and ``candidate_rankings`` does not mutate
    its inputs, this could be hoisted out of the loop — verify before changing.
    """
    train_rows, test_rows, split_meta = theorem_split(rows, test_ratio=test_ratio, seed=seed)
    sampled_test = list(test_rows)
    random.Random(seed).shuffle(sampled_test)
    if sample_size:
        sampled_test = sampled_test[:sample_size]

    queued_state = {"status": "queued", "accepted": None, "error": "", "error_class": ""}
    records: list[dict[str, Any]] = []
    pending: list[tuple[dict[str, Any], dict[str, Any]]] = []
    for strategy in strategies:
        prepared_train, prepared_test = prepare_representation_rows(train_rows, sampled_test, representation)
        rankings = candidate_rankings(
            prepared_train,
            prepared_test,
            representation,
            strategy,
            family_weight=family_weight,
            family_model=family_model,
        )
        for query, candidates in zip(prepared_test, rankings):
            query_id = f"{query.get('file', '')}:{query.get('theorem', '')}:{query.get('step_index', 0)}"
            for position, candidate in enumerate(candidates[:max_k], start=1):
                entry: dict[str, Any] = {
                    "backend": "state_reconstruction",
                    "query_id": query_id,
                    "strategy": strategy,
                    "rank": position,
                    "k": max_k,
                    "representation": representation,
                    "seed": seed,
                    "split_strategy": split_meta["strategy"],
                    "query_file": query["file"],
                    "query_theorem": query["theorem"],
                    "query_step_index": query["step_index"],
                    "gold_family": query["tactic_family"],
                    "gold_tactic": query["next_tactic"],
                    "candidate_family": candidate["tactic_family"],
                    "candidate_tactic": candidate["next_tactic"],
                    **queued_state,
                }
                records.append(entry)
                if execute:
                    pending.append((entry, query))
    if execute:
        execute_pending_batches(pending, records, workdir, timeout, batch_size, direct_lean, progress_output)
    return records


def summarize(records: list[dict[str, Any]], max_k: int) -> list[dict[str, Any]]:
    """Aggregate per-candidate records into per-strategy Lean acceptance rates.

    Records are grouped by ``(strategy, query_id)``. For each k in {1, 3, 5, max_k} that
    does not exceed ``max_k``, the row's ``lean_accept_at_{k}`` is computed over the
    *executable* queries only. A query is executable when every candidate with
    rank <= k has an error class accepted by ``is_executed_class``.
    ``lean_accept_at_{k}`` is the fraction of those queries with at least one accepted
    candidate in their top k. It is "" when no query is executable, e.g. when nothing
    was run.
    """
    grouped: dict[tuple[str, str], list[dict[str, Any]]] = {}
    for record in records:
        grouped.setdefault((record["strategy"], record["query_id"]), []).append(record)
    output = []
    for strategy in sorted({record["strategy"] for record in records}):
        query_groups = [items for (name, _), items in grouped.items() if name == strategy]
        row: dict[str, Any] = {"backend": "state_reconstruction", "strategy": strategy, "queries": len(query_groups)}
        for k in sorted({value for value in [1, 3, 5, max_k] if value <= max_k}):
            executable = 0
            hits = 0
            for items in query_groups:
                top = [item for item in items if int(item["rank"]) <= k]
                if not all(is_executed_class(str(item.get("error_class", ""))) for item in top):
                    # Queries with any non-executed top-k candidate (environment or
                    # reconstruction failure, timeout, not run) are excluded from both
                    # the numerator and the denominator. Otherwise the rate is inflated.
                    continue
                executable += 1
                if any(item.get("accepted") is True for item in top):
                    hits += 1
            row[f"lean_accept_at_{k}"] = hits / executable if executable else ""
        output.append(row)
    return output


def main() -> None:
    """Command-line entry point.

    Loads and normalizes the ``--data`` rows and builds (and, with ``--execute``, runs)
    the candidate records. Writes them to ``--cache-output`` and the per-strategy summary
    to ``--summary-output``, both as JSON Lines.
    """
    parser = argparse.ArgumentParser(description="Execute candidate tactics in reconstructed Lean proof states.")
    add = parser.add_argument
    add("--data", type=Path, required=True)
    add("--strategy", default="unguided,family_guided,family_soft,family_top_m,family_rrf")
    add("--representation", default="state_only")
    add("--test-ratio", type=float, default=0.3)
    add("--seed", type=int, default=42)
    add("--sample-size", type=int, default=500)
    add("--max-k", type=int, default=5)
    add("--family-weight", type=float, default=0.25)
    add("--family-model", default="logistic_regression")
    add("--execute", action="store_true")
    add("--direct-lean", action="store_true")
    add("--timeout", type=int, default=60)
    add("--batch-size", type=int, default=10)
    add("--progress-output", type=Path)
    add("--workdir", type=Path, default=Path("."))
    add("--cache-output", type=Path, default=Path("results/execution/state_execution_cache.jsonl"))
    add("--summary-output", type=Path, default=Path("results/tables/state_execution_summary.jsonl"))
    args = parser.parse_args()

    normalized_rows = [normalize_entry(entry, include_optional=True) for entry in load_jsonl(args.data)]
    records = build_records(
        rows=normalized_rows,
        strategies=parse_strategies(args.strategy),
        representation=args.representation,
        test_ratio=args.test_ratio,
        seed=args.seed,
        max_k=args.max_k,
        sample_size=args.sample_size,
        family_weight=args.family_weight,
        family_model=args.family_model,
        execute=args.execute,
        workdir=args.workdir.resolve(),
        timeout=args.timeout,
        batch_size=args.batch_size,
        direct_lean=args.direct_lean,
        progress_output=args.progress_output,
    )
    write_jsonl(records, args.cache_output)
    write_jsonl(summarize(records, args.max_k), args.summary_output)
    verb = "executed" if args.execute else "queued"
    print(f"{verb} {len(records)} reconstructed-state candidates in {args.cache_output}")
    print(f"Wrote summary to: {args.summary_output}")


if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script, not when it is imported.
    main()
