from __future__ import annotations

import json
import re
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Any


def _utc_now_iso() -> str:
    """Return the current time as a timezone-aware (UTC) ISO-8601 string."""
    now_utc = datetime.now(tz=timezone.utc)
    return now_utc.isoformat()


def module_import_for_rel_lean_file(path_rel: Path) -> str:
    """Turn a relative file path into an ``import`` line with a dotted module name.

    The final suffix is dropped and the remaining path components are joined
    with dots, e.g. ``A/B/C.lean`` becomes ``import A.B.C``.
    """
    without_suffix = path_rel.with_suffix("")
    dotted_name = ".".join(without_suffix.parts)
    return "import " + dotted_name


def safe_decl_name(text: str, *, fallback: str = "infra_item_frozen") -> str:
    """Sanitize ``text`` into a name usable as a declaration identifier.

    Every character outside ``[A-Za-z0-9_']`` is replaced by ``_`` and leading
    and trailing underscores are stripped.

    Args:
        text: Arbitrary input to sanitize.
        fallback: Returned unchanged when nothing usable remains.

    Returns:
        The sanitized name. It is prefixed with ``n_`` when it would otherwise
        begin with a digit or an apostrophe, because an identifier may contain
        those characters but (in Lean, which the rest of this module targets)
        may not start with them.
    """
    cleaned = re.sub(r"[^A-Za-z0-9_']", "_", text).strip("_")
    if not cleaned:
        return fallback
    # The apostrophe survives the substitution above, so it needs the same
    # treatment as a leading digit to keep the result a valid identifier.
    if cleaned[0].isdigit() or cleaned[0] == "'":
        cleaned = f"n_{cleaned}"
    return cleaned


@dataclass(frozen=True, slots=True)
class PrefixStoreConfig:
    """Immutable limits that decide when a new prefix chunk file is opened."""

    # A chunk that already holds this many items is considered full.
    # Values <= 0 disable the item-count check.
    chunk_item_limit: int = 20
    # A non-empty chunk is closed when appending would push its recorded line
    # count above this value. Values <= 0 disable the line-count check.
    chunk_line_limit: int = 1800


class InfraPrefixStore:
    """
    Append-only frozen prefix storage.

    Files are sharded under:
      <infra_dir>/GeneratedPrefix/Prefix_0001.lean
      <infra_dir>/GeneratedPrefix/Prefix_0002.lean
      ...
    and indexed by:
      <infra_dir>/PrefixIndex.lean
      <infra_dir>/GeneratedPrefix/prefix_manifest.json
    """

    def __init__(
        self,
        *,
        lean_root: Path,
        infra_dir_rel: Path,
        config: PrefixStoreConfig | None = None,
    ) -> None:
        self.lean_root = lean_root
        self.infra_dir_rel = infra_dir_rel
        self.infra_dir_abs = self.lean_root / self.infra_dir_rel
        self.prefix_dir_rel = self.infra_dir_rel / "GeneratedPrefix"
        self.prefix_dir_abs = self.infra_dir_abs / "GeneratedPrefix"
        self.prefix_index_rel = self.infra_dir_rel / "PrefixIndex.lean"
        self.prefix_index_abs = self.infra_dir_abs / "PrefixIndex.lean"
        self.manifest_rel = self.prefix_dir_rel / "prefix_manifest.json"
        self.manifest_abs = self.prefix_dir_abs / "prefix_manifest.json"
        self.config = config or PrefixStoreConfig()

    def _default_manifest(self) -> dict[str, Any]:
        return {
            "version": 1,
            "infra_dir": str(self.infra_dir_rel),
            "chunk_item_limit": int(self.config.chunk_item_limit),
            "chunk_line_limit": int(self.config.chunk_line_limit),
            "chunks": [],
            "frozen_items": [],
            "updated_at": _utc_now_iso(),
        }

    def _normalize_manifest(self, raw: Any) -> dict[str, Any]:
        if not isinstance(raw, dict):
            return self._default_manifest()
        out = self._default_manifest()
        out["version"] = int(raw.get("version", out["version"]))
        out["infra_dir"] = str(raw.get("infra_dir", out["infra_dir"]))
        out["chunk_item_limit"] = int(raw.get("chunk_item_limit", out["chunk_item_limit"]))
        out["chunk_line_limit"] = int(raw.get("chunk_line_limit", out["chunk_line_limit"]))
        chunks = raw.get("chunks")
        if isinstance(chunks, list):
            out["chunks"] = [c for c in chunks if isinstance(c, dict)]
        frozen_items = raw.get("frozen_items")
        if isinstance(frozen_items, list):
            out["frozen_items"] = [it for it in frozen_items if isinstance(it, dict)]
        out["updated_at"] = str(raw.get("updated_at", out["updated_at"]))
        return out

    def load_manifest(self) -> dict[str, Any]:
        if not self.manifest_abs.exists():
            return self._default_manifest()
        try:
            raw = json.loads(self.manifest_abs.read_text(encoding="utf-8"))
        except Exception:
            return self._default_manifest()
        return self._normalize_manifest(raw)

    def save_manifest(self, manifest: dict[str, Any]) -> None:
        manifest["updated_at"] = _utc_now_iso()
        self.manifest_abs.parent.mkdir(parents=True, exist_ok=True)
        self.manifest_abs.write_text(
            json.dumps(manifest, ensure_ascii=False, indent=2) + "\n",
            encoding="utf-8",
        )

    def ensure_layout(self) -> None:
        self.infra_dir_abs.mkdir(parents=True, exist_ok=True)
        self.prefix_dir_abs.mkdir(parents=True, exist_ok=True)
        manifest = self.load_manifest()
        self.save_manifest(manifest)
        self._refresh_prefix_index(manifest)

    def _chunk_rel_path(self, seq: int) -> Path:
        return self.prefix_dir_rel / f"Prefix_{int(seq):04d}.lean"

    def _chunk_abs_path(self, seq: int) -> Path:
        return self.lean_root / self._chunk_rel_path(seq)

    def _create_chunk_file(self, *, seq: int, prev_chunk_rel: Path | None) -> Path:
        rel = self._chunk_rel_path(seq)
        abs_path = self.lean_root / rel
        lines = ["import Mathlib"]
        if prev_chunk_rel is not None:
            lines.append(module_import_for_rel_lean_file(prev_chunk_rel))
        lines.extend(
            [
                "",
                "-- Auto-generated frozen prefix chunk.",
                "-- Keep append-only; do not edit by hand.",
                "",
            ]
        )
        abs_path.parent.mkdir(parents=True, exist_ok=True)
        abs_path.write_text("\n".join(lines), encoding="utf-8")
        return rel

    def _choose_chunk(self, *, manifest: dict[str, Any], add_lines: int) -> dict[str, Any]:
        chunks = manifest.get("chunks", [])
        if not isinstance(chunks, list):
            chunks = []
            manifest["chunks"] = chunks

        if chunks:
            last = chunks[-1]
            if isinstance(last, dict):
                item_count = int(last.get("item_count", 0))
                line_count = int(last.get("line_count", 0))
                need_new_chunk = False
                if int(self.config.chunk_item_limit) > 0 and item_count >= int(self.config.chunk_item_limit):
                    need_new_chunk = True
                if (
                    not need_new_chunk
                    and int(self.config.chunk_line_limit) > 0
                    and item_count > 0
                    and (line_count + add_lines) > int(self.config.chunk_line_limit)
                ):
                    need_new_chunk = True
                if not need_new_chunk:
                    return last

        seq = len(chunks) + 1
        prev_rel = None
        if chunks:
            prev = chunks[-1]
            prev_rel_raw = str(prev.get("rel_path", "")).strip() if isinstance(prev, dict) else ""
            if prev_rel_raw:
                prev_rel = Path(prev_rel_raw)

        rel = self._create_chunk_file(seq=seq, prev_chunk_rel=prev_rel)
        rec = {
            "seq": int(seq),
            "rel_path": str(rel),
            "item_count": 0,
            "line_count": 0,
            "item_ids": [],
            "labels": [],
        }
        chunks.append(rec)
        return rec

    def _refresh_prefix_index(self, manifest: dict[str, Any]) -> None:
        chunks = manifest.get("chunks", [])
        imports = ["import Mathlib"]
        if isinstance(chunks, list):
            for entry in chunks:
                if not isinstance(entry, dict):
                    continue
                rel_raw = str(entry.get("rel_path", "")).strip()
                if not rel_raw:
                    continue
                imports.append(module_import_for_rel_lean_file(Path(rel_raw)))
        imports.extend(
            [
                "",
                "-- Auto-generated by orchestrator.infra_prefix_store.",
                "-- This module should only aggregate frozen prefix chunks.",
                "",
            ]
        )
        self.prefix_index_abs.parent.mkdir(parents=True, exist_ok=True)
        self.prefix_index_abs.write_text("\n".join(imports), encoding="utf-8")

    def _ensure_primary_chunk_record(self, *, manifest: dict[str, Any]) -> dict[str, Any]:
        """
        Ensure chunk seq=1 exists in manifest/files so callers can place auxiliary
        declarations into GeneratedPrefix/Prefix_0001.lean without touching frozen_items.
        """
        chunks_raw = manifest.get("chunks")
        chunks = [c for c in chunks_raw if isinstance(c, dict)] if isinstance(chunks_raw, list) else []
        manifest["chunks"] = chunks

        for rec in chunks:
            if int(rec.get("seq", 0)) == 1:
                rel_raw = str(rec.get("rel_path", "")).strip()
                if rel_raw:
                    rel = Path(rel_raw)
                    abs_path = self.lean_root / rel
                    if not abs_path.exists():
                        self._create_chunk_file(seq=1, prev_chunk_rel=None)
                    return rec

        rel = self._create_chunk_file(seq=1, prev_chunk_rel=None)
        rec = {
            "seq": 1,
            "rel_path": str(rel),
            "item_count": 0,
            "line_count": 0,
            "item_ids": [],
            "labels": [],
        }
        chunks.insert(0, rec)
        self.save_manifest(manifest)
        self._refresh_prefix_index(manifest)
        return rec

    def append_verified_item(
        self,
        *,
        item_id: str,
        label: str,
        declaration_text: str,
        metadata: dict[str, Any] | None = None,
    ) -> Path:
        """
        Append a verified declaration to the latest chunk (or open a new chunk).
        Returns the relative chunk path under LEAN_ROOT.
        """
        self.ensure_layout()
        manifest = self.load_manifest()

        decl = declaration_text.strip()
        if not decl:
            raise ValueError("declaration_text must be non-empty")

        meta_payload = {
            "item_id": item_id,
            "label": label,
            "metadata": metadata or {},
        }
        block = "\n".join(
            [
                f"/- PREFIX_ITEM_META {json.dumps(meta_payload, ensure_ascii=False, sort_keys=True)} -/",
                decl,
                "",
            ]
        )
        add_lines = len(block.splitlines())

        chunk = self._choose_chunk(manifest=manifest, add_lines=add_lines)
        rel = Path(str(chunk.get("rel_path", "")))
        if not rel.parts:
            raise ValueError("invalid chunk record: missing rel_path")
        abs_path = self.lean_root / rel
        if not abs_path.exists():
            seq = int(chunk.get("seq", 1))
            prev_rel = None
            chunks = manifest.get("chunks", [])
            if isinstance(chunks, list) and len(chunks) >= 2 and seq > 1:
                prev_entry = chunks[seq - 2]
                if isinstance(prev_entry, dict):
                    prev_raw = str(prev_entry.get("rel_path", "")).strip()
                    if prev_raw:
                        prev_rel = Path(prev_raw)
            self._create_chunk_file(seq=seq, prev_chunk_rel=prev_rel)

        current = abs_path.read_text(encoding="utf-8")
        if current and not current.endswith("\n"):
            current += "\n"
        abs_path.write_text(current + block, encoding="utf-8")

        chunk["item_count"] = int(chunk.get("item_count", 0)) + 1
        chunk["line_count"] = int(chunk.get("line_count", 0)) + add_lines
        item_ids = chunk.get("item_ids")
        if not isinstance(item_ids, list):
            item_ids = []
            chunk["item_ids"] = item_ids
        item_ids.append(item_id)
        labels = chunk.get("labels")
        if not isinstance(labels, list):
            labels = []
            chunk["labels"] = labels
        labels.append(label)

        frozen_items = manifest.get("frozen_items")
        if not isinstance(frozen_items, list):
            frozen_items = []
            manifest["frozen_items"] = frozen_items
        frozen_items.append(
            {
                "item_id": item_id,
                "label": label,
                "chunk_seq": int(chunk.get("seq", 1)),
                "chunk_path": str(rel),
                "metadata": metadata or {},
            }
        )

        self.save_manifest(manifest)
        self._refresh_prefix_index(manifest)
        return rel

    def append_aux_block_to_primary_chunk(
        self,
        *,
        text_block: str,
        metadata: dict[str, Any] | None = None,
    ) -> Path:
        """
        Append a verified auxiliary declaration block to Prefix_0001.lean.
        This does NOT change `frozen_items` and therefore does not advance cursor.
        """
        self.ensure_layout()
        manifest = self.load_manifest()
        primary = self._ensure_primary_chunk_record(manifest=manifest)
        rel = Path(str(primary.get("rel_path", "")).strip())
        if not rel.parts:
            raise ValueError("invalid primary chunk rel_path")
        abs_path = self.lean_root / rel
        if not abs_path.exists():
            self._create_chunk_file(seq=1, prev_chunk_rel=None)

        block = text_block.strip()
        if not block:
            raise ValueError("text_block must be non-empty")

        meta_payload = metadata or {}
        decorated = "\n".join(
            [
                f"/- PREFIX_AUX_META {json.dumps(meta_payload, ensure_ascii=False, sort_keys=True)} -/",
                block,
                "",
            ]
        )
        current = abs_path.read_text(encoding="utf-8")
        if current and not current.endswith("\n"):
            current += "\n"
        abs_path.write_text(current + decorated, encoding="utf-8")
        return rel

    def _split_chunk_header_and_blocks(self, text: str) -> tuple[str, list[str]]:
        pattern = re.compile(r"(?m)^/- PREFIX_ITEM_META .+ -/\n")
        matches = list(pattern.finditer(text))
        if not matches:
            return text, []
        header = text[: matches[0].start()]
        blocks: list[str] = []
        for i, m in enumerate(matches):
            start = m.start()
            end = matches[i + 1].start() if i + 1 < len(matches) else len(text)
            blocks.append(text[start:end])
        return header, blocks

    def rewind_to_prefix_count(self, *, keep_count: int) -> None:
        """
        Truncate frozen suffix to keep only the first `keep_count` frozen items.
        This is used when plan_active has been rewritten and verified prefix needs realignment.
        """
        if keep_count < 0:
            raise ValueError("keep_count must be non-negative")

        self.ensure_layout()
        manifest = self.load_manifest()
        frozen_items_raw = manifest.get("frozen_items")
        frozen_items = (
            [it for it in frozen_items_raw if isinstance(it, dict)]
            if isinstance(frozen_items_raw, list)
            else []
        )
        total = len(frozen_items)
        if keep_count >= total:
            return

        kept_items = frozen_items[:keep_count]
        keep_by_chunk: dict[str, list[dict[str, Any]]] = {}
        for rec in kept_items:
            chunk_path = str(rec.get("chunk_path", "")).strip()
            if not chunk_path:
                continue
            keep_by_chunk.setdefault(chunk_path, []).append(rec)

        chunks_raw = manifest.get("chunks")
        chunks = [c for c in chunks_raw if isinstance(c, dict)] if isinstance(chunks_raw, list) else []
        new_chunks: list[dict[str, Any]] = []

        for old_chunk in chunks:
            rel_raw = str(old_chunk.get("rel_path", "")).strip()
            if not rel_raw:
                continue
            kept_chunk_items = keep_by_chunk.get(rel_raw, [])
            rel = Path(rel_raw)
            abs_path = self.lean_root / rel

            if not kept_chunk_items:
                try:
                    if abs_path.exists():
                        abs_path.unlink()
                except Exception:
                    pass
                continue

            if not abs_path.exists():
                raise FileNotFoundError(f"Missing chunk while rewinding prefix store: {abs_path}")
            text = abs_path.read_text(encoding="utf-8")
            header, blocks = self._split_chunk_header_and_blocks(text)
            keep_blocks = len(kept_chunk_items)
            if keep_blocks > len(blocks):
                raise ValueError(
                    f"Chunk {rel_raw} has fewer item blocks ({len(blocks)}) than requested keep count ({keep_blocks})"
                )
            kept_blocks = blocks[:keep_blocks]
            rebuilt = header + "".join(kept_blocks)
            abs_path.write_text(rebuilt, encoding="utf-8")

            line_count = sum(len(block.splitlines()) for block in kept_blocks)
            new_chunks.append(
                {
                    "seq": int(old_chunk.get("seq", len(new_chunks) + 1)),
                    "rel_path": rel_raw,
                    "item_count": keep_blocks,
                    "line_count": int(line_count),
                    "item_ids": [str(it.get("item_id", "")).strip() for it in kept_chunk_items],
                    "labels": [str(it.get("label", "")).strip() for it in kept_chunk_items],
                }
            )

        manifest["chunks"] = new_chunks
        manifest["frozen_items"] = kept_items
        self.save_manifest(manifest)
        self._refresh_prefix_index(manifest)
