from __future__ import annotations

import argparse
import hashlib
import json
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Any

from .agent_settings import resolve_stage_agents_settings
from .codex_client import (
    build_item_statement_agent_b_prompt,
    CodexCallResult,
    ITEM_STATEMENT_AGENT_A_PROMPT,
    ITEM_STATEMENT_AGENT_C_SEMANTIC_CHECK_PROMPT,
    run_codex,
    _agent_extra_args,
    _assemble_prompt,
)
from .config import LEAN_ROOT, ROOT
from .label_blocks import extract_label_main_declaration_snippet
from .lean_runner import lake_env_lean
from .log_utils import build_log_filename, file_snapshot, snapshot_delta, slugify
from .metrics import finish_run, log_event, start_run
from .protocol import STATEMENT_SEMANTIC_CHECK_END, STATEMENT_SEMANTIC_CHECK_START, extract_marked_json
from .state import load_state, save_state


@dataclass(frozen=True, slots=True)
class ItemFile:
    """One entry of the JSON item list, as normalized by `_load_item_file`.

    The instance itself is frozen, but `dependencies` and `raw` are the same
    mutable containers that came out of the parsed JSON.
    """

    # Ordering key: the entry's "index" when it is an int, otherwise list position + 1.
    index: int
    # Entry "label", stripped of surrounding whitespace (required, non-blank).
    label: str
    # Entry "env", stripped; defaults to "thm" when missing or blank.
    env: str
    # Entry "content" exactly as given (required, non-blank; not stripped).
    content: str
    # Entry "proof", stripped; None when missing, non-string, or blank.
    nl_answer: str | None
    # Entry "dependencies" when it is a list, otherwise an empty list.
    dependencies: list[Any]
    # Entry "notes", falling back to "note" when "notes" is absent/null.
    notes: Any | None
    # Entry "target_file", stripped; None when missing, non-string, or blank.
    target_file: str | None
    # The original JSON object this item was built from.
    raw: dict[str, Any]


def _sha256_file(path: Path) -> str:
    """Return the hexadecimal SHA-256 digest of the bytes stored at `path`."""
    hasher = hashlib.sha256()
    hasher.update(path.read_bytes())
    return hasher.hexdigest()


def _load_item_file(data_file: Path) -> list[ItemFile]:
    """Parse `data_file` (a UTF-8 JSON list) into `ItemFile` records sorted by index.

    List entries that are not JSON objects are skipped. Optional fields fall
    back to defaults: "index" -> position + 1, "dependencies" -> [],
    "env" -> "thm", "proof"/"target_file" -> None, "notes" -> "note".

    Raises:
        ValueError: if the top-level JSON value is not a list, or an object
            entry lacks a non-blank string "label" or "content".
    """
    payload = json.loads(data_file.read_text(encoding="utf-8"))
    if not isinstance(payload, list):
        raise ValueError(f"Expected a JSON list at {data_file}, got {type(payload).__name__}")

    def _nonblank(value: Any) -> bool:
        # True only for strings that contain something other than whitespace.
        return isinstance(value, str) and bool(value.strip())

    parsed: list[ItemFile] = []
    for pos, entry in enumerate(payload):
        if not isinstance(entry, dict):
            continue

        label = entry.get("label")
        if not _nonblank(label):
            raise ValueError(f"Missing/invalid label at position {pos}: {entry!r}")
        content = entry.get("content")
        if not _nonblank(content):
            raise ValueError(f"Missing/invalid content for label={label} at position {pos}")

        declared_index = entry.get("index")
        item_index = declared_index if isinstance(declared_index, int) else pos + 1

        declared_deps = entry.get("dependencies")
        declared_env = entry.get("env")
        declared_proof = entry.get("proof")
        declared_target = entry.get("target_file")

        # "notes" wins; "note" is only consulted when "notes" is absent or null.
        item_notes = entry.get("notes")
        if item_notes is None:
            item_notes = entry.get("note")

        parsed.append(
            ItemFile(
                index=int(item_index),
                label=label.strip(),
                env=(declared_env if _nonblank(declared_env) else "thm").strip(),
                content=content,
                nl_answer=declared_proof.strip() if _nonblank(declared_proof) else None,
                dependencies=declared_deps if isinstance(declared_deps, list) else [],
                notes=item_notes,
                target_file=declared_target.strip() if _nonblank(declared_target) else None,
                raw=entry,
            )
        )

    # sorted() is stable, so entries sharing an index keep their file order.
    return sorted(parsed, key=lambda it: it.index)


def _non_sorry_warning_blocks(lean_output: str) -> list[str]:
    """
    Collect multi-line warning blocks whose text never mentions 'sorry'.

    A block opens on any line containing "warning" (case-insensitive) and runs
    until the next such line or the end of the output. Lines before the first
    warning line are ignored. Each returned block is stripped of surrounding
    whitespace.
    """
    if not lean_output:
        return []

    # Group lines: every "warning" line opens a new group; other lines join the open one.
    groups: list[list[str]] = []
    for text in lean_output.splitlines():
        if "warning" in text.lower():
            groups.append([text])
        elif groups:
            groups[-1].append(text)

    return [
        "\n".join(group).strip()
        for group in groups
        if not any("sorry" in member.lower() for member in group)
    ]


def _first_error_line(output: str | None) -> str | None:
    """Return the first stripped line containing "error:" (case-insensitive).

    When no line matches, fall back to the first non-blank line. Returns None
    for empty/None input or when every line is blank.
    """
    if not output:
        return None

    candidates = [text for text in (raw.strip() for raw in output.splitlines()) if text]
    if not candidates:
        return None

    for text in candidates:
        if "error:" in text.lower():
            return text
    # No explicit error marker: the first non-empty line is the best summary available.
    return candidates[0]


def _maybe_report_math_blocker(*, stdout: str, agent: str, label: str) -> None:
    """
    Best-effort terminal hint: print a prominent [REPORT] line when the agent's
    stdout contains "mathematically unprovable" or "mathematically false"
    (case-insensitive). Does nothing for empty output.
    """
    haystack = stdout.lower() if stdout else ""
    markers = ("mathematically unprovable", "mathematically false")
    if any(marker in haystack for marker in markers):
        print(f"[REPORT] Agent {agent} indicates label={label} may be mathematically unprovable/false.")


def _ensure_item_file_container(abs_file: Path) -> None:
    """Make sure `abs_file` exists, creating parent directories as needed.

    A missing file is seeded with a minimal Lean header (`import Mathlib` plus
    a placeholder comment). An existing file is left untouched.
    """
    abs_file.parent.mkdir(parents=True, exist_ok=True)
    if not abs_file.exists():
        seed = (
            "import Mathlib\n"
            "\n"
            "-- Declarations for this item will be appended below by the statement pipeline.\n"
        )
        abs_file.write_text(seed, encoding="utf-8")


def _run_statement_agent_a_for_item(
    *,
    item: ItemFile,
    target_rel: Path,
    task_id: str,
    model: str | None,
    reasoning_effort: str | None,
    log_dir: Path,
    stage: str,
) -> CodexCallResult:
    """Assemble the Agent A statement prompt for `item` and run Codex on it.

    The prompt metadata carries the item's fields plus the target file path;
    Codex is executed with `LEAN_ROOT` as its working directory and its log
    is written under `log_dir`. Returns the result of `run_codex`.
    """
    print(
        f"[Agent A] [item_statement] start idx={item.index} label={item.label} "
        f"target={target_rel} model={model or 'default'}/{reasoning_effort or 'default'}"
    )
    target_str = str(target_rel)

    # The statement prompt expects `env` and `number_components`; in this mode we
    # synthesize stable defaults (meaning is carried primarily by `content`).
    item_payload = {
        "index": item.index,
        "label": item.label,
        "env": item.env or "thm",
        "number_components": [0, 0, item.index],
        "content": item.content,
        "dependencies": item.dependencies,
        "context": {
            "task_id": task_id,
            "source_mode": "item_per_file",
        },
        "nl_answer": item.nl_answer,
        "notes": item.notes,
    }
    prompt = _assemble_prompt(
        ITEM_STATEMENT_AGENT_A_PROMPT,
        {"item": item_payload, "target_file": target_str},
        extra_instructions=None,
    )

    codex_args = _agent_extra_args(model=model, reasoning_effort=reasoning_effort) or None
    run_meta = {
        "pipeline": "item_statement",
        "agent": "A",
        "task_id": task_id,
        "item_index": item.index,
        "label": item.label,
        "target_file": target_str,
        "model": model,
        "reasoning_effort": reasoning_effort,
    }
    return run_codex(
        prompt,
        extra_args=codex_args,
        log_name=build_log_filename("item_statement", "agent_a", task_id, f"idx{item.index}", item.label),
        log_dir=log_dir,
        cwd=LEAN_ROOT,
        stage=stage,
        log_meta=run_meta,
    )
def _run_statement_agent_b_for_item(
    *,
    item: ItemFile,
    target_rel: Path,
    task_id: str,
    error_log: str,
    model: str | None,
    reasoning_effort: str | None,
    log_dir: Path,
    extra_instructions: str | None = None,
    stage: str,
) -> CodexCallResult:
    """Build the Agent B prompt for `item` from `error_log` and run Codex on it.

    `error_log` and the optional `extra_instructions` are forwarded to the
    prompt builder together with the item's index, label, dependencies and
    notes. Codex runs with `LEAN_ROOT` as its working directory. Returns the
    result of `run_codex`.
    """
    print(
        f"[Agent B] [item_statement] start idx={item.index} label={item.label} "
        f"target={target_rel} model={model or 'default'}/{reasoning_effort or 'default'}"
    )
    item_context = {"task_id": task_id, "source_mode": "item_per_file"}
    prompt = build_item_statement_agent_b_prompt(
        target_rel,
        error_log,
        item_index=item.index,
        label=item.label,
        item_context=item_context,
        item_dependencies=item.dependencies,
        item_notes=item.notes,
        extra_instructions=extra_instructions,
    )

    # Unlike Agents A and C, the log name for B is keyed by the target path rather than the label.
    log_name = build_log_filename("item_statement", "agent_b", task_id, f"idx{item.index}", target_rel.as_posix())
    codex_args = _agent_extra_args(model=model, reasoning_effort=reasoning_effort) or None
    run_meta = {
        "pipeline": "item_statement",
        "agent": "B",
        "task_id": task_id,
        "item_index": item.index,
        "label": item.label,
        "target_file": str(target_rel),
        "model": model,
        "reasoning_effort": reasoning_effort,
    }
    return run_codex(
        prompt,
        extra_args=codex_args,
        log_name=log_name,
        log_dir=log_dir,
        cwd=LEAN_ROOT,
        stage=stage,
        log_meta=run_meta,
    )


def _run_statement_agent_c_semantic_check_for_item(
    *,
    item: ItemFile,
    target_rel: Path,
    task_id: str,
    formal_snippet: str,
    decl_info: dict[str, Any] | None,
    model: str | None,
    reasoning_effort: str | None,
    log_dir: Path,
    stage: str,
) -> CodexCallResult:
    """Run the Agent C semantic-check prompt for `item` against `formal_snippet`.

    The prompt metadata pairs the item's informal content with the Lean
    snippet and `decl_info`. Codex runs with `LEAN_ROOT` as its working
    directory. Returns the result of `run_codex`.

    Raises:
        RuntimeError: if the semantic-check prompt template is empty.
    """
    print(
        f"[Agent C] [item_statement] semantic_check start idx={item.index} label={item.label} "
        f"target={target_rel} model={model or 'default'}/{reasoning_effort or 'default'}"
    )
    template = ITEM_STATEMENT_AGENT_C_SEMANTIC_CHECK_PROMPT
    if not template.strip():
        raise RuntimeError("Missing prompts/statement/agent_c_semantic_check_prompt.txt")

    target_str = str(target_rel)
    payload = {
        "label": item.label,
        # `env` is sent as null here (the item's own env is not forwarded).
        "env": None,
        "content": item.content,
        "target_file": target_str,
        "formal_snippet": formal_snippet,
        "decl_info": decl_info,
        "dependencies": item.dependencies,
        "context": {
            "task_id": task_id,
            "source_mode": "item_per_file",
        },
    }
    prompt = _assemble_prompt(template, payload, extra_instructions=None)

    codex_args = _agent_extra_args(model=model, reasoning_effort=reasoning_effort) or None
    run_meta = {
        "pipeline": "item_statement",
        "agent": "C",
        "mode": "semantic_check",
        "task_id": task_id,
        "item_index": item.index,
        "label": item.label,
        "target_file": target_str,
        "model": model,
        "reasoning_effort": reasoning_effort,
    }
    return run_codex(
        prompt,
        extra_args=codex_args,
        log_name=build_log_filename(
            "item_statement", "agent_c_semantic_check", task_id, f"idx{item.index}", item.label
        ),
        log_dir=log_dir,
        cwd=LEAN_ROOT,
        stage=stage,
        log_meta=run_meta,
    )


def _semantic_check_and_maybe_set_failure(
    *,
    run_id: str,
    agent_settings: dict[str, Any],
    item: ItemFile,
    target_abs: Path,
    target_rel: Path,
    task_id: str,
    semantic_check: bool,
    semantic_check_policy: str,
    log_dir: Path,
    total_tokens_ref: list[int],
    stage: str,
) -> tuple[str | None, str | None, str | None]:
    """
    Run Agent C semantic check for the main labeled declaration in `target_abs`.

    Returns (failed_mode, err, b_extra_instructions). When the check is disabled,
    passes (status "ok"), or the policy is "warn", returns (None, None, None).
    Otherwise `failed_mode` is "semantic" and `err` is a one-line summary; for
    policy "fail" the third element repeats `err`, for any other policy it is a
    multi-part instruction block intended for Agent B.

    Tokens reported by Agent C are added to `total_tokens_ref[0]` (in-place
    accumulator). A reply that cannot be parsed into a JSON object is handled
    like a non-"ok" status and logged with status "unparsed".
    """
    if not semantic_check:
        return None, None, None

    try:
        file_text = target_abs.read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError):
        # Best effort: an unreadable/undecodable file is treated as empty so the check
        # still runs. Only I/O and decoding failures are tolerated; anything else is a
        # programming error and should surface.
        file_text = ""

    snippet, decl_info = extract_label_main_declaration_snippet(text=file_text, label=item.label, max_lines=180)
    if not snippet:
        # Fallback: provide the full file (truncated) and let Agent C locate the relevant declaration by label/name.
        max_chars = 20000
        snippet = file_text[:max_chars]
        decl_info = {"label_search": "full_file", "label_found": item.label in file_text}
        log_event(
            run_id,
            "item_statement_semantic_check",
            {
                "index": item.index,
                "label": item.label,
                "task_id": task_id,
                "status": "fallback_full_file",
                "policy": semantic_check_policy,
            },
        )

    c_res = _run_statement_agent_c_semantic_check_for_item(
        item=item,
        target_rel=target_rel,
        task_id=task_id,
        formal_snippet=snippet,
        decl_info=decl_info,
        model=agent_settings["c"].model,
        reasoning_effort=agent_settings["c"].reasoning_effort,
        log_dir=log_dir,
        stage=stage,
    )
    total_tokens_ref[0] += c_res.tokens_used or 0

    # Pull the marked JSON report out of Agent C's stdout; blank/missing fields become None.
    report, raw = extract_marked_json(c_res.stdout or "", STATEMENT_SEMANTIC_CHECK_START, STATEMENT_SEMANTIC_CHECK_END)
    status = None
    reason = None
    suggested = None
    if isinstance(report, dict):
        status = str(report.get("status") or "").strip().lower() or None
        reason = str(report.get("reason") or "").strip() or None
        suggested = str(report.get("suggested_revision") or "").strip() or None

    log_event(
        run_id,
        "item_statement_semantic_check",
        {
            "index": item.index,
            "label": item.label,
            "task_id": task_id,
            "status": status or "unparsed",
            "reason": reason,
            "tokens_used": c_res.tokens_used,
            "log_path": str(c_res.log_path) if c_res.log_path else None,
            "policy": semantic_check_policy,
        },
    )

    if status == "ok":
        print("[Semantic check] OK.")
        return None, None, None

    msg = (
        f"Statement semantic check flagged potential drift (status={status or 'unparsed'}). "
        + (reason or "No reason provided.")
    )
    if semantic_check_policy == "warn":
        # Warn-only: surface the message but report success to the caller.
        print(f"[Semantic check] WARNING: {msg}")
        return None, None, None
    if semantic_check_policy == "fail":
        return "semantic", msg, msg

    # Any other policy (i.e. "fix"): assemble detailed guidance for Agent B, skipping empty parts.
    b_extra = "\n".join(
        part
        for part in [
            "SEMANTIC CHECK FAILED (item-per-file statement stage): Please revise the Lean statement to match the intended meaning.",
            "Focus on: missing hypotheses from 'previous setup/subgoal', over-generalization (∀Φ), wrong quantifiers/directions/objects.",
            ("Suggested revision:\n" + suggested) if suggested else None,
            "Informal statement (authoritative):\n" + (item.content or ""),
            "Current Lean snippet:\n" + snippet,
            ("Raw semantic-check JSON:\n" + raw) if raw else None,
        ]
        if part
    )
    return "semantic", msg, b_extra
def main() -> None:
    """CLI entry point: run the STATEMENT stage in item-per-file mode.

    For each item at or after the saved `next_index`: ensure the target Lean file
    exists, run Agent A, compile with `lake env lean`, optionally run the Agent C
    semantic check, and invoke Agent B (up to `--max-b-retries` times) when Lean
    fails, when non-sorry warnings should be cleaned, or when the semantic check
    flags drift. Progress is saved after every successful item and the batch stops
    at the first failing item. Exactly one `item_end` event is logged per item.

    Raises:
        SystemExit: on empty `--project`/task id, or when `--only-label` matches nothing.
    """
    parser = argparse.ArgumentParser(
        description="Run STATEMENT stage in item-per-file mode (one JSON item -> one Lean file)."
    )
    parser.add_argument("--project", type=str, default="P2512_19197", help="Lean subdir under M2F/ (default: P2512_19197).")
    parser.add_argument("--data-file", type=Path, required=True, help="Path to a JSON list of items.")
    parser.add_argument(
        "--stage",
        type=str,
        default="statement",
        help="Codex stage for AGENTS.md swapping (default: statement). Use 'infra' for infra pipelines.",
    )
    parser.add_argument(
        "--task-id",
        type=str,
        default=None,
        help="Optional task id (defaults to the data file stem). Used for log partitioning only.",
    )
    parser.add_argument("--start-index", type=int, default=None, help="Start processing from this item index (overrides saved state).")
    parser.add_argument("--only-label", type=str, default=None, help="Process only this item label (e.g. n000001).")
    parser.add_argument("--max-items", type=int, default=None, help="Process at most this many items.")
    parser.add_argument("--max-b-retries", type=int, default=3, help="Max retries for Agent B when Lean still fails (default: 3).")
    parser.add_argument(
        "--clean-warnings-with-agent-b",
        action="store_true",
        help="If set, invoke Agent B to remove non-sorry warnings (default: false).",
    )
    parser.add_argument(
        "--statement-agent-config",
        type=Path,
        default=None,
        help="Optional TOML controlling per-agent model/reasoning for STATEMENT stage.",
    )
    # The env var only provides the default; --semantic-check / --no-semantic-check override it.
    default_semantic_check = os.getenv("STATEMENT_SEMANTIC_CHECK", "").strip().lower() in {
        "1",
        "true",
        "yes",
        "on",
        "enable",
        "enabled",
    }
    parser.add_argument(
        "--semantic-check",
        dest="semantic_check",
        action="store_true",
        default=default_semantic_check,
        help="Run an extra Agent C pass to check statement semantics (detect drift).",
    )
    parser.add_argument(
        "--no-semantic-check",
        dest="semantic_check",
        action="store_false",
        help="Disable semantic-check even if enabled by env var.",
    )
    parser.add_argument(
        "--semantic-check-policy",
        choices=["warn", "fix", "fail"],
        default=os.getenv("STATEMENT_SEMANTIC_CHECK_POLICY", "fix").strip().lower() or "fix",
        help="On semantic mismatch: warn (continue), fix (invoke Agent B), or fail (stop).",
    )
    args = parser.parse_args()

    project = args.project.strip()
    if not project:
        raise SystemExit("--project must be non-empty")
    task_id = (args.task_id or args.data_file.stem).strip()
    if not task_id:
        raise SystemExit("--task-id must be non-empty (or provide a non-empty data file stem)")

    # Keep logs partitioned per project/task_id under repo log/.
    project_slug = slugify(project, max_len=80)
    task_slug = slugify(task_id, max_len=120)
    logs_dir = ROOT / "log" / project_slug / "item_statement_logs" / task_slug
    agent_a_logs_dir = logs_dir / "agent_a"
    agent_b_logs_dir = logs_dir / "agent_b"
    agent_c_logs_dir = logs_dir / "agent_c"
    progress_file = logs_dir / "progress.json"
    agent_a_logs_dir.mkdir(parents=True, exist_ok=True)
    agent_b_logs_dir.mkdir(parents=True, exist_ok=True)
    agent_c_logs_dir.mkdir(parents=True, exist_ok=True)

    items = _load_item_file(args.data_file)
    if args.only_label:
        want = args.only_label.strip()
        items = [it for it in items if it.label == want]
        if not items:
            raise SystemExit(f"--only-label={want} not found in {args.data_file}")

    run_id = start_run(
        "item_statement",
        stage=1,
        name_tag=f"{project_slug}_{task_slug}",
        data_file=str(args.data_file),
        extra={"project": project, "task_id": task_id},
    )

    # Resume support: progress is keyed to the data file's hash, so editing the plan
    # resets next_index to 1 instead of resuming at a stale position.
    state = load_state(progress_file)
    plan_sha256 = _sha256_file(args.data_file)
    prev_plan_sha256 = str(state.get("plan_sha256", "") or "").strip()
    state_changed = False
    if prev_plan_sha256 and prev_plan_sha256 != plan_sha256:
        prev_next = int(state.get("next_index", 1) or 1)
        state["next_index"] = 1
        state_changed = True
        print(
            "[progress] plan hash changed; reset item_statement progress "
            f"from next_index={prev_next} to next_index=1."
        )
        log_event(
            run_id,
            "progress_reset_plan_changed",
            {
                "task_id": task_id,
                "progress_file": str(progress_file),
                "data_file": str(args.data_file),
                "previous_next_index": prev_next,
                "previous_plan_sha256": prev_plan_sha256,
                "plan_sha256": plan_sha256,
            },
        )
    if prev_plan_sha256 != plan_sha256:
        state["plan_sha256"] = plan_sha256
        state_changed = True
    if args.start_index is not None:
        # An explicit --start-index wins over both the saved state and the hash reset.
        state["next_index"] = int(args.start_index)
        state_changed = True
    if state_changed:
        save_state(state, progress_file, run_id=run_id)

    next_index = int(state.get("next_index", 1) or 1)
    processed = 0
    last_success = next_index - 1
    total_tokens = 0

    default_cfg = ROOT / "agent_configs/statement_agents.toml"
    agent_settings = resolve_stage_agents_settings(
        stage_prefix="STATEMENT",
        agent_keys=["a", "b", "c"],
        config_path=args.statement_agent_config,
        default_config_path=(default_cfg if default_cfg.exists() else None),
    ).agents

    stage = (args.stage or "statement").strip().lower()
    for item in items:
        if item.index < next_index:
            continue
        print(f"=== item-statement index={item.index} label={item.label} (task={task_id}) ===")

        if item.target_file:
            target_rel = Path(item.target_file)
        else:
            target_rel = Path(project) / f"{item.label}.lean"
        target_abs = LEAN_ROOT / target_rel
        _ensure_item_file_container(target_abs)

        before = file_snapshot(target_abs)
        log_event(
            run_id,
            "item_start",
            {"index": item.index, "label": item.label, "task_id": task_id, "target_file": str(target_rel)},
        )

        # --- Agent A: write the initial statement -------------------------------------
        a_res = _run_statement_agent_a_for_item(
            item=item,
            target_rel=target_rel,
            task_id=task_id,
            model=agent_settings["a"].model,
            reasoning_effort=agent_settings["a"].reasoning_effort,
            log_dir=agent_a_logs_dir,
            stage=stage,
        )
        total_tokens += a_res.tokens_used or 0
        after_a = file_snapshot(target_abs)
        log_event(
            run_id,
            "agent_a_result",
            {
                "index": item.index,
                "label": item.label,
                "task_id": task_id,
                "code": a_res.code,
                "tokens_used": a_res.tokens_used,
                "log_path": str(a_res.log_path) if a_res.log_path else None,
                "stderr_snippet": (a_res.stderr[:1500] if a_res.stderr else None),
                "file_before": before,
                "file_after": after_a,
                "file_delta": snapshot_delta(before, after_a),
            },
        )
        if a_res.code != 0:
            print("Agent A failed; stopping.")
            log_event(
                run_id,
                "item_end",
                {
                    "index": item.index,
                    "label": item.label,
                    "task_id": task_id,
                    "status": "agent_a_failed",
                    "error_snippet": (a_res.stderr[:2000] if a_res.stderr else None),
                },
            )
            break
        _maybe_report_math_blocker(stdout=a_res.stdout, agent="A", label=item.label)

        try:
            text_after_a = target_abs.read_text(encoding="utf-8")
        except FileNotFoundError:
            text_after_a = ""
        if item.label not in text_after_a:
            print("Agent A produced no visible labeled declaration; stopping (to avoid false success).")
            log_event(
                run_id,
                "item_end",
                {
                    "index": item.index,
                    "label": item.label,
                    "task_id": task_id,
                    "status": "agent_a_noop",
                    "error_snippet": "Target file does not contain the item label after Agent A.",
                },
            )
            break

        # Compile check
        print(f"[lean check] running: lake env lean {target_rel} (cwd={LEAN_ROOT})")
        code, out, err = lake_env_lean(target_rel)
        lean_output = "\n".join(part for part in (err, out) if part)
        non_sorry_blocks = _non_sorry_warning_blocks(lean_output)
        log_event(
            run_id,
            "lean_check",
            {
                "index": item.index,
                "label": item.label,
                "phase": "post_agent_a",
                "code": code,
                "has_non_sorry_warnings": bool(non_sorry_blocks),
                "compiled_file": str(target_rel),
            },
        )

        # Decide whether Agent B is needed: compile failure first, then (optionally) warnings.
        need_b = False
        b_error_log = ""
        b_extra_instructions: str | None = None
        need_b_reason: str | None = None
        if code != 0:
            need_b = True
            b_error_log = lean_output
            headline = _first_error_line(lean_output) or "Lean failed after Agent A."
            need_b_reason = f"Lean failed after Agent A: {headline}"
        elif non_sorry_blocks and args.clean_warnings_with_agent_b:
            need_b = True
            b_error_log = "Lean produced the following non-sorry warnings. Please remove them:\n\n" + "\n\n".join(
                non_sorry_blocks
            )
            need_b_reason = "Lean produced non-sorry warnings after Agent A; cleaning with Agent B."

        if not need_b:
            # Lean is happy: run the (optional) semantic check on Agent A's statement.
            semantic_tokens_ref = [0]
            sem_failed_mode, sem_err, sem_b_extra = _semantic_check_and_maybe_set_failure(
                run_id=run_id,
                agent_settings=agent_settings,
                item=item,
                target_abs=target_abs,
                target_rel=target_rel,
                task_id=task_id,
                semantic_check=bool(args.semantic_check),
                semantic_check_policy=str(args.semantic_check_policy),
                log_dir=agent_c_logs_dir,
                total_tokens_ref=semantic_tokens_ref,
                stage=stage,
            )
            total_tokens += semantic_tokens_ref[0]
            if sem_failed_mode is not None:
                if str(args.semantic_check_policy) == "fail":
                    print(sem_err or "Statement semantic check failed.")
                    log_event(
                        run_id,
                        "item_end",
                        {
                            "index": item.index,
                            "label": item.label,
                            "task_id": task_id,
                            "status": "semantic_failed",
                            "error_snippet": (sem_err[:2000] if sem_err else None),
                        },
                    )
                    break
                if sem_err:
                    print(sem_err)
                else:
                    print("Statement semantic check flagged drift.")
                need_b = True
                b_error_log = sem_err or "Statement semantic check flagged drift."
                b_extra_instructions = sem_b_extra
                headline = _first_error_line(sem_err) or "Statement semantic check flagged drift."
                need_b_reason = f"Semantic check flagged drift: {headline}"

        # --- Agent B: repair loop -------------------------------------------------------
        # Set once a terminal `item_end` event has been logged inside the loop, so the
        # generic failure handler below does not emit a second, conflicting `item_end`.
        item_end_logged = False
        b_attempts = 0
        while need_b and b_attempts < int(args.max_b_retries):
            b_attempts += 1
            reason = need_b_reason or "Fixing issues after Agent A."
            print(f"{reason} Calling Agent B (attempt {b_attempts}/{args.max_b_retries})...")
            b_res = _run_statement_agent_b_for_item(
                item=item,
                target_rel=target_rel,
                task_id=task_id,
                error_log=b_error_log,
                model=agent_settings["b"].model,
                reasoning_effort=agent_settings["b"].reasoning_effort,
                log_dir=agent_b_logs_dir,
                extra_instructions=b_extra_instructions,
                stage=stage,
            )
            total_tokens += b_res.tokens_used or 0
            log_event(
                run_id,
                "agent_b_result",
                {
                    "index": item.index,
                    "label": item.label,
                    "task_id": task_id,
                    "attempt": b_attempts,
                    "code": b_res.code,
                    "tokens_used": b_res.tokens_used,
                    "log_path": str(b_res.log_path) if b_res.log_path else None,
                    "stderr_snippet": (b_res.stderr[:1500] if b_res.stderr else None),
                },
            )
            if b_res.code != 0:
                print("Agent B failed; stopping.")
                log_event(
                    run_id,
                    "item_end",
                    {
                        "index": item.index,
                        "label": item.label,
                        "task_id": task_id,
                        "status": "agent_b_failed",
                        "error_snippet": (b_res.stderr[:2000] if b_res.stderr else None),
                    },
                )
                item_end_logged = True
                # Force the failure path below so the outer item loop stops too.
                code = 1
                break
            _maybe_report_math_blocker(stdout=b_res.stdout, agent="B", label=item.label)
            print(f"[lean check] running: lake env lean {target_rel} (cwd={LEAN_ROOT})")
            code, out, err = lake_env_lean(target_rel)
            lean_output = "\n".join(part for part in (err, out) if part)
            non_sorry_blocks = _non_sorry_warning_blocks(lean_output)
            log_event(
                run_id,
                "lean_check",
                {
                    "index": item.index,
                    "label": item.label,
                    "phase": f"post_agent_b_attempt{b_attempts}",
                    "code": code,
                    "has_non_sorry_warnings": bool(non_sorry_blocks),
                    "compiled_file": str(target_rel),
                },
            )
            if code == 0 and (not non_sorry_blocks or not args.clean_warnings_with_agent_b):
                # Lean is clean enough: re-run the semantic check on Agent B's revision.
                semantic_tokens_ref = [0]
                sem_failed_mode, sem_err, sem_b_extra = _semantic_check_and_maybe_set_failure(
                    run_id=run_id,
                    agent_settings=agent_settings,
                    item=item,
                    target_abs=target_abs,
                    target_rel=target_rel,
                    task_id=task_id,
                    semantic_check=bool(args.semantic_check),
                    semantic_check_policy=str(args.semantic_check_policy),
                    log_dir=agent_c_logs_dir,
                    total_tokens_ref=semantic_tokens_ref,
                    stage=stage,
                )
                total_tokens += semantic_tokens_ref[0]
                if sem_failed_mode is not None:
                    if str(args.semantic_check_policy) == "fail":
                        print(sem_err or "Statement semantic check failed after Agent B.")
                        log_event(
                            run_id,
                            "item_end",
                            {
                                "index": item.index,
                                "label": item.label,
                                "task_id": task_id,
                                "status": "semantic_failed_after_b",
                                "error_snippet": (sem_err[:2000] if sem_err else None),
                            },
                        )
                        item_end_logged = True
                        # Force the failure path below so the outer item loop stops too.
                        code = 1
                        break
                    need_b = True
                    b_error_log = sem_err or "Statement semantic check flagged drift."
                    b_extra_instructions = sem_b_extra
                    continue
                need_b = False
                break
            if code != 0:
                b_error_log = lean_output
                b_extra_instructions = None
            elif non_sorry_blocks and args.clean_warnings_with_agent_b:
                b_error_log = "Lean still reports these non-sorry warnings. Please remove them:\n\n" + "\n\n".join(
                    non_sorry_blocks
                )
                b_extra_instructions = None

        # NOTE(review): if the retries run out while Lean still succeeds (code == 0) but
        # warnings or semantic drift remain (need_b is still True), the item falls through
        # to the success path below. Confirm this best-effort behaviour is intended.
        if code != 0:
            # Revert on failure? For now, keep changes and record failure.
            if not item_end_logged:
                print("Item failed after Agent B; stopping.")
                log_event(
                    run_id,
                    "item_end",
                    {
                        "index": item.index,
                        "label": item.label,
                        "task_id": task_id,
                        "status": "failed_after_b",
                        "error_snippet": (lean_output[:2000] if lean_output else None),
                    },
                )
            break

        # Success
        state["next_index"] = item.index + 1
        save_state(state, progress_file, run_id=run_id)
        processed += 1
        last_success = item.index
        log_event(
            run_id,
            "item_end",
            {"index": item.index, "label": item.label, "task_id": task_id, "status": "ok"},
        )

        if args.max_items is not None and processed >= int(args.max_items):
            print(f"Reached max-items={args.max_items}, stopping batch.")
            break

    summary = {
        "pipeline": "item_statement",
        "stage": 1,
        "run_id": run_id,
        "project": project,
        "task_id": task_id,
        "data_file": str(args.data_file),
        "processed": processed,
        "last_success_index": last_success,
        "next_index": int(state.get("next_index", next_index)),
        "tokens_used_total": total_tokens,
        "paths": {
            "progress_file": str(progress_file),
            "logs_dir": str(logs_dir),
        },
    }
    finish_run(run_id, summary)


# Script entry point. This file uses relative imports, which need package context,
# so it is meant to be launched as a module (`python -m ...`) rather than by file path.
if __name__ == "__main__":
    main()
