from __future__ import annotations

import os
from dataclasses import dataclass
from pathlib import Path


class StageAgentsError(RuntimeError):
    """Error raised by this module's stage AGENTS.md helpers.

    Raised when a stage name is not recognised or when the stage-specific
    AGENTS.md source file is missing.
    """


@dataclass(frozen=True, slots=True)
class StageAgentsPaths:
    stage: str
    src: Path
    dst: Path


def _stage_agents_paths(*, lean_root: Path, stage: str) -> StageAgentsPaths:
    """Resolve the source and destination AGENTS.md paths for ``stage``.

    The stage name is normalized first: a falsy value is treated as the
    empty string, surrounding whitespace is stripped and the result is
    lower-cased. The normalized name must be one of the known stages.

    Args:
        lean_root: Directory that holds ``AGENTS.md`` and the per-stage
            ``.codex_<stage>`` directories.
        stage: Stage name, matched case-insensitively.

    Returns:
        A ``StageAgentsPaths`` carrying the normalized stage name, the
        stage-specific source file and the destination ``AGENTS.md``.

    Raises:
        StageAgentsError: If the normalized stage name is not recognised.
    """
    known_stages = (
        "statement",
        "proof",
        "final",
        "infra",
        "infra_statement",
        "infra_proof",
        "infra_final",
    )
    normalized = stage.strip().lower() if stage else ""

    if normalized in known_stages:
        stage_dir = lean_root / f".codex_{normalized}"
        return StageAgentsPaths(
            stage=normalized,
            src=stage_dir / "AGENTS.md",
            dst=lean_root / "AGENTS.md",
        )

    # The message reports the caller's raw value, not the normalized one.
    raise StageAgentsError(f"Unknown stage: {stage!r}")


def ensure_stage_agents(*, lean_root: Path, stage: str) -> StageAgentsPaths:
    """
    Ensure `LEAN_ROOT/AGENTS.md` matches the stage-specific AGENTS file.

    Rationale:
    - Run Codex from `LEAN_ROOT` so its sandbox "workspace-write" can edit any Lean file under `M2F/`.
    - Still enforce stage-specific rules by swapping `LEAN_ROOT/AGENTS.md` before each Codex call.

    The destination is only rewritten when its text differs from the stage
    file. The new content is first written to a temporary sibling file and
    then moved over the destination with `Path.replace`, so the destination
    path never holds a partially written file.

    Args:
        lean_root: Directory containing `AGENTS.md` and the `.codex_<stage>`
            directories.
        stage: Stage name; validated and normalized by `_stage_agents_paths`.

    Returns:
        The resolved `StageAgentsPaths` for the stage.

    Raises:
        StageAgentsError: If the stage is unknown or the stage-specific
            `AGENTS.md` does not exist.
        OSError: If reading, writing, or replacing a file fails. The
            temporary file is removed (best effort) before re-raising.
    """
    paths = _stage_agents_paths(lean_root=lean_root, stage=stage)
    if not paths.src.exists():
        raise StageAgentsError(f"Missing stage AGENTS.md: {paths.src}")

    want = paths.src.read_text(encoding="utf-8")
    have = paths.dst.read_text(encoding="utf-8") if paths.dst.exists() else None
    if have == want:
        # Already in sync; skip the write so the file is left untouched.
        return paths

    # The temp file lives next to the destination so the final replace is a
    # rename within one directory; the PID keeps concurrent processes from
    # sharing a temp name.
    tmp = paths.dst.with_name(f".AGENTS.md.tmp.{os.getpid()}")
    try:
        tmp.write_text(want, encoding="utf-8")
        tmp.replace(paths.dst)
    except BaseException:
        # Do not leave a stray temp file in LEAN_ROOT when the write or the
        # replace fails (including on interruption).
        try:
            tmp.unlink(missing_ok=True)
        except OSError:
            # Cleanup is best effort; the original failure is what matters.
            pass
        raise
    return paths
