from __future__ import annotations

import argparse
import difflib
import os
import time
from pathlib import Path

from .codex_client import run_agent_a, run_agent_b, run_agent_b_bookcheck
from .agent_settings import resolve_stage_agents_settings
from .codex_client import run_statement_agent_c_semantic_check
from .axiom_guard import (
    build_axiom_cleanup_instructions,
    find_axiom_decls,
    format_axiom_report,
)
from .config import DATA_PATH, LEAN_ROOT, METRICS_DIR, ROOT, STATEMENT_LOGS_DIR, resolve_item_target_file
from .book_maintainer import (
    compile_entry_for,
    ensure_book_exists,
    ensure_book_imports,
    ensure_chapter_imports,
    ensure_section_aggregate_exists,
)
from .lean_runner import lake_env_lean
from .loader import filter_items, load_items
from .label_blocks import (
    extract_label_main_declaration_snippet,
    find_label_main_decl_key_def_placeholder_violations,
)
from .log_utils import file_snapshot, snapshot_delta
from .metrics import finish_run, log_event, start_run
from .protocol import (
    STATEMENT_SEMANTIC_CHECK_END,
    STATEMENT_SEMANTIC_CHECK_START,
    extract_marked_json,
)
from .state import STATEMENT_PROGRESS_FILE, load_state, save_state


def _is_import_preamble_line(line: str) -> bool:
    s = (line or "").strip()
    if not s:
        return True
    if s.startswith("--"):
        return True
    return s.startswith("import ")


def _import_preamble_end(lines: list[str]) -> int:
    """
    Return the end index (exclusive) of the leading import preamble region.
    We allow blank lines and `--` comments interspersed, but stop at the first
    non-import / non-comment / non-blank line.
    """
    saw_import = False
    for i, ln in enumerate(lines):
        s = (ln or "").strip()
        if not s:
            continue
        if s.startswith("--"):
            continue
        if s.startswith("import "):
            saw_import = True
            continue
        # First non-preamble line.
        return i if saw_import else 0
    return len(lines) if saw_import else 0


def _merge_intervals(intervals: list[tuple[int, int]]) -> list[tuple[int, int]]:
    if not intervals:
        return []
    intervals = sorted(intervals)
    merged: list[tuple[int, int]] = []
    cur_s, cur_e = intervals[0]
    for s, e in intervals[1:]:
        if s <= cur_e:
            cur_e = max(cur_e, e)
        else:
            merged.append((cur_s, cur_e))
            cur_s, cur_e = s, e
    merged.append((cur_s, cur_e))
    return merged


def _maybe_report_math_blocker(*, stdout: str, agent: str, label: str) -> None:
    """
    Best-effort terminal hint: if an agent reports "mathematically unprovable/false", surface it prominently.
    """
    if not stdout:
        return
    lowered = stdout.lower()
    if "mathematically unprovable" in lowered or "mathematically false" in lowered:
        print(f"[REPORT] Agent {agent} indicates label={label} may be mathematically unprovable/false.")


def _semantic_check_and_maybe_set_failure(
    *,
    run_id: str,
    agent_settings,
    item,
    lean_file_abs: Path,
    compile_file_rel: Path,
    args,
    total_tokens_used_ref: list[int],
) -> tuple[str | None, str | None, str | None]:
    """
    Run Agent C semantic check for the statement and return (failed_mode, err, b_extra_instructions).
    Returns (None, None, None) when semantic check is disabled or passes.

    Flow:
    - Disabled (``args.semantic_check`` missing or falsy): return immediately
      without calling Agent C.
    - Extract the main declaration for ``item.label`` from ``lean_file_abs``
      (an unreadable file is treated as empty text). If no snippet is found:
      policy "warn" prints a warning and passes; any other policy returns
      ("semantic", msg, msg).
    - Otherwise call Agent C and parse its marked-JSON report. ``status == "ok"``
      passes. Any other status, including an unparsed report, is handled per
      ``args.semantic_check_policy``:
      - "warn" prints and passes;
      - "fail" returns ("semantic", msg, msg);
      - any other value (the "fix" path) returns ("semantic", msg, <Agent B instructions>).

    Side effects: emits ``statement_semantic_check`` events via ``log_event``,
    prints to stdout, and adds Agent C's token usage to
    ``total_tokens_used_ref[0]`` (a one-element list used as an out-parameter).
    """
    if not getattr(args, "semantic_check", False):
        return None, None, None

    # Best-effort read: any read error falls back to "" (which presumably yields no snippet below).
    try:
        file_text = lean_file_abs.read_text(encoding="utf-8")
    except Exception:
        file_text = ""

    snippet, decl_info = extract_label_main_declaration_snippet(text=file_text, label=item.label, max_lines=180)
    if not snippet:
        msg = (
            "Statement semantic check failed: could not locate the main declaration whose docstring starts with "
            f"label={item.label!r}."
        )
        log_event(
            run_id,
            "statement_semantic_check",
            {"index": item.index, "label": item.label, "status": "no_decl_found", "policy": args.semantic_check_policy},
        )
        if args.semantic_check_policy == "warn":
            print(f"[Semantic check] WARNING: {msg}")
            return None, None, None
        # Both "fix" and "fail" treat a missing declaration as a failure; msg doubles as B's instructions.
        return "semantic", msg, msg

    # Ask Agent C to compare the formal snippet against the informal statement and its context.
    c_res = run_statement_agent_c_semantic_check(
        label=item.label,
        env=item.env,
        content=item.content,
        target_file=compile_file_rel,
        formal_snippet=snippet,
        decl_info=decl_info,
        dependencies=item.dependencies,
        context=item.context,
        model=agent_settings.agents["C"].model,
        reasoning_effort=agent_settings.agents["C"].reasoning_effort,
    )
    # `tokens_used` may be None; count it as 0.
    total_tokens_used_ref[0] += c_res.tokens_used or 0

    # Only a dict report is trusted; each field is normalised to a stripped string
    # (status lowercased), with empty values mapped to None.
    report, raw = extract_marked_json(c_res.stdout or "", STATEMENT_SEMANTIC_CHECK_START, STATEMENT_SEMANTIC_CHECK_END)
    status = None
    reason = None
    suggested = None
    if isinstance(report, dict):
        status = str(report.get("status") or "").strip().lower() or None
        reason = str(report.get("reason") or "").strip() or None
        suggested = str(report.get("suggested_revision") or "").strip() or None

    log_event(
        run_id,
        "statement_semantic_check",
        {
            "index": item.index,
            "label": item.label,
            "status": status or "unparsed",
            "reason": reason,
            "tokens_used": c_res.tokens_used,
            "log_path": str(c_res.log_path) if c_res.log_path else None,
            "policy": args.semantic_check_policy,
        },
    )

    if status == "ok":
        print("[Semantic check] OK.")
        return None, None, None

    # Anything other than an explicit "ok" (including an unparsed report) counts as potential drift.
    msg = (
        f"Statement semantic check flagged potential drift (status={status or 'unparsed'}). "
        + (reason or "No reason provided.")
    )
    if args.semantic_check_policy == "warn":
        print(f"[Semantic check] WARNING: {msg}")
        return None, None, None
    if args.semantic_check_policy == "fail":
        return "semantic", msg, msg

    # policy=fix
    # (Any policy value other than "warn"/"fail" also lands here.)
    # Build Agent B instructions; optional parts (suggestion, raw JSON) are dropped when empty.
    b_extra = "\n".join(
        part
        for part in [
            "SEMANTIC CHECK FAILED (statement stage): Please revise the Lean statement to match the intended meaning.",
            "Focus on: missing hypotheses from 'previous setup/subgoal', over-generalization (∀Φ), wrong quantifiers/directions/objects.",
            ("Suggested revision:\n" + suggested) if suggested else None,
            "Informal statement (authoritative):\n" + (item.content or ""),
            "Current Lean snippet:\n" + snippet,
            ("Raw semantic-check JSON:\n" + raw) if raw else None,
        ]
        if part
    )
    return "semantic", msg, b_extra


def _key_def_placeholders_and_maybe_set_failure(
    *,
    run_id: str,
    item,
    lean_file_abs: Path,
    require_concrete_key_defs: bool,
) -> tuple[str | None, str | None, str | None]:
    """
    Enforce: key local definitions used by the main labeled declaration must not be `:= sorry`.
    Returns (failed_mode, err, b_extra_instructions) or (None, None, None).
    """
    if not require_concrete_key_defs:
        return None, None, None

    try:
        file_text = lean_file_abs.read_text(encoding="utf-8")
    except Exception:
        file_text = ""

    violations = find_label_main_decl_key_def_placeholder_violations(text=file_text, label=item.label)
    if not violations:
        return None, None, None

    reason = (
        "Key-definition placeholder check failed: the main statement depends on local `def`/`abbrev` "
        "declarations that are still `:= sorry`."
    )
    details = "\n".join(f"- {v}" for v in violations)
    msg = reason + "\n\n" + details
    b_extra = "\n".join(
        [
            "KEY-DEF PLACEHOLDER CHECK FAILED (statement stage): revise the current item so the main statement has meaningful semantics.",
            "Do not leave key local `def`/`abbrev` used by the main theorem/predicate as `:= sorry`.",
            "Keep theorem/lemma proofs as `:= sorry` in statement stage, but make key definitions concrete terms.",
            "Violations:",
            details,
            "Informal statement (authoritative):",
            item.content or "",
        ]
    )
    log_event(
        run_id,
        "statement_key_def_check",
        {
            "index": item.index,
            "label": item.label,
            "status": "drift",
            "violations": violations,
        },
    )
    return "key_def", msg, b_extra


def _compute_agent_b_allowed_intervals(*, before_text: str, after_text: str) -> list[tuple[int, int]]:
    """
    Compute the line intervals of the post-Agent-A file that Agent B is allowed to edit.

    Agent B may only edit:
    - the import preamble (for necessary imports; handled by the scope checker), and
    - declaration blocks that Agent A inserted or modified for this item.

    The "Agent A touched region" is approximated by diffing pre-A vs post-A text
    line by line and collecting the post-A ranges ``[j1, j2)`` of every
    'replace'/'insert' opcode. Each merged range is padded by 2 lines on each
    side, clamped to the file, and merged again.
    """
    old_lines = before_text.splitlines()
    new_lines = after_text.splitlines()
    new_len = len(new_lines)

    opcodes = difflib.SequenceMatcher(a=old_lines, b=new_lines).get_opcodes()
    touched = [
        (j1, j2)
        for tag, _i1, _i2, j1, j2 in opcodes
        if tag in ("replace", "insert") and j1 != j2
    ]

    # Pad slightly so harmless whitespace/context edits next to the block are allowed.
    padded = [(max(0, lo - 2), min(new_len, hi + 2)) for lo, hi in _merge_intervals(touched)]
    return _merge_intervals(padded)


def _agent_b_edit_scope_violations(
    *,
    base_after_a_text: str,
    after_b_text: str,
    allowed_intervals_in_after_a: list[tuple[int, int]],
) -> list[str]:
    """
    Return human-readable violations if Agent B edited outside:
    - import preamble lines, or
    - regions touched by Agent A for this item (as line intervals in the post-A file).

    Edits are found by diffing the post-Agent-A text against the post-Agent-B
    text line by line. An edit is in scope when either:
    - it lies in the leading import preamble (or only rewrites `import ...` lines),
      AND every line Agent B wrote there is itself preamble content (blank,
      `--` comment, or `import ...`); or
    - for replace/delete, every consumed base line lies inside some half-open
      interval ``[s, e)`` of ``allowed_intervals_in_after_a``; for insert, the
      insertion point ``i1`` satisfies ``s <= i1 <= e`` (interval ends count).

    Returns an empty list when every edit is in scope. Line numbers in the
    messages are 1-based and refer to the post-Agent-A file.
    """
    base_lines = base_after_a_text.splitlines()
    new_lines = after_b_text.splitlines()

    preamble_end = _import_preamble_end(base_lines)

    def in_allowed_interval(i: int) -> bool:
        return any(s <= i < e for s, e in allowed_intervals_in_after_a)

    def writes_only_preamble_lines(j1: int, j2: int) -> bool:
        # The text Agent B wrote for this opcode (empty for a pure delete).
        return all(_is_import_preamble_line(x) for x in new_lines[j1:j2])

    def is_import_only_region(i1: int, i2: int, j1: int, j2: int) -> bool:
        # Without this content check, a `replace` over import/blank/comment lines
        # could smuggle arbitrary declarations into the preamble; the `insert`
        # branch below already enforces the same rule.
        if not writes_only_preamble_lines(j1, j2):
            return False
        if i2 <= preamble_end:
            return True
        # Also allow editing individual `import ...` lines even if the preamble is oddly short.
        return all(
            base_lines[i].strip().startswith("import ")
            for i in range(i1, i2)
            if 0 <= i < len(base_lines)
        )

    sm = difflib.SequenceMatcher(a=base_lines, b=new_lines)
    violations: list[str] = []
    for tag, i1, i2, j1, j2 in sm.get_opcodes():
        if tag == "equal":
            continue

        if tag == "insert":
            # Insertions don't consume lines from the base; authorize them if inserted into:
            # - import preamble region (with import-preamble content only), OR
            # - an allowed interval boundary/inside.
            if i1 <= preamble_end and writes_only_preamble_lines(j1, j2):
                continue
            if any(s <= i1 <= e for (s, e) in allowed_intervals_in_after_a):
                continue
            violations.append(f"insert at line {i1 + 1} (post-Agent-A) outside allowed regions")
            continue

        # replace/delete consume a base range [i1, i2).
        if is_import_only_region(i1, i2, j1, j2):
            continue
        if all(in_allowed_interval(i) for i in range(i1, i2)):
            continue
        violations.append(f"{tag} lines {i1 + 1}-{i2} (post-Agent-A) touches protected content")

    return violations


def _prebuild_project_aggregates(*, project: str, items: list[object]) -> Path:
    """
    Before running a whole statement JSON, pre-create/update aggregation modules:
    - `<project>/Chapters/ChapXX.lean` imports all touched `sectionYY.lean`
    - `<project>/Book.lean` imports only `Chapters/ChapXX.lean`

    Returns the relative path to `<project>/Book.lean` for compilation checks.
    """
    project = (project or "").strip()
    if not project:
        raise ValueError("project must be non-empty")

    book_rel = ensure_book_exists(project=project)

    # Ensure each referenced section file exists so that ChapXX/Book imports are valid.
    chapters: set[int] = set()
    seen_sections: set[tuple[int, int]] = set()
    for it in items:
        chap = int(getattr(it, "chapter"))
        sec = int(getattr(it, "section"))
        chapters.add(chap)
        seen_sections.add((chap, sec))

    for chap, sec in sorted(seen_sections):
        section_abs = LEAN_ROOT / project / "Chapters" / f"Chap{chap:02d}" / f"section{sec:02d}.lean"
        section_abs.parent.mkdir(parents=True, exist_ok=True)
        if not section_abs.exists():
            section_abs.write_text("import Mathlib\n", encoding="utf-8")
        section_rel = section_abs.relative_to(LEAN_ROOT)
        # Create/refresh section aggregates if parts exist (rare in stage1, but safe).
        section_rel = ensure_section_aggregate_exists(section_rel)
        chap_update = ensure_chapter_imports(project=project, section_aggregate_rel=section_rel)
        ensure_book_imports(project=project, chapter_aggregate_rel=chap_update.chapter_rel)

    # Also make sure Book imports all chapter aggregates referenced (even if a chapter had no sections somehow).
    for chap in sorted(chapters):
        chap_rel = Path(project) / "Chapters" / f"Chap{chap:02d}.lean"
        ensure_book_imports(project=project, chapter_aggregate_rel=chap_rel)

    return book_rel


def _non_sorry_warning_blocks(lean_output: str) -> list[str]:
    """
    Extract multi-line warning blocks that do not mention 'sorry'.
    """
    if not lean_output:
        return []

    lines = lean_output.splitlines()
    blocks: list[str] = []
    i = 0
    while i < len(lines):
        line = lines[i]
        lowered = line.lower()
        if "warning" in lowered:
            block_lines = [line]
            block_lines_lower = [lowered]
            i += 1
            # accumulate lines until the next warning starts
            while i < len(lines):
                next_line = lines[i]
                next_lower = next_line.lower()
                if "warning" in next_lower:
                    break
                block_lines.append(next_line)
                block_lines_lower.append(next_lower)
                i += 1
            block_text_lower = "\n".join(block_lines_lower)
            if "sorry" not in block_text_lower:
                blocks.append("\n".join(block_lines).strip())
            continue
        i += 1

    return blocks


def main() -> None:
    parser = argparse.ArgumentParser(description="Run Codex-to-Lean statement orchestrator.")
    parser.add_argument(
        "--start-index",
        type=int,
        default=None,
        help="Start processing from this global index (overrides saved state).",
    )
    parser.add_argument(
        "--max-items",
        type=int,
        default=None,
        help="Process at most this many items (useful for batching).",
    )
    parser.add_argument(
        "--max-b-retries",
        type=int,
        default=3,
        help="Max retries for Agent B when Lean still fails after a fix (default: 3).",
    )
    parser.add_argument(
        "--clean-warnings-with-agent-b",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="Whether to call Agent B to clean non-sorry warnings reported by Lean (default: false).",
    )
    parser.add_argument(
        "--book-check-every",
        type=int,
        default=int(os.getenv("STATEMENT_BOOK_CHECK_EVERY", "10")),
        help=(
            "How often to run `lake env lean <project>/Book.lean` within the same chapter (default: 10). "
            "1 = always (legacy behavior), 0 = never. "
            "Regardless of this value, Book.lean is checked once at chapter boundaries when project mode is enabled."
        ),
    )
    parser.add_argument(
        "--data-file",
        type=Path,
        default=None,
        help="Path to the JSON data file to process (default: config.DATA_PATH).",
    )
    parser.add_argument(
        "--statement-agent-config",
        type=Path,
        default=None,
        help="Path to a TOML file controlling per-agent model/reasoning for STATEMENT stage. "
        "Default: use $STATEMENT_AGENT_CONFIG_FILE, else repo `agent_configs/statement_agents.toml` if present.",
    )
    default_semantic_check = os.getenv("STATEMENT_SEMANTIC_CHECK", "").strip().lower() in {
        "1",
        "true",
        "yes",
        "on",
    }
    parser.add_argument(
        "--semantic-check",
        action=argparse.BooleanOptionalAction,
        default=default_semantic_check,
        help=(
            "Run an extra statement semantic-drift check via Agent C after Lean passes "
            "(default: controlled by $STATEMENT_SEMANTIC_CHECK)."
        ),
    )
    parser.add_argument(
        "--semantic-check-policy",
        choices=["warn", "fix", "fail"],
        default=os.getenv("STATEMENT_SEMANTIC_CHECK_POLICY", "fix").strip().lower() or "fix",
        help=(
            "How to handle semantic-check failures: warn=log and continue; "
            "fix=invoke statement Agent B to revise; fail=stop this run (default: fix)."
        ),
    )
    default_require_concrete_key_defs = os.getenv("STATEMENT_REQUIRE_CONCRETE_KEY_DEFS", "0").strip().lower() in {
        "1",
        "true",
        "yes",
        "on",
        "enable",
        "enabled",
    }
    parser.add_argument(
        "--require-concrete-key-defs",
        action=argparse.BooleanOptionalAction,
        default=default_require_concrete_key_defs,
        help=(
            "Require key local `def`/`abbrev` used by the main labeled declaration to be concrete "
            "(not `:= sorry`). Default: disabled."
        ),
    )
    args = parser.parse_args()

    # statement pipeline 起点：确定从哪个 index 开始
    # 1) 若命令行给出 --start-index，优先使用
    # 2) 否则若环境变量 STATEMENT_START_INDEX 或 START_INDEX 存在，则尝试解析
    # 3) 否则回退到 statement_progress.json 中的 next_index
    env_start = os.environ.get("STATEMENT_START_INDEX") or os.environ.get("START_INDEX")
    start_index = args.start_index
    if start_index is None and env_start:
        try:
            start_index = int(env_start)
        except ValueError:
            pass

    data_path = args.data_file or DATA_PATH
    print(f"Using data file: {data_path}")
    if not Path(data_path).exists():
        print(f"Data file not found: {data_path}")
        print("Put your JSON list under `data/` and pass it via --data-file. See `data/README.md`.")
        return
    items = load_items(data_path=data_path)
    state = load_state()
    default_start = state.get("next_index", 1)
    effective_start = start_index if start_index is not None else default_start

    items = list(filter_items(items, min_index=effective_start))

    project = os.environ.get("FORMAL_PROJECT", "").strip()
    book_rel: Path | None = None
    if project:
        try:
            book_rel = _prebuild_project_aggregates(project=project, items=list(items))
            print(f"[Aggregates] prebuilt project aggregators; book={book_rel}")
        except Exception as e:
            print(f"[Aggregates] prebuild failed: {e}")
            book_rel = None

    book_check_every = int(args.book_check_every or 0)
    if book_check_every < 0:
        book_check_every = 0
    current_chapter: int | None = None
    chapter_item_count = 0

    default_cfg = ROOT / "agent_configs/statement_agents.toml"
    statement_agent_cfg = args.statement_agent_config or (
        Path(os.environ["STATEMENT_AGENT_CONFIG_FILE"]) if os.environ.get("STATEMENT_AGENT_CONFIG_FILE") else None
    )
    agent_settings = resolve_stage_agents_settings(
        stage_prefix="STATEMENT_AGENT",
        agent_keys=["A", "B", "C"],
        config_path=statement_agent_cfg,
        default_config_path=default_cfg,
    )

    run_id = start_run(
        "statement",
        stage=1,
        name_tag=Path(data_path).name,
        data_file=str(data_path),
        extra={
            "start_index": effective_start,
            "clean_warnings_with_agent_b": args.clean_warnings_with_agent_b,
            "book_check_every": book_check_every,
            "require_concrete_key_defs": args.require_concrete_key_defs,
            "statement_agent_config": str(agent_settings.source_path) if agent_settings.source_path else None,
            "statement_agent_a_model": agent_settings.agents["A"].model,
            "statement_agent_a_reasoning_effort": agent_settings.agents["A"].reasoning_effort,
            "statement_agent_b_model": agent_settings.agents["B"].model,
            "statement_agent_b_reasoning_effort": agent_settings.agents["B"].reasoning_effort,
        },
    )
    run_start = time.monotonic()
    processed = 0
    last_success = state.get("next_index", 1) - 1
    total_tokens_used = 0
    total_items_failed = 0

    try:
        for idx, item in enumerate(items):
            next_item = items[idx + 1] if idx + 1 < len(items) else None
            # 每条数据的管线：Agent A 生成 → lean 检查 → Agent B（如需） → lean 复查 → 记录进度
            print(f"=== index={item.index} label={item.label} ({item.chapter}.{item.section}.{item.local_index}) ===")
            item_start = time.monotonic()
            log_event(
                run_id,
                "item_start",
                {
                    "index": item.index,
                    "label": item.label,
                    "chapter": item.chapter,
                    "section": item.section,
                    "local_index": item.local_index,
                },
            )

            lean_file_abs = resolve_item_target_file(
                item.chapter, item.section, label=item.label, target_file=getattr(item, "target_file", None)
            )
            lean_file_abs.parent.mkdir(parents=True, exist_ok=True)
            # 若章节文件不存在则创建，填入最小导入；确保后续 codex/lean 有文件载体
            if not lean_file_abs.exists():
                lean_file_abs.write_text("import Mathlib\n", encoding="utf-8")
            lean_file_rel = lean_file_abs.relative_to(LEAN_ROOT)
            # If we're editing a part file, ensure the aggregate section file exists (for compilation/imports).
            compile_file_rel = compile_entry_for(ensure_section_aggregate_exists(lean_file_rel))

            # 步骤 1：调用 Agent A 生成/追加 Lean 声明骨架
            file_before_a = file_snapshot(lean_file_abs)
            before_text_a = lean_file_abs.read_text(encoding="utf-8")
            agent_a_start = time.monotonic()
            agent_a_res = run_agent_a(
                item,
                model=agent_settings.agents["A"].model,
                reasoning_effort=agent_settings.agents["A"].reasoning_effort,
            )
            agent_a_seconds = time.monotonic() - agent_a_start
            file_after_a = file_snapshot(lean_file_abs)
            after_text_a = lean_file_abs.read_text(encoding="utf-8")
            total_tokens_used += agent_a_res.tokens_used or 0
            item_tokens_used = agent_a_res.tokens_used or 0
            log_event(
                run_id,
                "agent_a_result",
                {
                    "index": item.index,
                    "label": item.label,
                    "code": agent_a_res.code,
                    "seconds": agent_a_seconds,
                    "tokens_used": agent_a_res.tokens_used,
                    "log_path": str(agent_a_res.log_path) if agent_a_res.log_path else None,
                    "file_before": file_before_a,
                    "file_after": file_after_a,
                    "file_delta": snapshot_delta(file_before_a, file_after_a),
                },
            )
            if agent_a_res.code != 0:
                print(f"Agent A failed with code {agent_a_res.code}. Stopping.\n{agent_a_res.stderr}")
                total_items_failed += 1
                log_event(
                    run_id,
                    "item_end",
                    {
                        "index": item.index,
                        "label": item.label,
                        "status": "agent_a_failed",
                        "seconds": time.monotonic() - item_start,
                        "tokens_used_total": item_tokens_used,
                        "error_snippet": agent_a_res.stderr[:1000] if agent_a_res.stderr else None,
                    },
                )
                break
            _maybe_report_math_blocker(stdout=agent_a_res.stdout or "", agent="A", label=item.label)

            axiom_decls = find_axiom_decls(lean_file_abs)
            if axiom_decls:
                print("Detected forbidden `axiom` declarations after Agent A; requesting cleanup.")
                cleanup_instructions = build_axiom_cleanup_instructions(lean_file_rel, axiom_decls)
                cleanup_start = time.monotonic()
                cleanup_res = run_agent_a(
                    item,
                    extra_instructions=cleanup_instructions,
                )
                cleanup_seconds = time.monotonic() - cleanup_start
                total_tokens_used += cleanup_res.tokens_used or 0
                log_event(
                    run_id,
                    "agent_a_axiom_cleanup_result",
                    {
                        "index": item.index,
                        "label": item.label,
                        "code": cleanup_res.code,
                        "seconds": cleanup_seconds,
                        "tokens_used": cleanup_res.tokens_used,
                        "log_path": str(cleanup_res.log_path) if cleanup_res.log_path else None,
                    },
                )
                if cleanup_res.code != 0:
                    print(f"Agent A axiom cleanup failed with code {cleanup_res.code}. Stopping.\n{cleanup_res.stderr}")
                    total_items_failed += 1
                    break
                axiom_decls = find_axiom_decls(lean_file_abs)
                if axiom_decls:
                    print("Axiom cleanup failed; `axiom` declarations still present.")
                    print(format_axiom_report(axiom_decls))
                    total_items_failed += 1
                    break

            # 步骤 2：首次尝试编译，若成功则推进进度
            failed_mode: str | None = None  # "target" | "book" | "semantic" | "key_def" | None
            b_extra_instructions: str | None = None
            force_book_check_in_retry = False
            should_check_book = False
            book_check_reason = None
            if book_rel is not None:
                chap = int(item.chapter)
                if current_chapter is None or current_chapter != chap:
                    current_chapter = chap
                    chapter_item_count = 0
                chapter_item_count += 1
                chapter_boundary = next_item is None or int(next_item.chapter) != chap
                if book_check_every <= 0:
                    should_check_book = chapter_boundary
                    book_check_reason = "chapter_boundary_only" if chapter_boundary else "disabled"
                elif book_check_every == 1:
                    should_check_book = True
                    book_check_reason = "always"
                else:
                    should_check_book = chapter_boundary or (chapter_item_count % book_check_every == 0)
                    book_check_reason = (
                        "chapter_boundary"
                        if chapter_boundary
                        else ("every_n" if should_check_book else "throttled")
                    )

            # Rule: before checking Book.lean, always check the target compile entry first.
            lean_start = time.monotonic()
            code, out, err = lake_env_lean(compile_file_rel)
            lean_seconds = time.monotonic() - lean_start
            lean_output = "\n".join(part for part in (err, out) if part)
            non_sorry_blocks = _non_sorry_warning_blocks(lean_output)
            log_event(
                run_id,
                "lean_check",
                {
                    "index": item.index,
                    "label": item.label,
                    "phase": "post_agent_a_target",
                    "code": code,
                    "seconds": lean_seconds,
                    "has_non_sorry_warnings": bool(non_sorry_blocks),
                    "compiled_file": str(compile_file_rel),
                },
            )
            if code == 0 and non_sorry_blocks and args.clean_warnings_with_agent_b:
                failed_mode = "target"
                warnings_text = "\n\n".join(non_sorry_blocks)
                err = "Lean produced the following non-sorry warnings. Please remove them:\n\n" + warnings_text
            elif code != 0:
                failed_mode = "target"
                # `lake env lean` may print errors to stdout; keep a combined log for Agent B.
                err = lean_output

            if failed_mode is None and book_rel is not None and should_check_book:
                lean_start = time.monotonic()
                book_code, book_out, book_err = lake_env_lean(book_rel)
                book_seconds = time.monotonic() - lean_start
                book_output = "\n".join(part for part in (book_err, book_out) if part)
                book_non_sorry_blocks = _non_sorry_warning_blocks(book_output)
                log_event(
                    run_id,
                    "lean_check",
                    {
                        "index": item.index,
                        "label": item.label,
                        "phase": "post_agent_a_book",
                        "code": book_code,
                        "seconds": book_seconds,
                        "has_non_sorry_warnings": bool(book_non_sorry_blocks),
                        "compiled_file": str(book_rel),
                    },
                )
                if book_code == 0 and book_non_sorry_blocks and args.clean_warnings_with_agent_b:
                    failed_mode = "book"
                    warnings_text = "\n\n".join(book_non_sorry_blocks)
                    err = "Lean produced the following non-sorry warnings (via Book.lean). Please remove them:\n\n" + warnings_text
                elif book_code != 0:
                    failed_mode = "book"
                    out, err = book_out, book_output
                    non_sorry_blocks = book_non_sorry_blocks
            elif failed_mode is None and book_rel is not None and not should_check_book:
                log_event(
                    run_id,
                    "lean_check_skipped",
                    {
                        "index": item.index,
                        "label": item.label,
                        "phase": "post_agent_a_book",
                        "reason": book_check_reason,
                        "chapter": int(item.chapter),
                        "chapter_item_count": chapter_item_count,
                        "book_check_every": book_check_every,
                        "compiled_file": str(book_rel),
                    },
                )

            if failed_mode is None and (not non_sorry_blocks or not args.clean_warnings_with_agent_b):
                key_failed_mode, key_err, key_b_extra = _key_def_placeholders_and_maybe_set_failure(
                    run_id=run_id,
                    item=item,
                    lean_file_abs=lean_file_abs,
                    require_concrete_key_defs=bool(args.require_concrete_key_defs),
                )
                if key_failed_mode is not None:
                    failed_mode = key_failed_mode
                    err = key_err or err
                    b_extra_instructions = key_b_extra

            if failed_mode is None and (not non_sorry_blocks or not args.clean_warnings_with_agent_b):
                semantic_tokens_ref = [0]
                sem_failed_mode, sem_err, sem_b_extra = _semantic_check_and_maybe_set_failure(
                    run_id=run_id,
                    agent_settings=agent_settings,
                    item=item,
                    lean_file_abs=lean_file_abs,
                    compile_file_rel=compile_file_rel,
                    args=args,
                    total_tokens_used_ref=semantic_tokens_ref,
                )
                total_tokens_used += semantic_tokens_ref[0]
                item_tokens_used += semantic_tokens_ref[0]
                if sem_failed_mode is not None:
                    if args.semantic_check_policy == "fail":
                        print(sem_err or "Statement semantic check failed.")
                        total_items_failed += 1
                        state["next_index"] = item.index
                        save_state(state, run_id=run_id)
                        log_event(
                            run_id,
                            "item_end",
                            {
                                "index": item.index,
                                "label": item.label,
                                "status": "semantic_failed",
                                "seconds": time.monotonic() - item_start,
                                "tokens_used_total": item_tokens_used,
                                "error_snippet": (sem_err or "")[:1000] or None,
                            },
                        )
                        break
                    failed_mode = sem_failed_mode
                    err = sem_err or err
                    b_extra_instructions = sem_b_extra

            if failed_mode is None and (not non_sorry_blocks or not args.clean_warnings_with_agent_b):
                if non_sorry_blocks and not args.clean_warnings_with_agent_b:
                    print("Lean OK, but has non-sorry warnings (--no-clean-warnings-with-agent-b); continuing.")
                    status = "ok_with_warnings"
                else:
                    print("Lean OK.")
                    status = "ok"
                state["next_index"] = item.index + 1
                save_state(state, run_id=run_id)
                processed += 1
                last_success = item.index
                log_event(
                    run_id,
                    "item_end",
                    {
                        "index": item.index,
                        "label": item.label,
                        "status": status,
                        "seconds": time.monotonic() - item_start,
                        "tokens_used_total": item_tokens_used,
                    },
                )
                if args.max_items is not None and processed >= args.max_items:
                    print(f"Reached max-items={args.max_items}, stopping batch.")
                    break
                continue

            if failed_mode == "book":
                print("Book.lean check failed (or has non-sorry warnings); calling Agent B (Book-check mode)...")
            elif failed_mode == "semantic":
                print("Statement semantic check flagged drift; calling Agent B (semantic-fix mode)...")
            elif failed_mode == "key_def":
                print("Key-definition placeholder check failed; calling Agent B (semantic-fix mode)...")
            elif code == 0 and non_sorry_blocks and args.clean_warnings_with_agent_b:
                print("Lean produced warnings (non-sorry); invoking Agent B to clean.")
            else:
                print("Lean failed, calling Agent B...")
            # 步骤 3：编译失败则转交 Agent B 进行修复（可多次尝试）
            success_after_b = False
            ok_after_b_status = "ok_after_b"
            force_book_check_in_retry = bool(book_rel is not None and failed_mode == "book")
            semantic_hard_failure = False
            restrict_b_scope = os.getenv("STATEMENT_AGENT_B_SCOPE", "").strip().lower() in {
                "restricted",
                "restrict",
                "1",
                "true",
                "yes",
            }
            # Optional guardrail: restrict Agent B edits to the newly-touched region plus import preamble.
            # Default is permissive ("free") so Agent B can repair earlier declarations when necessary.
            b_allowed_intervals = (
                _compute_agent_b_allowed_intervals(before_text=before_text_a, after_text=after_text_a)
                if restrict_b_scope
                else []
            )
            b_approved_text = after_text_a
            for attempt in range(1, args.max_b_retries + 1):
                print(f"[Agent B attempt {attempt}/{args.max_b_retries}]")
                # In restricted mode, always start from the last approved text to prevent drift.
                if restrict_b_scope and lean_file_abs.read_text(encoding="utf-8") != b_approved_text:
                    lean_file_abs.write_text(b_approved_text, encoding="utf-8")
                file_before_b = file_snapshot(lean_file_abs)
                b_start = time.monotonic()
                run_b = (
                    run_agent_b_bookcheck
                    if (book_rel is not None and failed_mode == "book")
                    else run_agent_b
                )
                before_text_b = lean_file_abs.read_text(encoding="utf-8")
                b_res = run_b(
                    lean_file_abs,
                    error_log=err,
                    item_index=item.index,
                    item_context=item.context,
                    item_dependencies=item.dependencies,
                    label=item.label,
                    model=agent_settings.agents["B"].model,
                    reasoning_effort=agent_settings.agents["B"].reasoning_effort,
                    extra_instructions=b_extra_instructions,
                )
                _maybe_report_math_blocker(stdout=b_res.stdout or "", agent="B", label=item.label)
                after_text_b = lean_file_abs.read_text(encoding="utf-8")
                scope_violations: list[str] = []
                if restrict_b_scope:
                    scope_violations = _agent_b_edit_scope_violations(
                        base_after_a_text=after_text_a,
                        after_b_text=after_text_b,
                        allowed_intervals_in_after_a=b_allowed_intervals,
                    )
                    if scope_violations:
                        # Revert and ask Agent B to retry within scope.
                        lean_file_abs.write_text(b_approved_text, encoding="utf-8")
                        after_text_b = b_approved_text
                        err = (
                            "Policy violation: Agent B edited outside the current item's new/modified declarations.\n\n"
                            "Allowed edits:\n"
                            "- import preamble (necessary `import ...` only)\n"
                            "- declaration blocks inserted/modified by Agent A for this item\n\n"
                            "Violations:\n- "
                            + "\n- ".join(scope_violations)
                        )
                        log_event(
                            run_id,
                            "agent_b_edit_scope_violation",
                            {
                                "index": item.index,
                                "label": item.label,
                                "attempt": attempt,
                                "violations": scope_violations,
                            },
                        )
                b_seconds = time.monotonic() - b_start
                file_after_b = file_snapshot(lean_file_abs)
                total_tokens_used += b_res.tokens_used or 0
                item_tokens_used += b_res.tokens_used or 0
                log_event(
                    run_id,
                    "agent_b_result",
                    {
                        "index": item.index,
                        "label": item.label,
                        "attempt": attempt,
                        "code": b_res.code,
                        "seconds": b_seconds,
                        "tokens_used": b_res.tokens_used,
                        "log_path": str(b_res.log_path) if b_res.log_path else None,
                        "file_before": file_before_b,
                        "file_after": file_after_b,
                        "file_delta": snapshot_delta(file_before_b, file_after_b),
                    },
                )
                if b_res.code != 0:
                    print(f"Agent B failed with code {b_res.code}. Stopping.\n{b_res.stderr}")
                    success_after_b = False
                    break
                if scope_violations:
                    print("Agent B edited outside allowed scope; retrying.")
                    continue

                if restrict_b_scope:
                    b_approved_text = after_text_b

                # 步骤 4：复查编译；成功则更新进度，失败则视尝试次数决定是否继续
                failed_mode = None
                lean_start = time.monotonic()
                code, out, err = lake_env_lean(compile_file_rel)
                lean_seconds = time.monotonic() - lean_start
                lean_output = "\n".join(part for part in (err, out) if part)
                non_sorry_blocks = _non_sorry_warning_blocks(lean_output)
                log_event(
                    run_id,
                    "lean_check",
                    {
                        "index": item.index,
                        "label": item.label,
                        "phase": f"post_agent_b_attempt{attempt}",
                        "code": code,
                        "seconds": lean_seconds,
                        "has_non_sorry_warnings": bool(non_sorry_blocks),
                        "compiled_file": str(compile_file_rel),
                    },
                )
                if code == 0 and non_sorry_blocks and args.clean_warnings_with_agent_b:
                    failed_mode = "target"
                    warnings_text = "\n\n".join(non_sorry_blocks)
                    err = "Lean still reports these non-sorry warnings. Please remove them:\n\n" + warnings_text
                elif code != 0:
                    failed_mode = "target"
                    # `lake env lean` may print errors to stdout; keep a combined log for Agent B.
                    err = lean_output

                if failed_mode is None and book_rel is not None and (should_check_book or force_book_check_in_retry):
                    lean_start = time.monotonic()
                    book_code, book_out, book_err = lake_env_lean(book_rel)
                    book_seconds = time.monotonic() - lean_start
                    book_output = "\n".join(part for part in (book_err, book_out) if part)
                    book_non_sorry_blocks = _non_sorry_warning_blocks(book_output)
                    log_event(
                        run_id,
                        "lean_check",
                        {
                            "index": item.index,
                            "label": item.label,
                            "phase": f"post_agent_b_attempt{attempt}_book",
                            "code": book_code,
                            "seconds": book_seconds,
                            "has_non_sorry_warnings": bool(book_non_sorry_blocks),
                            "compiled_file": str(book_rel),
                        },
                    )
                    if book_code == 0 and book_non_sorry_blocks and args.clean_warnings_with_agent_b:
                        failed_mode = "book"
                        warnings_text = "\n\n".join(book_non_sorry_blocks)
                        err = (
                            "Lean still reports these non-sorry warnings (via Book.lean). Please remove them:\n\n"
                            + warnings_text
                        )
                    elif book_code != 0:
                        failed_mode = "book"
                        out, err = book_out, book_output
                        non_sorry_blocks = book_non_sorry_blocks
                elif failed_mode is None and book_rel is not None and not (should_check_book or force_book_check_in_retry):
                    log_event(
                        run_id,
                        "lean_check_skipped",
                        {
                            "index": item.index,
                            "label": item.label,
                            "phase": f"post_agent_b_attempt{attempt}_book",
                            "reason": book_check_reason,
                            "chapter": int(item.chapter),
                            "chapter_item_count": chapter_item_count,
                            "book_check_every": book_check_every,
                            "compiled_file": str(book_rel),
                        },
                    )

                if failed_mode is None and code == 0 and (
                    not non_sorry_blocks or not args.clean_warnings_with_agent_b
                ):
                    key_failed_mode, key_err, key_b_extra = _key_def_placeholders_and_maybe_set_failure(
                        run_id=run_id,
                        item=item,
                        lean_file_abs=lean_file_abs,
                        require_concrete_key_defs=bool(args.require_concrete_key_defs),
                    )
                    if key_failed_mode is not None:
                        failed_mode = key_failed_mode
                        err = key_err or err
                        b_extra_instructions = key_b_extra
                        print("Key-definition placeholder check still failing; retrying Agent B if attempts remain.")
                        continue

                    semantic_tokens_ref = [0]
                    sem_failed_mode, sem_err, sem_b_extra = _semantic_check_and_maybe_set_failure(
                        run_id=run_id,
                        agent_settings=agent_settings,
                        item=item,
                        lean_file_abs=lean_file_abs,
                        compile_file_rel=compile_file_rel,
                        args=args,
                        total_tokens_used_ref=semantic_tokens_ref,
                    )
                    total_tokens_used += semantic_tokens_ref[0]
                    item_tokens_used += semantic_tokens_ref[0]
                    if sem_failed_mode is not None:
                        if args.semantic_check_policy == "fail":
                            print(sem_err or "Statement semantic check failed after Agent B.")
                            total_items_failed += 1
                            state["next_index"] = item.index
                            save_state(state, run_id=run_id)
                            log_event(
                                run_id,
                                "item_end",
                                {
                                    "index": item.index,
                                    "label": item.label,
                                    "status": "semantic_failed_after_b",
                                    "seconds": time.monotonic() - item_start,
                                    "tokens_used_total": item_tokens_used,
                                    "error_snippet": (sem_err or "")[:1000] or None,
                                },
                            )
                            semantic_hard_failure = True
                            break
                        failed_mode = sem_failed_mode
                        err = sem_err or err
                        b_extra_instructions = sem_b_extra
                        print("Semantic drift persists; retrying Agent B if attempts remain.")
                        continue

                    if non_sorry_blocks and not args.clean_warnings_with_agent_b:
                        print(
                            "Lean OK after Agent B, but has non-sorry warnings (--no-clean-warnings-with-agent-b); continuing."
                        )
                        ok_after_b_status = "ok_after_b_with_warnings"
                    else:
                        print("Lean OK after Agent B.")
                        ok_after_b_status = "ok_after_b"
                    state["next_index"] = item.index + 1
                    save_state(state, run_id=run_id)
                    processed += 1
                    last_success = item.index
                    if args.max_items is not None and processed >= args.max_items:
                        print(f"Reached max-items={args.max_items}, stopping batch.")
                        success_after_b = True
                        break
                    success_after_b = True
                    break

                if failed_mode == "book":
                    print("Book.lean check still failing after Agent B; will retry if attempts remain.")
                elif code == 0 and non_sorry_blocks and args.clean_warnings_with_agent_b:
                    print("Lean still has non-sorry warnings after Agent B attempt; will retry if attempts remain.")
                else:
                    print("Lean still failing after Agent B attempt; will retry if attempts remain.")

            if semantic_hard_failure:
                break

            if success_after_b:
                log_event(
                    run_id,
                    "item_end",
                    {
                        "index": item.index,
                        "label": item.label,
                        "status": ok_after_b_status,
                        "seconds": time.monotonic() - item_start,
                        "tokens_used_total": item_tokens_used,
                    },
                )
                if args.max_items is not None and processed >= args.max_items:
                    break
                continue

            print("Lean still failing after Agent B retries.")
            print(err)
            total_items_failed += 1
            state["next_index"] = item.index
            save_state(state, run_id=run_id)
            log_event(
                run_id,
                "item_end",
                {
                    "index": item.index,
                    "label": item.label,
                    "status": "failed_after_b",
                    "seconds": time.monotonic() - item_start,
                    "tokens_used_total": item_tokens_used,
                    "error_snippet": err[:1000] if err else None,
                },
            )
            break
    finally:
        finish_run(
            run_id,
            {
                "pipeline": "statement",
                "stage": 1,
                "run_id": run_id,
                "data_file": str(data_path),
                "processed": processed,
                "last_success_index": last_success,
                "next_index": state.get("next_index", 1),
                "items_failed": total_items_failed,
                "tokens_used_total": total_tokens_used,
                "require_concrete_key_defs": args.require_concrete_key_defs,
                "seconds_total": time.monotonic() - run_start,
                "paths": {
                    "progress_file": str(STATEMENT_PROGRESS_FILE),
                    "statement_logs_dir": str(STATEMENT_LOGS_DIR),
                    "metrics_dir": str(METRICS_DIR),
                },
            },
        )

    print(f"Processed {processed} items. Last successful index: {last_success}. Next index: {state.get('next_index', 1)}.")


if __name__ == "__main__":
    # Entry point when this file is executed directly. Because the module uses
    # package-relative imports (e.g. `from .codex_client import ...`), it must be
    # launched as a module (`python -m <package>.<this_module>`), not by file path;
    # otherwise the relative imports fail before `main()` is reached.
    main()
