from __future__ import annotations

import argparse
import hashlib
import json
import os
import re
import subprocess
import sys
import time
from pathlib import Path
from typing import Any

from .agent_settings import resolve_stage_agents_settings
from .codex_client import run_infra_plan_agent, run_infra_plan_check_agent
from .config import INFRA_LOGS_DIR, LEAN_ROOT, ROOT
from .metrics import finish_run, log_event, start_run
from .protocol import (
    INFRA_PLAN_CHECK_END,
    INFRA_PLAN_CHECK_START,
    INFRA_PLAN_END,
    INFRA_PLAN_START,
    extract_marked_json,
)

# --- File names used inside a bench's `infra_<bench_id>/` directory ---
# Base plan; `_iter_plan_entries` / `_plan_rel_for_round` treat it as round 0.
INFRA_PLAN_FILENAME = "infra_plan.json"
# Per-round expansion plan (round >= 1 in `_plan_rel_for_round`).
INFRA_PLAN_EXPAND_TEMPLATE = "infra_plan_expand_{round}.json"
# Old `expend` spelling of the expansion plan name, still read by
# `_iter_plan_entries` for backward compatibility.
INFRA_PLAN_EXPEND_TEMPLATE = "infra_plan_expend_{round}.json"
# Written by `_write_plan_failure_report`.
INFRA_PLAN_FAILURE_REPORT_FILENAME = "infra_plan_last_failure_report.json"
# NOTE(review): the two names below are not referenced in this part of the file;
# presumably consumed by later stages -- confirm before changing.
PUBLIC_API_FILENAME = "PUBLIC_API.json"
INFRA_ENTRY_FILENAME = "Main.lean"
# --- Retry / quality knobs ---
# NOTE(review): the two values below look like defaults for plan-generation attempts
# and plan-check rounds; their call sites are outside this part of the file.
INFRA_PLAN_MAX_ATTEMPTS = 5
INFRA_PLAN_CHECK_ROUNDS_DEFAULT = 50
# NOTE(review): presumably a minimum proof-text length; not referenced in this part of the file.
PROOF_MIN_CHARS = 220
# Quoted in the plan-check auto-fix instructions as the minimum number of explicit proof steps.
PROOF_MIN_STEPS = 3
# `_polish_plan_with_check` gives up once the same failure fingerprint repeats this many times.
PLAN_CHECK_FAIL_STAGNATION_MAX = 2
# NOTE(review): presumably the default cap on expansion rounds; call site not visible here.
INFRA_EXPAND_MAX_ROUNDS_DEFAULT = 2
# NOTE(review): the three names below are not referenced in this part of the file.
DIRECT_ITEM_LOOP_STATE_FILENAME = "item_loop_state.json"
DIRECT_ITEM_FAILURE_REPORTS_DIRNAME = "failure_reports"
DIRECT_ITEM_PLAN_HISTORY_DIRNAME = "plan_history"
# Matches a path segment of the form `infra_<id>` and captures `<id>`.
_INFRA_DIR_RE = re.compile(r"^infra_(.+)$")
# Match expansion-plan file names (new and old spelling) and capture the round number.
_INFRA_PLAN_EXPAND_RE = re.compile(r"^infra_plan_expand_(\d+)\.json$")
_INFRA_PLAN_EXPEND_RE = re.compile(r"^infra_plan_expend_(\d+)\.json$")
# NOTE(review): presumably the phrases `_is_blocker_coverage_failure` looks for in
# plan-check failure text; that helper is defined outside this part of the file.
_BLOCKER_COVERAGE_FAIL_KEYWORDS = (
    "no substantive blocker reduction",
    "missing required producer theorem",
    "closure not satisfied for target blocker",
)


def _normalize_rel_to_root(path: Path, root: Path) -> Path:
    """Rewrite ``path`` so that it is expressed relative to ``root`` when possible.

    * Absolute paths located under ``root`` are returned relative to it;
      absolute paths elsewhere are returned unchanged.
    * Relative paths whose first segment equals ``root.name`` have that
      segment dropped; any other relative path is returned unchanged.
    """
    if not path.is_absolute():
        segments = path.parts
        if segments and segments[0] == root.name:
            # Drop the redundant leading "<root name>/" prefix.
            return Path(*segments[1:])
        return path

    try:
        relative = path.relative_to(root)
    except ValueError:
        # Absolute path outside ``root``: nothing to strip.
        return path
    return relative


def _normalize_rel_to_lean_root(path: Path) -> Path:
    """Normalize ``path`` against the configured ``LEAN_ROOT``.

    Thin convenience wrapper around ``_normalize_rel_to_root``.
    """
    return _normalize_rel_to_root(path, root=LEAN_ROOT)


def _resolve_infra_root(bench_file: Path) -> tuple[Path, Path]:
    """
    Pick the Lean root under which infra artifacts for ``bench_file`` live.

    ``LEAN_ROOT`` wins when it contains a ``Question_bench`` entry or the bench
    file itself. Otherwise ``ROOT / "M2F"`` is tried with the same test
    (common worktree layout). If neither matches, ``LEAN_ROOT`` is used anyway.

    Returns ``(root, bench_file_relative_to_that_root)``.
    """
    lean_rel = _normalize_rel_to_root(bench_file, LEAN_ROOT)
    lean_root_matches = (
        (LEAN_ROOT / "Question_bench").exists() or (LEAN_ROOT / lean_rel).exists()
    )
    if not lean_root_matches:
        fallback_root = ROOT / "M2F"
        fallback_rel = _normalize_rel_to_root(bench_file, fallback_root)
        fallback_matches = (
            (fallback_root / "Question_bench").exists()
            or (fallback_root / fallback_rel).exists()
        )
        if fallback_matches:
            return fallback_root, fallback_rel
    # Either LEAN_ROOT matched, or nothing did and LEAN_ROOT is the default.
    return LEAN_ROOT, lean_rel


def _bench_id(bench_file_rel: Path) -> str:
    """Return the bench identifier: the file name with its final suffix removed."""
    identifier = bench_file_rel.stem
    return identifier


def _infra_dir_for_bench_file(bench_file_rel: Path) -> Path:
    """Return the ``infra_<bench_id>`` directory that sits next to the bench file."""
    bench_name = _bench_id(bench_file_rel)
    sibling_dir = bench_file_rel.parent
    return sibling_dir / f"infra_{bench_name}"


def _infra_plan_path(bench_file_rel: Path) -> Path:
    """Return the relative path of the base plan file for the given bench file."""
    infra_dir = _infra_dir_for_bench_file(bench_file_rel)
    return infra_dir / INFRA_PLAN_FILENAME


def _infra_plan_expand_path(bench_file_rel: Path, round_idx: int) -> Path:
    """Return the relative path of the expansion plan file for ``round_idx``."""
    file_name = INFRA_PLAN_EXPAND_TEMPLATE.format(round=int(round_idx))
    infra_dir = _infra_dir_for_bench_file(bench_file_rel)
    return infra_dir / file_name


def _infra_plan_expend_path(bench_file_rel: Path, round_idx: int) -> Path:
    """Return the expansion plan path for ``round_idx`` using the old ``expend`` spelling."""
    file_name = INFRA_PLAN_EXPEND_TEMPLATE.format(round=int(round_idx))
    infra_dir = _infra_dir_for_bench_file(bench_file_rel)
    return infra_dir / file_name


def _infra_plan_failure_report_path(bench_file_rel: Path) -> Path:
    """Return the relative path of the plan failure report for the given bench file."""
    infra_dir = _infra_dir_for_bench_file(bench_file_rel)
    return infra_dir / INFRA_PLAN_FAILURE_REPORT_FILENAME


def _plan_rel_for_round(bench_file_rel: Path, round_idx: int) -> Path:
    """Return the plan path for a round: the base plan for round <= 0, else the expansion plan."""
    is_expansion_round = int(round_idx) > 0
    if is_expansion_round:
        return _infra_plan_expand_path(bench_file_rel, round_idx)
    return _infra_plan_path(bench_file_rel)


def _write_plan_failure_report(
    *,
    lean_root: Path,
    bench_file_rel: Path,
    payload: dict[str, Any],
) -> Path:
    """Write ``payload`` as the plan failure report for a bench file.

    The report is stored under ``lean_root`` as indented, ASCII-escaped JSON
    followed by a newline; missing parent directories are created.

    Returns the report path relative to ``lean_root``.
    """
    report_rel = _infra_plan_failure_report_path(bench_file_rel)
    report_abs = lean_root / report_rel
    report_abs.parent.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(payload, ensure_ascii=True, indent=2)
    report_abs.write_text(serialized + "\n", encoding="utf-8")
    return report_rel


def _infra_plan_draft_path(infra_dir_rel: Path, *, round_idx: int, attempt: int, source: str) -> Path:
    """Build the draft-plan file path for a given round, attempt and source tag.

    ``source`` is sanitised for use in a file name: runs of characters outside
    ``[A-Za-z0-9._-]`` become ``_`` and surrounding underscores are trimmed.
    An empty result falls back to ``"agent"``.
    """
    source_tag = re.sub(r"[^A-Za-z0-9._-]+", "_", str(source)).strip("_")
    if not source_tag:
        source_tag = "agent"
    file_name = f"infra_plan_draft_round{int(round_idx)}_attempt{int(attempt)}_{source_tag}.json"
    return infra_dir_rel / file_name


def _write_plan_draft(
    *,
    lean_root: Path,
    infra_dir_rel: Path,
    round_idx: int,
    attempt: int,
    payload: Any,
    source: str,
) -> Path | None:
    """Best-effort dump of a plan draft to disk.

    Serialises ``payload`` as indented UTF-8 JSON under ``lean_root`` at the
    path given by ``_infra_plan_draft_path``. Returns that path relative to
    ``lean_root``, or ``None`` if anything fails -- a draft is never allowed
    to break the caller.
    """
    try:
        draft_rel = _infra_plan_draft_path(
            infra_dir_rel,
            round_idx=round_idx,
            attempt=attempt,
            source=source,
        )
        draft_abs = lean_root / draft_rel
        draft_abs.parent.mkdir(parents=True, exist_ok=True)
        text = json.dumps(payload, ensure_ascii=False, indent=2) + "\n"
        draft_abs.write_text(text, encoding="utf-8")
    except Exception:
        # Deliberate swallow: drafts are a debugging aid only.
        return None
    return draft_rel


def _write_plan_file(path_abs: Path, items: list[dict[str, Any]]) -> None:
    """Write ``items`` to ``path_abs`` as indented UTF-8 JSON with a trailing newline.

    Missing parent directories are created first.
    """
    path_abs.parent.mkdir(parents=True, exist_ok=True)
    rendered = json.dumps(items, ensure_ascii=False, indent=2)
    with path_abs.open("w", encoding="utf-8") as handle:
        handle.write(rendered + "\n")


def _iter_plan_entries(infra_dir_abs: Path) -> list[tuple[int, Path]]:
    """List the plan files found in an infra directory as ``(round, path)`` pairs.

    The base plan file counts as round 0. Expansion plans contribute their
    numeric round; the ``expand`` spelling takes precedence over the old
    ``expend`` spelling for the same round. The result is sorted by round.
    """
    by_round: dict[int, Path] = {}

    base_plan = infra_dir_abs / INFRA_PLAN_FILENAME
    if base_plan.exists():
        by_round[0] = base_plan

    # Current spelling: a later file (in sorted order) overwrites an earlier
    # one that maps to the same round number.
    for candidate in sorted(infra_dir_abs.glob("infra_plan_expand_*.json")):
        matched = _INFRA_PLAN_EXPAND_RE.match(candidate.name)
        if matched:
            by_round[int(matched.group(1))] = candidate

    # Old spelling: only fills rounds that have no entry yet.
    for candidate in sorted(infra_dir_abs.glob("infra_plan_expend_*.json")):
        matched = _INFRA_PLAN_EXPEND_RE.match(candidate.name)
        if not matched:
            continue
        round_no = int(matched.group(1))
        if round_no not in by_round:
            by_round[round_no] = candidate

    return [(round_no, by_round[round_no]) for round_no in sorted(by_round)]


def _iter_plan_paths(infra_dir_abs: Path) -> list[Path]:
    """Return the plan file paths of an infra directory, in the order given by ``_iter_plan_entries``."""
    plan_paths: list[Path] = []
    for _round_no, plan_path in _iter_plan_entries(infra_dir_abs):
        plan_paths.append(plan_path)
    return plan_paths


def _infer_bench_and_infra_from_input(path_rel: Path) -> tuple[Path, Path, str] | None:
    """Derive bench file, infra directory and infra id from a path inside an infra directory.

    The first path segment of the form ``infra_<id>`` (with a non-blank id)
    decides the result:

    * bench file: ``<segments before it>/<id>.lean``
    * infra directory: the path up to and including that segment
    * infra id: ``<id>`` with surrounding whitespace removed

    Returns ``None`` when no segment qualifies.
    """
    segments = path_rel.parts
    for position, segment in enumerate(segments):
        matched = _INFRA_DIR_RE.match(segment)
        infra_id = (matched.group(1) or "").strip() if matched else ""
        if infra_id:
            infra_dir_rel = Path(*segments[: position + 1])
            bench_rel = Path(*segments[:position]) / f"{infra_id}.lean"
            return bench_rel, infra_dir_rel, infra_id
    return None


def _safe_stem(text: str) -> str:
    """Turn ``text`` into a file-name-safe stem.

    Runs of characters outside ``[A-Za-z0-9._-]`` collapse to ``_`` and
    surrounding underscores are trimmed; an empty result becomes ``"file"``.
    """
    cleaned = re.sub(r"[^A-Za-z0-9._-]+", "_", text).strip("_")
    if cleaned:
        return cleaned
    return "file"


def _join_instructions(*parts: str | None) -> str | None:
    """Combine instruction fragments into one block separated by blank lines.

    Non-string and blank fragments are skipped and the rest are stripped.
    Returns ``None`` when nothing remains.
    """
    kept: list[str] = []
    for fragment in parts:
        if not isinstance(fragment, str):
            continue
        trimmed = fragment.strip()
        if trimmed:
            kept.append(trimmed)
    return "\n\n".join(kept) if kept else None


def _truncate_text(text: str, *, max_chars: int = 2400) -> str:
    """Strip ``text`` and shorten it to at most ``max_chars`` characters.

    Args:
        text: Text to shorten; a falsy value (e.g. ``None`` or ``""``) is
            treated as the empty string.
        max_chars: Upper bound on the length of the result.

    Returns:
        The stripped text if it already fits. Otherwise a prefix followed by
        ``"..."`` whose total length is exactly ``max_chars``. When
        ``max_chars`` is too small to hold the ellipsis, the text is hard-cut
        to ``max_chars`` characters (empty for non-positive limits).
    """
    ellipsis = "..."
    s = (text or "").strip()
    if len(s) <= max_chars:
        return s
    if max_chars < len(ellipsis):
        # No room for the ellipsis. Slicing with ``max_chars - 3`` would be a
        # negative index here and produce a result longer than the limit.
        return s[: max(0, max_chars)]
    return s[: max_chars - len(ellipsis)] + ellipsis


def _read_json_file(path: Path) -> Any | None:
    """Load a UTF-8 JSON file, returning ``None`` if it is missing or unreadable.

    Any error while reading or parsing is swallowed and reported as ``None``.
    """
    if not path.exists():
        return None
    try:
        with path.open("r", encoding="utf-8") as handle:
            parsed = json.load(handle)
    except Exception:
        return None
    return parsed


def _load_existing_plan_items(
    infra_dir_abs: Path,
    *,
    exclude: set[Path] | None = None,
) -> list[dict[str, Any]]:
    """Collect well-formed plan items from every plan file in an infra directory.

    Plan files come from ``_iter_plan_paths``. Files whose resolved path is in
    ``exclude`` are skipped, as are files that cannot be read, are not valid
    JSON, or do not contain a top-level list. An entry is kept only when it
    is a dict whose ``label``, ``env`` and ``target_file`` values are strings.
    """
    excluded_resolved = {p.resolve() for p in (exclude or set())}
    required_str_keys = ("label", "env", "target_file")
    collected: list[dict[str, Any]] = []

    for plan_path in _iter_plan_paths(infra_dir_abs):
        try:
            is_excluded = plan_path.resolve() in excluded_resolved
        except Exception:
            # If the path cannot be resolved, treat it as not excluded.
            is_excluded = False
        if is_excluded:
            continue

        try:
            loaded = json.loads(plan_path.read_text(encoding="utf-8"))
        except Exception:
            continue
        if not isinstance(loaded, list):
            continue

        collected.extend(
            entry
            for entry in loaded
            if isinstance(entry, dict)
            and all(isinstance(entry.get(key), str) for key in required_str_keys)
        )
    return collected


def _latest_missing_theory_signal(history_file: Path) -> dict[str, Any] | None:
    """Return the most recent missing-theory feedback recorded in a JSON-lines history file.

    Lines are scanned from last to first. A line qualifies when it parses to a
    dict with ``kind == "agent_a_feedback"`` whose ``payload["feedback"]`` is a
    dict with ``status == "failed_missing_theory"`` and a non-blank string
    ``reason``. The matching ``feedback`` dict is returned; ``None`` is
    returned if the file is missing, unreadable, or has no qualifying line.
    """
    if not history_file.exists():
        return None
    try:
        raw_lines = history_file.read_text(encoding="utf-8").splitlines()
    except Exception:
        return None

    for raw in reversed(raw_lines):
        text = raw.strip()
        if not text:
            continue
        try:
            record = json.loads(text)
        except Exception:
            # Unparseable lines are skipped rather than aborting the scan.
            continue
        if not (isinstance(record, dict) and record.get("kind") == "agent_a_feedback"):
            continue

        payload = record.get("payload")
        feedback = payload.get("feedback") if isinstance(payload, dict) else None
        if not isinstance(feedback, dict):
            continue
        if feedback.get("status") != "failed_missing_theory":
            continue

        reason = feedback.get("reason")
        if isinstance(reason, str) and reason.strip():
            return feedback
    return None


def _signal_fingerprint(signal: dict[str, Any]) -> str:
    """Return a SHA-1 hex digest of ``signal`` serialised as compact JSON with sorted keys.

    Because keys are sorted, dicts that differ only in key order produce the
    same digest.
    """
    canonical = json.dumps(signal, ensure_ascii=False, sort_keys=True, separators=(",", ":"))
    hasher = hashlib.sha1()
    hasher.update(canonical.encode("utf-8"))
    return hasher.hexdigest()


def _build_expansion_round_instructions(
    *,
    round_idx: int,
    prior_items: list[dict[str, Any]],
    bench_file_rel: Path,
) -> str:
    """Compose the extra prompt text for an expansion round.

    The text names the round and bench file, states the rules for extending an
    existing plan, and ends with a preview of the labels already taken: the
    first 20 non-blank labels from ``prior_items`` (or ``none``), followed by a
    ``(+N more)`` marker when there are more.
    """
    preview_limit = 20
    stripped_labels = (str(item.get("label", "")).strip() for item in prior_items)
    known_labels = [label for label in stripped_labels if label]

    if known_labels:
        label_preview = ", ".join(known_labels[:preview_limit])
    else:
        label_preview = "none"

    hidden_count = len(known_labels) - preview_limit
    if hidden_count > 0:
        label_preview += f", ... (+{hidden_count} more)"

    rules = (
        f"EXPANSION ROUND #{round_idx}: you are extending existing infra for `{bench_file_rel}`.\n"
        "Generate only the additional items needed to resolve the current missing-theory blocker.\n"
        "Do not duplicate existing declarations or labels.\n"
        "Dependencies may reference earlier labels from prior plans and earlier items in this expansion plan.\n"
        "Keep new `public=true` items minimal; default to internal (`public=false`) unless directly required by the bench theorem.\n"
        "Existing labels (do not duplicate): "
    )
    return rules + label_preview


def _generate_checked_plan_items(
    *,
    bench_file_rel: Path,
    missing_theory_signal: dict[str, Any],
    task_id: str,
    round_idx: int,
    run_id: str,
    infra_agent_settings: Any,
    infra_dir_rel: Path,
    public_api_cap: int,
    lean_root: Path,
    max_plan_check_rounds: int,
    max_plan_attempts: int,
    external_dependency_labels: set[str],
    forbidden_labels: set[str],
    existing_public_count: int,
    plan_extra_instructions: str | None = None,
) -> tuple[list[dict[str, Any]] | None, list[dict[str, Any]] | None]:
    """Generate an infra plan with the plan agent and vet it with the plan-check agent.

    Up to ``max(1, max_plan_attempts)`` generation attempts are made. Each attempt:

    1. calls ``run_infra_plan_agent`` and extracts the JSON between the
       ``INFRA_PLAN_START`` / ``INFRA_PLAN_END`` markers,
    2. saves the raw plan as a draft file (best effort),
    3. validates it with ``_validate_and_normalize_plan_items``,
    4. if ``max_plan_check_rounds`` > 0, runs up to that many
       ``run_infra_plan_check_agent`` rounds on the validated plan.

    Any failure discards the current plan and starts the next generation attempt
    with a feedback message appended to ``plan_extra_instructions``. The feedback
    is always rebuilt from ``plan_extra_instructions`` plus one message; it does
    not accumulate across attempts.

    Returns:
        ``(plan_items, last_structurally_valid_items)``.
        ``plan_items`` is the accepted plan, or ``None`` if every attempt failed.
        ``last_structurally_valid_items`` is the most recent plan that passed
        ``_validate_and_normalize_plan_items`` without ``require_nondef_proof``,
        or ``None`` if no plan ever did.

    Raises:
        Whatever the final ``_validate_and_normalize_plan_items(...,
        require_nondef_proof=True)`` call raises; that call is not guarded here.
    """
    plan_items: list[dict[str, Any]] | None = None
    last_structurally_valid_items: list[dict[str, Any]] | None = None
    extra_instructions: str | None = plan_extra_instructions
    total_attempts = max(1, int(max_plan_attempts))
    for attempt in range(1, total_attempts + 1):
        print(
            f"[infra] generating plan (round {round_idx}, attempt {attempt}/{total_attempts})..."
        )
        plan_res = run_infra_plan_agent(
            bench_file=bench_file_rel,
            missing_theory_signal=missing_theory_signal,
            task_id=task_id,
            model=infra_agent_settings.agents["PLAN"].model,
            reasoning_effort=infra_agent_settings.agents["PLAN"].reasoning_effort,
            extra_instructions=extra_instructions,
        )
        log_event(
            run_id,
            "infra_plan_agent",
            {
                "round": round_idx,
                "attempt": attempt,
                "code": plan_res.code,
                "tokens_used": plan_res.tokens_used,
                "log_path": str(plan_res.log_path) if plan_res.log_path else None,
            },
        )
        # Non-zero agent exit code: retry with a reminder to emit valid JSON.
        if plan_res.code != 0:
            extra_instructions = _join_instructions(
                plan_extra_instructions,
                "Your previous output failed to parse. Output valid JSON only.",
            )
            continue
        raw_plan, raw_block = extract_marked_json(plan_res.stdout or "", INFRA_PLAN_START, INFRA_PLAN_END)
        # No marked block found in stdout: retry with a reminder about the markers.
        if raw_block is None:
            extra_instructions = _join_instructions(
                plan_extra_instructions,
                "Missing INFRA_PLAN markers. Wrap the JSON array with <<<INFRA_PLAN>>>.",
            )
            continue
        # Keep the raw agent output on disk before validating it (best effort; None on failure).
        draft_rel = _write_plan_draft(
            lean_root=lean_root,
            infra_dir_rel=infra_dir_rel,
            round_idx=round_idx,
            attempt=attempt,
            payload=raw_plan,
            source="agent",
        )
        if draft_rel is not None:
            log_event(
                run_id,
                "infra_plan_draft_written",
                {
                    "round": round_idx,
                    "attempt": attempt,
                    "source": "agent",
                    "path": str(draft_rel),
                },
            )
        try:
            plan_items = _validate_and_normalize_plan_items(
                raw_plan,
                infra_dir_rel=infra_dir_rel,
                public_api_cap=public_api_cap,
                lean_root=lean_root,
                external_dependency_labels=external_dependency_labels,
                forbidden_labels=forbidden_labels,
                existing_public_count=existing_public_count,
            )
        except Exception as e:
            # Feed the validation error back to the plan agent on the next attempt.
            extra_instructions = _join_instructions(
                plan_extra_instructions,
                f"Plan validation failed: {e}. Please regenerate a valid plan.",
            )
            plan_items = None
            continue
        last_structurally_valid_items = plan_items

        # `checked_items` is the plan currently under review by the check agent;
        # it is replaced by the check agent's fixed plan between check rounds.
        checked_items = plan_items
        check_rounds = max(0, int(max_plan_check_rounds))
        if check_rounds > 0:
            check_extra_instructions: str | None = plan_extra_instructions
            # Track consecutive check rounds that report the same proof-issue fingerprint.
            last_issue_fp: str | None = None
            same_issue_count = 0
            # Every path through this loop body ends in `break` or `continue`;
            # `plan_items is None` after the loop means "regenerate".
            for check_attempt in range(1, check_rounds + 1):
                print(
                    f"[infra] plan check (round {round_idx}, attempt {check_attempt}/{check_rounds})..."
                )
                check_res = run_infra_plan_check_agent(
                    bench_file=bench_file_rel,
                    missing_theory_signal=missing_theory_signal,
                    infra_plan=checked_items,
                    public_api_cap=public_api_cap,
                    task_id=task_id,
                    attempt=check_attempt,
                    model=infra_agent_settings.agents["CHECK"].model,
                    reasoning_effort=infra_agent_settings.agents["CHECK"].reasoning_effort,
                    extra_instructions=check_extra_instructions,
                )
                log_event(
                    run_id,
                    "infra_plan_check",
                    {
                        "round": round_idx,
                        "attempt": check_attempt,
                        "code": check_res.code,
                        "tokens_used": check_res.tokens_used,
                        "log_path": str(check_res.log_path) if check_res.log_path else None,
                    },
                )
                # Check agent exited non-zero: drop the plan and regenerate.
                if check_res.code != 0:
                    extra_instructions = _join_instructions(
                        plan_extra_instructions,
                        "Plan-check agent failed to run. Please regenerate a strictly valid plan.",
                    )
                    plan_items = None
                    break

                check_payload, check_block = extract_marked_json(
                    check_res.stdout or "", INFRA_PLAN_CHECK_START, INFRA_PLAN_CHECK_END
                )
                # Check output must be a marked JSON object; otherwise regenerate.
                if check_block is None or not isinstance(check_payload, dict):
                    extra_instructions = _join_instructions(
                        plan_extra_instructions,
                        "Plan-check output missing INFRA_PLAN_CHECK markers or invalid JSON. "
                        "Please regenerate a strictly valid plan.",
                    )
                    plan_items = None
                    break

                # status "ok" accepts the reviewed plan as-is, "fixed" supplies a
                # replacement plan; any other status is handled as a failure below.
                status = str(check_payload.get("status", "")).strip().lower()
                candidate_items: list[dict[str, Any]] | None = None
                if status == "ok":
                    candidate_items = checked_items
                if status == "fixed":
                    fixed_plan = check_payload.get("fixed_plan")
                    changes = check_payload.get("changes")
                    log_event(
                        run_id,
                        "infra_plan_check_fixed",
                        {"round": round_idx, "attempt": check_attempt, "changes": changes},
                    )
                    if not isinstance(fixed_plan, list):
                        extra_instructions = _join_instructions(
                            plan_extra_instructions,
                            "Plan-check returned status=fixed but no fixed_plan. Please regenerate a valid plan.",
                        )
                        plan_items = None
                        break
                    # The draft name uses the outer generation attempt; the check
                    # attempt is carried in the `source` tag.
                    fixed_draft_rel = _write_plan_draft(
                        lean_root=lean_root,
                        infra_dir_rel=infra_dir_rel,
                        round_idx=round_idx,
                        attempt=attempt,
                        payload=fixed_plan,
                        source=f"check{check_attempt}_fixed",
                    )
                    if fixed_draft_rel is not None:
                        log_event(
                            run_id,
                            "infra_plan_draft_written",
                            {
                                "round": round_idx,
                                "attempt": attempt,
                                "source": f"check{check_attempt}_fixed",
                                "path": str(fixed_draft_rel),
                            },
                        )
                    try:
                        candidate_items = _validate_and_normalize_plan_items(
                            fixed_plan,
                            infra_dir_rel=infra_dir_rel,
                            public_api_cap=public_api_cap,
                            lean_root=lean_root,
                            external_dependency_labels=external_dependency_labels,
                            forbidden_labels=forbidden_labels,
                            existing_public_count=existing_public_count,
                        )
                        last_structurally_valid_items = candidate_items
                    except Exception as e:
                        extra_instructions = _join_instructions(
                            plan_extra_instructions,
                            f"Fixed plan validation failed: {e}. Please regenerate a valid plan.",
                        )
                        plan_items = None
                        break
                if candidate_items is not None:
                    last_structurally_valid_items = candidate_items
                    # A non-empty result means the candidate is not acceptable yet.
                    # NOTE(review): `_collect_nondef_proof_issues` and `_issues_fingerprint`
                    # are defined outside this part of the file -- confirm their contracts there.
                    proof_issues = _collect_nondef_proof_issues(candidate_items)
                    if proof_issues:
                        issue_fp = _issues_fingerprint(proof_issues)
                        if issue_fp == last_issue_fp:
                            same_issue_count += 1
                        else:
                            last_issue_fp = issue_fp
                            same_issue_count = 1
                        issue_preview = proof_issues[:8]
                        print(
                            f"[infra] plan-check proof issues: {len(proof_issues)} item(s); "
                            f"sample={issue_preview}"
                        )
                        log_event(
                            run_id,
                            "infra_plan_check_proof_issues",
                            {
                                "round": round_idx,
                                "attempt": check_attempt,
                                "issue_count": len(proof_issues),
                                "same_issue_count": same_issue_count,
                                "fingerprint": issue_fp,
                                "sample": issue_preview,
                            },
                        )
                        # Same issues in two consecutive rounds: stop polishing and
                        # ask the plan agent for a different decomposition.
                        if same_issue_count >= 2:
                            extra_instructions = _join_instructions(
                                plan_extra_instructions,
                                "Plan-check stagnated with the same proof-quality issues twice. "
                                "Regenerate a materially different plan decomposition with finer-grained items. "
                                "Every non-def item must include a detailed English proof with explicit steps "
                                "that is provable from dependencies closure + mathlib. "
                                f"Current proof issues: {issue_preview}.",
                            )
                            plan_items = None
                            break
                        # Last check round and issues remain: regenerate.
                        if check_attempt >= check_rounds:
                            extra_instructions = _join_instructions(
                                plan_extra_instructions,
                                "Plan-check rounds exhausted before proof-quality constraints were met. "
                                "Regenerate a valid plan where every non-def item has detailed, stepwise proof text. "
                                f"Outstanding issues: {issue_preview}.",
                            )
                            plan_items = None
                            break
                        # Otherwise run another check round on the candidate, asking
                        # the check agent to rewrite the plan itself.
                        checked_items = candidate_items
                        check_extra_instructions = _join_instructions(
                            plan_extra_instructions,
                            "Auto-fix required: rewrite plan and set status='fixed' with full fixed_plan. "
                            "For every item with env != def/abbrev, add key `proof` as detailed English proof "
                            f"(>= {PROOF_MIN_STEPS} explicit steps, sufficiently detailed for normal proof budget) "
                            "using only dependencies closure + mathlib; avoid circular reasoning and future items. "
                            f"Outstanding issues: {issue_preview}.",
                        )
                        continue

                    # No proof issues: validate once more with `require_nondef_proof=True`.
                    # This call is not wrapped in try/except, so an error propagates to the caller.
                    plan_items = _validate_and_normalize_plan_items(
                        candidate_items,
                        infra_dir_rel=infra_dir_rel,
                        public_api_cap=public_api_cap,
                        lean_root=lean_root,
                        require_nondef_proof=True,
                        external_dependency_labels=external_dependency_labels,
                        forbidden_labels=forbidden_labels,
                        existing_public_count=existing_public_count,
                    )
                    break

                # Reached only when status is neither "ok" nor "fixed":
                # forward the checker's summary and issues to the next generation attempt.
                issues = check_payload.get("issues")
                summary = check_payload.get("summary") or "plan check failed"
                extra_instructions = _join_instructions(
                    plan_extra_instructions,
                    f"Plan-check failed: {summary}. Issues: {issues}. Please regenerate a valid plan.",
                )
                plan_items = None
                break

            # The check phase rejected the plan: move on to the next generation attempt.
            if plan_items is None:
                continue

        # Plan accepted (or checking disabled): stop generating.
        break
    return plan_items, last_structurally_valid_items


def _polish_plan_with_check(
    *,
    bench_file_rel: Path,
    missing_theory_signal: dict[str, Any],
    task_id: str,
    round_idx: int,
    run_id: str,
    infra_agent_settings: Any,
    infra_dir_rel: Path,
    public_api_cap: int,
    lean_root: Path,
    plan_abs: Path,
    plan_rel: Path,
    plan_items: list[dict[str, Any]],
    max_plan_check_rounds: int,
    external_dependency_labels: set[str],
    forbidden_labels: set[str],
    existing_public_count: int,
    plan_extra_instructions: str | None = None,
) -> tuple[list[dict[str, Any]] | None, str | None, dict[str, Any] | None]:
    check_rounds = max(0, int(max_plan_check_rounds))
    current_items = plan_items
    if check_rounds <= 0:
        return current_items, None, None

    check_extra_instructions: str | None = plan_extra_instructions
    last_issue_fp: str | None = None
    same_issue_count = 0
    last_fail_fp: str | None = None
    same_fail_count = 0
    prev_status: str | None = None
    last_failure_detail: dict[str, Any] | None = None

    for check_attempt in range(1, check_rounds + 1):
        print(f"[infra] plan check (round {round_idx}, attempt {check_attempt}/{check_rounds})...")
        check_res = run_infra_plan_check_agent(
            bench_file=bench_file_rel,
            missing_theory_signal=missing_theory_signal,
            infra_plan=current_items,
            public_api_cap=public_api_cap,
            task_id=task_id,
            attempt=check_attempt,
            model=infra_agent_settings.agents["CHECK"].model,
            reasoning_effort=infra_agent_settings.agents["CHECK"].reasoning_effort,
            extra_instructions=check_extra_instructions,
        )
        log_event(
            run_id,
            "infra_plan_check",
            {
                "round": round_idx,
                "attempt": check_attempt,
                "code": check_res.code,
                "tokens_used": check_res.tokens_used,
                "log_path": str(check_res.log_path) if check_res.log_path else None,
            },
        )

        if check_res.code != 0:
            last_failure_detail = {
                "phase": "check_agent_failed",
                "attempt": check_attempt,
                "check_rounds": check_rounds,
                "code": check_res.code,
                "log_path": str(check_res.log_path) if check_res.log_path else None,
            }
            check_extra_instructions = _join_instructions(
                plan_extra_instructions,
                "Plan-check agent failed to run. Return status='fixed' with a complete fixed_plan JSON array.",
            )
            continue

        check_payload, check_block = extract_marked_json(
            check_res.stdout or "", INFRA_PLAN_CHECK_START, INFRA_PLAN_CHECK_END
        )
        if check_block is None or not isinstance(check_payload, dict):
            last_failure_detail = {
                "phase": "check_output_invalid",
                "attempt": check_attempt,
                "check_rounds": check_rounds,
            }
            check_extra_instructions = _join_instructions(
                plan_extra_instructions,
                "Plan-check output missing INFRA_PLAN_CHECK markers or invalid JSON. "
                "Return status='fixed' with full fixed_plan.",
            )
            continue

        status = str(check_payload.get("status", "")).strip().lower()
        candidate_items: list[dict[str, Any]] | None = None
        if status == "ok":
            candidate_items = current_items
        elif status == "fixed":
            fixed_plan = check_payload.get("fixed_plan")
            changes = check_payload.get("changes")
            log_event(
                run_id,
                "infra_plan_check_fixed",
                {"round": round_idx, "attempt": check_attempt, "changes": changes},
            )
            if not isinstance(fixed_plan, list):
                last_failure_detail = {
                    "phase": "fixed_plan_missing",
                    "attempt": check_attempt,
                    "check_rounds": check_rounds,
                }
                check_extra_instructions = _join_instructions(
                    plan_extra_instructions,
                    "You reported status='fixed' but omitted fixed_plan. "
                    "Return a complete fixed_plan JSON array.",
                )
                continue
            fixed_draft_rel = _write_plan_draft(
                lean_root=lean_root,
                infra_dir_rel=infra_dir_rel,
                round_idx=round_idx,
                attempt=check_attempt,
                payload=fixed_plan,
                source=f"check{check_attempt}_fixed",
            )
            if fixed_draft_rel is not None:
                log_event(
                    run_id,
                    "infra_plan_draft_written",
                    {
                        "round": round_idx,
                        "attempt": check_attempt,
                        "source": f"check{check_attempt}_fixed",
                        "path": str(fixed_draft_rel),
                    },
                )
            try:
                candidate_items = _validate_and_normalize_plan_items(
                    fixed_plan,
                    infra_dir_rel=infra_dir_rel,
                    public_api_cap=public_api_cap,
                    lean_root=lean_root,
                    external_dependency_labels=external_dependency_labels,
                    forbidden_labels=forbidden_labels,
                    existing_public_count=existing_public_count,
                )
            except Exception as e:
                last_failure_detail = {
                    "phase": "fixed_plan_validation_failed",
                    "attempt": check_attempt,
                    "check_rounds": check_rounds,
                    "error": str(e),
                }
                check_extra_instructions = _join_instructions(
                    plan_extra_instructions,
                    f"fixed_plan validation failed: {e}. Return a corrected fixed_plan.",
                )
                continue
            # Persist every successful fix immediately so reruns continue from latest polished plan.
            _write_plan_file(plan_abs, candidate_items)
            log_event(
                run_id,
                "infra_plan_polished_fixed",
                {"round": round_idx, "attempt": check_attempt, "path": str(plan_rel), "items": len(candidate_items)},
            )
            current_items = candidate_items
        else:
            summary = check_payload.get("summary") or "plan check failed"
            issues = check_payload.get("issues")
            fail_issues = _collect_plan_check_fail_issues(summary=summary, issues=issues)
            fail_fp = _issues_fingerprint(fail_issues)
            if fail_fp == last_fail_fp:
                same_fail_count += 1
            else:
                last_fail_fp = fail_fp
                same_fail_count = 1
            blocker_coverage_fail = _is_blocker_coverage_failure(fail_issues)
            log_event(
                run_id,
                "infra_plan_check_fail_issues",
                {
                    "round": round_idx,
                    "attempt": check_attempt,
                    "summary": str(summary),
                    "issues_count": len(fail_issues),
                    "same_fail_count": same_fail_count,
                    "fingerprint": fail_fp,
                    "blocker_coverage_failure": blocker_coverage_fail,
                    "sample": fail_issues[:8],
                },
            )
            last_failure_detail = {
                "phase": "check_reported_failure",
                "attempt": check_attempt,
                "check_rounds": check_rounds,
                "summary": summary,
                "issues": issues,
                "same_fail_count": same_fail_count,
                "fingerprint": fail_fp,
                "blocker_coverage_failure": blocker_coverage_fail,
            }
            if blocker_coverage_fail and prev_status == "fixed":
                return None, "plan_polish_oscillation_blocker_coverage", last_failure_detail
            if same_fail_count >= PLAN_CHECK_FAIL_STAGNATION_MAX:
                return None, "plan_polish_fail_issues_stagnated", last_failure_detail
            check_extra_instructions = _join_instructions(
                plan_extra_instructions,
                "Plan-check reported failure. Repair the existing plan and return status='fixed' with full fixed_plan.",
                f"Summary: {summary}",
                f"Issues: {issues}",
                "Do not oscillate between deleting blocker-closing items and reintroducing non-closed placeholders; "
                "if missing producer theorem direction cannot be supplied in this plan closure, return status='fail'.",
            )
            prev_status = status
            continue

        if candidate_items is None:
            continue

        # A successful ok/fixed parse resets repeated-failure tracking.
        prev_status = status
        last_fail_fp = None
        same_fail_count = 0

        proof_issues = _collect_nondef_proof_issues(candidate_items)
        if proof_issues:
            issue_fp = _issues_fingerprint(proof_issues)
            if issue_fp == last_issue_fp:
                same_issue_count += 1
            else:
                last_issue_fp = issue_fp
                same_issue_count = 1
            issue_preview = proof_issues[:8]
            log_event(
                run_id,
                "infra_plan_check_proof_issues",
                {
                    "round": round_idx,
                    "attempt": check_attempt,
                    "issue_count": len(proof_issues),
                    "same_issue_count": same_issue_count,
                    "fingerprint": issue_fp,
                    "sample": issue_preview,
                },
            )
            check_extra_instructions = _join_instructions(
                plan_extra_instructions,
                "Repair the current plan and return status='fixed' with full fixed_plan. "
                "Every env != def/abbrev item must include a detailed, stepwise proof "
                f"(>= {PROOF_MIN_STEPS} steps) provable from dependencies + mathlib only.",
                f"Outstanding proof issues: {issue_preview}",
            )
            last_failure_detail = {
                "phase": "proof_issues_remaining",
                "attempt": check_attempt,
                "check_rounds": check_rounds,
                "issue_count": len(proof_issues),
                "same_issue_count": same_issue_count,
                "fingerprint": issue_fp,
                "sample": issue_preview,
            }
            if check_attempt >= check_rounds:
                return None, "plan_polish_check_rounds_exhausted_proof_issues", last_failure_detail
            continue

        try:
            normalized = _validate_and_normalize_plan_items(
                candidate_items,
                infra_dir_rel=infra_dir_rel,
                public_api_cap=public_api_cap,
                lean_root=lean_root,
                require_nondef_proof=True,
                external_dependency_labels=external_dependency_labels,
                forbidden_labels=forbidden_labels,
                existing_public_count=existing_public_count,
            )
        except Exception as e:
            last_failure_detail = {
                "phase": "normalized_plan_validation_failed",
                "attempt": check_attempt,
                "check_rounds": check_rounds,
                "error": str(e),
            }
            check_extra_instructions = _join_instructions(
                plan_extra_instructions,
                "Repair the current plan and return status='fixed' with full fixed_plan.",
                f"Validation error: {e}",
            )
            if check_attempt >= check_rounds:
                return None, f"plan_polish_validation_exhausted:{e}", last_failure_detail
            continue

        if normalized != current_items:
            _write_plan_file(plan_abs, normalized)
            log_event(
                run_id,
                "infra_plan_polished_fixed",
                {"round": round_idx, "attempt": check_attempt, "path": str(plan_rel), "items": len(normalized)},
            )
            current_items = normalized

        if status == "ok":
            return current_items, None, None

        # status=fixed but checker has not yet confirmed "ok"; keep polishing.
        check_extra_instructions = plan_extra_instructions
        if check_attempt >= check_rounds:
            last_failure_detail = {
                "phase": "checker_never_confirmed_ok",
                "attempt": check_attempt,
                "check_rounds": check_rounds,
            }
            return None, "plan_polish_check_rounds_exhausted_without_ok", last_failure_detail

    if last_failure_detail is None:
        last_failure_detail = {
            "phase": "check_rounds_exhausted",
            "check_rounds": check_rounds,
        }
    return None, "plan_polish_check_rounds_exhausted", last_failure_detail


def _public_api_path(bench_file_rel: Path) -> Path:
    """Return the public-API JSON path that belongs to ``bench_file_rel``.

    The file is named ``PUBLIC_API_FILENAME`` and sits directly inside the
    directory returned by ``_infra_dir_for_bench_file``.
    """
    infra_dir = _infra_dir_for_bench_file(bench_file_rel)
    return infra_dir / PUBLIC_API_FILENAME


def _infra_entry_path(bench_file_rel: Path) -> Path:
    """Return the infra entry-file path that belongs to ``bench_file_rel``.

    The file is named ``INFRA_ENTRY_FILENAME`` and sits directly inside the
    directory returned by ``_infra_dir_for_bench_file``.
    """
    infra_dir = _infra_dir_for_bench_file(bench_file_rel)
    return infra_dir / INFRA_ENTRY_FILENAME


def _lean_module_segment(seg: str) -> str:
    s = (seg or "").strip()
    if not s:
        return s
    # Assume valid Lean identifier; do not quote here (infra_<id> is valid).
    return s


def _module_import_for_rel_lean_file(rel_lean_file: Path) -> str:
    """Build the Lean ``import`` line for a relative ``.lean`` file path.

    The ``.lean`` suffix is dropped and the remaining path components are
    joined with dots after passing each through ``_lean_module_segment``.

    Raises:
        ValueError: if ``rel_lean_file`` does not end in ``.lean``.
    """
    if rel_lean_file.suffix != ".lean":
        raise ValueError(f"expected .lean file: {rel_lean_file}")
    stem_path = rel_lean_file.with_suffix("")
    segments = [_lean_module_segment(part) for part in stem_path.parts]
    return "import " + ".".join(segments)


def _ensure_import_line(file_text: str, import_line: str) -> str:
    if not import_line.strip():
        return file_text
    lines = (file_text or "").splitlines()
    normalized = import_line.strip()
    if any(line.strip() == normalized for line in lines):
        return file_text

    i = 0
    while i < len(lines):
        s = lines[i].lstrip()
        if s.startswith("import ") or s.startswith("--") or s.startswith("/-") or s.strip() == "":
            i += 1
            continue
        break

    last_import = None
    for j in range(0, i):
        if lines[j].lstrip().startswith("import "):
            last_import = j
    insert_at = (last_import + 1) if last_import is not None else 0
    new_lines = list(lines)
    new_lines.insert(insert_at, normalized)
    return "\n".join(new_lines) + ("\n" if file_text.endswith("\n") else "")


def _ensure_import_lines(file_text: str, import_lines: list[str]) -> str:
    """Apply ``_ensure_import_line`` for every entry of ``import_lines`` in order.

    Each call receives the text produced by the previous one; the final text
    is returned.
    """
    text = file_text
    for import_line in import_lines:
        text = _ensure_import_line(text, import_line)
    return text


def _env_requires_proof(env: Any) -> bool:
    env_str = str(env or "").strip().lower()
    return env_str not in {"def", "abbrev"}


_STEP_LINE_RE = re.compile(r"^\s*(?:[-*]|(?:step\s*)?\d+[).:])\s+", re.IGNORECASE)


def _estimate_proof_steps(text: str) -> int:
    if not text:
        return 0
    lines = [line.rstrip() for line in text.splitlines() if line.strip()]
    step_lines = sum(1 for line in lines if _STEP_LINE_RE.match(line))
    if step_lines > 0:
        return step_lines
    sentence_steps = [seg for seg in re.split(r"[.;]\s+", text.strip()) if seg.strip()]
    return len(sentence_steps)


def _collect_nondef_proof_issues(items: list[dict[str, Any]]) -> list[dict[str, str]]:
    """List proof-quality problems for every item whose env requires a proof.

    Items for which ``_env_requires_proof`` is false are skipped. For the
    rest, a missing or blank ``proof`` yields one ``missing_proof`` issue;
    otherwise the stripped proof may yield ``proof_too_short`` (fewer than
    ``PROOF_MIN_CHARS`` characters) and/or ``proof_too_coarse`` (fewer than
    ``PROOF_MIN_STEPS`` estimated steps). Each issue is a dict with
    ``label``, ``problem`` and ``detail`` keys.
    """
    found: list[dict[str, str]] = []

    def _flag(label: str, problem: str, detail: str) -> None:
        found.append({"label": label, "problem": problem, "detail": detail})

    for item in items:
        if not _env_requires_proof(item.get("env")):
            continue
        label = str(item.get("label", "")).strip() or "<unknown>"
        raw_proof = item.get("proof")
        proof_text = raw_proof.strip() if isinstance(raw_proof, str) else ""
        if not proof_text:
            _flag(
                label,
                "missing_proof",
                "env != def/abbrev items must include a non-empty English proof field.",
            )
            continue

        n_chars = len(proof_text)
        if n_chars < PROOF_MIN_CHARS:
            _flag(
                label,
                "proof_too_short",
                f"proof length {n_chars} < {PROOF_MIN_CHARS} chars; expand into a detailed derivation.",
            )

        n_steps = _estimate_proof_steps(proof_text)
        if n_steps < PROOF_MIN_STEPS:
            _flag(
                label,
                "proof_too_coarse",
                f"estimated step count {n_steps} < {PROOF_MIN_STEPS}; split into explicit proof steps.",
            )
    return found


def _issues_fingerprint(issues: list[dict[str, str]]) -> str:
    payload = json.dumps(
        sorted(issues, key=lambda it: (it.get("label", ""), it.get("problem", ""), it.get("detail", ""))),
        ensure_ascii=False,
        separators=(",", ":"),
    )
    return hashlib.sha1(payload.encode("utf-8")).hexdigest()


def _collect_plan_check_fail_issues(*, summary: Any, issues: Any) -> list[dict[str, str]]:
    normalized: list[dict[str, str]] = []
    summary_text = str(summary or "").strip()
    if summary_text:
        normalized.append(
            {
                "label": "<plan-check>",
                "problem": "summary",
                "detail": summary_text,
            }
        )
    if isinstance(issues, list):
        for issue in issues:
            if not isinstance(issue, dict):
                normalized.append(
                    {
                        "label": "<plan-check>",
                        "problem": "issue",
                        "detail": str(issue).strip(),
                    }
                )
                continue
            normalized.append(
                {
                    "label": str(issue.get("label", "")).strip() or "<plan-check>",
                    "problem": str(issue.get("problem", "")).strip() or "issue",
                    "detail": str(issue.get("detail", "")).strip(),
                }
            )
    return normalized


def _is_blocker_coverage_failure(fail_issues: list[dict[str, str]]) -> bool:
    """Tell whether any blocker-coverage keyword occurs in the given issues.

    The lower-cased ``problem`` and ``detail`` of every dict entry are joined
    into one string, which is searched for each phrase in
    ``_BLOCKER_COVERAGE_FAIL_KEYWORDS``. An empty issue list is never a
    coverage failure.
    """
    if not fail_issues:
        return False

    fragments = []
    for entry in fail_issues:
        if not isinstance(entry, dict):
            continue
        fragments.append(f"{entry.get('problem', '')} {entry.get('detail', '')}".lower())
    haystack = " ".join(fragments)

    for keyword in _BLOCKER_COVERAGE_FAIL_KEYWORDS:
        if keyword in haystack:
            return True
    return False


def _build_draft_imports_by_file(
    items: list[dict[str, Any]],
    *,
    external_label_to_file: dict[str, Path] | None = None,
) -> dict[Path, list[str]]:
    """Work out which import lines each target file needs for its dependencies.

    Every item's ``target_file`` gets an entry in the result, possibly empty.
    A dependency label is resolved to the file that declares it (plan items
    take precedence over ``external_label_to_file``); when that file differs
    from the item's own target, the corresponding import line is recorded.
    Unknown labels and non-string or blank dependencies are ignored.

    Returns:
        Mapping from target file to its sorted, de-duplicated import lines.
    """
    owner_of: dict[str, Path] = dict(external_label_to_file or {})
    imports_for: dict[Path, set[str]] = {}
    targets: list[Path] = []

    # First pass: register every target file and the labels it declares, so
    # that forward references resolve in the second pass.
    for item in items:
        target = Path(str(item.get("target_file", "")))
        targets.append(target)
        imports_for.setdefault(target, set())
        name = str(item.get("label", "")).strip()
        if name:
            owner_of[name] = target

    # Second pass: turn cross-file dependencies into import lines.
    for item, target in zip(items, targets):
        deps = item.get("dependencies")
        if not isinstance(deps, list):
            continue
        for dep in deps:
            if not isinstance(dep, str):
                continue
            key = dep.strip()
            if not key:
                continue
            source = owner_of.get(key)
            if source is not None and source != target:
                imports_for[target].add(_module_import_for_rel_lean_file(source))

    return {path: sorted(found) for path, found in imports_for.items()}


def _ensure_file_container(abs_file: Path, *, draft_imports: list[str] | None = None) -> None:
    """Make sure ``abs_file`` exists and carries the requested import lines.

    The parent directory is created if needed. An existing file is rewritten
    only when ``_ensure_import_lines`` changes its text. A missing file is
    created with ``import Mathlib`` followed by the de-duplicated draft
    imports and a placeholder comment.
    """
    wanted = [entry.strip() for entry in (draft_imports or []) if isinstance(entry, str) and entry.strip()]
    abs_file.parent.mkdir(parents=True, exist_ok=True)

    if not abs_file.exists():
        # dict.fromkeys keeps first-seen order while dropping duplicates.
        header = list(dict.fromkeys(["import Mathlib", *wanted]))
        body = header + ["", "-- Infra declarations will be appended below.", ""]
        abs_file.write_text("\n".join(body), encoding="utf-8")
        return

    current = abs_file.read_text(encoding="utf-8")
    patched = _ensure_import_lines(current, wanted)
    if patched != current:
        abs_file.write_text(patched, encoding="utf-8")


def _validate_and_normalize_plan_items(
    raw_items: Any,
    *,
    infra_dir_rel: Path,
    public_api_cap: int,
    lean_root: Path,
    require_nondef_proof: bool = False,
    external_dependency_labels: set[str] | None = None,
    forbidden_labels: set[str] | None = None,
    existing_public_count: int = 0,
) -> list[dict[str, Any]]:
    """Validate a raw infra plan and return its normalized items sorted by index.

    Each returned item is a copy of the input entry with these fields
    normalized: ``index`` (falls back to the 1-based position when not an
    int), ``label`` (stripped), ``env`` (stripped, lower-cased), ``content``,
    ``dependencies`` (list of stripped labels), ``target_file`` (string path
    relative to ``lean_root`` and located under ``infra_dir_rel``),
    ``public``, ``priority`` (defaults to ``"must"``) and, when present,
    ``proof`` (stripped).

    Args:
        raw_items: Parsed plan payload; must be a list of dicts.
        infra_dir_rel: Directory (relative to ``lean_root``) that every
            ``target_file`` has to live in. Bare filenames are placed there.
        public_api_cap: Upper bound for ``existing_public_count`` plus the
            number of ``public`` items in this plan.
        lean_root: Root used to relativize ``target_file`` values.
        require_nondef_proof: When true, items whose env requires a proof
            must carry a non-empty ``proof``.
        external_dependency_labels: Labels that dependencies may reference
            even though they are not part of this plan.
        forbidden_labels: Labels that must not be reused by this plan.
        existing_public_count: Public items already counted against the cap.

    Raises:
        ValueError: on any structural problem (wrong types, missing fields,
            duplicate or forbidden labels, target file outside the infra
            directory or containing ``..``, reserved entry filename, public
            cap exceeded, unknown dependency, or a dependency that does not
            refer to an earlier index).
    """
    if not isinstance(raw_items, list):
        raise ValueError("infra plan must be a JSON array")

    items: list[dict[str, Any]] = []
    seen_labels: set[str] = set()
    public_count = 0
    external_labels = set(external_dependency_labels or set())
    forbidden = set(forbidden_labels or set())

    for pos, entry in enumerate(raw_items, start=1):
        if not isinstance(entry, dict):
            raise ValueError(f"item[{pos}] must be an object")
        label = entry.get("label")
        env = entry.get("env")
        content = entry.get("content")
        target_file = entry.get("target_file")
        if not isinstance(label, str) or not label.strip():
            raise ValueError(f"item[{pos}] missing non-empty label")
        label = label.strip()
        if label in forbidden:
            raise ValueError(f"item[{pos}] label already exists in prior plans: {label}")
        if label in seen_labels:
            raise ValueError(f"duplicate label: {label}")
        seen_labels.add(label)
        if not isinstance(env, str) or not env.strip():
            raise ValueError(f"item[{pos}] missing non-empty env")
        if not isinstance(content, str) or not content.strip():
            raise ValueError(f"item[{pos}] missing non-empty content")
        if not isinstance(target_file, str) or not target_file.strip():
            raise ValueError(f"item[{pos}] missing non-empty target_file")

        target_rel = _normalize_rel_to_root(Path(target_file.strip()), lean_root)
        if target_rel.is_absolute():
            raise ValueError(f"item[{pos}] target_file must be relative to LEAN_ROOT")
        # `Path.relative_to` below is purely lexical and does not collapse
        # `..`, so a path like `infra_x/../../Other.lean` would pass the
        # containment check while pointing outside the infra directory.
        if ".." in target_rel.parts:
            raise ValueError(
                f"item[{pos}] target_file must not contain '..' segments (got {target_rel})"
            )
        # If only a filename was provided, place it under infra_dir.
        if len(target_rel.parts) == 1:
            target_rel = infra_dir_rel / target_rel
        # Enforce infra directory
        try:
            target_rel.relative_to(infra_dir_rel)
        except Exception as exc:
            raise ValueError(
                f"item[{pos}] target_file must be under {infra_dir_rel} (got {target_rel})"
            ) from exc
        if target_rel.name == INFRA_ENTRY_FILENAME:
            raise ValueError(
                f"item[{pos}] target_file must not be {INFRA_ENTRY_FILENAME} (reserved for infra entry file)"
            )

        idx = entry.get("index")
        if not isinstance(idx, int):
            idx = pos
        deps = entry.get("dependencies")
        if not isinstance(deps, list):
            deps = []
        dep_labels: list[str] = []
        for dep in deps:
            if not isinstance(dep, str) or not dep.strip():
                raise ValueError(f"item[{pos}] has invalid dependency entry: {dep!r}")
            dep_labels.append(dep.strip())
        public = entry.get("public")
        if not isinstance(public, bool):
            raise ValueError(f"item[{pos}] missing boolean field public")
        if public:
            public_count += 1
        priority = entry.get("priority")
        if not isinstance(priority, str) or not priority.strip():
            priority = "must"
        proof = entry.get("proof")
        if proof is not None and not isinstance(proof, str):
            raise ValueError(f"item[{pos}] field `proof` must be a string when present")
        env_norm = env.strip().lower()
        proof_norm = proof.strip() if isinstance(proof, str) else None
        if require_nondef_proof and _env_requires_proof(env_norm) and not proof_norm:
            raise ValueError(f"item[{pos}] env={env_norm} must include non-empty `proof`")

        # Start from a copy so unknown extra fields of the entry are kept.
        normalized = dict(entry)
        normalized["index"] = int(idx)
        normalized["label"] = label
        normalized["env"] = env_norm
        normalized["content"] = content
        normalized["dependencies"] = dep_labels
        normalized["target_file"] = str(target_rel)
        normalized["public"] = public
        normalized["priority"] = priority.strip()
        if proof_norm is not None:
            normalized["proof"] = proof_norm
        items.append(normalized)

    if existing_public_count + public_count > public_api_cap:
        raise ValueError(
            f"public items exceed cap: existing={existing_public_count} + new={public_count} > {public_api_cap}"
        )

    # Sort by index for stable processing
    items.sort(key=lambda it: int(it.get("index", 0)))

    # Dependencies must name an earlier item of this plan or a known
    # external label.
    label_to_index = {str(it["label"]): int(it["index"]) for it in items}
    for it in items:
        cur_idx = int(it["index"])
        for dep in it.get("dependencies", []):
            dep_idx = label_to_index.get(dep)
            if dep_idx is None:
                if dep in external_labels:
                    continue
                raise ValueError(f"item[{cur_idx}] dependency not found: {dep}")
            if dep_idx >= cur_idx:
                raise ValueError(
                    f"item[{cur_idx}] dependency must refer to earlier item: {dep} (index {dep_idx})"
                )

    return items


def _write_public_api_json(items: list[dict[str, Any]], *, path_abs: Path) -> None:
    """Write the public items of a plan to ``path_abs`` as indented JSON.

    Only items whose ``public`` field is exactly ``True`` are exported, each
    reduced to its ``label``, ``env`` and ``target_file``. The parent
    directory is created if needed and the file ends with a newline.
    """
    exported = []
    for item in items:
        if item.get("public") is not True:
            continue
        exported.append(
            {"label": item["label"], "env": item["env"], "target_file": item["target_file"]}
        )

    path_abs.parent.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(exported, ensure_ascii=False, indent=2)
    path_abs.write_text(serialized + "\n", encoding="utf-8")


def _write_infra_entry_file(*, bench_file_rel: Path, item_files: list[Path], lean_root: Path) -> None:
    """(Re)write the infra entry file so it imports every item file.

    The entry file starts with ``import Mathlib`` and a blank line, followed
    by one import per distinct ``.lean`` path in ``item_files`` (sorted, and
    excluding any file named ``INFRA_ENTRY_FILENAME``). Existing content is
    overwritten.
    """
    entry_abs = lean_root / _infra_entry_path(bench_file_rel)
    unique_files = sorted({Path(item) for item in item_files})
    import_lines = [
        _module_import_for_rel_lean_file(candidate)
        for candidate in unique_files
        if candidate.suffix == ".lean" and candidate.name != INFRA_ENTRY_FILENAME
    ]
    header = ["import Mathlib", ""]
    entry_abs.parent.mkdir(parents=True, exist_ok=True)
    entry_abs.write_text("\n".join(header + import_lines + [""]), encoding="utf-8")


def _ensure_bench_imports_infra(*, bench_file_rel: Path, lean_root: Path) -> None:
    """Add the infra entry import to the bench file if it is not there yet.

    Nothing happens when the bench file does not exist; the file is written
    back only when ``_ensure_import_line`` actually changed its text.
    """
    entry_rel = _infra_entry_path(bench_file_rel)
    bench_abs = lean_root / bench_file_rel
    if bench_abs.exists():
        entry_import = _module_import_for_rel_lean_file(entry_rel)
        original = bench_abs.read_text(encoding="utf-8")
        patched = _ensure_import_line(original, entry_import)
        if patched != original:
            bench_abs.write_text(patched, encoding="utf-8")


def _run_subprocess(module: str, args: list[str], env_overrides: dict[str, str] | None = None) -> int:
    """Run ``python -m <module> <args...>`` from ``ROOT`` and return its exit code.

    The child uses the current interpreter and a copy of the current
    environment, updated with ``env_overrides`` when provided.
    """
    child_env = dict(os.environ)
    if env_overrides:
        child_env.update(env_overrides)
    completed = subprocess.run([sys.executable, "-m", module, *args], cwd=ROOT, env=child_env)
    return int(completed.returncode)


def _run_statement_stage(
    *,
    plan_path: Path,
    project: str,
    task_id: str,
    max_b_retries: int,
    env_overrides: dict[str, str] | None,
) -> bool:
    """Run the item statement pipeline for the ``infra_statement`` stage.

    Invokes ``orchestrator.item_statement_pipeline`` through
    ``_run_subprocess`` with semantic checking enabled.

    Returns:
        ``True`` when the subprocess exits with code 0.
    """
    cli_args = [
        "--data-file",
        str(plan_path),
        "--project",
        project,
        "--task-id",
        task_id,
        "--stage",
        "infra_statement",
        "--semantic-check",
        "--max-b-retries",
        str(max_b_retries),
    ]
    exit_code = _run_subprocess(
        "orchestrator.item_statement_pipeline",
        cli_args,
        env_overrides=env_overrides,
    )
    return exit_code == 0


def _run_proof_stage(
    *,
    plan_path: Path,
    project: str,
    task_id: str,
    max_b_retries: int,
    max_c_replans: int,
    env_overrides: dict[str, str] | None,
) -> bool:
    """Run the item proof pipeline for the ``infra_proof`` stage.

    Invokes ``orchestrator.item_proof_pipeline`` through ``_run_subprocess``
    with the given retry and replan limits.

    Returns:
        ``True`` when the subprocess exits with code 0.
    """
    cli_args = [
        "--data-file",
        str(plan_path),
        "--project",
        project,
        "--task-id",
        task_id,
        "--stage",
        "infra_proof",
        "--max-b-retries",
        str(max_b_retries),
        "--max-c-replans",
        str(max_c_replans),
    ]
    exit_code = _run_subprocess(
        "orchestrator.item_proof_pipeline",
        cli_args,
        env_overrides=env_overrides,
    )
    return exit_code == 0


def _load_direct_blocker(
    *,
    lean_root: Path,
    infra_dir_abs: Path,
) -> dict[str, Any] | None:
    """Gather what is known about the item that blocked the direct-item loop.

    Information is read from the loop state file's ``last_error`` (index,
    reason, failure report path) and from a failure report. The report named
    by the state is preferred; if it is absent or not a JSON object, the most
    recently modified ``failure_idx*.json`` in the failure-reports directory
    that parses to a dict is used instead. The report fills in the index and
    reason only when the state did not provide them.

    Returns:
        ``None`` when no failed index could be determined; otherwise a dict
        with ``failed_index``, ``reason``, ``failed_item``, ``report``,
        ``report_path`` (relative to ``lean_root`` when possible, else as-is,
        or ``None`` without a report path) and ``state``.
    """
    state = _read_json_file(infra_dir_abs / DIRECT_ITEM_LOOP_STATE_FILENAME)
    last_error = None
    if isinstance(state, dict) and isinstance(state.get("last_error"), dict):
        last_error = state.get("last_error")

    fail_index: int | None = None
    reason: str | None = None
    report_abs: Path | None = None
    if last_error is not None:
        raw_index = last_error.get("index")
        if isinstance(raw_index, int):
            fail_index = int(raw_index)
        raw_reason = last_error.get("reason")
        if isinstance(raw_reason, str) and raw_reason.strip():
            reason = raw_reason.strip()
        raw_report = last_error.get("failure_report")
        if isinstance(raw_report, str) and raw_report.strip():
            candidate = Path(raw_report.strip())
            report_abs = candidate if candidate.is_absolute() else lean_root / candidate

    report_payload: dict[str, Any] | None = None
    if report_abs is not None and report_abs.exists():
        loaded = _read_json_file(report_abs)
        if isinstance(loaded, dict):
            report_payload = loaded

    if report_payload is None:
        # Fall back to the newest readable report in the reports directory.
        reports_dir = infra_dir_abs / DIRECT_ITEM_FAILURE_REPORTS_DIRNAME
        newest_path: Path | None = None
        if reports_dir.exists():
            dated = []
            for path in reports_dir.glob("failure_idx*.json"):
                try:
                    mtime = path.stat().st_mtime
                except Exception:
                    continue
                dated.append((mtime, path))
            dated.sort(key=lambda pair: pair[0], reverse=True)
            for _, path in dated:
                loaded = _read_json_file(path)
                if isinstance(loaded, dict):
                    newest_path = path
                    report_payload = loaded
                    break
        report_abs = newest_path or report_abs

    failed_item = None
    if isinstance(report_payload, dict) and isinstance(report_payload.get("failed_item"), dict):
        failed_item = report_payload.get("failed_item")

    if fail_index is None and failed_item is not None:
        report_index = failed_item.get("index")
        if isinstance(report_index, int):
            fail_index = int(report_index)
    if reason is None and isinstance(report_payload, dict):
        report_reason = report_payload.get("reason")
        if isinstance(report_reason, str) and report_reason.strip():
            reason = report_reason.strip()

    if fail_index is None:
        return None

    report_rel = None
    if report_abs is not None:
        try:
            report_rel = str(report_abs.relative_to(lean_root))
        except Exception:
            report_rel = str(report_abs)

    return {
        "failed_index": int(fail_index),
        "reason": str(reason or "unknown"),
        "failed_item": failed_item if failed_item is not None else {},
        "report": report_payload if report_payload is not None else {},
        "report_path": report_rel,
        "state": state if isinstance(state, dict) else {},
    }


def _ensure_prefix_frozen(
    *,
    before_items: list[dict[str, Any]],
    after_items: list[dict[str, Any]],
    fail_index: int,
) -> None:
    before_map = {
        int(it.get("index", 0)): it
        for it in before_items
        if isinstance(it, dict) and isinstance(it.get("index"), int)
    }
    after_map = {
        int(it.get("index", 0)): it
        for it in after_items
        if isinstance(it, dict) and isinstance(it.get("index"), int)
    }
    for idx in sorted(i for i in before_map.keys() if int(i) < int(fail_index)):
        if idx not in after_map:
            raise ValueError(f"prefix item index {idx} missing after suffix replan")
        if before_map[idx] != after_map[idx]:
            b_label = str(before_map[idx].get("label", ""))
            a_label = str(after_map[idx].get("label", ""))
            raise ValueError(
                f"prefix item index {idx} changed (before label={b_label!r}, after label={a_label!r})"
            )


def _archive_plan_snapshot(
    *,
    infra_dir_abs: Path,
    plan_abs: Path,
    fail_index: int,
    attempt_no: int,
) -> Path | None:
    """Copy the current plan file into the plan-history directory.

    The copy is named after the failed index, the attempt number and a UTC
    timestamp. This is best-effort: any exception while reading the plan or
    writing the copy results in ``None`` instead of propagating.

    Returns:
        Path of the written snapshot, or ``None`` if it could not be made.
    """
    try:
        contents = plan_abs.read_text(encoding="utf-8")
        history_dir = infra_dir_abs / DIRECT_ITEM_PLAN_HISTORY_DIRNAME
        history_dir.mkdir(parents=True, exist_ok=True)
        stamp = time.strftime("%Y%m%dT%H%M%SZ", time.gmtime())
        snapshot = history_dir / (
            f"plan_before_suffix_idx{int(fail_index):04d}_attempt{int(attempt_no):03d}_{stamp}.json"
        )
        snapshot.write_text(contents, encoding="utf-8")
    except Exception:
        return None
    return snapshot


def _attempt_direct_suffix_replan(
    *,
    bench_file_rel: Path,
    missing_theory_signal: dict[str, Any],
    task_id: str,
    run_id: str,
    round_idx: int,
    attempt_no: int,
    max_attempts: int,
    infra_agent_settings: Any,
    infra_dir_rel: Path,
    infra_dir_abs: Path,
    lean_root: Path,
    plan_abs: Path,
    plan_rel: Path,
    before_items: list[dict[str, Any]],
    fail_index: int,
    blocker: dict[str, Any],
    public_api_cap: int,
    external_dependency_labels: set[str],
    forbidden_labels: set[str],
    existing_public_count: int,
) -> tuple[list[dict[str, Any]] | None, str]:
    """Ask the plan-check agent to rewrite the plan suffix after a direct-item failure.

    The blocker details (failed item, reason, diagnosis, evidence, a truncated
    dump of the report detail) are rendered into extra instructions and sent,
    together with ``before_items``, to ``run_infra_plan_check_agent``. The
    agent must answer with ``status == "fixed"`` and a ``fixed_plan`` list.
    That plan is saved as a draft, validated and normalized, and checked to
    leave every item before ``fail_index`` unchanged. On acceptance the
    current plan file is archived; this function does not itself write the
    new plan to ``plan_abs``.

    Returns:
        ``(items, "ok")`` with the normalized replacement plan on success, or
        ``(None, reason)`` where ``reason`` names the step that failed.
    """
    failed_item = blocker.get("failed_item") if isinstance(blocker.get("failed_item"), dict) else {}
    report_payload = blocker.get("report") if isinstance(blocker.get("report"), dict) else {}
    reason = str(blocker.get("reason", "unknown"))
    # `report_path` may be present with value None (no report located);
    # `or ""` keeps that from being rendered as the literal string "None".
    report_path = str(blocker.get("report_path") or "").strip()
    report_detail = report_payload.get("detail") if isinstance(report_payload.get("detail"), dict) else {}
    diagnosis = report_detail.get("blocker_diagnosis") if isinstance(report_detail, dict) and isinstance(report_detail.get("blocker_diagnosis"), dict) else {}
    failure_class = str(diagnosis.get("failure_class", "")).strip()
    diagnosis_text = str(diagnosis.get("diagnosis", "")).strip()
    planner_guidance = str(diagnosis.get("planner_guidance", "")).strip()
    evidence_lines = diagnosis.get("evidence_lines") if isinstance(diagnosis.get("evidence_lines"), list) else []
    # Keep at most six non-blank evidence lines for the prompt.
    evidence_preview = []
    for line in evidence_lines:
        if isinstance(line, str) and line.strip():
            evidence_preview.append(line.strip())
        if len(evidence_preview) >= 6:
            break
    sorry_ctx = report_detail.get("sorry_context") if isinstance(report_detail.get("sorry_context"), dict) else {}
    sorry_count = sorry_ctx.get("sorry_count_total")
    if not isinstance(sorry_count, int):
        sorry_count = None
    detail_excerpt = _truncate_text(
        json.dumps(report_payload.get("detail", {}), ensure_ascii=False),
        max_chars=1600,
    )
    diagnosis_evidence_block = (
        "\n".join(f"  - {line}" for line in evidence_preview)
        if evidence_preview
        else "  - <none>"
    )
    suffix_instructions = "\n".join(
        [
            "DIRECT-ITEM FAILURE FEEDBACK (MUST APPLY):",
            f"- failed_index: {int(fail_index)}",
            f"- failed_label: {failed_item.get('label', '')}",
            f"- failed_item_id: {failed_item.get('item_id', '')}",
            f"- blocked_reason: {reason}",
            (f"- failure_report: {report_path}" if report_path else "- failure_report: <none>"),
            (f"- failure_class: {failure_class}" if failure_class else "- failure_class: <unknown>"),
            (f"- diagnosis: {diagnosis_text}" if diagnosis_text else "- diagnosis: <none>"),
            (f"- planner_guidance: {planner_guidance}" if planner_guidance else "- planner_guidance: <none>"),
            (
                f"- current_item_sorry_count: {int(sorry_count)}"
                if isinstance(sorry_count, int)
                else "- current_item_sorry_count: <unknown>"
            ),
            "- diagnosis_evidence_lines:",
            diagnosis_evidence_block,
            f"- failure_detail_excerpt: {detail_excerpt}",
            "",
            "HARD CONSTRAINTS:",
            "1) Rewrite suffix only: only items with `index >= failed_index` may change.",
            "2) Frozen prefix: items with `index < failed_index` must remain byte-identical in all fields.",
            "3) Keep dependencies acyclic and earlier-only; suffix may depend on frozen prefix labels.",
            "4) Keep item_id and label stable when still semantically valid.",
            "5) Output status=`fixed` with a full-plan `fixed_plan` array.",
            "",
            "DIAGNOSIS-AWARE FIX RULES:",
            "- If failure_class=math_incorrect_or_unprovable: revise the failed statement/hypotheses directly, do not only rename.",
            "- If failure_class=missing_theory_or_library_gap: insert prerequisite bridge lemmas before the failed item.",
            "- If failure_class contains decomposition/plan_gap/coarse/stalled: split failed item into finer lemmas and rewire dependencies.",
        ]
    )
    print(
        f"[infra][direct] suffix replan attempt {attempt_no}/{max_attempts} "
        f"from index={int(fail_index)}"
    )
    check_res = run_infra_plan_check_agent(
        bench_file=bench_file_rel,
        missing_theory_signal=missing_theory_signal,
        infra_plan=before_items,
        public_api_cap=int(public_api_cap),
        task_id=f"{task_id}_direct_suffix_{int(fail_index)}",
        attempt=int(attempt_no),
        model=infra_agent_settings.agents["CHECK"].model,
        reasoning_effort=infra_agent_settings.agents["CHECK"].reasoning_effort,
        extra_instructions=suffix_instructions,
    )
    log_event(
        run_id,
        "infra_direct_suffix_replan_attempt",
        {
            "round": int(round_idx),
            "attempt": int(attempt_no),
            "max_attempts": int(max_attempts),
            "failed_index": int(fail_index),
            "code": check_res.code,
            "tokens_used": check_res.tokens_used,
            "log_path": str(check_res.log_path) if check_res.log_path else None,
        },
    )
    if check_res.code != 0:
        return None, f"suffix_replan_agent_failed_code_{check_res.code}"

    check_payload, check_block = extract_marked_json(
        check_res.stdout or "", INFRA_PLAN_CHECK_START, INFRA_PLAN_CHECK_END
    )
    if check_block is None or not isinstance(check_payload, dict):
        return None, "suffix_replan_invalid_check_output"
    # Only an explicit "fixed" answer carries a replacement plan.
    status = str(check_payload.get("status", "")).strip().lower()
    if status != "fixed":
        return None, f"suffix_replan_status_{status or 'missing'}"
    fixed_plan = check_payload.get("fixed_plan")
    if not isinstance(fixed_plan, list):
        return None, "suffix_replan_missing_fixed_plan"

    # Persist the raw proposal before validating it, so a rejected plan can
    # still be inspected.
    fixed_draft_rel = _write_plan_draft(
        lean_root=lean_root,
        infra_dir_rel=infra_dir_rel,
        round_idx=int(round_idx),
        attempt=int(attempt_no),
        payload=fixed_plan,
        source=f"direct_suffix_replan_{int(fail_index)}",
    )
    if fixed_draft_rel is not None:
        log_event(
            run_id,
            "infra_plan_draft_written",
            {
                "round": int(round_idx),
                "attempt": int(attempt_no),
                "source": f"direct_suffix_replan_{int(fail_index)}",
                "path": str(fixed_draft_rel),
            },
        )

    try:
        candidate_items = _validate_and_normalize_plan_items(
            fixed_plan,
            infra_dir_rel=infra_dir_rel,
            public_api_cap=int(public_api_cap),
            lean_root=lean_root,
            external_dependency_labels=external_dependency_labels,
            forbidden_labels=forbidden_labels,
            existing_public_count=existing_public_count,
        )
    except Exception as e:
        return None, f"suffix_replan_validation_failed: {e}"

    try:
        _ensure_prefix_frozen(
            before_items=before_items,
            after_items=candidate_items,
            fail_index=int(fail_index),
        )
    except Exception as e:
        return None, f"suffix_replan_prefix_changed: {e}"

    # Archive the plan file as it is now; snapshotting is best-effort and a
    # failure to archive does not reject the replan.
    snapshot = _archive_plan_snapshot(
        infra_dir_abs=infra_dir_abs,
        plan_abs=plan_abs,
        fail_index=int(fail_index),
        attempt_no=int(attempt_no),
    )
    if snapshot is not None:
        log_event(
            run_id,
            "infra_direct_suffix_snapshot",
            {"round": int(round_idx), "attempt": int(attempt_no), "path": str(snapshot)},
        )

    print(
        f"[infra][direct] suffix replan accepted at attempt {attempt_no}/{max_attempts}; "
        f"plan={plan_rel}"
    )
    return candidate_items, "ok"


def _run_direct_item_stage(
    *,
    bench_file_rel: Path,
    plan_path: Path,
    task_id: str,
    env_overrides: dict[str, str] | None,
    simulate_success: bool,
    chunk_item_limit: int,
    chunk_line_limit: int,
    max_items: int | None,
    start_index: int | None,
    statement_max_b_retries: int,
    proof_max_b_retries: int,
    proof_max_c_replans: int,
) -> bool:
    """Run the direct-item pipeline module through ``_run_subprocess``.

    The command line for ``orchestrator.infra_item_direct_pipeline`` is
    assembled from the keyword arguments. Numeric settings are coerced with
    ``int`` before being turned into strings. ``--simulate-success`` is only
    passed when ``simulate_success`` is truthy, and ``--max-items`` /
    ``--start-index`` are only passed when their value is not ``None``.

    Returns:
        ``True`` when ``_run_subprocess`` returns ``0``; ``False`` otherwise.
    """
    # Flags that are always present, in the order they appear on the command line.
    required_pairs = [
        ("--bench-file", str(bench_file_rel)),
        ("--plan-file", str(plan_path)),
        ("--task-id", task_id),
        ("--chunk-item-limit", str(int(chunk_item_limit))),
        ("--chunk-line-limit", str(int(chunk_line_limit))),
        ("--statement-max-b-retries", str(int(statement_max_b_retries))),
        ("--proof-max-b-retries", str(int(proof_max_b_retries))),
        ("--proof-max-c-replans", str(int(proof_max_c_replans))),
    ]
    cli_args = [token for pair in required_pairs for token in pair]

    if simulate_success:
        cli_args += ["--simulate-success"]

    # Optional integer flags are emitted only when a value was supplied.
    for flag, value in (("--max-items", max_items), ("--start-index", start_index)):
        if value is not None:
            cli_args += [flag, str(int(value))]

    exit_code = _run_subprocess(
        "orchestrator.infra_item_direct_pipeline",
        cli_args,
        env_overrides=env_overrides,
    )
    return exit_code == 0


def _run_final_stage_on_files(
    *,
    files: list[Path],
    lean_root: Path,
    env_overrides: dict[str, str] | None,
    task_id: str,
    round_idx: int,
) -> tuple[bool, dict[str, Any] | None]:
    """Run the final pipeline on each listed Lean file that still contains "sorry".

    Files are visited in the given order. A file is skipped unless its suffix
    is ``.lean``, it exists under ``lean_root``, and its UTF-8 text contains
    the substring ``sorry``. For every remaining file,
    ``orchestrator.final_pipeline`` is launched through ``_run_subprocess``
    with per-file history and progress paths located under
    ``INFRA_LOGS_DIR / "infra_final_runs" / <safe task id> / round_NN``.

    Returns:
        ``(True, None)`` when every launched run returned ``0`` (or nothing
        needed to run). On the first non-zero result the sweep stops and
        ``(False, signal)`` is returned, where ``signal`` is whatever
        ``_latest_missing_theory_signal`` yields for that file's history
        file (it may be ``None``).
    """
    trace_dir = INFRA_LOGS_DIR / "infra_final_runs" / _safe_stem(task_id) / f"round_{round_idx:02d}"
    trace_dir.mkdir(parents=True, exist_ok=True)

    # Flags that are identical for every file; appended after the per-file ones.
    common_flags = (
        "--force-stage",
        "infra",
        "--missing-theory-policy",
        "continue",
        "--auto-infra-sprint",
        "--scaffold-sorry-budget",
        "0",
    )

    for rel_file in files:
        if rel_file.suffix != ".lean":
            continue
        source_path = lean_root / rel_file
        # The file is read lazily, right before its own run, and only if it exists.
        needs_work = source_path.exists() and "sorry" in source_path.read_text(encoding="utf-8")
        if not needs_work:
            continue

        trace_stem = _safe_stem(rel_file.with_suffix("").as_posix())
        history_file = trace_dir / f"{trace_stem}.history.jsonl"
        progress_file = trace_dir / f"{trace_stem}.progress.json"
        exit_code = _run_subprocess(
            "orchestrator.final_pipeline",
            [
                "--only-file",
                str(rel_file),
                "--history-file",
                str(history_file),
                "--progress-file",
                str(progress_file),
                *common_flags,
            ],
            env_overrides=env_overrides,
        )
        if exit_code == 0:
            continue

        # First failure ends the sweep; surface the recorded signal, if any.
        signal = _latest_missing_theory_signal(history_file)
        if signal is not None:
            reason = str(signal.get("reason", "")).strip()
            shown_reason = reason[:300] if reason else "<no reason>"
            print(
                f"[infra] final sweep blocked by missing theory in {rel_file}: "
                + shown_reason
            )
        return False, signal

    return True, None


def run_infra_pipeline(
    *,
    bench_file: Path,
    missing_theory_signal: dict[str, Any],
    infra_public_api_cap: int,
    max_b_retries: int,
    infra_statement_max_b_retries: int | None,
    max_c_replans: int,
    infra_plan_generate_attempts: int,
    max_plan_check_rounds: int,
    infra_expand_max_rounds: int = INFRA_EXPAND_MAX_ROUNDS_DEFAULT,
    infra_agent_config: Path | None = None,
    infra_exec_mode: str = "direct_item",
    infra_direct_simulate_success: bool = False,
    infra_direct_chunk_item_limit: int = 20,
    infra_direct_chunk_line_limit: int = 1800,
    infra_direct_max_items: int | None = None,
    infra_direct_start_index: int | None = None,
    infra_direct_statement_max_b_retries: int = 3,
    infra_direct_proof_max_b_retries: int = 3,
    infra_direct_proof_max_c_replans: int = 1,
) -> bool:
    run_start = time.monotonic()
    written_plan_paths: list[str] = []
    rounds_completed = 0
    statement_max_b_retries = (
        int(infra_statement_max_b_retries)
        if infra_statement_max_b_retries is not None
        else int(max_b_retries)
    )

    def _finish(status: str, *, reason: str | None = None, items: int | None = None) -> None:
        summary: dict[str, Any] = {
            "pipeline": "infra_pipeline",
            "stage": 0,
            "run_id": run_id,
            "status": status,
            "task_id": task_id,
            "bench_file": str(bench_file_rel),
            "seconds_total": time.monotonic() - run_start,
            "paths": {
                "infra_dir": str(infra_dir_rel),
                "plan_path": str(_infra_plan_path(bench_file_rel)),
                "public_api_path": str(_public_api_path(bench_file_rel)),
                "entry_path": str(_infra_entry_path(bench_file_rel)),
                "plan_paths": written_plan_paths,
            },
            "rounds_completed": rounds_completed,
        }
        if reason:
            summary["reason"] = reason
        if items is not None:
            summary["items"] = int(items)
        finish_run(run_id, summary)

    lean_root, input_rel = _resolve_infra_root(bench_file)
    bench_file_rel = input_rel
    direct_infra_mode = False
    inferred = _infer_bench_and_infra_from_input(input_rel)
    if inferred is not None:
        bench_file_rel, infra_dir_rel, _ = inferred
        direct_infra_mode = True
    else:
        infra_dir_rel = _infra_dir_for_bench_file(bench_file_rel)
    infra_dir_abs = lean_root / infra_dir_rel
    infra_dir_abs.mkdir(parents=True, exist_ok=True)
    env_overrides = None
    try:
        rel_to_root = lean_root.relative_to(ROOT)
        env_overrides = {"LEAN_PROJECT_DIR": rel_to_root.as_posix()}
    except Exception:
        env_overrides = None

    default_cfg = ROOT / "agent_configs/infra_agents.toml"
    cfg_from_env = Path(os.environ["INFRA_AGENT_CONFIG_FILE"]) if os.environ.get("INFRA_AGENT_CONFIG_FILE") else None
    infra_cfg = infra_agent_config or cfg_from_env
    infra_agent_settings = resolve_stage_agents_settings(
        stage_prefix="INFRA_AGENT",
        agent_keys=["PLAN", "CHECK"],
        config_path=infra_cfg,
        default_config_path=(default_cfg if default_cfg.exists() else None),
    )

    task_id = f"infra_{_bench_id(bench_file_rel)}"
    run_id = start_run(
        "infra_pipeline",
        stage=0,
        name_tag=task_id,
        data_file=str(_infra_plan_path(bench_file_rel)),
        extra={
            "bench_file": str(bench_file_rel),
            "input_path": str(input_rel),
            "direct_infra_mode": bool(direct_infra_mode),
            "infra_agent_config": str(infra_agent_settings.source_path) if infra_agent_settings.source_path else None,
            "infra_plan_model": infra_agent_settings.agents["PLAN"].model,
            "infra_plan_reasoning_effort": infra_agent_settings.agents["PLAN"].reasoning_effort,
            "infra_plan_generate_attempts": int(max(1, int(infra_plan_generate_attempts))),
            "infra_plan_check_model": infra_agent_settings.agents["CHECK"].model,
            "infra_plan_check_reasoning_effort": infra_agent_settings.agents["CHECK"].reasoning_effort,
            "infra_expand_max_rounds": int(max(0, int(infra_expand_max_rounds))),
            "infra_exec_mode": str(infra_exec_mode),
            "infra_direct_simulate_success": bool(infra_direct_simulate_success),
            "infra_direct_chunk_item_limit": int(infra_direct_chunk_item_limit),
            "infra_direct_chunk_line_limit": int(infra_direct_chunk_line_limit),
            "infra_direct_max_items": (
                int(infra_direct_max_items)
                if infra_direct_max_items is not None
                else None
            ),
            "infra_direct_start_index": (
                int(infra_direct_start_index)
                if infra_direct_start_index is not None
                else None
            ),
            "infra_direct_statement_max_b_retries": int(infra_direct_statement_max_b_retries),
            "infra_direct_proof_max_b_retries": int(infra_direct_proof_max_b_retries),
            "infra_direct_proof_max_c_replans": int(infra_direct_proof_max_c_replans),
        },
    )

    def _report_plan_failure(
        *,
        reason: str,
        round_idx: int | None,
        plan_rel: Path | None,
        detail: dict[str, Any] | None = None,
        items: int | None = None,
    ) -> Path | None:
        payload: dict[str, Any] = {
            "timestamp_utc": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
            "run_id": run_id,
            "task_id": task_id,
            "bench_file": str(bench_file_rel),
            "reason": reason,
        }
        if round_idx is not None:
            payload["round"] = int(round_idx)
        if plan_rel is not None:
            payload["plan_path"] = str(plan_rel)
        if items is not None:
            payload["items"] = int(items)
        if detail:
            payload["detail"] = detail
        try:
            report_rel = _write_plan_failure_report(
                lean_root=lean_root,
                bench_file_rel=bench_file_rel,
                payload=payload,
            )
            log_event(
                run_id,
                "infra_plan_failure_report_written",
                {
                    "reason": reason,
                    "round": round_idx,
                    "plan_path": str(plan_rel) if plan_rel is not None else None,
                    "report_path": str(report_rel),
                },
            )
            return report_rel
        except Exception as e:
            log_event(
                run_id,
                "infra_plan_failure_report_write_failed",
                {
                    "reason": reason,
                    "round": round_idx,
                    "plan_path": str(plan_rel) if plan_rel is not None else None,
                    "error": str(e),
                },
            )
            return None

    project = task_id
    max_expand_rounds = max(0, int(infra_expand_max_rounds))
    next_signal: dict[str, Any] | None = missing_theory_signal
    seen_blockers: set[str] = set()
    total_items = 0
    existing_entries = _iter_plan_entries(infra_dir_abs)
    existing_rounds = [idx for idx, _ in existing_entries if idx > 0]
    base_exists = any(idx == 0 for idx, _ in existing_entries)
    resume_round: int | None = existing_rounds[-1] if (direct_infra_mode and existing_rounds) else None
    next_generate_round = 0
    max_generated_rounds = max_expand_rounds + 1
    if direct_infra_mode:
        if existing_rounds:
            next_generate_round = existing_rounds[-1] + 1
            max_generated_rounds = max_expand_rounds
        elif base_exists:
            # Existing base infra plan present; first new generation should be k=1.
            next_generate_round = 1
            max_generated_rounds = max_expand_rounds
        else:
            next_generate_round = 0
            max_generated_rounds = max_expand_rounds + 1
    generated_rounds = 0

    while True:
        using_existing_round = resume_round is not None
        if using_existing_round:
            round_idx = int(resume_round)
            resume_round = None
            entries_map = dict(_iter_plan_entries(infra_dir_abs))
            plan_abs = entries_map.get(round_idx)
            if plan_abs is None:
                _report_plan_failure(
                    reason=f"missing_existing_plan_round_{round_idx}",
                    round_idx=round_idx,
                    plan_rel=None,
                    items=total_items,
                )
                _finish("failed", reason=f"missing_existing_plan_round_{round_idx}", items=total_items)
                return False
            try:
                plan_rel = plan_abs.relative_to(lean_root)
            except Exception:
                plan_rel = plan_abs
        else:
            if generated_rounds >= max_generated_rounds:
                _finish("failed", reason="infra_expand_exhausted", items=total_items)
                return False
            if next_signal is None:
                _finish("failed", reason="missing_theory_signal_unavailable", items=total_items)
                return False
            round_idx = int(next_generate_round)
            next_generate_round += 1
            generated_rounds += 1
            plan_rel = _plan_rel_for_round(bench_file_rel, round_idx)
            plan_abs = lean_root / plan_rel

        rounds_completed += 1
        round_task_id = task_id if round_idx == 0 else f"{task_id}_expend_{round_idx}"
        prior_items = _load_existing_plan_items(infra_dir_abs, exclude={plan_abs})
        prior_labels = {
            str(it.get("label", "")).strip()
            for it in prior_items
            if isinstance(it, dict) and isinstance(it.get("label"), str) and str(it.get("label")).strip()
        }
        prior_public_count = sum(1 for it in prior_items if isinstance(it, dict) and it.get("public") is True)

        round_instructions = None
        if round_idx > 0:
            round_instructions = _build_expansion_round_instructions(
                round_idx=round_idx,
                prior_items=prior_items,
                bench_file_rel=bench_file_rel,
            )
        plan_file_exists = plan_abs.exists()

        log_event(
            run_id,
            "infra_round_start",
            {
                "round": round_idx,
                "round_mode": (
                    "resume_existing"
                    if using_existing_round
                    else ("reuse_for_polish" if plan_file_exists else "generate")
                ),
                "task_id": round_task_id,
                "plan_path": str(plan_rel),
                "prior_items": len(prior_items),
                "prior_public_items": prior_public_count,
            },
        )

        if using_existing_round and not plan_file_exists:
            _report_plan_failure(
                reason=f"missing_existing_plan_round_{round_idx}",
                round_idx=round_idx,
                plan_rel=plan_rel,
                items=total_items,
            )
            _finish("failed", reason=f"missing_existing_plan_round_{round_idx}", items=total_items)
            return False

        if plan_file_exists:
            try:
                raw_plan = json.loads(plan_abs.read_text(encoding="utf-8"))
            except Exception as e:
                _report_plan_failure(
                    reason=f"existing_plan_unreadable_round_{round_idx}:{e}",
                    round_idx=round_idx,
                    plan_rel=plan_rel,
                    detail={"error": str(e)},
                    items=total_items,
                )
                _finish("failed", reason=f"existing_plan_unreadable_round_{round_idx}:{e}", items=total_items)
                return False
            try:
                plan_items = _validate_and_normalize_plan_items(
                    raw_plan,
                    infra_dir_rel=infra_dir_rel,
                    public_api_cap=infra_public_api_cap,
                    lean_root=lean_root,
                    external_dependency_labels=prior_labels,
                    forbidden_labels=prior_labels,
                    existing_public_count=prior_public_count,
                )
            except Exception as e:
                _report_plan_failure(
                    reason=f"existing_plan_invalid_round_{round_idx}:{e}",
                    round_idx=round_idx,
                    plan_rel=plan_rel,
                    detail={"error": str(e)},
                    items=total_items,
                )
                _finish("failed", reason=f"existing_plan_invalid_round_{round_idx}:{e}", items=total_items)
                return False
            log_event(
                run_id,
                "infra_plan_reused",
                {"round": round_idx, "path": str(plan_rel), "items": len(plan_items), "for_polish": True},
            )
        else:
            plan_items, fallback_plan_items = _generate_checked_plan_items(
                bench_file_rel=bench_file_rel,
                missing_theory_signal=next_signal,
                task_id=round_task_id,
                round_idx=round_idx,
                run_id=run_id,
                infra_agent_settings=infra_agent_settings,
                infra_dir_rel=infra_dir_rel,
                public_api_cap=infra_public_api_cap,
                lean_root=lean_root,
                max_plan_check_rounds=0,
                max_plan_attempts=int(max(1, int(infra_plan_generate_attempts))),
                external_dependency_labels=prior_labels,
                forbidden_labels=prior_labels,
                existing_public_count=prior_public_count,
                plan_extra_instructions=round_instructions,
            )
            if not plan_items:
                if fallback_plan_items:
                    _write_plan_file(plan_abs, fallback_plan_items)
                    if str(plan_rel) not in written_plan_paths:
                        written_plan_paths.append(str(plan_rel))
                    log_event(
                        run_id,
                        "infra_plan_written_retry_exhausted",
                        {
                            "round": round_idx,
                            "path": str(plan_rel),
                            "items": len(fallback_plan_items),
                            "status": "manual_review_required",
                        },
                    )
                    fallback_total_items = len(_load_existing_plan_items(infra_dir_abs))
                    print(
                        "[infra] plan retries exhausted; wrote last structurally valid plan "
                        f"to {plan_rel}. Exiting for manual review."
                    )
                    _report_plan_failure(
                        reason="plan_generation_retries_exhausted_plan_written_manual_review",
                        round_idx=round_idx,
                        plan_rel=plan_rel,
                        detail={
                            "status": "manual_review_required",
                            "fallback_items": len(fallback_plan_items),
                        },
                        items=fallback_total_items,
                    )
                    _finish(
                        "failed",
                        reason="plan_generation_retries_exhausted_plan_written_manual_review",
                        items=fallback_total_items,
                    )
                    return False

                print("[infra] failed to generate a valid infra plan; aborting.")
                _report_plan_failure(
                    reason="plan_generation_failed",
                    round_idx=round_idx,
                    plan_rel=plan_rel,
                    items=total_items,
                )
                _finish("failed", reason="plan_generation_failed", items=total_items)
                return False

            _write_plan_file(plan_abs, plan_items)
            log_event(
                run_id,
                "infra_plan_written_initial",
                {"round": round_idx, "path": str(plan_rel), "items": len(plan_items)},
            )

        if str(plan_rel) not in written_plan_paths:
            written_plan_paths.append(str(plan_rel))

        plan_items, polish_fail_reason, polish_fail_detail = _polish_plan_with_check(
            bench_file_rel=bench_file_rel,
            missing_theory_signal=next_signal,
            task_id=round_task_id,
            round_idx=round_idx,
            run_id=run_id,
            infra_agent_settings=infra_agent_settings,
            infra_dir_rel=infra_dir_rel,
            public_api_cap=infra_public_api_cap,
            lean_root=lean_root,
            plan_abs=plan_abs,
            plan_rel=plan_rel,
            plan_items=plan_items,
            max_plan_check_rounds=max_plan_check_rounds,
            external_dependency_labels=prior_labels,
            forbidden_labels=prior_labels,
            existing_public_count=prior_public_count,
            plan_extra_instructions=round_instructions,
        )
        if plan_items is None:
            latest_total_items = len(_load_existing_plan_items(infra_dir_abs))
            fail_reason = polish_fail_reason or "plan_polish_failed_with_latest_preserved"
            report_rel = _report_plan_failure(
                reason=fail_reason,
                round_idx=round_idx,
                plan_rel=plan_rel,
                detail=polish_fail_detail,
                items=latest_total_items,
            )
            print(
                "[infra][report] plan check/fix failed; exiting immediately. "
                f"reason={fail_reason} plan={plan_rel} "
                + (f"failure_report={report_rel}" if report_rel is not None else "")
            )
            log_event(
                run_id,
                "infra_plan_polish_failed",
                {
                    "round": round_idx,
                    "reason": fail_reason,
                    "plan_path": str(plan_rel),
                    "items": latest_total_items,
                    "detail": polish_fail_detail,
                    "failure_report_path": str(report_rel) if report_rel is not None else None,
                },
            )
            _finish(
                "failed",
                reason=fail_reason,
                items=latest_total_items,
            )
            return False

        all_items = _load_existing_plan_items(infra_dir_abs)
        total_items = len(all_items)
        _write_public_api_json(all_items, path_abs=lean_root / _public_api_path(bench_file_rel))
        log_event(
            run_id,
            "infra_public_api_written",
            {"path": str(_public_api_path(bench_file_rel)), "public_count": sum(1 for it in all_items if it.get("public") is True)},
        )

        if str(infra_exec_mode).strip().lower() == "direct_item":
            direct_plan_attempt_cap = max(0, int(infra_plan_generate_attempts))
            direct_plan_attempt_used = 0
            direct_exec_attempt = 0
            direct_start_index = (
                int(infra_direct_start_index)
                if infra_direct_start_index is not None
                else None
            )
            while True:
                direct_exec_attempt += 1
                state_now = _read_json_file(infra_dir_abs / DIRECT_ITEM_LOOP_STATE_FILENAME)
                cursor_now = (
                    int(state_now.get("cursor_index"))
                    if isinstance(state_now, dict) and isinstance(state_now.get("cursor_index"), int)
                    else direct_start_index
                )
                frozen_now = (
                    int(state_now.get("frozen_count"))
                    if isinstance(state_now, dict) and isinstance(state_now.get("frozen_count"), int)
                    else None
                )
                print(
                    "[infra][direct] run attempt "
                    f"{direct_exec_attempt} | frozen={frozen_now if frozen_now is not None else '?'}"
                    f"/{len(plan_items)} | cursor={cursor_now if cursor_now is not None else '?'} | "
                    f"plan_attempts_left={max(0, direct_plan_attempt_cap - direct_plan_attempt_used)}"
                )
                log_event(
                    run_id,
                    "infra_direct_item_stage_start",
                    {
                        "round": round_idx,
                        "task_id": round_task_id,
                        "plan_path": str(plan_rel),
                        "direct_exec_attempt": int(direct_exec_attempt),
                        "plan_attempts_used": int(direct_plan_attempt_used),
                        "plan_attempts_cap": int(direct_plan_attempt_cap),
                        "simulate_success": bool(infra_direct_simulate_success),
                        "chunk_item_limit": int(infra_direct_chunk_item_limit),
                        "chunk_line_limit": int(infra_direct_chunk_line_limit),
                        "max_items": (
                            int(infra_direct_max_items)
                            if infra_direct_max_items is not None
                            else None
                        ),
                        "start_index": int(direct_start_index) if direct_start_index is not None else None,
                        "cursor_index": int(cursor_now) if cursor_now is not None else None,
                        "frozen_count": int(frozen_now) if frozen_now is not None else None,
                        "statement_max_b_retries": int(infra_direct_statement_max_b_retries),
                        "proof_max_b_retries": int(infra_direct_proof_max_b_retries),
                        "proof_max_c_replans": int(infra_direct_proof_max_c_replans),
                    },
                )
                direct_ok = _run_direct_item_stage(
                    bench_file_rel=bench_file_rel,
                    plan_path=plan_abs,
                    task_id=round_task_id,
                    env_overrides=env_overrides,
                    simulate_success=bool(infra_direct_simulate_success),
                    chunk_item_limit=int(infra_direct_chunk_item_limit),
                    chunk_line_limit=int(infra_direct_chunk_line_limit),
                    max_items=(
                        int(infra_direct_max_items)
                        if infra_direct_max_items is not None
                        else None
                    ),
                    start_index=direct_start_index,
                    statement_max_b_retries=int(infra_direct_statement_max_b_retries),
                    proof_max_b_retries=int(infra_direct_proof_max_b_retries),
                    proof_max_c_replans=int(infra_direct_proof_max_c_replans),
                )
                # After first execution pass, always resume from loop state unless a blocker sets a new start index.
                direct_start_index = None

                if direct_ok:
                    log_event(
                        run_id,
                        "infra_direct_item_stage_end",
                        {
                            "round": round_idx,
                            "task_id": round_task_id,
                            "status": "ok",
                            "plan_path": str(plan_rel),
                            "direct_exec_attempt": int(direct_exec_attempt),
                            "plan_attempts_used": int(direct_plan_attempt_used),
                        },
                    )
                    _finish("ok", items=total_items)
                    return True

                blocker = _load_direct_blocker(
                    lean_root=lean_root,
                    infra_dir_abs=infra_dir_abs,
                )
                if blocker is None:
                    print("[infra] direct item stage failed, but blocker context is unavailable.")
                    _finish(
                        "failed",
                        reason=f"direct_item_stage_failed_no_blocker_context_round_{round_idx}",
                        items=total_items,
                    )
                    return False
                fail_index = int(blocker.get("failed_index"))
                fail_reason = str(blocker.get("reason", "unknown"))
                print(
                    "[infra][direct] blocked at "
                    f"index={fail_index} reason={fail_reason}; preparing suffix replan."
                )
                log_event(
                    run_id,
                    "infra_direct_item_stage_end",
                    {
                        "round": round_idx,
                        "task_id": round_task_id,
                        "status": "blocked",
                        "plan_path": str(plan_rel),
                        "direct_exec_attempt": int(direct_exec_attempt),
                        "blocked_index": int(fail_index),
                        "blocked_reason": fail_reason,
                        "report_path": blocker.get("report_path"),
                    },
                )

                suffix_applied = False
                suffix_fail_reason: str | None = None
                replan_signal = next_signal if isinstance(next_signal, dict) else missing_theory_signal
                while direct_plan_attempt_used < direct_plan_attempt_cap:
                    attempt_no = direct_plan_attempt_used + 1
                    candidate_items, attempt_reason = _attempt_direct_suffix_replan(
                        bench_file_rel=bench_file_rel,
                        missing_theory_signal=replan_signal,
                        task_id=round_task_id,
                        run_id=run_id,
                        round_idx=round_idx,
                        attempt_no=attempt_no,
                        max_attempts=direct_plan_attempt_cap,
                        infra_agent_settings=infra_agent_settings,
                        infra_dir_rel=infra_dir_rel,
                        infra_dir_abs=infra_dir_abs,
                        lean_root=lean_root,
                        plan_abs=plan_abs,
                        plan_rel=plan_rel,
                        before_items=plan_items,
                        fail_index=fail_index,
                        blocker=blocker,
                        public_api_cap=infra_public_api_cap,
                        external_dependency_labels=prior_labels,
                        forbidden_labels=prior_labels,
                        existing_public_count=prior_public_count,
                    )
                    direct_plan_attempt_used += 1
                    if candidate_items is None:
                        suffix_fail_reason = attempt_reason
                        print(
                            "[infra][direct] suffix replan attempt failed: "
                            f"{attempt_reason} (used {direct_plan_attempt_used}/{direct_plan_attempt_cap})"
                        )
                        continue

                    _write_plan_file(plan_abs, candidate_items)
                    plan_items = candidate_items
                    all_items = _load_existing_plan_items(infra_dir_abs)
                    total_items = len(all_items)
                    _write_public_api_json(all_items, path_abs=lean_root / _public_api_path(bench_file_rel))
                    log_event(
                        run_id,
                        "infra_direct_suffix_replan_applied",
                        {
                            "round": round_idx,
                            "attempt": attempt_no,
                            "failed_index": fail_index,
                            "plan_path": str(plan_rel),
                            "items": len(candidate_items),
                            "plan_attempts_used": int(direct_plan_attempt_used),
                            "plan_attempts_cap": int(direct_plan_attempt_cap),
                        },
                    )
                    print(
                        "[infra][direct] suffix updated; "
                        f"resuming direct-item from index={fail_index}."
                    )
                    direct_start_index = int(fail_index)
                    suffix_applied = True
                    break

                if suffix_applied:
                    continue

                print(
                    "[infra][direct] plan attempts exhausted while fixing blocked item; "
                    "exiting for manual debug."
                )
                _finish(
                    "failed",
                    reason=(
                        f"direct_item_plan_attempts_exhausted_round_{round_idx}"
                        + (f":{suffix_fail_reason}" if suffix_fail_reason else "")
                    ),
                    items=total_items,
                )
                return False

        item_files = [Path(it["target_file"]) for it in plan_items]
        prior_label_to_file: dict[str, Path] = {
            str(it.get("label", "")).strip(): Path(str(it.get("target_file", "")))
            for it in prior_items
            if isinstance(it, dict)
            and isinstance(it.get("label"), str)
            and str(it.get("label", "")).strip()
            and isinstance(it.get("target_file"), str)
            and str(it.get("target_file", "")).strip()
        }
        draft_imports_by_file = _build_draft_imports_by_file(
            plan_items,
            external_label_to_file=prior_label_to_file,
        )
        for rel in sorted(set(item_files)):
            file_draft_imports = draft_imports_by_file.get(rel, [])
            _ensure_file_container(lean_root / rel, draft_imports=file_draft_imports)
            log_event(
                run_id,
                "infra_file_seeded",
                {"round": round_idx, "file": str(rel), "draft_imports": file_draft_imports},
            )

        all_item_files = [Path(it["target_file"]) for it in all_items if isinstance(it.get("target_file"), str)]
        _write_infra_entry_file(bench_file_rel=bench_file_rel, item_files=all_item_files, lean_root=lean_root)
        _ensure_bench_imports_infra(bench_file_rel=bench_file_rel, lean_root=lean_root)

        if not _run_statement_stage(
            plan_path=plan_abs,
            project=project,
            task_id=round_task_id,
            max_b_retries=statement_max_b_retries,
            env_overrides=env_overrides,
        ):
            print("[infra] statement stage failed.")
            _finish("failed", reason=f"statement_stage_failed_round_{round_idx}", items=total_items)
            return False
        if not _run_proof_stage(
            plan_path=plan_abs,
            project=project,
            task_id=round_task_id,
            max_b_retries=max_b_retries,
            max_c_replans=max_c_replans,
            env_overrides=env_overrides,
        ):
            print("[infra] proof stage failed.")
            _finish("failed", reason=f"proof_stage_failed_round_{round_idx}", items=total_items)
            return False

        infra_final_files = sorted({p.relative_to(lean_root) for p in infra_dir_abs.rglob("*.lean") if p.is_file()})
        print(f"[infra] final sweep on {len(infra_final_files)} infra files (round {round_idx})...")
        log_event(
            run_id,
            "infra_final_sweep_start",
            {"round": round_idx, "file_count": len(infra_final_files), "infra_dir": str(infra_dir_rel)},
        )
        final_ok, missing_signal = _run_final_stage_on_files(
            files=infra_final_files,
            lean_root=lean_root,
            env_overrides=env_overrides,
            task_id=task_id,
            round_idx=round_idx,
        )
        if final_ok:
            log_event(
                run_id,
                "infra_final_sweep_end",
                {"round": round_idx, "status": "ok", "file_count": len(infra_final_files)},
            )
            _finish("ok", items=total_items)
            return True

        log_event(
            run_id,
            "infra_final_sweep_end",
            {
                "round": round_idx,
                "status": "failed",
                "file_count": len(infra_final_files),
                "missing_theory_signal_present": bool(missing_signal),
            },
        )

        if missing_signal is None:
            print("[infra] final stage failed (non-missing-theory blocker).")
            _finish("failed", reason=f"infra_final_stage_failed_round_{round_idx}", items=total_items)
            return False

        sig_fp = _signal_fingerprint(missing_signal)
        if sig_fp in seen_blockers:
            print("[infra] missing-theory blocker repeated; stopping expansion to avoid spin.")
            _finish("failed", reason=f"infra_expand_stagnated_round_{round_idx}", items=total_items)
            return False
        seen_blockers.add(sig_fp)

        print(
            f"[infra] missing theory remains after round {round_idx}; "
            f"launching expansion round {next_generate_round}."
        )
        next_signal = missing_signal
        continue

    _finish("failed", reason="infra_expand_unexpected_exit", items=total_items)
    return False


def main(argv: list[str] | None = None) -> None:
    """Command-line entry point: run the infra sub-pipeline for one bench file.

    Parses the CLI options, loads the missing-theory signal JSON named by
    ``--signal-file`` and forwards everything to ``run_infra_pipeline``.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None``
            (the default) keeps the normal command-line behaviour.

    Raises:
        SystemExit: Always. Exit code 0 when ``run_infra_pipeline`` returns a
            truthy result; 1 when it returns a falsy result or when the
            signal file cannot be read or decoded as JSON. argparse itself
            exits with code 2 on invalid arguments.
    """
    parser = argparse.ArgumentParser(description="Run infra sub-pipeline for a single bench file.")
    parser.add_argument("--bench-file", type=Path, required=True, help="Bench Lean file under Question_bench/.")
    parser.add_argument("--signal-file", type=Path, required=True, help="Path to missing-theory signal JSON.")
    parser.add_argument("--infra-public-api-cap", type=int, default=30, help="Max public API items (default: 30).")
    parser.add_argument(
        "--infra-plan-generate-attempts",
        type=int,
        default=INFRA_PLAN_MAX_ATTEMPTS,
        help=f"Max plan-agent retries when generating initial plan JSON (default: {INFRA_PLAN_MAX_ATTEMPTS}).",
    )
    parser.add_argument(
        "--infra-plan-check-rounds",
        type=int,
        default=INFRA_PLAN_CHECK_ROUNDS_DEFAULT,
        help=(
            "Max plan check+auto-fix rounds "
            f"(default: {INFRA_PLAN_CHECK_ROUNDS_DEFAULT}; set 0 to disable)."
        ),
    )
    parser.add_argument(
        "--infra-expand-max-rounds",
        type=int,
        default=INFRA_EXPAND_MAX_ROUNDS_DEFAULT,
        help=(
            "When infra final sweep still reports `failed_missing_theory`, run up to N expansion rounds "
            "(each round writes `infra_plan_expand_<k>.json`; old `infra_plan_expend_<k>.json` is still accepted). "
            f"Default: {INFRA_EXPAND_MAX_ROUNDS_DEFAULT}; set 0 to disable expansion."
        ),
    )
    parser.add_argument(
        "--infra-agent-config",
        type=Path,
        default=None,
        help=(
            "Path to a TOML file controlling infra plan/check agent model and reasoning. "
            "Default: use $INFRA_AGENT_CONFIG_FILE, else repo `agent_configs/infra_agents.toml` if present."
        ),
    )
    parser.add_argument(
        "--infra-exec-mode",
        type=str,
        choices=["legacy", "direct_item"],
        default="direct_item",
        help=(
            "Infra execution mode: `direct_item` runs the direct item scaffold pipeline, "
            "`legacy` runs statement->proof->final."
        ),
    )
    parser.add_argument(
        "--infra-direct-simulate-success",
        action="store_true",
        help=(
            "When using --infra-exec-mode direct_item, simulate successful item "
            "execution and append placeholder declarations into GeneratedPrefix."
        ),
    )
    parser.add_argument(
        "--infra-direct-chunk-item-limit",
        type=int,
        default=20,
        help="When using direct_item mode: max items per GeneratedPrefix chunk (default: 20).",
    )
    parser.add_argument(
        "--infra-direct-chunk-line-limit",
        type=int,
        default=1800,
        help="When using direct_item mode: soft max lines per GeneratedPrefix chunk (default: 1800).",
    )
    parser.add_argument(
        "--infra-direct-max-items",
        type=int,
        default=None,
        help="When using direct_item mode: process at most this many items in this run.",
    )
    parser.add_argument(
        "--infra-direct-start-index",
        type=int,
        default=None,
        help="When using direct_item mode: override cursor start index.",
    )
    parser.add_argument(
        "--infra-direct-statement-max-b-retries",
        type=int,
        default=3,
        help="When using direct_item mode: statement stage max Agent B retries per item (default: 3).",
    )
    parser.add_argument(
        "--infra-direct-proof-max-b-retries",
        type=int,
        default=3,
        help="When using direct_item mode: proof stage max Agent B retries per item (default: 3).",
    )
    parser.add_argument(
        "--infra-direct-proof-max-c-replans",
        type=int,
        default=1,
        help="When using direct_item mode: proof stage max Agent C replans per item (default: 1).",
    )
    parser.add_argument("--max-b-retries", type=int, default=3)
    parser.add_argument(
        "--infra-statement-max-b-retries",
        type=int,
        default=None,
        help=(
            "Max retries for Agent B in infra statement stage only. "
            "Default: use --max-b-retries."
        ),
    )
    parser.add_argument("--max-c-replans", type=int, default=1)
    args = parser.parse_args(argv)

    # The signal file is user-supplied: report an unreadable or malformed file
    # as a short diagnostic on stderr instead of an unhandled traceback.
    # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
    try:
        signal = json.loads(args.signal_file.read_text(encoding="utf-8"))
    except (OSError, ValueError) as exc:
        print(f"[infra] cannot load signal file {args.signal_file}: {exc}", file=sys.stderr)
        raise SystemExit(1) from exc

    # argparse has already converted every option (type=int / choices /
    # store_true), and the optional ones are either an int or None, so the
    # parsed values are forwarded without further casting.
    ok = run_infra_pipeline(
        bench_file=args.bench_file,
        missing_theory_signal=signal,
        infra_public_api_cap=args.infra_public_api_cap,
        max_b_retries=args.max_b_retries,
        infra_statement_max_b_retries=args.infra_statement_max_b_retries,
        max_c_replans=args.max_c_replans,
        infra_plan_generate_attempts=args.infra_plan_generate_attempts,
        max_plan_check_rounds=args.infra_plan_check_rounds,
        infra_expand_max_rounds=args.infra_expand_max_rounds,
        infra_agent_config=args.infra_agent_config,
        infra_exec_mode=args.infra_exec_mode,
        infra_direct_simulate_success=args.infra_direct_simulate_success,
        infra_direct_chunk_item_limit=args.infra_direct_chunk_item_limit,
        infra_direct_chunk_line_limit=args.infra_direct_chunk_line_limit,
        infra_direct_max_items=args.infra_direct_max_items,
        infra_direct_start_index=args.infra_direct_start_index,
        infra_direct_statement_max_b_retries=args.infra_direct_statement_max_b_retries,
        infra_direct_proof_max_b_retries=args.infra_direct_proof_max_b_retries,
        infra_direct_proof_max_c_replans=args.infra_direct_proof_max_c_replans,
    )
    # Propagate the pipeline outcome as the process exit status.
    raise SystemExit(0 if ok else 1)


# Run the CLI only when this file is executed as a script; importing the
# module must not start the pipeline.
if __name__ == "__main__":
    main()
