"""Job Dispatcher - Submits search tasks and tracks active task tags."""
from __future__ import annotations

import os
import time
import uuid
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING

from agent.state import AgentState, PieceStatus
from agent import state_machine
from agent.state_manager import StateManager
from server.run_manager import RunManager, make_response
if TYPE_CHECKING:  # pragma: no cover - typing only
    from agent.tooling import ToolClient


class JobDispatcher:
    """Dispatches search jobs for the unified local poller-driven workflow."""

    def __init__(
        self,
        state_manager: StateManager,
        run_manager: Optional[RunManager] = None,
        repo_root: Optional[str] = None,
        tool_client: Optional["ToolClient"] = None,
    ):
        self.state_manager = state_manager
        self.run_manager = run_manager or RunManager()
        self.repo_root = repo_root or self.run_manager.repo_root
        self.tool_client = tool_client

    @staticmethod
    def _cpu_worker_num() -> int:
        return max(1, os.cpu_count() or 1)

    @staticmethod
    def _truthy_env(name: str) -> bool:
        return os.getenv(name, "").strip().lower() in ("1", "true", "yes", "on")

    def _normalize_worker_num(self, params: Dict[str, Any]) -> None:
        """Normalize runner worker_num so agent-side inputs can safely express auto mode.

        Rules:
        - If ANUM_FORCE_WORKER_NUM_FROM_CPU is true, always use host CPU core count.
        - If worker_num is missing / <= 0 / invalid, fallback to host CPU core count.
        - String aliases {"auto","cpu","cores","host_cpu"} map to host CPU core count.
        """
        cpu_workers = self._cpu_worker_num()
        if self._truthy_env("ANUM_FORCE_WORKER_NUM_FROM_CPU"):
            params["worker_num"] = cpu_workers
            return

        raw = params.get("worker_num")
        if raw is None:
            params["worker_num"] = cpu_workers
            return

        if isinstance(raw, str):
            token = raw.strip().lower()
            if token in ("auto", "cpu", "cores", "host_cpu"):
                params["worker_num"] = cpu_workers
                return
        try:
            value = int(raw)
        except (TypeError, ValueError):
            params["worker_num"] = cpu_workers
            return
        if value <= 0:
            params["worker_num"] = cpu_workers
            return
        params["worker_num"] = value

    def _resolve_total_worker_budget(self, runner_params: Optional[Dict[str, Any]]) -> int:
        params = runner_params.copy() if runner_params else {}
        self._normalize_worker_num(params)
        try:
            total = int(params.get("worker_num", self._cpu_worker_num()))
        except (TypeError, ValueError):
            total = self._cpu_worker_num()
        return max(1, total)

    @staticmethod
    def _safe_int(value: Any) -> Optional[int]:
        try:
            return int(value)
        except (TypeError, ValueError):
            return None

    @classmethod
    def _extract_piece_worker_overrides(cls, runner_params: Optional[Dict[str, Any]]) -> Dict[str, int]:
        if not isinstance(runner_params, dict):
            return {}
        merged: Dict[str, Any] = {}
        direct = runner_params.get("per_piece_worker_num")
        if isinstance(direct, dict):
            merged.update(direct)
        worker_plan = runner_params.get("worker_plan")
        if isinstance(worker_plan, dict):
            per_piece = worker_plan.get("per_piece")
            if isinstance(per_piece, dict):
                merged.update(per_piece)
        overrides: Dict[str, int] = {}
        for piece_id, raw in merged.items():
            parsed = cls._safe_int(raw)
            if parsed is None or parsed <= 0:
                continue
            overrides[str(piece_id)] = parsed
        return overrides

    @staticmethod
    def _allocate_workers(total_workers: int, piece_count: int) -> List[int]:
        if piece_count <= 0:
            return []
        total = max(1, int(total_workers))
        if total < piece_count:
            return [1] * piece_count
        base = total // piece_count
        remainder = total % piece_count
        allocation = []
        for idx in range(piece_count):
            allocation.append(base + (1 if idx < remainder else 0))
        return allocation

    @staticmethod
    def _build_request_id(base: Optional[str], suffix: Optional[str] = None) -> str:
        token = str(base or "req").strip() or "req"
        token = token.replace(" ", "_")
        if suffix:
            suffix_token = str(suffix).strip().replace(" ", "_")
            token = f"{token}_{suffix_token}"
        return f"{token}_{int(time.time())}_{uuid.uuid4().hex[:8]}"

    @staticmethod
    def _has_error_code(response: Dict[str, Any], code: str) -> bool:
        if response.get("status") != "error":
            return False
        errors = response.get("errors") or []
        return any(str(item.get("code")) == code for item in errors if isinstance(item, dict))

    def submit_search(
        self,
        exp_id: str,
        spec: Dict[str, Any],
        task_tag: Optional[str] = None,
        piece_id: Optional[str] = None,
        runner_params: Optional[Dict[str, Any]] = None,
        request_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Submit a search task in the unified poller-driven workflow.

        Args:
            exp_id: Experiment ID for state persistence
            spec: Approximation spec defining the search
            task_tag: Optional custom task tag (auto-generated if not provided)
            piece_id: Optional piece ID for parallel multi-piece execution
            runner_params: Additional runner parameters (worker_num, max_tasks, etc.)
            request_id: Optional request_id for idempotent tracking

        Returns:
            Response dict with:
                - status: "ok" or "error"
                - data: {run_id, task_tag, artifact_id}
        """
        params = runner_params.copy() if runner_params else {}
        self._normalize_worker_num(params)

        # Generate task tag if needed (for parallel pieces)
        if not task_tag and piece_id:
            task_tag = f"{exp_id}_piece_{piece_id}"

        # Submit via RunManager
        runner = {"type": "local", "params": params}
        def _submit_once(req_id: Optional[str]) -> Dict[str, Any]:
            payload = {
                "spec": spec,
                "task_tag": task_tag,
                "runner": runner,
                "request_id": req_id,
            }
            if self.tool_client:
                return self.tool_client.call("anum.run.submit", payload)
            return self.run_manager.submit(
                spec=spec,
                task_tag=task_tag,
                runner=runner,
                request_id=req_id,
            )

        response = _submit_once(request_id)
        if self._has_error_code(response, "idempotency_conflict"):
            retry_request_id = self._build_request_id(
                request_id or spec.get("request_id") or task_tag or exp_id,
                suffix=(piece_id or task_tag),
            )
            retry_response = _submit_once(retry_request_id)
            if retry_response.get("status") == "ok":
                merged = dict(retry_response)
                warnings = list(merged.get("warnings") or [])
                warnings.append(
                    {
                        "code": "request_id_refreshed",
                        "message": "Resubmitted after idempotency conflict with a refreshed request_id.",
                        "details": {"old_request_id": request_id, "new_request_id": retry_request_id},
                    }
                )
                merged["warnings"] = warnings
                response = merged

        return response

    def submit_parallel_search(
        self,
        exp_id: str,
        state: AgentState,
        spec: Dict[str, Any],
        runner_params: Optional[Dict[str, Any]] = None,
        piece_ids: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """Submit parallel search tasks for all domain pieces.

        Args:
            exp_id: Experiment ID
            state: AgentState (will be updated with piece statuses)
            spec: Approximation spec with multiple pieces
            runner_params: Runner parameters for each search

        Returns:
            Response dict with submitted task information
        """
        pieces = spec.get("domain", {}).get("pieces", [])
        if not pieces:
            return make_response(
                "error",
                errors=[{"code": "no_pieces", "message": "No domain pieces found in spec"}],
            )

        def _is_search_piece(piece: Dict[str, Any]) -> bool:
            strategy = piece.get("strategy")
            if isinstance(strategy, dict):
                mode = str(strategy.get("mode", "search")).lower()
                if mode not in ("search", "auto", "dag"):
                    return False
            return True

        search_entries: List[Tuple[str, Dict[str, Any]]] = []
        skipped: List[str] = []
        skipped_details: List[Dict[str, Any]] = []
        for idx, piece in enumerate(pieces):
            piece_id = str(piece.get("piece_id", idx))
            if _is_search_piece(piece):
                search_entries.append((piece_id, piece))
            else:
                skipped.append(piece_id)
                mode = "unknown"
                strategy = piece.get("strategy")
                if isinstance(strategy, dict):
                    mode = str(strategy.get("mode", "search")).lower()
                skipped_details.append(
                    {
                        "piece_id": piece_id,
                        "mode": mode,
                        "reason": "non_search_strategy",
                    }
                )

        missing_piece_ids: List[str] = []
        if piece_ids is not None:
            requested = [str(item) for item in piece_ids if str(item).strip()]
            requested_set = set(requested)
            available_ids = {piece_id for piece_id, _ in search_entries}
            missing_piece_ids = [piece_id for piece_id in requested if piece_id not in available_ids]
            search_entries = [
                (piece_id, piece)
                for piece_id, piece in search_entries
                if piece_id in requested_set
            ]

        if not search_entries:
            if piece_ids is not None:
                return make_response(
                    "error",
                    errors=[
                        {
                            "code": "no_matching_search_pieces",
                            "message": "No search pieces matched requested piece_ids.",
                            "details": {"piece_ids": [str(item) for item in piece_ids]},
                        }
                    ],
                )
            return make_response(
                "error",
                errors=[
                    {
                        "code": "no_search_pieces",
                        "message": "All domain pieces are marked as non-search strategies.",
                        "details": {"skipped": skipped},
                    }
                ],
            )

        state.parallel_mode = True
        if piece_ids is None:
            state.piece_statuses = {}

        total_workers = self._resolve_total_worker_budget(runner_params)
        worker_allocations = self._allocate_workers(total_workers, len(search_entries))
        piece_overrides = self._extract_piece_worker_overrides(runner_params)
        if piece_overrides and "worker_num" not in (runner_params or {}):
            total_workers = max(total_workers, sum(piece_overrides.values()))
        submitted = []
        errors = []
        actual_allocations: Dict[str, int] = {}

        for idx, (piece_id, piece) in enumerate(search_entries):
            task_tag = f"{exp_id}_piece_{piece_id}"
            piece_request_id = self._build_request_id(spec.get("request_id") or exp_id, suffix=str(piece_id))

            # Create piece-specific spec (single piece)
            piece_spec = spec.copy()
            piece_spec["domain"] = {"pieces": [piece]}
            piece_runner_params = runner_params.copy() if runner_params else {}
            piece_runner_params.pop("per_piece_worker_num", None)
            piece_runner_params.pop("worker_plan", None)
            allocation = piece_overrides.get(piece_id, worker_allocations[idx])
            piece_runner_params["worker_num"] = max(1, int(allocation))
            actual_allocations[piece_id] = piece_runner_params["worker_num"]

            # Submit search for this piece
            response = self.submit_search(
                exp_id=exp_id,
                spec=piece_spec,
                task_tag=task_tag,
                piece_id=piece_id,
                runner_params=piece_runner_params,
                request_id=piece_request_id,
            )

            if response.get("status") == "ok":
                state.piece_statuses[piece_id] = PieceStatus(
                    piece_id=piece_id,
                    task_tag=task_tag,
                    status="running",
                )
                if task_tag not in state.pending_callbacks:
                    state.pending_callbacks.append(task_tag)
                submitted.append({
                    "piece_id": piece_id,
                    "task_tag": task_tag,
                    "run_id": response.get("data", {}).get("run_id"),
                })
            else:
                errors.append({
                    "piece_id": piece_id,
                    "errors": response.get("errors", []),
                })

        # Save updated state
        self.state_manager.save(exp_id, state)

        if errors and not submitted:
            return make_response(
                "error",
                errors=[{"code": "all_failed", "message": "All piece submissions failed", "details": errors}],
            )

        data = {
            "submitted": submitted,
            "failed": errors,
            "total_pieces": len(search_entries),
            "worker_allocation": {
                "total_workers": total_workers,
                "per_piece": actual_allocations,
            },
        }
        if piece_ids is not None:
            data["requested_piece_ids"] = [str(item) for item in piece_ids]
        if skipped:
            data["skipped"] = skipped
            data["skipped_details"] = skipped_details
        warnings = []
        if errors:
            warnings.append({"code": "partial_failure", "message": f"{len(errors)} pieces failed"})
        if missing_piece_ids:
            warnings.append(
                {
                    "code": "piece_ids_not_found",
                    "message": "Some requested piece_ids were not found in search pieces.",
                    "details": {"piece_ids": missing_piece_ids},
                }
            )
        if total_workers < len(search_entries):
            warnings.append(
                {
                    "code": "worker_overcommit",
                    "message": "Piece count exceeds worker budget; each piece was assigned at least one worker.",
                    "details": {
                        "total_workers": total_workers,
                        "piece_count": len(search_entries),
                    },
                }
            )
        return make_response(
            "ok",
            data=data,
            warnings=warnings,
        )

    def poll_all_pieces(self, exp_id: str) -> Dict[str, Any]:
        """Poll status of all pieces in a parallel search.

        Args:
            exp_id: Experiment ID

        Returns:
            Aggregated status of all pieces
        """
        state = self.state_manager.load(exp_id)

        statuses = []
        for piece_id, piece_status in state.piece_statuses.items():
            if self.tool_client:
                poll_response = self.tool_client.call(
                    "anum.run.poll",
                    {"task_tag": piece_status.task_tag},
                )
            else:
                poll_response = self.run_manager.poll(piece_status.task_tag)
            poll_data = poll_response.get("data", {})

            statuses.append({
                "piece_id": piece_id,
                "task_tag": piece_status.task_tag,
                "status": piece_status.status,
                "poll_status": poll_data.get("status"),
                "progress": poll_data.get("progress"),
                "best_summary": poll_data.get("best_summary"),
            })

        all_done = state.all_pieces_done()
        any_failed = state.any_piece_failed()

        return make_response(
            "ok",
            data={
                "pieces": statuses,
                "all_done": all_done,
                "any_failed": any_failed,
                "parallel_mode": state.parallel_mode,
            },
        )

    def stop_all_pieces(self, exp_id: str) -> Dict[str, Any]:
        """Stop active runs for an experiment (parallel or single-run).

        Args:
            exp_id: Experiment ID

        Returns:
            Stop response for active task tags
        """
        state = self.state_manager.load(exp_id)

        stopped = []
        errors = []
        targets: List[Dict[str, Any]] = []
        active_run_statuses = {"queued", "running", "pending", "submitted"}

        for piece_id, piece_status in state.piece_statuses.items():
            if not piece_status.task_tag:
                continue
            if str(piece_status.status).lower() in ("pending", "running"):
                targets.append({"task_tag": piece_status.task_tag, "piece_id": piece_id})

        for task_tag in state.pending_callbacks:
            if task_tag:
                targets.append({"task_tag": task_tag, "piece_id": None})

        for run in state.runs:
            if not run.task_tag:
                continue
            if str(run.status or "").lower() in active_run_statuses:
                targets.append({"task_tag": run.task_tag, "piece_id": None})

        dedup_targets: List[Dict[str, Any]] = []
        seen = set()
        for item in targets:
            task_tag = item.get("task_tag")
            if not task_tag or task_tag in seen:
                continue
            seen.add(task_tag)
            dedup_targets.append(item)

        for target in dedup_targets:
            task_tag = target["task_tag"]
            piece_id = target.get("piece_id")
            if self.tool_client:
                response = self.tool_client.call(
                    "anum.run.stop",
                    {"task_tag": task_tag},
                )
            else:
                response = self.run_manager.stop(task_tag)
            if response.get("status") == "ok":
                stopped.append(task_tag)
                state.pending_callbacks = [tag for tag in state.pending_callbacks if tag != task_tag]
                for run in state.runs:
                    if run.task_tag == task_tag:
                        run.status = "stopped"
                for p_status in state.piece_statuses.values():
                    if p_status.task_tag == task_tag and str(p_status.status).lower() in ("pending", "running"):
                        p_status.status = "stopped"
            else:
                errors.append({
                    "task_tag": task_tag,
                    "piece_id": piece_id,
                    "errors": response.get("errors", []),
                })

        if not state.pending_callbacks:
            active_runs = {
                str(run.status or "").lower()
                for run in state.runs
                if str(run.status or "").lower() in active_run_statuses
            }
            if not active_runs:
                state_machine.apply_transition(state, state_machine.PHASE_STOPPED)

        self.state_manager.save(exp_id, state)

        return make_response(
            "ok",
            data={
                "stopped": stopped,
                "errors": errors,
                "targets": [item.get("task_tag") for item in dedup_targets],
            },
        )
