"""Internal experiment-resume engine used by the detached poller.

This module loads persisted agent state, builds an observation from run status,
and advances the async runtime to its next decision. The public entrypoint is
`python -m agent.cli`; this module remains an internal helper/debug utility.
"""
from __future__ import annotations

import argparse
import os
import sys
import time
from typing import Any, Dict, List, Optional

# Put the project root (the parent of this file's directory) at the front of
# sys.path so the absolute `agent.*` / `server.*` imports below resolve, e.g.
# when this file is executed directly. Must run before those imports.
root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, root_dir)

from agent.event_log import EventLog
from agent.experiment_view import fetch_log_tail
from agent.state import AgentState
from agent import state_machine
from agent.state_manager import StateManager
from agent.trace import TraceLogger, wrap_tool_client
from agent.tooling import build_internal_tool_client
from server.run_manager import RunManager


def load_task_result(
    task_tag: str,
    base_op_folder: str = "./op",
    tool_client: Optional[Any] = None,
) -> Dict[str, Any]:
    """Load the result of a search task and normalize it for the resume flow.

    Args:
        task_tag: The task tag identifying the search run.
        base_op_folder: Base folder for operation files. Only used when no
            ``tool_client`` is supplied and a local ``RunManager`` is created.
        tool_client: Optional client exposing ``call(tool_name, params)``.
            When given, the result is fetched through the ``anum.run.result``
            tool instead of a local ``RunManager``.

    Returns:
        Dict whose ``status`` is one of:
            - "completed": response status was "ok"; also carries
              ``best_candidate``, ``optimization_error``, ``ops``,
              ``artifact_id`` and ``summary`` (the raw response data).
            - "running": response status was "running"; carries
              ``error_message`` and ``errors``.
            - "failed": any other response status; carries ``error_message``
              and ``errors``.
    """
    if tool_client is not None:
        result_response = tool_client.call("anum.run.result", {"task_tag": task_tag})
    else:
        manager = RunManager(base_op_folder=base_op_folder)
        result_response = manager.result(task_tag)

    response_status = result_response.get("status")

    if response_status == "ok":
        # A key that is present but null must behave like a missing key;
        # otherwise the `.get` lookups below raise AttributeError on None.
        data = result_response.get("data") or {}
        best = data.get("best_candidate") or {}
        return {
            "status": "completed",
            "best_candidate": best,
            "optimization_error": best.get("optimization_error"),
            "ops": best.get("ops"),
            "artifact_id": best.get("artifact_id"),
            "summary": data,
        }

    # Normalize a null "errors" value to an empty list so callers can safely
    # iterate over it or extend it.
    errors = result_response.get("errors") or []
    if response_status == "running":
        return {
            "status": "running",
            "error_message": "Task still running",
            "errors": errors,
        }
    return {
        "status": "failed",
        "error_message": "Failed to get result",
        "errors": errors,
    }


def format_observation(
    result: Dict[str, Any],
    event: str,
    task_tag: str,
    piece_id: Optional[str] = None,
    supervision_reason: Optional[str] = None,
) -> Dict[str, Any]:
    """Turn a task result into the observation dict handed to the LLM.

    Args:
        result: Task result, e.g. as produced by load_task_result.
        event: Event type (task_complete, task_failed, timeout, manual_resume).
        task_tag: Task tag the result belongs to.
        piece_id: Optional piece ID for parallel execution.
        supervision_reason: Optional reason for the resume; included only
            when non-empty.

    Returns:
        Observation dict with the base fields (event, timestamp, piece_id,
        task_status, task_tag), a status-dependent ``message``, and any
        non-empty ``poll_status``/``errors``/``logs``/``supervision`` copied
        from ``result``.
    """
    status = result.get("status")
    obs: Dict[str, Any] = {
        "event": event,
        "timestamp": int(time.time()),
        "piece_id": piece_id,
        "task_status": status,
        "task_tag": task_tag,
    }
    if supervision_reason:
        obs["supervision_reason"] = supervision_reason
    poll_status = result.get("poll_status")
    if poll_status:
        obs["poll_status"] = poll_status

    if status == "completed":
        best_error = result.get("optimization_error")
        ops = result.get("ops")
        obs["best_result"] = {
            "optimization_error": best_error,
            "ops": ops,
            "artifact_id": result.get("artifact_id"),
        }
        obs["message"] = (
            f"Search completed. Best error: {best_error}, "
            f"Operations: {ops}"
        )
    elif status == "failed":
        failure = result.get("error_message")
        obs["error"] = failure
        obs["message"] = f"Search failed: {failure}"
    else:
        obs["message"] = f"Task status: {status}"
        # Progress fields are forwarded whenever present, even if falsy (0, {}).
        for field in ("progress", "best_summary"):
            value = result.get(field)
            if value is not None:
                obs[field] = value

    # Diagnostic extras are forwarded only when non-empty.
    for field in ("errors", "logs", "supervision"):
        value = result.get(field)
        if value:
            obs[field] = value

    return obs


def _manual_resume_event_payload(result: Dict[str, Any]) -> Dict[str, Any]:
    compact: Dict[str, Any] = {}
    for key in ("status", "poll_status", "progress", "best_summary", "errors", "supervision"):
        value = result.get(key)
        if value is not None:
            compact[key] = value
    return compact


def _load_supervision_snapshot(
    state_manager: StateManager,
    exp_id: str,
    task_tag: str,
) -> Optional[Dict[str, Any]]:
    try:
        payload = state_manager.load_supervision(exp_id)
    except Exception:
        return None
    if not isinstance(payload, dict):
        return None
    task_state = (payload.get("tasks") or {}).get(task_tag)
    global_state = payload.get("global") or {}
    if not isinstance(task_state, dict):
        task_state = {}
    if not isinstance(global_state, dict):
        global_state = {}
    repartition_focus = global_state.get("repartition_focus")
    task_repartition_context = None
    if isinstance(repartition_focus, dict):
        ranked = repartition_focus.get("ranked_pieces")
        if isinstance(ranked, list):
            for item in ranked:
                if not isinstance(item, dict):
                    continue
                if str(item.get("task_tag") or "") == task_tag:
                    task_repartition_context = item
                    break
    snapshot = {
        "task": {
            "last_best_error": task_state.get("last_best_error"),
            "stagnation_count": task_state.get("stagnation_count"),
            "health_score": task_state.get("health_score"),
            "throughput_tps": task_state.get("throughput_tps"),
            "last_reason": task_state.get("last_reason"),
            "piece_id": task_state.get("piece_id"),
            "repartition_context": task_repartition_context,
        },
        "global": {
            "last_best_error": global_state.get("last_best_error"),
            "plateau_count": global_state.get("plateau_count"),
            "last_resume_reason": global_state.get("last_resume_reason"),
            "worker_plan": global_state.get("worker_plan"),
            "repartition_focus": repartition_focus,
        },
    }
    return snapshot


def handle_piece_callback(
    state_manager: StateManager,
    exp_id: str,
    piece_id: str,
    event: str,
    result: Dict[str, Any],
    task_tag: str,
    event_id: Optional[str],
) -> bool:
    """Record the outcome of one piece of a parallel run.

    The piece is stored as "completed" only when ``result["status"]`` is
    "completed"; every other status is stored as "failed" together with the
    result's ``error_message``.

    Args:
        state_manager: StateManager instance used to persist the piece status.
        exp_id: Experiment ID.
        piece_id: Piece ID the callback refers to.
        event: Event type (accepted for interface symmetry; not read here).
        result: Task result for the piece.
        task_tag: Task tag of the piece's run.
        event_id: ID of the event-log record that triggered this callback.

    Returns:
        True if the updated state reports all pieces done, so the agent
        should be resumed.
    """
    succeeded = result.get("status") == "completed"
    if succeeded:
        piece_status, failure = "completed", None
    else:
        piece_status, failure = "failed", result.get("error_message")

    updated_state = state_manager.update_piece_status(
        exp_id=exp_id,
        piece_id=piece_id,
        status=piece_status,
        result=result,
        error=failure,
        task_tag=task_tag,
        event_id=event_id,
        run_status=piece_status,
        update_pending=True,
    )
    return updated_state.all_pieces_done()


def aggregate_piece_results(state: AgentState) -> Dict[str, Any]:
    """Aggregate results from all finished pieces into one observation.

    Args:
        state: AgentState providing ``get_successful_pieces()``,
            ``get_failed_pieces()``, ``piece_statuses`` and ``current_spec``.

    Returns:
        Observation dict with event ``"all_pieces_complete"``, piece counts,
        the result with the lowest ``optimization_error`` (``best_overall``),
        the best/worst piece errors, the metric threshold (None when absent
        or not numeric) and a human-readable ``message``. When a threshold
        and at least one error are available and no piece failed,
        ``all_pieces_pass_threshold`` is included as well.
    """
    successful = state.get_successful_pieces()
    failed = state.get_failed_pieces()

    spec = state.current_spec
    try:
        threshold = float(((spec or {}).get("metric") or {}).get("threshold"))
    except (AttributeError, TypeError, ValueError):
        # Missing, null, non-numeric, or oddly shaped metric config: no threshold.
        threshold = None

    best_overall = None
    best_error = None
    worst_error = None
    for piece in successful:
        if not piece.result:
            continue
        error = piece.result.get("optimization_error")
        if error is None:
            continue
        # Seed from the first reported error rather than an `inf` sentinel so
        # that a piece whose error is itself infinite is still tracked as best.
        if best_error is None or error < best_error:
            best_error = error
            best_overall = piece.result
        if worst_error is None or error > worst_error:
            worst_error = error

    observation = {
        "event": "all_pieces_complete",
        "timestamp": int(time.time()),
        "total_pieces": len(state.piece_statuses),
        "successful_pieces": len(successful),
        "failed_pieces": len(failed),
        "best_overall": best_overall,
        "best_piece_error": best_error,
        "worst_piece_error": worst_error,
        "metric_threshold": threshold,
    }
    if worst_error is not None and threshold is not None and not failed:
        observation["all_pieces_pass_threshold"] = bool(worst_error <= threshold)

    best_text = best_error if best_error is not None else "N/A"
    worst_text = worst_error if worst_error is not None else "N/A"
    if failed:
        observation["failed_piece_ids"] = [p.piece_id for p in failed]
        observation["message"] = (
            f"Parallel search complete. {len(successful)} succeeded, "
            f"{len(failed)} failed. Best error: {best_text}, "
            f"Worst error: {worst_text}"
        )
    else:
        observation["message"] = (
            f"All {len(successful)} pieces completed successfully. "
            f"Best error: {best_text}, "
            f"Worst error: {worst_text}"
        )

    return observation


def find_run_id(state: AgentState, task_tag: str) -> Optional[str]:
    """Return the run_id of the first run in ``state.runs`` with this task tag, or None."""
    matching_ids = (run.run_id for run in state.runs if run.task_tag == task_tag)
    return next(matching_ids, None)


def resume_experiment(
    *,
    exp_id: str,
    piece_id: Optional[str] = None,
    event: str = "task_complete",
    reason: str = "run_terminal",
    task_tag: Optional[str] = None,
    dry_run: bool = False,
    state_manager: Optional[StateManager] = None,
    tool_client: Optional[Any] = None,
) -> int:
    """Resume an experiment in response to a task callback or a manual trigger.

    Loads the persisted state, polls and loads the task result, records the
    event in the event log, updates the state (single-run or per-piece in
    parallel mode), builds an observation and hands it to the runtime.

    Args:
        exp_id: Experiment ID whose persisted state is loaded.
        piece_id: Optional piece ID for parallel execution.
        event: Trigger type: "task_complete", "task_failed", "timeout" or
            "manual_resume".
        reason: Supervision reason recorded in the event payload and the
            observation.
        task_tag: Task tag of the run. When omitted it is taken from the
            piece's status entry if ``piece_id`` is known, otherwise from the
            last entry of ``state.pending_callbacks``.
        dry_run: When True, print the observation instead of invoking the
            runtime. The event-log append and state updates still happen.
        state_manager: StateManager to use; a default one is created if None.
        tool_client: Tool client to use; if None, an internal client wrapped
            with the trace logger is built.

    Returns:
        0 on success or a benign no-op (duplicate event, waiting for other
        pieces, dry run); 1 when the experiment does not exist, no task tag
        can be determined, or the tag-less manual resume fails.
    """
    print(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] Agent resume triggered")
    print(f"  exp_id: {exp_id}")
    print(f"  piece_id: {piece_id}")
    print(f"  event: {event}")
    print(f"  reason: {reason}")

    state_manager = state_manager or StateManager()
    trace = TraceLogger.from_env(
        exp_id=exp_id,
        state_manager=state_manager,
        component="resume",
    )
    if tool_client is None:
        tool_client = wrap_tool_client(build_internal_tool_client(), trace)
    if trace and trace.enabled:
        trace.log(
            "state_transitions",
            {
                "event": "resume_start",
                "exp_id": exp_id,
                "event_type": event,
                "reason": reason,
                "piece_id": piece_id,
                "task_tag": task_tag,
                "dry_run": dry_run,
            },
        )

    # Check if experiment exists
    if not state_manager.exists(exp_id):
        print(f"ERROR: Experiment {exp_id} not found")
        return 1

    # Load state
    state = state_manager.load(exp_id)
    print(f"  phase: {state.phase}")
    print(f"  pending_callbacks: {state.pending_callbacks}")

    # Determine task_tag: explicit argument, then the piece's recorded tag,
    # then the most recently registered pending callback.
    resolved_task_tag = task_tag
    if not resolved_task_tag and piece_id and piece_id in state.piece_statuses:
        resolved_task_tag = state.piece_statuses[piece_id].task_tag
    elif not task_tag and state.pending_callbacks:
        resolved_task_tag = state.pending_callbacks[-1]

    if not resolved_task_tag and event == "manual_resume":
        # Nothing to observe: just let the runtime continue from its phase.
        print("  no pending task tag; running planning/evaluating steps directly")
        if trace and trace.enabled:
            trace.log(
                "decisions",
                {
                    "event": "resume_without_task_tag",
                    "exp_id": exp_id,
                    "reason": "manual_resume_without_pending",
                },
            )
        if dry_run:
            print("  [DRY RUN] Would invoke async runtime without callback observation")
            return 0
        try:
            from agent.runtime import Runtime

            runtime = Runtime.from_state(state, state_manager, tool_client=tool_client)
            runtime._run_steps()
            print(f"  Agent tick completed. New phase: {runtime.state.phase}")
            print(f"  pending_callbacks: {runtime.state.pending_callbacks}")
            return 0
        except Exception as exc:
            print(f"ERROR: planning resume failed: {exc}")
            return 1

    if not resolved_task_tag:
        print("ERROR: Could not determine task_tag")
        return 1

    task_tag = resolved_task_tag
    print(f"  task_tag: {task_tag}")

    poll_response = tool_client.call("anum.run.poll", {"task_tag": task_tag})
    poll_data = poll_response.get("data", {}) if poll_response.get("status") == "ok" else {}
    poll_status = poll_data.get("status")

    # Load task result
    if event == "manual_resume":
        if poll_status in ("queued", "running"):
            # Still in flight: report progress from the poll, not a result.
            result = {
                "status": "running",
                "poll_status": poll_status,
                "progress": poll_data.get("progress"),
                "best_summary": poll_data.get("best_summary"),
            }
        else:
            result = load_task_result(task_tag, tool_client=tool_client)
        supervision_snapshot = _load_supervision_snapshot(state_manager, exp_id, task_tag)
        if supervision_snapshot:
            result["supervision"] = supervision_snapshot
    else:
        result = load_task_result(task_tag, tool_client=tool_client)

    if poll_status and result.get("poll_status") is None:
        result["poll_status"] = poll_status

    # A manual resume that finds the run already finished is promoted to the
    # matching terminal event so it is handled like a regular callback.
    effective_event = event
    if (
        event == "manual_resume"
        and poll_status not in ("queued", "running")
        and result.get("status") in ("completed", "failed")
    ):
        effective_event = "task_complete" if result.get("status") == "completed" else "task_failed"

    is_terminal_event = effective_event in ("task_complete", "task_failed", "timeout")
    if is_terminal_event and result.get("status") == "running" and poll_status and poll_status not in ("queued", "running"):
        # The poll says the run has ended, yet no result is available:
        # report it as a failure without a best candidate.
        result = {
            "status": "failed",
            "error_message": f"Task ended without best result (poll_status={poll_status})",
            # `errors` may be present but null; normalize before extending.
            "errors": list(result.get("errors") or []) + [
                {
                    "code": "no_best_result",
                    "message": "run ended without any best candidate",
                    "details": {"poll_status": poll_status},
                }
            ],
            "poll_status": poll_status,
            "progress": poll_data.get("progress"),
            "best_summary": poll_data.get("best_summary"),
        }

    # Attach a log tail for callbacks that did not complete cleanly.
    include_logs = event != "manual_resume" and (
        result.get("status") != "completed" or bool(result.get("errors"))
    )
    if include_logs:
        logs = fetch_log_tail(tool_client, task_tag, lines=50, max_files=2)
        if logs:
            result["logs"] = logs
    print(f"  result status: {result.get('status')}")
    # Compare against None: a best error of exactly 0 must still be printed.
    if result.get("optimization_error") is not None:
        print(f"  best error: {result.get('optimization_error')}")
    if trace and trace.enabled:
        trace.log(
            "decisions",
            {
                "event": "resume_result_loaded",
                "exp_id": exp_id,
                "task_tag": task_tag,
                "effective_event": effective_event,
                "result": result,
            },
        )

    run_id = find_run_id(state, task_tag)
    event_log = EventLog()
    if effective_event == "manual_resume":
        event_payload = _manual_resume_event_payload(result)
    else:
        event_payload = dict(result)
    event_payload["supervision_reason"] = reason
    appended, event_record = event_log.append_event(
        exp_id=exp_id,
        event_type=effective_event,
        task_tag=task_tag,
        piece_id=piece_id,
        run_id=run_id,
        payload=event_payload,
    )

    if not appended:
        print("  duplicate event detected; ignoring")
        if trace and trace.enabled:
            trace.log(
                "decisions",
                {
                    "event": "resume_duplicate_event",
                    "exp_id": exp_id,
                    "task_tag": task_tag,
                    "effective_event": effective_event,
                },
            )
        return 0

    event_id = event_record.get("event_id")
    run_status = None
    if is_terminal_event:
        run_status = "completed" if result.get("status") == "completed" else "failed"

    # Handle parallel piece execution
    if state.parallel_mode and state.piece_statuses:
        if effective_event == "manual_resume":
            state.last_event_id = event_id
            state.updated_at = int(time.time())
            state_manager.save(exp_id, state)
            observation = format_observation(
                result,
                effective_event,
                task_tag,
                piece_id,
                supervision_reason=reason,
            )
        elif piece_id and is_terminal_event:
            all_done = handle_piece_callback(
                state_manager=state_manager,
                exp_id=exp_id,
                piece_id=piece_id,
                event=effective_event,
                result=result,
                task_tag=task_tag,
                event_id=event_id,
            )

            if not all_done:
                if effective_event == "task_complete":
                    print(f"  Piece {piece_id} recorded, waiting for other pieces")
                    return 0
                # For failure/timeout events, wake the agent immediately to re-plan this piece.
                state = state_manager.load(exp_id)
                observation = format_observation(
                    result,
                    effective_event,
                    task_tag,
                    piece_id,
                    supervision_reason=reason,
                )
                observation["parallel_partial_terminal"] = True
                observation["pending_piece_count"] = len(state.pending_callbacks)
                print("  Piece terminal event requires immediate intervention; resuming agent")
            else:
                # All pieces done - reload state and aggregate
                state = state_manager.load(exp_id)
                observation = aggregate_piece_results(state)
                observation["supervision_reason"] = reason

        else:
            if not state.all_pieces_done():
                print("  Parallel pieces still running; skipping resume")
                return 0
            state.last_event_id = event_id
            # Stamp the modification time like every other save in this function.
            state.updated_at = int(time.time())
            state_manager.save(exp_id, state)
            observation = aggregate_piece_results(state)
            observation["supervision_reason"] = reason
    else:
        if is_terminal_event and task_tag in state.pending_callbacks:
            state.pending_callbacks.remove(task_tag)
        if run_status and task_tag:
            for run in state.runs:
                if run.task_tag == task_tag:
                    run.status = run_status
                    break
        state.last_event_id = event_id
        state.updated_at = int(time.time())
        state_manager.save(exp_id, state)
        observation = format_observation(
            result,
            effective_event,
            task_tag,
            piece_id,
            supervision_reason=reason,
        )

    print(f"  observation: {observation.get('message')}")
    if trace and trace.enabled:
        trace.log(
            "decisions",
            {
                "event": "resume_observation_built",
                "exp_id": exp_id,
                "task_tag": task_tag,
                "effective_event": effective_event,
                "observation": observation,
            },
        )

    if dry_run:
        # Only the runtime invocation is skipped; the event and state updates
        # above have already been persisted.
        print("  [DRY RUN] Would invoke async runtime with observation")
        import json

        print(json.dumps(observation, indent=2, ensure_ascii=False))
        return 0

    # Import and invoke async runtime
    try:
        from agent.runtime import Runtime

        # Pass the same (possibly traced or caller-injected) tool client that
        # the tag-less manual path uses, so the runtime's tool calls go
        # through it as well.
        runtime = Runtime.from_state(state, state_manager, tool_client=tool_client)
        runtime.step(observation)
        print(f"  Agent step completed. New phase: {runtime.state.phase}")
        if trace and trace.enabled:
            trace.log(
                "state_transitions",
                {
                    "event": "resume_runtime_step_completed",
                    "exp_id": exp_id,
                    "phase": runtime.state.phase,
                    "pending_callbacks": list(runtime.state.pending_callbacks),
                },
            )
    except ImportError:
        print("  WARNING: Runtime not available, saving observation to state")
        # Fallback: just update state with observation
        state.history.append(({"action": "CALLBACK_RECEIVED"}, observation))
        state_machine.apply_transition(state, "evaluating")
        state_manager.save(exp_id, state)
        if trace and trace.enabled:
            trace.log(
                "state_transitions",
                {
                    "event": "resume_fallback_saved_observation",
                    "exp_id": exp_id,
                },
            )

    return 0


def main() -> int:
    """Parse the command-line flags and run one resume.

    Returns:
        The exit code produced by resume_experiment.
    """
    event_choices = ["task_complete", "task_failed", "timeout", "manual_resume"]
    reason_choices = [
        "target_reached",
        "stagnation",
        "heartbeat",
        "piece_timeout",
        "plateau",
        "imbalance",
        "run_terminal",
    ]

    cli = argparse.ArgumentParser(description="Resume agent after task callback")
    cli.add_argument("--exp_id", type=str, required=True, help="Experiment ID")
    cli.add_argument("--piece_id", type=str, default=None, help="Piece ID (for parallel)")
    cli.add_argument(
        "--event",
        type=str,
        default="task_complete",
        choices=event_choices,
        help="Event type",
    )
    cli.add_argument(
        "--reason",
        type=str,
        default="run_terminal",
        choices=reason_choices,
        help="Supervision reason associated with this resume trigger.",
    )
    cli.add_argument("--task_tag", type=str, default=None, help="Task tag (optional)")
    cli.add_argument(
        "--dry_run",
        action="store_true",
        help="Load state and result but don't invoke LLM",
    )

    opts = cli.parse_args()
    return resume_experiment(
        exp_id=opts.exp_id,
        piece_id=opts.piece_id,
        event=opts.event,
        reason=opts.reason,
        task_tag=opts.task_tag,
        dry_run=opts.dry_run,
    )


# Script entry point: use main()'s return value as the process exit status.
if __name__ == "__main__":
    sys.exit(main())
