"""Piece lifecycle manager — track, split, submit, evaluate pieces."""
from __future__ import annotations

import copy
import math
import os
import time
from typing import Any, Dict, List, Optional, Tuple

from agent.state import AgentState, PieceStatus
from agent.state_manager import StateManager
from agent.tooling import ToolClient
from server.run_manager import make_response


def safe_float(value: Any) -> Optional[float]:
    """Convert *value* to a finite float, or return ``None``.

    Accepts anything ``float()`` accepts (numbers, numeric strings, ...).
    Returns ``None`` when the conversion fails or yields NaN / +-inf, so callers
    can treat "missing", "malformed" and "non-finite" values the same way.
    """
    try:
        v = float(value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: ints too large for a double (e.g. 10**400) raise rather
        # than returning inf; treat them like any other non-finite value.
        return None
    return v if math.isfinite(v) else None


class PieceManager:
    """Manages piece lifecycle: submit, poll, evaluate, split.

    The pieces themselves live in ``state.current_spec["domain"]["pieces"]``.
    Per-piece run bookkeeping lives in ``state.piece_statuses`` (keyed by
    piece id) and ``state.pending_callbacks`` (a list of task tags).
    Operations that change that bookkeeping persist it via :meth:`save`.
    """

    # Fallback for stop_criteria.max_pieces when it is absent or not numeric.
    _DEFAULT_MAX_PIECES = 20
    # Fallback for stop_criteria.min_piece_width when absent, non-numeric or <= 0.
    _DEFAULT_MIN_PIECE_WIDTH = 0.01
    # Run statuses meaning the piece's task still occupies workers.
    _IN_FLIGHT_STATUSES = ("pending", "queued", "running")

    def __init__(
        self,
        state: AgentState,
        state_manager: StateManager,
        tools: ToolClient,
    ):
        self.state = state
        self.state_manager = state_manager
        self.tools = tools

    def save(self) -> None:
        """Stamp ``updated_at`` and persist the state; no-op without an ``exp_id``."""
        if self.state and self.state.exp_id:
            self.state.updated_at = int(time.time())
            self.state_manager.save(self.state.exp_id, self.state)

    # ── Internal helpers ───────────────────────────────────────────

    @staticmethod
    def _default_total_workers() -> int:
        """Worker budget used when the caller gives none: the CPU count (at least 1)."""
        return max(1, os.cpu_count() or 1)

    def _task_tag(self, piece_id: str) -> str:
        """Task tag under which *piece_id* is submitted for this experiment."""
        return f"{self.state.exp_id}_piece_{piece_id}"

    def _find_piece(self, piece_id: str) -> Optional[Dict[str, Any]]:
        """Return the first spec piece whose ``str(piece_id)`` matches, else ``None``."""
        for piece in self.pieces():
            if str(piece.get("piece_id")) == piece_id:
                return piece
        return None

    # ── Query helpers ──────────────────────────────────────────────

    def spec(self) -> Dict[str, Any]:
        """Current spec dict (the live object, not a copy), or ``{}`` if unset."""
        return self.state.current_spec or {}

    def pieces(self) -> List[Dict[str, Any]]:
        """The spec's ``domain.pieces`` list (live when non-empty), or ``[]``."""
        return (self.spec().get("domain") or {}).get("pieces") or []

    def search_piece_ids(self) -> List[str]:
        """Ids of pieces whose strategy mode is search/auto/dag (mode defaults to "search")."""
        out = []
        for piece in self.pieces():
            mode = (piece.get("strategy") or {}).get("mode", "search")
            if mode in ("search", "auto", "dag"):
                out.append(str(piece.get("piece_id", "")))
        return out

    def metric_threshold(self) -> Optional[float]:
        """``metric.threshold`` from the spec as a finite float, or ``None``."""
        metric = self.spec().get("metric") or {}
        return safe_float(metric.get("threshold"))

    def max_pieces(self) -> int:
        """``stop_criteria.max_pieces`` as an int; 20 when absent or non-numeric."""
        raw = safe_float((self.spec().get("stop_criteria") or {}).get("max_pieces"))
        return int(raw) if raw is not None else self._DEFAULT_MAX_PIECES

    def min_piece_width(self) -> float:
        """``stop_criteria.min_piece_width`` if positive, else 0.01."""
        v = safe_float((self.spec().get("stop_criteria") or {}).get("min_piece_width"))
        return v if v and v > 0 else self._DEFAULT_MIN_PIECE_WIDTH

    def piece_interval(self, piece_id: str) -> Optional[Tuple[float, float]]:
        """First valid ``(start, end)`` with ``start < end`` for *piece_id*, else ``None``."""
        for piece in self.pieces():
            if str(piece.get("piece_id")) == piece_id:
                iv = piece.get("interval") or {}
                start = safe_float(iv.get("start"))
                end = safe_float(iv.get("end"))
                if start is not None and end is not None and start < end:
                    return (start, end)
        return None

    # ── Poll ───────────────────────────────────────────────────────

    def poll_piece(self, piece_id: str) -> Dict[str, Any]:
        """Poll the piece's task via ``anum.run.poll`` and return a flat summary.

        Returns ``{"status": "no_task", ...}`` when the piece was never submitted.
        If the poll call does not report ``"ok"``, fields fall back to defaults
        (status ``"unknown"``, error ``None``, elapsed 0.0, completed 0).
        """
        ps = self.state.piece_statuses.get(piece_id)
        if not ps or not ps.task_tag:
            return {"status": "no_task", "piece_id": piece_id}
        resp = self.tools.call("anum.run.poll", {"task_tag": ps.task_tag})
        data = (resp.get("data") or {}) if resp.get("status") == "ok" else {}
        bs = data.get("best_summary") or {}
        return {
            "piece_id": piece_id,
            "status": data.get("status", "unknown"),
            "best_error": safe_float(bs.get("error")),
            "ops": bs.get("ops"),
            "elapsed_s": safe_float(data.get("elapsed_s")) or 0.0,
            "completed_tasks": (data.get("progress") or {}).get("completed_tasks", 0),
            "task_tag": ps.task_tag,
        }

    def poll_all(self) -> List[Dict[str, Any]]:
        """:meth:`poll_piece` for every search piece, in spec order."""
        return [self.poll_piece(pid) for pid in self.search_piece_ids()]

    # ── Verify ─────────────────────────────────────────────────────

    def verify_piece(self, piece_id: str) -> Dict[str, Any]:
        """Verify the best candidate of a piece's run.

        Fetches the best candidate via ``anum.run.result`` and evaluates it with
        ``anum.verify.evaluate``. On any failure returns
        ``{"pass": False, "reason": <no_task|no_result|no_artifact|verify_error>}``;
        otherwise returns the pass flag plus the reported metrics.
        """
        ps = self.state.piece_statuses.get(piece_id)
        if not ps or not ps.task_tag:
            return {"pass": False, "reason": "no_task"}
        result_resp = self.tools.call("anum.run.result", {"task_tag": ps.task_tag})
        if result_resp.get("status") != "ok":
            return {"pass": False, "reason": "no_result"}
        best = (result_resp.get("data") or {}).get("best_candidate") or {}
        artifact_id = best.get("artifact_id")
        if not artifact_id:
            return {"pass": False, "reason": "no_artifact"}
        verify_resp = self.tools.call("anum.verify.evaluate", {
            "spec": self.spec(),
            "candidate_artifact_id": artifact_id,
            "task_tag": ps.task_tag,
            "piece_id": piece_id,
            # NOTE(review): presumably a verification depth understood by the
            # verify tool — confirm its meaning there.
            "level": 2,
        })
        if verify_resp.get("status") != "ok":
            return {"pass": False, "reason": "verify_error"}
        vdata = verify_resp.get("data") or {}
        metrics = vdata.get("metrics") or {}
        return {
            "pass": vdata.get("pass", False),
            "max_abs": safe_float(metrics.get("max_abs")),
            "max_ulp": safe_float(metrics.get("max_ulp")),
            "p99_ulp": safe_float(metrics.get("p99_ulp")),
            "metric_name": vdata.get("metric_name"),
            "metric_value": safe_float(metrics.get(vdata.get("metric_name"))),
            "threshold": safe_float(vdata.get("threshold")),
            "artifact_id": artifact_id,
            "metrics": metrics,
        }

    def quality_acceptable(self, verify_result: Dict[str, Any]) -> bool:
        """Whether a :meth:`verify_piece` result reports ``pass``."""
        return bool(verify_result.get("pass", False))

    # ── Submit ─────────────────────────────────────────────────────

    def runner_params_for_piece(self, worker_num: int) -> Dict[str, Any]:
        """Build runner params from ``request_constraints.runner_params``.

        A truthy, int-convertible ``per_piece_worker_num`` entry overrides
        *worker_num* and is removed from the returned params. ``worker_num`` in
        the result is always at least 1.
        """
        constraints = self.state.request_constraints or {}
        configured = constraints.get("runner_params") or {}
        params = dict(configured) if isinstance(configured, dict) else {}
        per_piece_worker_num = params.pop("per_piece_worker_num", None)
        if per_piece_worker_num:
            try:
                worker_num = max(1, int(per_piece_worker_num))
            except (TypeError, ValueError, OverflowError):
                # Malformed override (incl. float inf): keep the caller's value.
                pass
        params["worker_num"] = max(1, int(worker_num))
        return params

    def submit_piece(self, piece_id: str, worker_num: int = 0) -> Dict[str, Any]:
        """Submit a single-piece copy of the spec via ``anum.run.submit``.

        When *worker_num* is 0, the CPU count is split evenly across all search
        pieces. On success, records a ``"running"`` status, registers the task
        tag in ``pending_callbacks`` and saves. Returns the tool response, or an
        error response if *piece_id* is not in the spec.
        """
        piece = self._find_piece(piece_id)
        if piece is None:
            return make_response("error", errors=[{
                "code": "piece_not_found",
                "message": f"piece_id '{piece_id}' not in spec",
            }])
        task_tag = self._task_tag(piece_id)
        piece_spec = copy.deepcopy(self.spec())
        # Copy the piece too, so the outgoing request never aliases live state.
        # NOTE(review): the submitted domain carries only this piece; any other
        # keys of the original spec["domain"] are dropped — confirm intended.
        piece_spec["domain"] = {"pieces": [copy.deepcopy(piece)]}
        if not worker_num:
            n_search = max(1, len(self.search_piece_ids()))
            worker_num = max(1, self._default_total_workers() // n_search)
        runner = {"type": "local", "params": self.runner_params_for_piece(worker_num)}
        request_id = f"{self.state.exp_id}_{piece_id}_{int(time.time())}"
        resp = self.tools.call("anum.run.submit", {
            "spec": piece_spec,
            "task_tag": task_tag,
            "runner": runner,
            "request_id": request_id,
        })
        if resp.get("status") == "ok":
            self.state.piece_statuses[piece_id] = PieceStatus(
                piece_id=piece_id, task_tag=task_tag, status="running",
            )
            if task_tag not in self.state.pending_callbacks:
                self.state.pending_callbacks.append(task_tag)
            self.save()
        return resp

    def submit_missing_pieces(self, total_worker_num: int = 0) -> List[str]:
        """Submit every search piece that is neither in flight nor completed.

        *total_worker_num* (default: CPU count) is shared between the pieces that
        will occupy workers: those still in flight plus those submitted now.
        Returns the ids whose submission reported ``"ok"``.
        """
        if not total_worker_num:
            total_worker_num = self._default_total_workers()
        missing: List[str] = []
        in_flight = 0
        for pid in self.search_piece_ids():
            ps = self.state.piece_statuses.get(pid)
            if ps and ps.task_tag:
                status = str(ps.status or "").strip().lower()
                if status in self._IN_FLIGHT_STATUSES:
                    in_flight += 1
                    continue
                if status == "completed":
                    continue
            missing.append(pid)
        if not missing:
            return []
        # Completed pieces hold no workers, so they are left out of the divisor;
        # counting them would under-allocate the newly submitted pieces.
        per_piece = max(1, total_worker_num // (in_flight + len(missing)))
        submitted = []
        for pid in missing:
            resp = self.submit_piece(pid, worker_num=per_piece)
            if resp.get("status") == "ok":
                submitted.append(pid)
        return submitted

    # ── Stop ───────────────────────────────────────────────────────

    def stop_piece(self, piece_id: str) -> bool:
        """Stop the piece's task via ``anum.run.stop``.

        On success, drops its tag from ``pending_callbacks``, marks the status
        ``"stopped"``, saves and returns True. Returns False if there is no task
        or the stop call does not report ``"ok"``.
        """
        ps = self.state.piece_statuses.get(piece_id)
        if not ps or not ps.task_tag:
            return False
        resp = self.tools.call("anum.run.stop", {"task_tag": ps.task_tag})
        if resp.get("status") != "ok":
            return False
        self.state.pending_callbacks = [
            t for t in self.state.pending_callbacks if t != ps.task_tag
        ]
        ps.status = "stopped"
        self.save()
        return True

    # ── Split ──────────────────────────────────────────────────────

    def can_split(self, piece_id: str) -> bool:
        """True if a split stays within ``max_pieces`` and both halves can be
        at least ``min_piece_width`` wide."""
        if len(self.pieces()) >= self.max_pieces():
            return False
        interval = self.piece_interval(piece_id)
        if interval is None:
            return False
        width = interval[1] - interval[0]
        return width >= self.min_piece_width() * 2.0

    @staticmethod
    def _split_piece_payload(
        source: Dict[str, Any],
        piece_id: str,
        start: float,
        end: float,
    ) -> Dict[str, Any]:
        """Build a child piece of *source* covering ``[start, end]``.

        Keeps the ``excluded_points`` that fall inside the closed interval (so a
        point equal to the split point ends up in both children), deep-copies
        ``transform``/``strategy`` and defaults ``strategy`` to search mode.
        Other keys of *source* are not carried over.
        """
        payload: Dict[str, Any] = {
            "piece_id": piece_id,
            "interval": {"start": start, "end": end},
        }
        excluded = source.get("excluded_points")
        if isinstance(excluded, list):
            kept = []
            for value in excluded:
                point = safe_float(value)
                if point is not None and start <= point <= end:
                    kept.append(point)
            if kept:
                payload["excluded_points"] = kept
        for key in ("transform", "strategy"):
            if key in source:
                payload[key] = copy.deepcopy(source.get(key))
        payload.setdefault("strategy", {"mode": "search"})
        return payload

    def _resolve_split_point(
        self, start: float, end: float, requested: Optional[float],
    ) -> float:
        """Choose a split point strictly inside ``(start, end)``.

        A missing, non-finite (incl. NaN) or out-of-range request falls back to
        the midpoint. A valid request is clamped so both halves are at least
        ``min_piece_width`` wide (``can_split`` guarantees that is possible).
        """
        midpoint = (start + end) / 2.0
        point = safe_float(requested)
        if point is None or not start < point < end:
            return midpoint
        min_width = self.min_piece_width()
        point = min(max(point, start + min_width), end - min_width)
        # Float rounding could push the clamped value onto an endpoint.
        return point if start < point < end else midpoint

    def _child_ids(self, piece_id: str) -> Tuple[str, str]:
        """``<id>_0``/``<id>_1``, suffixed ``_<n>`` until neither exists yet."""
        existing_ids = {str(p.get("piece_id")) for p in self.pieces()}
        id_0, id_1 = f"{piece_id}_0", f"{piece_id}_1"
        suffix = 0
        while id_0 in existing_ids or id_1 in existing_ids:
            suffix += 1
            id_0, id_1 = f"{piece_id}_0_{suffix}", f"{piece_id}_1_{suffix}"
        return id_0, id_1

    def split_piece(self, piece_id: str, split_point: Optional[float] = None) -> Optional[Tuple[str, str]]:
        """Replace *piece_id* in the spec with two children split at *split_point*.

        Drops the parent's status and pending callback, saves, and returns the
        two new ids; returns ``None`` when :meth:`can_split` refuses.
        NOTE(review): this does not stop the parent's running task; presumably
        callers use :meth:`stop_piece` first — confirm.
        """
        if not self.can_split(piece_id):
            return None
        interval = self.piece_interval(piece_id)
        if interval is None:
            return None
        start, end = interval
        split_point = self._resolve_split_point(start, end, split_point)
        new_id_0, new_id_1 = self._child_ids(piece_id)

        new_pieces: List[Dict[str, Any]] = []
        for p in self.pieces():
            if str(p.get("piece_id")) == piece_id:
                new_pieces.append(self._split_piece_payload(p, new_id_0, start, split_point))
                new_pieces.append(self._split_piece_payload(p, new_id_1, split_point, end))
            else:
                new_pieces.append(p)

        # can_split succeeded, so the spec has a non-empty domain.pieces list.
        self.spec()["domain"]["pieces"] = new_pieces
        old_status = self.state.piece_statuses.pop(piece_id, None)
        stale_tags = {self._task_tag(piece_id)}
        if old_status is not None and old_status.task_tag:
            # Also drop the tag actually recorded at submit time.
            stale_tags.add(old_status.task_tag)
        self.state.pending_callbacks = [
            t for t in self.state.pending_callbacks if t not in stale_tags
        ]
        self.save()
        return (new_id_0, new_id_1)