from __future__ import annotations

import argparse
import math
import os
import time
from typing import Any, Dict, List, Optional, Tuple

from agent.piece_focus import build_repartition_focus
from agent.resume import resume_experiment
from agent.state_manager import StateManager
from agent.trace import TraceLogger, wrap_tool_client
from agent.tooling import build_internal_tool_client


def _find_piece_id(state, task_tag: str) -> Optional[str]:
    for piece_id, piece in state.piece_statuses.items():
        if piece.task_tag == task_tag:
            return piece_id
    return None


def _safe_float(value: Any) -> Optional[float]:
    try:
        number = float(value)
    except (TypeError, ValueError):
        return None
    if not math.isfinite(number):
        return None
    return number


def _safe_int(value: Any) -> Optional[int]:
    try:
        return int(value)
    except (TypeError, ValueError):
        return None


def _extract_best_error(poll_response: Dict[str, Any]) -> Optional[float]:
    best = poll_response.get("data", {}).get("best_summary") or {}
    return _safe_float(best.get("error"))


def _extract_target_threshold(state) -> Optional[float]:
    spec = getattr(state, "current_spec", None) or {}
    metric = spec.get("metric") or {}
    metric_type = str(metric.get("type", "")).strip().lower()
    if metric_type == "ulp":
        return None
    return _safe_float(metric.get("threshold"))


def _extract_piece_timeout_s(state, override_timeout_s: int) -> Optional[float]:
    if override_timeout_s and override_timeout_s > 0:
        return float(override_timeout_s)
    spec = getattr(state, "current_spec", None) or {}
    stop_criteria = spec.get("stop_criteria") or {}
    timeout_s = _safe_float(stop_criteria.get("max_wall_time_s"))
    if timeout_s is None or timeout_s <= 0:
        return None
    return timeout_s


def _default_stagnation_delta(threshold: Optional[float]) -> float:
    if threshold is None:
        return 1e-8
    return max(1e-8, 0.05 * threshold)


def _load_supervision_state(state_manager: Any, exp_id: str) -> Dict[str, Any]:
    loader = getattr(state_manager, "load_supervision", None)
    if callable(loader):
        try:
            payload = loader(exp_id)
            if isinstance(payload, dict):
                tasks = payload.get("tasks")
                if isinstance(tasks, dict):
                    global_state = payload.get("global")
                    if not isinstance(global_state, dict):
                        global_state = {}
                    return {"version": 1, "tasks": tasks, "global": global_state}
        except Exception:
            pass
    return {"version": 1, "tasks": {}, "global": {}}


def _save_supervision_state(state_manager: Any, exp_id: str, payload: Dict[str, Any]) -> None:
    saver = getattr(state_manager, "save_supervision", None)
    if callable(saver):
        try:
            saver(exp_id, payload)
        except Exception:
            pass


def _task_supervision_entry(tasks: Dict[str, Any], task_tag: str, now: float) -> Dict[str, Any]:
    entry = tasks.get(task_tag)
    if not isinstance(entry, dict):
        entry = {}
    entry.setdefault("last_best_error", None)
    entry.setdefault("stagnation_count", 0)
    entry.setdefault("last_resume_at", now)
    entry.setdefault("last_reason", None)
    entry.setdefault("target_stop_requested_at", None)
    entry.setdefault("timeout_stop_requested_at", None)
    entry.setdefault("last_poll_at", None)
    entry.setdefault("last_completed_tasks", 0)
    entry.setdefault("last_processing_tasks", 0)
    entry.setdefault("last_rel_improve", 0.0)
    entry.setdefault("throughput_tps", 0.0)
    entry.setdefault("health_score", 0.0)
    entry.setdefault("last_improve_at", None)
    entry.setdefault("piece_id", None)
    tasks[task_tag] = entry
    return entry


def _global_supervision_entry(global_state: Dict[str, Any]) -> Dict[str, Any]:
    global_state.setdefault("last_best_error", None)
    global_state.setdefault("plateau_count", 0)
    global_state.setdefault("last_resume_at", 0.0)
    global_state.setdefault("last_resume_reason", None)
    global_state.setdefault("last_manual_resume_at", 0.0)
    global_state.setdefault("repartition_focus", None)
    history = global_state.get("manual_resume_history")
    if not isinstance(history, list):
        history = []
    global_state["manual_resume_history"] = history
    return global_state


def _extract_progress_counts(poll_data: Dict[str, Any]) -> Tuple[int, int, int]:
    """Return ``(pending, processing, completed)`` task counts from ``poll_data``.

    Counts are read from ``poll_data["progress"]``; a value that is missing or
    not convertible by ``_safe_int`` counts as 0, and negatives are clamped to 0.
    """
    progress = poll_data.get("progress") or {}

    def _count(field: str) -> int:
        parsed = _safe_int(progress.get(field)) or 0
        return max(0, parsed)

    pending = _count("pending_tasks")
    processing = _count("processing_tasks")
    completed = _count("completed_tasks")
    return pending, processing, completed


def _resolve_worker_budget() -> int:
    raw = os.getenv("ANUM_DEFAULT_WORKER_NUM")
    if raw is not None:
        parsed = _safe_int(raw)
        if parsed and parsed > 0:
            return parsed
    return max(1, os.cpu_count() or 1)


def _build_worker_plan(
    rows: List[Dict[str, Any]],
    total_workers: int,
    spec: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    if not rows:
        return {
            "mode": "marginal_utility",
            "total_workers": total_workers,
            "per_task": {},
            "per_piece": {},
            "ranked_pieces": [],
            "primary_piece_id": None,
        }

    focus = build_repartition_focus(spec=spec, rows=rows, limit=max(1, len(rows)))
    ranked_pieces = focus.get("ranked_pieces") if isinstance(focus, dict) else []
    if not isinstance(ranked_pieces, list):
        ranked_pieces = []
    focus_by_task: Dict[str, Dict[str, Any]] = {}
    for item in ranked_pieces:
        if isinstance(item, dict) and item.get("task_tag") is not None:
            focus_by_task[str(item.get("task_tag"))] = item
    max_focus_score = max((_safe_float((item or {}).get("score")) or 0.0) for item in ranked_pieces) if ranked_pieces else 0.0
    max_throughput = max((_safe_float((row or {}).get("throughput_tps")) or 0.0) for row in rows) if rows else 0.0

    n = len(rows)
    budget = max(1, int(total_workers))
    base = [1] * n
    remaining = max(0, budget - n)

    weights: List[float] = []
    for row in rows:
        focus_row = focus_by_task.get(str(row.get("task_tag")))
        pressure = _safe_float((focus_row or {}).get("score")) or 0.0
        pressure_norm = pressure / max(1e-6, max_focus_score) if max_focus_score > 0 else 0.0
        rel_improve = max(0.0, _safe_float(row.get("last_rel_improve")) or 0.0)
        rel_norm = min(1.0, rel_improve * 25.0)
        throughput = max(0.0, _safe_float(row.get("throughput_tps")) or 0.0)
        throughput_norm = throughput / max(1.0, max_throughput) if max_throughput > 0 else 0.0
        health = min(1.0, max(0.0, _safe_float(row.get("health_score")) or 0.0))
        utility = (0.50 * rel_norm) + (0.35 * throughput_norm) + (0.15 * health)
        weight = 0.05 + (0.70 * utility) + (0.30 * pressure_norm)
        weights.append(max(1e-6, float(weight)))

    if remaining > 0:
        weight_sum = sum(weights) or 1.0
        shares = [remaining * (w / weight_sum) for w in weights]
        floors = [int(math.floor(x)) for x in shares]
        remainder = remaining - sum(floors)
        allocation = [base[i] + floors[i] for i in range(n)]
        order = sorted(
            range(n),
            key=lambda i: shares[i] - floors[i],
            reverse=True,
        )
        for idx in order[:remainder]:
            allocation[idx] += 1
    else:
        allocation = base

    per_task: Dict[str, int] = {}
    per_piece: Dict[str, int] = {}
    for idx, row in enumerate(rows):
        workers = max(1, int(allocation[idx]))
        task_tag = str(row.get("task_tag"))
        per_task[task_tag] = workers
        piece_id = row.get("piece_id")
        if piece_id is not None:
            per_piece[str(piece_id)] = workers

    return {
        "mode": "marginal_utility",
        "total_workers": budget,
        "per_task": per_task,
        "per_piece": per_piece,
        "ranked_pieces": ranked_pieces[:8],
        "primary_piece_id": focus.get("primary_piece_id") if isinstance(focus, dict) else None,
        "updated_at": int(time.time()),
    }


def _manual_resume_allowed(
    global_state: Dict[str, Any],
    now: float,
    cooldown_s: int,
    max_per_hour: int,
) -> bool:
    """Decide whether another manual resume fits within the throttling limits.

    As a side effect ``global_state["manual_resume_history"]`` is rewritten to
    hold only the finite timestamps from the last 3600 seconds. A resume is
    refused while the cooldown since ``last_manual_resume_at`` is still
    running, or when the pruned history already holds ``max_per_hour``
    entries. A limit of 0 or less disables that check.
    """
    raw_history = global_state.get("manual_resume_history")
    candidates = raw_history if isinstance(raw_history, list) else []

    cutoff = now - 3600.0
    recent: List[float] = []
    for item in candidates:
        stamp = _safe_float(item)
        if stamp is not None and stamp >= cutoff:
            recent.append(stamp)
    global_state["manual_resume_history"] = recent

    last_manual = _safe_float(global_state.get("last_manual_resume_at")) or 0.0
    in_cooldown = cooldown_s > 0 and last_manual > 0 and (now - last_manual) < cooldown_s
    over_quota = max_per_hour > 0 and len(recent) >= max_per_hour
    return not (in_cooldown or over_quota)


def _record_resume(
    global_state: Dict[str, Any],
    now: float,
    *,
    event: str,
    reason: str,
) -> None:
    global_state["last_resume_at"] = now
    global_state["last_resume_reason"] = reason
    if event != "manual_resume":
        return
    global_state["last_manual_resume_at"] = now
    history = global_state.get("manual_resume_history")
    if not isinstance(history, list):
        history = []
    history.append(now)
    global_state["manual_resume_history"] = history[-200:]


def _invoke_resume(
    exp_id: str,
    task_tag: str,
    event: str,
    piece_id: Optional[str],
    reason: str,
    timeout_s: int,
    *,
    state_manager: Optional[StateManager] = None,
    tool_client: Optional[Any] = None,
) -> None:
    """Call ``resume_experiment`` for one task, printing any failure instead of raising.

    ``timeout_s`` is not forwarded to ``resume_experiment``; it only appears
    in the failure message.
    """
    call_kwargs = {
        "exp_id": exp_id,
        "task_tag": task_tag,
        "event": event,
        "piece_id": piece_id,
        "reason": reason,
        "state_manager": state_manager,
        "tool_client": tool_client,
    }
    try:
        resume_experiment(**call_kwargs)
    except Exception as exc:
        details = (task_tag, event, reason, timeout_s, exc)
        print(
            "Resume failed for task_tag=%s event=%s reason=%s timeout_s=%s error=%s"
            % details
        )


def _invoke_global_resume(
    exp_id: str,
    reason: str,
    timeout_s: int,
    *,
    state_manager: Optional[StateManager] = None,
    tool_client: Optional[Any] = None,
) -> None:
    """Call ``resume_experiment`` with a ``"manual_resume"`` event and no task tag.

    Failures are printed rather than raised. ``timeout_s`` is not forwarded to
    ``resume_experiment``; it only appears in the failure message.
    """
    call_kwargs = {
        "exp_id": exp_id,
        "event": "manual_resume",
        "reason": reason,
        "state_manager": state_manager,
        "tool_client": tool_client,
    }
    try:
        resume_experiment(**call_kwargs)
    except Exception as exc:
        details = (exp_id, reason, timeout_s, exc)
        print(
            "Global resume failed for exp_id=%s reason=%s timeout_s=%s error=%s"
            % details
        )


def main() -> int:
    parser = argparse.ArgumentParser(description="Poll runs and resume agent when tasks complete.")
    parser.add_argument("--exp_id", type=str, required=True, help="Experiment ID")
    parser.add_argument("--interval_s", type=int, default=300, help="Polling interval in seconds")
    parser.add_argument(
        "--running_resume_interval_s",
        type=int,
        default=10800,
        help="If a task is still running, trigger a periodic manual resume at this interval.",
    )
    parser.add_argument(
        "--stagnation_rounds",
        type=int,
        default=30,
        help="Trigger a manual resume after this many no-improvement polls.",
    )
    parser.add_argument(
        "--manual_resume_min_interval_s",
        type=int,
        default=1200,
        help="Minimum spacing between routine manual resumes (hard floor is 1200s).",
    )
    parser.add_argument(
        "--stagnation_min_delta",
        type=float,
        default=None,
        help="Minimum best_error delta counted as improvement (default=max(1e-8, 0.05*threshold)).",
    )
    parser.add_argument(
        "--running_improve_epsilon",
        type=float,
        default=0.0,
        help="Backward-compatible alias merged into stagnation_min_delta.",
    )
    parser.add_argument(
        "--piece_timeout_s",
        type=int,
        default=0,
        help="Override piece timeout. 0 means use spec.stop_criteria.max_wall_time_s.",
    )
    parser.add_argument(
        "--target_stop_grace_s",
        type=int,
        default=30,
        help="Grace window after sending STOP_RUN on target reached before manual resume.",
    )
    parser.add_argument(
        "--disable_running_resume",
        action="store_true",
        help="Only resume on terminal task events (complete/failed/timeout).",
    )
    parser.add_argument(
        "--disable_auto_stop_on_target",
        action="store_true",
        help="Disable auto-stop when best error meets the spec metric threshold.",
    )
    parser.add_argument(
        "--resume_timeout_s",
        type=int,
        default=900,
        help="Advisory timeout metadata recorded with each resume invocation.",
    )
    parser.add_argument(
        "--max_resumes_per_round",
        type=int,
        default=int(os.getenv("ANUM_POLLER_MAX_RESUMES_PER_ROUND", "1")),
        help="Max number of resume invocations per poll round.",
    )
    parser.add_argument(
        "--global_manual_resume_cooldown_s",
        type=int,
        default=int(os.getenv("ANUM_POLLER_MANUAL_RESUME_COOLDOWN_S", "3600")),
        help="Global cooldown between manual_resume invocations.",
    )
    parser.add_argument(
        "--manual_resume_max_per_hour",
        type=int,
        default=int(os.getenv("ANUM_POLLER_MANUAL_RESUME_MAX_PER_HOUR", "4")),
        help="Global upper bound for manual_resume invocations within a rolling hour.",
    )
    parser.add_argument(
        "--global_plateau_rounds",
        type=int,
        default=6,
        help="Trigger a manual resume when global best error plateaus for this many polls.",
    )
    parser.add_argument(
        "--global_plateau_min_delta",
        type=float,
        default=None,
        help="Minimum global best_error drop counted as improvement (default=max(1e-8, 0.05*threshold)).",
    )
    parser.add_argument(
        "--imbalance_ratio",
        type=float,
        default=5.0,
        help="Trigger imbalance handling when worst_piece_error / best_piece_error exceeds this ratio.",
    )
    parser.add_argument(
        "--imbalance_stagnation_rounds",
        type=int,
        default=8,
        help="Require this many stagnation rounds on the worst piece before imbalance resume is triggered.",
    )
    parser.add_argument("--once", action="store_true", help="Run a single polling iteration")
    parser.add_argument("--max_rounds", type=int, default=0, help="Max polling rounds (0 = unlimited)")
    args = parser.parse_args()
    routine_resume_min_interval_s = max(1200, int(args.manual_resume_min_interval_s))
    heartbeat_interval_s = max(routine_resume_min_interval_s, int(args.running_resume_interval_s))

    state_manager = StateManager()
    if not state_manager.exists(args.exp_id):
        print(f"ERROR: Experiment {args.exp_id} not found")
        return 1

    trace = TraceLogger.from_env(
        exp_id=args.exp_id,
        state_manager=state_manager,
        component="poller",
    )
    tool_client = wrap_tool_client(build_internal_tool_client(), trace)
    if trace and trace.enabled:
        trace.log(
            "state_transitions",
            {
                "event": "poller_start",
                "exp_id": args.exp_id,
                "args": {
                    "interval_s": args.interval_s,
                    "running_resume_interval_s": args.running_resume_interval_s,
                    "stagnation_rounds": args.stagnation_rounds,
                    "global_plateau_rounds": args.global_plateau_rounds,
                    "imbalance_ratio": args.imbalance_ratio,
                    "max_resumes_per_round": args.max_resumes_per_round,
                    "manual_resume_cooldown_s": args.global_manual_resume_cooldown_s,
                    "manual_resume_max_per_hour": args.manual_resume_max_per_hour,
                },
            },
        )
    rounds = 0

    while True:
        rounds += 1
        state = state_manager.load(args.exp_id)
        pending = list(state.pending_callbacks)
        if not pending:
            phase = str(getattr(state, "phase", "") or "")
            global_keepalive = phase in ("planning", "evaluating") and not args.disable_running_resume
            if global_keepalive:
                now = time.time()
                supervision = _load_supervision_state(state_manager, args.exp_id)
                global_state = supervision.get("global")
                if not isinstance(global_state, dict):
                    global_state = {}
                    supervision["global"] = global_state
                global_state = _global_supervision_entry(global_state)

                manual_resume_cooldown_s = max(0, int(args.global_manual_resume_cooldown_s))
                manual_resume_max_per_hour = max(0, int(args.manual_resume_max_per_hour))
                last_resume_at = _safe_float(global_state.get("last_resume_at")) or 0.0
                due = (last_resume_at <= 0.0) or ((now - last_resume_at) >= max(1, heartbeat_interval_s))

                if due and _manual_resume_allowed(
                    global_state,
                    now,
                    cooldown_s=manual_resume_cooldown_s,
                    max_per_hour=manual_resume_max_per_hour,
                ):
                    print(
                        "No pending callbacks; trigger manual_resume exp_id=%s phase=%s reason=heartbeat"
                        % (args.exp_id, phase)
                    )
                    _invoke_global_resume(
                        args.exp_id,
                        reason="heartbeat",
                        timeout_s=args.resume_timeout_s,
                        state_manager=state_manager,
                        tool_client=tool_client,
                    )
                    _record_resume(global_state, now, event="manual_resume", reason="heartbeat")
                else:
                    print(
                        "No pending callbacks; waiting in phase=%s (next manual resume when due)"
                        % phase
                    )

                _save_supervision_state(state_manager, args.exp_id, supervision)
                if args.once:
                    return 0
                if args.max_rounds and rounds >= args.max_rounds:
                    return 0
                time.sleep(max(1, args.interval_s))
                continue

            print("No pending callbacks; exiting")
            if trace and trace.enabled:
                trace.log(
                    "state_transitions",
                    {"event": "poller_exit", "exp_id": args.exp_id, "reason": "no_pending_callbacks"},
                )
            return 0

        supervision = _load_supervision_state(state_manager, args.exp_id)
        task_state_map = supervision.setdefault("tasks", {})
        global_state = supervision.get("global")
        if not isinstance(global_state, dict):
            global_state = {}
            supervision["global"] = global_state
        global_state = _global_supervision_entry(global_state)
        for task_tag in list(task_state_map.keys()):
            if task_tag not in pending:
                task_state_map.pop(task_tag, None)

        threshold = _extract_target_threshold(state)
        running_rows: List[Dict[str, Any]] = []
        resume_triggered = False
        resume_invocations = 0
        max_resumes_per_round = max(1, int(args.max_resumes_per_round))
        manual_resume_cooldown_s = max(0, int(args.global_manual_resume_cooldown_s))
        manual_resume_max_per_hour = max(0, int(args.manual_resume_max_per_hour))

        for task_tag in pending:
            now = time.time()
            task_state = _task_supervision_entry(task_state_map, task_tag, now)
            piece_id = _find_piece_id(state, task_tag) if state.parallel_mode else None
            last_resume_at = _safe_float(task_state.get("last_resume_at"))
            routine_resume_due = (
                last_resume_at is None or (now - last_resume_at) >= routine_resume_min_interval_s
            )
            poll_response = tool_client.call("anum.run.poll", {"task_tag": task_tag})
            if poll_response.get("status") != "ok":
                continue
            poll_data = poll_response.get("data", {})
            poll_status = poll_data.get("status")
            is_running = poll_status in ("queued", "running")
            is_active = poll_status in ("queued", "running", "stopping")
            pending_tasks, processing_tasks, completed_tasks = _extract_progress_counts(poll_data)
            last_poll_at = _safe_float(task_state.get("last_poll_at"))
            poll_delta_s = (
                max(1.0, now - last_poll_at)
                if last_poll_at is not None
                else float(max(1, int(args.interval_s)))
            )
            prev_completed = _safe_int(task_state.get("last_completed_tasks"))
            if prev_completed is None:
                prev_completed = completed_tasks
            completed_delta = max(0, completed_tasks - prev_completed)
            task_state["last_poll_at"] = now
            task_state["last_completed_tasks"] = completed_tasks
            task_state["last_processing_tasks"] = processing_tasks
            if piece_id is not None:
                task_state["piece_id"] = piece_id

            event = None
            reason = "run_terminal"
            best_error = _extract_best_error(poll_response)
            piece_timeout_s = _extract_piece_timeout_s(state, args.piece_timeout_s)
            elapsed_s = _safe_float(poll_data.get("elapsed_s"))
            timeout_stop_sent_at = _safe_float(task_state.get("timeout_stop_requested_at"))

            min_delta = args.stagnation_min_delta
            if min_delta is None:
                min_delta = _default_stagnation_delta(threshold)
            if args.running_improve_epsilon > 0:
                min_delta = max(min_delta, args.running_improve_epsilon)

            rel_improve = 0.0
            if is_running and best_error is not None:
                previous = _safe_float(task_state.get("last_best_error"))
                if previous is None:
                    task_state["last_best_error"] = best_error
                    task_state["stagnation_count"] = 0
                else:
                    improvement = previous - best_error
                    rel_improve = max(0.0, improvement) / max(abs(previous), 1e-12)
                    if improvement > min_delta:
                        task_state["last_best_error"] = best_error
                        task_state["stagnation_count"] = 0
                        task_state["last_improve_at"] = now
                    else:
                        task_state["stagnation_count"] = int(task_state.get("stagnation_count", 0)) + 1
            task_state["last_rel_improve"] = rel_improve
            throughput_tps = completed_delta / max(1.0, poll_delta_s)
            task_state["throughput_tps"] = throughput_tps
            improve_norm = min(1.0, max(0.0, rel_improve * 200.0))
            throughput_norm = min(1.0, throughput_tps / max(1.0, float(max(1, processing_tasks))))
            activity_norm = 1.0 if processing_tasks > 0 else (0.5 if pending_tasks > 0 else 0.0)
            task_state["health_score"] = round(
                (0.55 * improve_norm) + (0.35 * throughput_norm) + (0.10 * activity_norm),
                6,
            )

            target_branch_taken = False
            if (
                is_running
                and not args.disable_auto_stop_on_target
                and threshold is not None
                and best_error is not None
                and best_error <= threshold
            ):
                target_branch_taken = True
                stop_sent_at = _safe_float(task_state.get("target_stop_requested_at"))
                if stop_sent_at is None:
                    stop_resp = tool_client.call("anum.run.stop", {"task_tag": task_tag})
                    print(
                        "Auto-stop triggered for task_tag=%s: best_error=%s <= threshold=%s (resp=%s)"
                        % (task_tag, best_error, threshold, stop_resp.get("status"))
                    )
                    if stop_resp.get("status") == "ok":
                        task_state["target_stop_requested_at"] = now
                        task_state["last_reason"] = "target_reached"
                else:
                    grace_s = max(1, args.target_stop_grace_s)
                    if (
                        not args.disable_running_resume
                        and (now - stop_sent_at) >= grace_s
                    ):
                        event = "manual_resume"
                        reason = "target_reached"
                        task_state["last_resume_at"] = now
                        task_state["last_reason"] = reason
                        # Keep throttling on long-running stop requests.
                        task_state["target_stop_requested_at"] = now

            if (
                is_running
                and event is None
                and not target_branch_taken
                and piece_timeout_s is not None
                and elapsed_s is not None
                and elapsed_s >= piece_timeout_s
                and timeout_stop_sent_at is None
            ):
                stop_resp = tool_client.call("anum.run.stop", {"task_tag": task_tag})
                print(
                    "Piece-timeout stop for task_tag=%s: elapsed_s=%s >= timeout_s=%s (resp=%s)"
                    % (task_tag, elapsed_s, piece_timeout_s, stop_resp.get("status"))
                )
                if stop_resp.get("status") == "ok":
                    task_state["timeout_stop_requested_at"] = now
                    timeout_stop_sent_at = now
                    task_state["target_stop_requested_at"] = None
                else:
                    # Stop failure is exceptional; resume immediately for intervention.
                    event = "manual_resume"
                    reason = "run_terminal"
                    task_state["last_resume_at"] = now
                    task_state["last_reason"] = reason

            if (
                is_running
                and event is None
                and not target_branch_taken
                and not args.disable_running_resume
                and timeout_stop_sent_at is None
                and routine_resume_due
                and int(task_state.get("stagnation_count", 0)) >= max(1, int(args.stagnation_rounds))
            ):
                event = "manual_resume"
                reason = "stagnation"
                task_state["last_resume_at"] = now
                task_state["last_reason"] = reason
                task_state["stagnation_count"] = 0

            if (
                is_running
                and event is None
                and not target_branch_taken
                and not args.disable_running_resume
                and timeout_stop_sent_at is None
            ):
                due = (now - float(task_state.get("last_resume_at", now))) >= max(1, heartbeat_interval_s)
                if due:
                    event = "manual_resume"
                    reason = "heartbeat"
                    task_state["last_resume_at"] = now
                    task_state["last_reason"] = reason
            if (
                event is None
                and not is_active
                and timeout_stop_sent_at is not None
            ):
                event = "timeout"
                reason = "piece_timeout"
                task_state["last_resume_at"] = now
                task_state["last_reason"] = reason
                task_state["timeout_stop_requested_at"] = None
                timeout_stop_sent_at = None
                task_state["target_stop_requested_at"] = None

            if event is None and poll_status == "stopping":
                result_resp = tool_client.call("anum.run.result", {"task_tag": task_tag})
                result_status = result_resp.get("status")
                if result_status == "ok":
                    event = "task_complete"
                    if _safe_float(task_state.get("target_stop_requested_at")) is not None:
                        reason = "target_reached"
                    elif timeout_stop_sent_at is not None:
                        reason = "piece_timeout"
                    else:
                        reason = "run_terminal"
                    task_state["timeout_stop_requested_at"] = None
                    timeout_stop_sent_at = None
                    task_state["target_stop_requested_at"] = None
                    task_state["last_resume_at"] = now
                    task_state["last_reason"] = reason
                elif result_status in ("error", "not_found"):
                    if timeout_stop_sent_at is not None:
                        event = "timeout"
                        reason = "piece_timeout"
                    else:
                        event = "task_failed"
                        reason = "run_terminal"
                    task_state["timeout_stop_requested_at"] = None
                    timeout_stop_sent_at = None
                    task_state["target_stop_requested_at"] = None
                    task_state["last_resume_at"] = now
                    task_state["last_reason"] = reason

            if event is None and not is_active:
                result_resp = tool_client.call("anum.run.result", {"task_tag": task_tag})
                result_status = result_resp.get("status")
                if result_status == "ok":
                    event = "task_complete"
                    if _safe_float(task_state.get("target_stop_requested_at")) is not None:
                        reason = "target_reached"
                    else:
                        reason = "run_terminal"
                elif result_status == "running":
                    if poll_status not in (None, "queued", "running", "stopping"):
                        event = "task_failed"
                        reason = "run_terminal"
                    else:
                        event = None
                else:
                    event = "task_failed"
                    reason = "run_terminal"

            if is_running and best_error is not None:
                running_rows.append(
                    {
                        "task_tag": task_tag,
                        "piece_id": piece_id,
                        "best_error": best_error,
                        "error": best_error,
                        "status": poll_status,
                        "stagnation_count": int(task_state.get("stagnation_count", 0)),
                        "health_score": _safe_float(task_state.get("health_score")) or 0.0,
                        "throughput_tps": throughput_tps,
                        "last_rel_improve": rel_improve,
                        "routine_resume_due": routine_resume_due,
                        "timeout_stop_pending": timeout_stop_sent_at is not None,
                        "target_stop_pending": _safe_float(task_state.get("target_stop_requested_at")) is not None,
                    }
                )

            if event is None:
                continue
            if event == "manual_resume" and resume_invocations >= max_resumes_per_round:
                if trace and trace.enabled:
                    trace.log(
                        "decisions",
                        {
                            "event": "resume_skipped",
                            "exp_id": args.exp_id,
                            "task_tag": task_tag,
                            "reason": "max_resumes_per_round",
                            "candidate_event": event,
                            "candidate_reason": reason,
                        },
                    )
                continue
            if event == "manual_resume" and not _manual_resume_allowed(
                global_state,
                now,
                cooldown_s=manual_resume_cooldown_s,
                max_per_hour=manual_resume_max_per_hour,
            ):
                if trace and trace.enabled:
                    trace.log(
                        "decisions",
                        {
                            "event": "resume_skipped",
                            "exp_id": args.exp_id,
                            "task_tag": task_tag,
                            "reason": "manual_resume_throttled",
                            "candidate_event": event,
                            "candidate_reason": reason,
                        },
                    )
                continue
            print(f"Resuming for task_tag={task_tag} status={poll_status} event={event} reason={reason}")
            _invoke_resume(
                args.exp_id,
                task_tag,
                event,
                piece_id,
                reason,
                timeout_s=args.resume_timeout_s,
                state_manager=state_manager,
                tool_client=tool_client,
            )
            _record_resume(global_state, now, event=event, reason=reason)
            if trace and trace.enabled:
                trace.log(
                    "decisions",
                    {
                        "event": "resume_invoked",
                        "exp_id": args.exp_id,
                        "task_tag": task_tag,
                        "piece_id": piece_id,
                        "resume_event": event,
                        "resume_reason": reason,
                        "poll_status": poll_status,
                    },
                )
            resume_triggered = True
            if event == "manual_resume":
                resume_invocations += 1
            if event in ("task_complete", "task_failed", "timeout"):
                task_state_map.pop(task_tag, None)

        worker_plan = _build_worker_plan(running_rows, _resolve_worker_budget(), spec=state.current_spec)
        global_state["worker_plan"] = worker_plan
        repartition_focus = build_repartition_focus(
            spec=state.current_spec,
            rows=running_rows,
            worker_plan=worker_plan,
            limit=8,
        )
        if isinstance(repartition_focus, dict):
            repartition_focus = dict(repartition_focus)
            repartition_focus["updated_at"] = int(time.time())
            repartition_focus["threshold"] = threshold
            global_state["repartition_focus"] = repartition_focus
        if trace and trace.enabled:
            trace.log(
                "worker_plans",
                {
                    "event": "worker_plan",
                    "exp_id": args.exp_id,
                    "round": rounds,
                    "plan": worker_plan,
                    "repartition_focus": global_state.get("repartition_focus"),
                    "running_rows": running_rows,
                },
            )

        if running_rows:
            now = time.time()
            # Piecewise global objective: minimize worst-piece error.
            global_best_error = max(float(row["best_error"]) for row in running_rows)
            global_min_delta = args.global_plateau_min_delta
            if global_min_delta is None:
                global_min_delta = _default_stagnation_delta(threshold)
            previous_global_best = _safe_float(global_state.get("last_best_error"))
            if previous_global_best is None or (previous_global_best - global_best_error) > global_min_delta:
                global_state["last_best_error"] = global_best_error
                global_state["plateau_count"] = 0
                global_state["last_improve_at"] = now
            else:
                global_state["plateau_count"] = int(global_state.get("plateau_count", 0)) + 1

            global_last_resume_at = _safe_float(global_state.get("last_resume_at")) or 0.0
            global_resume_due = (now - global_last_resume_at) >= routine_resume_min_interval_s
            plateau_rounds = max(1, int(args.global_plateau_rounds))
            if (
                not resume_triggered
                and not args.disable_running_resume
                and global_resume_due
                and int(global_state.get("plateau_count", 0)) >= plateau_rounds
                and resume_invocations < max_resumes_per_round
            ):
                target = max(
                    running_rows,
                    key=lambda item: (
                        float(item.get("best_error", 0.0)),
                        int(item.get("stagnation_count", 0)),
                    ),
                )
                if not target.get("timeout_stop_pending") and not target.get("target_stop_pending"):
                    if not _manual_resume_allowed(
                        global_state,
                        now,
                        cooldown_s=manual_resume_cooldown_s,
                        max_per_hour=manual_resume_max_per_hour,
                    ):
                        if trace and trace.enabled:
                            trace.log(
                                "decisions",
                                {
                                    "event": "resume_skipped",
                                    "exp_id": args.exp_id,
                                    "task_tag": target["task_tag"],
                                    "reason": "manual_resume_throttled",
                                    "candidate_event": "manual_resume",
                                    "candidate_reason": "plateau",
                                },
                            )
                        pass
                    else:
                        print(
                            "Global plateau detected; triggering resume for task_tag=%s"
                            % target["task_tag"]
                        )
                        _invoke_resume(
                            args.exp_id,
                            str(target["task_tag"]),
                            "manual_resume",
                            target.get("piece_id"),
                            "plateau",
                            timeout_s=args.resume_timeout_s,
                            state_manager=state_manager,
                            tool_client=tool_client,
                        )
                        target_state = task_state_map.get(str(target["task_tag"]))
                        if isinstance(target_state, dict):
                            target_state["last_resume_at"] = now
                            target_state["last_reason"] = "plateau"
                        _record_resume(global_state, now, event="manual_resume", reason="plateau")
                        if trace and trace.enabled:
                            trace.log(
                                "decisions",
                                {
                                    "event": "resume_invoked",
                                    "exp_id": args.exp_id,
                                    "task_tag": target["task_tag"],
                                    "piece_id": target.get("piece_id"),
                                    "resume_event": "manual_resume",
                                    "resume_reason": "plateau",
                                    "plateau_count": global_state.get("plateau_count"),
                                },
                            )
                        global_state["plateau_count"] = 0
                        resume_triggered = True
                        resume_invocations += 1

            if (
                not resume_triggered
                and not args.disable_running_resume
                and global_resume_due
                and len(running_rows) >= 2
                and resume_invocations < max_resumes_per_round
            ):
                best_row = min(running_rows, key=lambda item: float(item.get("best_error", float("inf"))))
                worst_row = max(running_rows, key=lambda item: float(item.get("best_error", 0.0)))
                best_value = float(max(best_row.get("best_error", 0.0), 0.0))
                worst_value = float(max(worst_row.get("best_error", 0.0), 0.0))
                ratio = float("inf") if best_value <= 0 else worst_value / best_value
                worst_stagnation = int(worst_row.get("stagnation_count", 0))
                if (
                    ratio >= max(1.0, float(args.imbalance_ratio))
                    and worst_stagnation >= max(1, int(args.imbalance_stagnation_rounds))
                    and not worst_row.get("timeout_stop_pending")
                    and not worst_row.get("target_stop_pending")
                ):
                    if not _manual_resume_allowed(
                        global_state,
                        now,
                        cooldown_s=manual_resume_cooldown_s,
                        max_per_hour=manual_resume_max_per_hour,
                    ):
                        if trace and trace.enabled:
                            trace.log(
                                "decisions",
                                {
                                    "event": "resume_skipped",
                                    "exp_id": args.exp_id,
                                    "task_tag": worst_row["task_tag"],
                                    "reason": "manual_resume_throttled",
                                    "candidate_event": "manual_resume",
                                    "candidate_reason": "imbalance",
                                },
                            )
                        pass
                    else:
                        print(
                            "Piece imbalance detected (ratio=%s); triggering resume for task_tag=%s"
                            % (ratio, worst_row["task_tag"])
                        )
                        _invoke_resume(
                            args.exp_id,
                            str(worst_row["task_tag"]),
                            "manual_resume",
                            worst_row.get("piece_id"),
                            "imbalance",
                            timeout_s=args.resume_timeout_s,
                            state_manager=state_manager,
                            tool_client=tool_client,
                        )
                        target_state = task_state_map.get(str(worst_row["task_tag"]))
                        if isinstance(target_state, dict):
                            target_state["last_resume_at"] = now
                            target_state["last_reason"] = "imbalance"
                        _record_resume(global_state, now, event="manual_resume", reason="imbalance")
                        if trace and trace.enabled:
                            trace.log(
                                "decisions",
                                {
                                    "event": "resume_invoked",
                                    "exp_id": args.exp_id,
                                    "task_tag": worst_row["task_tag"],
                                    "piece_id": worst_row.get("piece_id"),
                                    "resume_event": "manual_resume",
                                    "resume_reason": "imbalance",
                                    "imbalance_ratio": ratio,
                                },
                            )
                        resume_triggered = True
                        resume_invocations += 1

        _save_supervision_state(state_manager, args.exp_id, supervision)
        if trace and trace.enabled:
            trace.log(
                "decisions",
                {
                    "event": "poll_round_complete",
                    "exp_id": args.exp_id,
                    "round": rounds,
                    "pending_callbacks": pending,
                    "resume_invocations": resume_invocations,
                },
            )

        if args.once:
            return 0
        if args.max_rounds and rounds >= args.max_rounds:
            return 0
        time.sleep(max(1, args.interval_s))


if __name__ == "__main__":
    # Script entry point: run main() and hand its return value to the
    # interpreter as the process exit status.
    exit_status = main()
    raise SystemExit(exit_status)
