"""Agent runtime — LLM plans, deterministic poll, LLM-guided split.

Architecture:
  - LLM plans the initial spec (start).
  - Poll/evaluate loop is deterministic (no LLM).
  - When a piece stagnates past hard thresholds, LLM is asked whether/how to split.
  - Heartbeat resumes skip LLM entirely.

Public API:
  - Runtime(llm_client?, state_manager?, state?, tool_client?)
  - Runtime.start(user_request, exp_id?) -> AgentState
  - Runtime.from_state(state, state_manager, llm_client?, tool_client?) -> Runtime
  - Runtime.step(observation) -> AgentState
"""
from __future__ import annotations

import json
import os
import time
import uuid
from typing import Any, Dict, List, Optional

from agent.actions import ActionType, parse_action
from agent.action_executor import (
    update_spec_action,
    write_spec_action,
)
from agent.llm import (
    DEFAULT_REASONING_EFFORT,
    LLMClient,
    _build_llm_from_env,
    build_production_llm_namespace,
)
from agent.piece_manager import PieceManager, safe_float
from agent.prompts import build_system_prompt
from agent.request_constraints import extract_request_constraints
from agent.state import AgentState
from agent import state_machine
from agent.state_manager import StateManager
from agent.tooling import ToolClient, build_internal_tool_client
from agent.trace import TraceLogger, wrap_tool_client
from server.run_manager import make_response

# Completed tasks each configured worker must contribute (subject to a fixed
# floor applied by the runtime) before a piece is examined for stagnation.
MIN_COMPLETED_TASKS_PER_WORKER: int = 15
# Number of consecutive non-improving poll rounds after which a piece is
# treated as stagnated and escalated (strategy tool / LLM split decision).
STAGNATION_ROUNDS_FOR_SPLIT: int = 30


class Runtime:
    """Deterministic agent runtime for DAG search experiments.

    The LLM is used only to plan the initial spec (``start``) and to decide
    how to react to a stagnated piece; polling, verification and piece
    (re)submission are deterministic. State lives in ``self.state`` and is
    persisted through ``self.state_manager``.
    """

    # Number of LLM round-trips allowed to produce a valid initial spec.
    _MAX_PLAN_ATTEMPTS = 3
    # Floor on completed tasks before a piece may be judged stagnated,
    # independent of the configured worker count.
    _MIN_COMPLETED_TASKS_FLOOR = 500
    # A new best error counts as progress only if it is below
    # ``previous_best * ratio`` (i.e. at least a 1% improvement).
    _STAGNATION_IMPROVEMENT_RATIO = 0.99

    def __init__(
        self,
        llm_client: Optional[LLMClient] = None,
        state_manager: Optional[StateManager] = None,
        state: Optional[AgentState] = None,
        tool_client: Optional[ToolClient] = None,
    ):
        """Create a runtime.

        Args:
            llm_client: LLM used for planning and split decisions. When
                omitted it is built lazily (see ``_ensure_llm``).
            state_manager: Persistence backend; ``StateManager()`` is used
                when omitted.
            state: Existing agent state to resume, or ``None`` before
                ``start`` is called.
            tool_client: Tool client to wrap with tracing; an internal client
                is built lazily when omitted.
        """
        self.llm = llm_client
        self.state_manager = state_manager or StateManager()
        self.state = state
        self.trace = TraceLogger.from_env(
            exp_id_getter=lambda: self.state.exp_id if self.state else None,
            state_manager=self.state_manager,
            component="runtime",
        )
        self._tool_client = tool_client
        self._tools: Optional[ToolClient] = None
        self._pm: Optional[PieceManager] = None

    @property
    def tools(self) -> ToolClient:
        """Tool client wrapped with trace logging, created on first use."""
        if self._tools is None:
            base = self._tool_client or build_internal_tool_client()
            self._tools = wrap_tool_client(base, self.trace)
        return self._tools

    @property
    def pm(self) -> PieceManager:
        """PieceManager bound to the current state.

        Rebuilt lazily whenever ``self._pm`` is cleared (done after spec
        changes and at the start of each step).
        """
        if self._pm is None:
            self._pm = PieceManager(self.state, self.state_manager, self.tools)
        return self._pm

    def _log(self, stream: str, event: Dict[str, Any]) -> None:
        """Write ``event`` to trace ``stream`` if tracing is enabled."""
        if self.trace and self.trace.enabled:
            self.trace.log(stream, event)

    def _save(self) -> None:
        """Stamp ``updated_at`` and persist the state (no-op without an exp_id)."""
        if self.state and self.state.exp_id:
            self.state.updated_at = int(time.time())
            self.state_manager.save(self.state.exp_id, self.state)

    # ── Public API ─────────────────────────────────────────────────

    @classmethod
    def from_state(
        cls,
        state: AgentState,
        state_manager: StateManager,
        llm_client: Optional[LLMClient] = None,
        tool_client: Optional[ToolClient] = None,
    ) -> "Runtime":
        """Build a runtime that resumes an existing ``state``."""
        return cls(llm_client=llm_client, state_manager=state_manager, state=state, tool_client=tool_client)

    def _ensure_llm(self) -> LLMClient:
        """Return the LLM client, building it from the experiment config if needed.

        The ``"llm"`` section of the stored experiment config supplies the
        api base, model, timeout and reasoning options; the API key is never
        taken from the config (``llm_api_key=None``).

        Raises:
            RuntimeError: if the client builder exits via ``SystemExit``.
        """
        if self.llm is not None:
            return self.llm
        llm_cfg: Dict[str, Any] = {}
        try:
            loaded = self.state_manager.load_config(self.state.exp_id).get("llm")
        except Exception:
            # Missing state or config is tolerated: fall back to defaults.
            loaded = None
        # A present-but-null (or malformed) "llm" entry must not crash below.
        if isinstance(loaded, dict):
            llm_cfg = loaded
        args = build_production_llm_namespace(
            llm_api_base=llm_cfg.get("api_base"),
            llm_api_key=None,
            llm_model=llm_cfg.get("model"),
            llm_timeout_s=llm_cfg.get("timeout_s"),
            llm_enable_thinking=llm_cfg.get("enable_thinking"),
            llm_reasoning_effort=llm_cfg.get("reasoning_effort"),
        )
        try:
            self.llm = _build_llm_from_env(args)
        except SystemExit as exc:
            raise RuntimeError(str(exc)) from exc
        return self.llm

    def start(self, user_request: str, exp_id: Optional[str] = None) -> AgentState:
        """Create a new experiment, plan its spec with the LLM and submit pieces.

        Args:
            user_request: Natural-language request; request constraints are
                extracted from it.
            exp_id: Experiment id; a random ``exp_<hex>`` id is generated
                when omitted.

        Returns:
            The freshly created and persisted ``AgentState``.

        Raises:
            RuntimeError: if the LLM cannot produce a valid spec.
        """
        exp_id = exp_id or f"exp_{uuid.uuid4().hex[:8]}"
        self.state = AgentState(
            session_id=uuid.uuid4().hex[:12],
            exp_id=exp_id,
            user_request=user_request,
            request_constraints=extract_request_constraints(user_request),
            phase="init",
            created_at=int(time.time()),
        )
        self._pm = None
        self._log("state_transitions", {"event": "experiment_start", "exp_id": exp_id})
        self.state_manager.save_config(exp_id, {
            "user_request": user_request,
            "request_constraints": self.state.request_constraints,
            "created_at": self.state.created_at,
            "max_attempts_per_piece": 10,
            "llm": {
                "api_base": os.getenv("ANUM_LLM_API_BASE"),
                "model": os.getenv("ANUM_LLM_MODEL"),
                "timeout_s": None,
                "enable_thinking": str(os.getenv("ANUM_LLM_ENABLE_THINKING", "")).lower() in ("1", "true", "yes", "on"),
                "reasoning_effort": os.getenv("ANUM_LLM_REASONING_EFFORT") or DEFAULT_REASONING_EFFORT,
            },
        })
        state_machine.apply_transition(self.state, "planning")
        self._plan_with_llm()
        self._submit_all_search_pieces()
        self._save()
        return self.state

    def step(self, observation: Dict[str, Any]) -> AgentState:
        """Advance the experiment in response to a supervision ``observation``.

        A heartbeat whose ``task_status`` is ``"running"`` only re-asserts
        the running phase (no polling, no LLM). Any other observation runs
        one deterministic refinement pass.

        Raises:
            RuntimeError: if no state has been loaded or started.
        """
        if self.state is None:
            raise RuntimeError("No state loaded.")
        self._pm = None
        reason = observation.get("supervision_reason", "")
        self._log("decisions", {
            "event": "runtime_step",
            "exp_id": self.state.exp_id,
            "supervision_reason": reason,
        })
        if reason == "heartbeat" and observation.get("task_status") == "running":
            state_machine.apply_transition(self.state, "running")
            self._save()
            return self.state
        state_machine.apply_transition(self.state, "evaluating")
        self._refinement_step()
        self._save()
        return self.state

    # ── LLM planning (used only in start) ──────────────────────────

    def _plan_with_llm(self) -> None:
        """Ask the LLM for a WRITE_SPEC action and apply it.

        Validation/policy failures and non-WRITE_SPEC replies are fed back to
        the LLM as observations. Generation or parse errors are logged and
        the same conversation is retried.

        Raises:
            RuntimeError: if no valid spec is applied within
                ``_MAX_PLAN_ATTEMPTS`` attempts.
        """
        system_prompt = build_system_prompt()
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": self.state.user_request},
        ]
        for attempt in range(self._MAX_PLAN_ATTEMPTS):
            try:
                response_text = self._ensure_llm().generate(messages)
                action = parse_action(response_text)
            except Exception as exc:
                self._log("decisions", {"event": "llm_plan_error", "error": str(exc), "attempt": attempt})
                continue
            if action.type == ActionType.WRITE_SPEC:
                spec = action.args.get("spec")
                if isinstance(spec, dict):
                    response = self._apply_spec(spec)
                    if response.get("status") == "ok":
                        self._log("decisions", {
                            "event": "spec_planned",
                            "pieces": len(self.pm.pieces()),
                        })
                        return
                    messages.append({"role": "assistant", "content": response_text})
                    messages.append({"role": "user", "content": json.dumps({
                        "observation": response,
                        "message": "The proposed WRITE_SPEC failed validation or policy checks. Correct the spec and reply with WRITE_SPEC only.",
                    })})
                    continue
            messages.append({"role": "assistant", "content": response_text})
            messages.append({"role": "user", "content": json.dumps({
                "observation": {"status": "error", "message": "Expected WRITE_SPEC action with a valid spec."}
            })})
        raise RuntimeError(f"LLM failed to produce a valid spec after {self._MAX_PLAN_ATTEMPTS} attempts")

    def _record_math_inspection(self, spec: Dict[str, Any]) -> None:
        """Run the math inspector on ``spec`` and record the call in history.

        A non-ok inspector response is only logged as a warning.
        """
        inspect_resp = self.tools.call("anum.math.inspect", {"spec": spec})
        self.state.history.append(
            (
                {"action": "MATH_INSPECTOR", "args": {"spec": spec}},
                inspect_resp,
            )
        )
        if inspect_resp.get("status") != "ok":
            self._log("decisions", {"event": "math_inspector_warning", "response": inspect_resp})

    def _apply_spec(self, spec: Dict[str, Any]) -> Dict[str, Any]:
        """Write ``spec`` via the action executor and record it in history.

        On success parallel mode is enabled and the PieceManager is rebuilt
        on next access. Returns the executor's response.
        """
        response = write_spec_action(
            state=self.state,
            tools=self.tools,
            spec=spec,
            after_apply_spec=self._record_math_inspection,
        )
        self.state.history.append(({"action": "WRITE_SPEC", "args": {"spec": spec}}, response))
        if response.get("status") == "ok":
            self.state.parallel_mode = True
            self._pm = None
        else:
            self._log("decisions", {"event": "spec_validation_failed", "response": response})
        return response

    def _latest_math_inspector_data(self) -> Optional[Dict[str, Any]]:
        """Return the ``data`` of the most recent successful MATH_INSPECTOR call.

        Returns ``None`` when there is no state or no such history entry.
        """
        if not self.state:
            return None
        for action, response in reversed(self.state.history):
            if not isinstance(action, dict) or not isinstance(response, dict):
                continue
            if action.get("action") != "MATH_INSPECTOR":
                continue
            if response.get("status") == "ok" and isinstance(response.get("data"), dict):
                return response.get("data")
        return None

    # ── Submit ─────────────────────────────────────────────────────

    def _submit_all_search_pieces(self) -> None:
        """Submit every not-yet-submitted piece; enter ``running`` if any were."""
        total_wn = self._total_worker_num()
        submitted = self.pm.submit_missing_pieces(total_worker_num=total_wn)
        if submitted:
            state_machine.apply_transition(self.state, "running")
            self._log("decisions", {"event": "pieces_submitted", "pieces": submitted})

    def _submit_missing_pieces(self) -> Any:
        """Submit pieces that have no live task, logging what was submitted.

        Returns whatever ``PieceManager.submit_missing_pieces`` returns.
        """
        submitted = self.pm.submit_missing_pieces(total_worker_num=self._total_worker_num())
        if submitted:
            self._log("decisions", {"event": "missing_pieces_submitted", "pieces": submitted})
        return submitted

    def _total_worker_num(self) -> int:
        """Total worker budget: ``runner_params.worker_num`` or the CPU count.

        Invalid or missing values fall back to ``os.cpu_count()``; the result
        is always at least 1.
        """
        constraints = self.state.request_constraints or {}
        rp = constraints.get("runner_params") or {}
        wn = rp.get("worker_num")
        if wn:
            try:
                return max(1, int(wn))
            except (TypeError, ValueError):
                pass
        return max(1, os.cpu_count() or 1)

    # ── Deterministic refinement ───────────────────────────────────

    def _refinement_step(self) -> None:
        """Poll all pieces once, finish verified ones and react to stagnation.

        * Pieces whose best error meets the metric threshold are verified.
          If verification passes with acceptable quality, they are stopped
          and marked completed.
        * Pieces that have run enough tasks and stopped improving are
          collected. Only the worst of them (highest best error) is handled
          per step.
        """
        polls = self.pm.poll_all()
        if not polls:
            self._submit_all_search_pieces()
            state_machine.apply_transition(self.state, "running")
            return

        # Loop invariants: none of these change while polls are processed.
        threshold = self.pm.metric_threshold()
        min_tasks = max(
            self._MIN_COMPLETED_TASKS_FLOOR,
            self._total_worker_num() * MIN_COMPLETED_TASKS_PER_WORKER,
        )
        stagnated: List[Dict[str, Any]] = []
        for poll in polls:
            pid = poll["piece_id"]
            if poll.get("status", "unknown") not in ("running", "queued"):
                continue
            best_error = poll.get("best_error")
            if best_error is None:
                continue
            if threshold is not None and best_error <= threshold:
                if self._try_complete_piece(pid):
                    continue
                # Verification failed: keep treating it as an active search.
            completed = poll.get("completed_tasks", 0)
            if completed >= min_tasks and self._is_stagnated(poll):
                stagnated.append(poll)

        self._submit_missing_pieces()

        if stagnated:
            worst = max(stagnated, key=lambda p: p.get("best_error", 0))
            self._handle_stagnated_piece(worst)

        state_machine.apply_transition(self.state, "running")

    def _try_complete_piece(self, pid: str) -> bool:
        """Verify piece ``pid`` and, if it passes, stop it and mark it completed.

        Returns ``True`` when the piece was completed, ``False`` when
        verification or the quality check failed.
        """
        vresult = self.pm.verify_piece(pid)
        if not (vresult.get("pass") and self.pm.quality_acceptable(vresult)):
            return False
        self._log("decisions", {
            "event": "piece_verified",
            "piece_id": pid,
            "max_abs": vresult.get("max_abs"),
            "max_ulp": vresult.get("max_ulp"),
        })
        self.pm.stop_piece(pid)
        ps = self.state.piece_statuses.get(pid)
        if ps:
            ps.status = "completed"
            ps.result = {
                "verified_metric": vresult.get("metric_name"),
                "verified_value": vresult.get("metric_value"),
                "verify_threshold": vresult.get("threshold"),
                "artifact_id": vresult.get("artifact_id"),
            }
        return True

    def _stagnation_state(self) -> Dict[str, Dict[str, Any]]:
        """Return ``state.retry_state["stagnation"]``, creating it if needed.

        Maps piece_id -> ``{"best": <best error seen>, "count": <stagnant rounds>}``.
        """
        if not isinstance(self.state.retry_state, dict):
            self.state.retry_state = {}
        stag = self.state.retry_state.get("stagnation")
        if not isinstance(stag, dict):
            stag = {}
            self.state.retry_state["stagnation"] = stag
        return stag

    def _reset_stagnation(self, piece_id: str) -> None:
        """Forget the stagnation history of ``piece_id`` (next poll is a new baseline)."""
        self._stagnation_state().pop(piece_id, None)

    def _is_stagnated(self, poll: Dict[str, Any]) -> bool:
        """Update the stagnation counter for ``poll``'s piece and test it.

        The first observation of a piece only records a baseline. Each later
        round without at least a ``_STAGNATION_IMPROVEMENT_RATIO`` improvement
        increments the counter; any sufficient improvement resets it.

        Returns:
            ``True`` once the counter reaches ``STAGNATION_ROUNDS_FOR_SPLIT``.
        """
        pid = poll["piece_id"]
        best_error = poll.get("best_error")
        if best_error is None:
            return False
        stag = self._stagnation_state()
        entry = stag.get(pid)
        prev = safe_float(entry.get("best")) if isinstance(entry, dict) else None
        if prev is None or best_error < prev * self._STAGNATION_IMPROVEMENT_RATIO:
            # New baseline (first sighting, unreadable record, or real progress).
            stag[pid] = {"best": best_error, "count": 0}
            return False
        count = int(entry.get("count", 0)) + 1
        entry["count"] = count
        entry["best"] = prev
        stag[pid] = entry
        return count >= STAGNATION_ROUNDS_FOR_SPLIT

    def _handle_stagnated_piece(self, poll: Dict[str, Any]) -> None:
        """React to one stagnated piece.

        Escalation order:
          1. Ask the strategy tool; an applied ``map_reuse`` patch ends handling.
          2. If the piece may not be split, log and stop.
          3. Ask the LLM whether/where to split; if it says split, stop the
             piece, split it and submit both halves.

        Once handled, the piece's stagnation counter is reset. The tool and
        the LLM are therefore consulted again only after another full
        stagnation window, rather than on every subsequent poll round.
        """
        pid = poll["piece_id"]
        interval = self.pm.piece_interval(pid)
        if interval is None:
            return
        self._reset_stagnation(pid)
        best_error = poll.get("best_error")
        elapsed_h = (safe_float(poll.get("elapsed_s")) or 0.0) / 3600.0
        threshold = self.pm.metric_threshold()

        strategy_suggestion = self._suggest_for_stagnated_piece(
            piece_id=pid,
            interval=interval,
            best_error=best_error,
            threshold=threshold,
        )
        if self._apply_strategy_suggestion(pid, strategy_suggestion):
            return

        if not self.pm.can_split(pid):
            self._log("decisions", {
                "event": "split_blocked",
                "piece_id": pid,
                "reason": "cannot_split (max_pieces or min_width)",
            })
            return

        decision = self._ask_llm_split_decision(
            piece_id=pid,
            interval=interval,
            best_error=best_error,
            elapsed_h=elapsed_h,
            threshold=threshold,
            current_pieces=len(self.pm.pieces()),
            max_pieces=self.pm.max_pieces(),
            math_context=self._latest_math_inspector_data(),
            strategy_suggestion=strategy_suggestion,
        )

        if str(decision.get("action") or "").strip().lower() != "split":
            self._log("decisions", {
                "event": "llm_decided_no_split",
                "piece_id": pid,
                "decision": decision,
            })
            return

        split_point = safe_float(decision.get("split_point"))
        lo, hi = float(interval[0]), float(interval[1])
        if split_point is not None and not (lo < split_point < hi):
            # A point outside the open interval cannot split this piece; fall
            # back to the same ``None`` split point used when the LLM replies null.
            self._log("decisions", {
                "event": "split_point_out_of_range",
                "piece_id": pid,
                "split_point": split_point,
                "interval": [lo, hi],
            })
            split_point = None
        self._log("decisions", {
            "event": "splitting_piece",
            "piece_id": pid,
            "best_error": best_error,
            "elapsed_h": elapsed_h,
            "split_point": split_point,
            "llm_reason": decision.get("reason", ""),
        })

        self.pm.stop_piece(pid)
        result = self.pm.split_piece(pid, split_point=split_point)
        if result is None:
            self._log("decisions", {"event": "split_failed", "piece_id": pid})
            return
        new_id_0, new_id_1 = result
        total_wn = self._total_worker_num()
        active_count = sum(
            1 for p in self.pm.search_piece_ids()
            if (self.state.piece_statuses.get(p) is not None and self.state.piece_statuses[p].status == "running")
        ) + 2  # +2 for the two new pieces about to be submitted
        per_piece = max(1, total_wn // active_count)
        self.pm.submit_piece(new_id_0, worker_num=per_piece)
        self.pm.submit_piece(new_id_1, worker_num=per_piece)
        self._log("decisions", {
            "event": "split_done",
            "old_piece": pid,
            "new_pieces": [new_id_0, new_id_1],
        })

    def _apply_strategy_suggestion(
        self,
        piece_id: str,
        strategy_suggestion: Optional[Dict[str, Any]],
    ) -> bool:
        """Apply a ``map_reuse`` spec patch from the strategy tool, if present.

        On success the piece's current task is stopped and forgotten, its
        stagnation history is cleared, the state is saved and the piece is
        resubmitted. Without this resubmission it would wait for the next
        non-heartbeat step, because heartbeat steps skip refinement.

        Returns:
            ``True`` if the patch was applied, ``False`` otherwise (no
            suggestion, a different decision, missing patch, or patch failure).
        """
        if not isinstance(strategy_suggestion, dict):
            return False
        decision = str(strategy_suggestion.get("decision") or "").strip().lower()
        if decision != "map_reuse":
            return False
        patch = strategy_suggestion.get("spec_patch")
        if not isinstance(patch, dict):
            self._log("decisions", {
                "event": "strategy_suggestion_ignored",
                "piece_id": piece_id,
                "reason": "map_reuse missing spec_patch",
            })
            return False

        response = update_spec_action(
            state=self.state,
            tools=self.tools,
            patch=patch,
            after_apply_spec=self._record_math_inspection,
        )
        self.state.history.append((
            {"action": "UPDATE_SPEC", "args": {"patch": patch, "source": "SUGGEST_STRATEGY"}},
            response,
        ))
        if response.get("status") != "ok":
            self._log("decisions", {
                "event": "strategy_patch_failed",
                "piece_id": piece_id,
                "response": response,
            })
            return False

        old_status = self.state.piece_statuses.get(piece_id)
        old_task_tag = old_status.task_tag if old_status is not None else None
        if old_task_tag:
            self.pm.stop_piece(piece_id)
            self.state.pending_callbacks = [
                tag for tag in (self.state.pending_callbacks or []) if tag != old_task_tag
            ]
        self.state.piece_statuses.pop(piece_id, None)
        self._pm = None
        # The restarted search must not inherit the old best error / counter.
        self._reset_stagnation(piece_id)
        self._save()
        self._log("decisions", {
            "event": "strategy_patch_applied",
            "piece_id": piece_id,
            "decision": decision,
            "notes": strategy_suggestion.get("notes"),
        })
        self._submit_missing_pieces()
        return True

    def _suggest_for_stagnated_piece(
        self,
        *,
        piece_id: str,
        interval: tuple,
        best_error: Optional[float],
        threshold: Optional[float],
    ) -> Optional[Dict[str, Any]]:
        """Ask the ``anum.strategy.suggest`` tool what to do with a stagnated piece.

        The request reports a synthetic failed verification with the interval
        midpoint as counterexample. The call is recorded in history.

        Returns:
            The tool's ``data`` dict on success, else ``None``.
        """
        start, end = interval
        mid = 0.5 * (float(start) + float(end))
        verify = {
            "pass": False,
            "metric_name": (self.pm.spec().get("metric") or {}).get("type"),
            "metric_value": best_error,
            "threshold": threshold,
            "failure_modes": ["stagnation"],
            "counterexamples": [{"x": mid, "source": "stagnation_midpoint"}],
        }
        attempt = int((self.state.piece_attempts or {}).get(piece_id, 0)) if self.state else 0
        request = {
            "spec": self.pm.spec(),
            "piece_id": piece_id,
            "verify": verify,
            "attempt": attempt,
            "split_policy": "counterexample",
        }
        response = self.tools.call("anum.strategy.suggest", request)
        self.state.history.append(({"action": "SUGGEST_STRATEGY", "args": request}, response))
        self._log("decisions", {
            "event": "strategy_suggested",
            "piece_id": piece_id,
            "response": response,
        })
        if response.get("status") == "ok" and isinstance(response.get("data"), dict):
            return response.get("data")
        return None

    @staticmethod
    def _split_point_from_strategy_suggestion(
        strategy_suggestion: Optional[Dict[str, Any]],
        interval: tuple,
    ) -> Optional[float]:
        """Derive a split point from a ``split`` strategy suggestion.

        Collects the suggested sub-intervals lying inside ``interval``. If
        there are at least two, returns the end of the first one, provided it
        lies strictly inside ``interval``. Otherwise returns ``None``. The
        decision is matched case-insensitively, like ``map_reuse``.
        """
        if not isinstance(strategy_suggestion, dict):
            return None
        decision = str(strategy_suggestion.get("decision") or "").strip().lower()
        if decision != "split":
            return None
        patch = strategy_suggestion.get("spec_patch")
        pieces = ((patch or {}).get("domain") or {}).get("pieces") if isinstance(patch, dict) else None
        if not isinstance(pieces, list):
            return None
        start, end = float(interval[0]), float(interval[1])
        segments = []
        for piece in pieces:
            if not isinstance(piece, dict):
                continue
            piece_interval = piece.get("interval")
            if not isinstance(piece_interval, dict):
                continue
            seg_start = safe_float(piece_interval.get("start"))
            seg_end = safe_float(piece_interval.get("end"))
            if seg_start is None or seg_end is None:
                continue
            if start <= seg_start < seg_end <= end:
                segments.append((seg_start, seg_end))
        segments = sorted(set(segments))
        if len(segments) < 2:
            return None
        split_point = segments[0][1]
        if start < split_point < end:
            return split_point
        return None

    def _ask_llm_split_decision(
        self,
        *,
        piece_id: str,
        interval: tuple,
        best_error: Optional[float],
        elapsed_h: float,
        threshold: Optional[float],
        current_pieces: int,
        max_pieces: int,
        math_context: Optional[Dict[str, Any]] = None,
        strategy_suggestion: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Ask the LLM whether (and where) to split a stagnated piece.

        Returns:
            A dict with (at least) ``action``. On success it is the parsed
            LLM JSON object. If the LLM call fails or the reply is not a JSON
            object, the fallback is the strategy tool's split point when one
            can be derived, else ``{"action": "continue", ...}``.
        """
        target_name = (self.pm.spec().get("target") or {}).get("name", "unknown")
        best_error_text = f"{best_error:.3e}" if best_error is not None else "unknown"
        threshold_text = f"{threshold:.3e}" if threshold is not None else "not specified"
        error_gap = (
            f"{best_error / threshold:.1f}x"
            if best_error is not None and threshold not in (None, 0)
            else "unknown"
        )
        math_summary = json.dumps(math_context or {}, ensure_ascii=True, sort_keys=True)[:2000]
        strategy_summary = json.dumps(strategy_suggestion or {}, ensure_ascii=True, sort_keys=True)[:2000]
        prompt = (
            f"You are optimizing a numerical approximation for {target_name}.\n"
            f"Piece '{piece_id}' covers [{interval[0]}, {interval[1]}] and has stagnated.\n"
            f"Current best error: {best_error_text} (target: {threshold_text}).\n"
            f"Search has run for {elapsed_h:.1f} hours.\n"
            f"Current pieces: {current_pieces}/{max_pieces} max.\n\n"
            f"MATH_INSPECTOR context:\n{math_summary}\n\n"
            f"SUGGEST_STRATEGY recommendation:\n{strategy_summary}\n\n"
            f"Should this piece be split into two sub-intervals?\n"
            f"Consider:\n"
            f"- Function behavior in this interval (singularities, inflection points, rapid changes)\n"
            f"- Whether the error gap ({error_gap}) justifies splitting\n"
            f"- Remaining piece budget ({max_pieces - current_pieces} slots left)\n\n"
            f"Reply with JSON: {{\"action\": \"split\" or \"continue\", \"split_point\": <float or null>, \"reason\": \"...\"}}\n"
            f"If splitting, choose split_point based on function structure, not just midpoint."
        )
        try:
            response_text = self._ensure_llm().generate([
                {"role": "system", "content": "You are a numerical analysis expert. Reply only with JSON."},
                {"role": "user", "content": prompt},
            ])
            from agent.actions import _extract_json
            decision = _extract_json(response_text)
        except Exception as exc:
            self._log("decisions", {"event": "llm_split_error", "error": str(exc)})
        else:
            if isinstance(decision, dict):
                return decision
            # Callers read the decision with .get(); a list/scalar would crash them.
            self._log("decisions", {
                "event": "llm_split_error",
                "error": f"expected a JSON object, got {type(decision).__name__}",
            })
        split_point = self._split_point_from_strategy_suggestion(strategy_suggestion, interval)
        if split_point is not None:
            return {
                "action": "split",
                "split_point": split_point,
                "reason": "LLM unavailable; using SUGGEST_STRATEGY split recommendation.",
            }
        return {"action": "continue", "split_point": None, "reason": "LLM unavailable, default continue"}

    # ── Compat shim ────────────────────────────────────────────────

    def _run_steps(self) -> None:
        """Run one refinement pass and persist the state (legacy entry point)."""
        self._refinement_step()
        self._save()
