from __future__ import annotations

from typing import Dict, Iterable, Set

from agent.actions import ActionType
from agent.state import AgentState


# Phase identifiers; used as keys of the transition and action tables below.
PHASE_INIT: str = "init"
PHASE_PLANNING: str = "planning"
PHASE_RUNNING: str = "running"
PHASE_EVALUATING: str = "evaluating"
# The remaining three phases have no outgoing transitions and no allowed
# actions in the tables defined in this module.
PHASE_FINALIZED: str = "finalized"
PHASE_FAILED: str = "failed"
PHASE_STOPPED: str = "stopped"


# Maps each phase to the set of phases it may move to next.
# PHASE_FAILED and PHASE_STOPPED never appear as targets here because
# can_transition() accepts them from any phase.
ALLOWED_TRANSITIONS: Dict[str, Set[str]] = {
    PHASE_INIT: {PHASE_PLANNING},
    PHASE_PLANNING: {PHASE_RUNNING, PHASE_EVALUATING, PHASE_FINALIZED},
    PHASE_RUNNING: {PHASE_EVALUATING},
    PHASE_EVALUATING: {PHASE_RUNNING, PHASE_FINALIZED},
    # Terminal phases: each gets its own (distinct) empty set.
    **{
        terminal: set()
        for terminal in (PHASE_FINALIZED, PHASE_FAILED, PHASE_STOPPED)
    },
}


def _actions(*values: ActionType) -> Set[ActionType]:
    """Collect the given actions into a new set, dropping duplicates.

    Calling it with no arguments yields a fresh empty set.
    """
    return {*values}


# The actions shared by every phase that permits actions at all
# (see ALLOWED_ACTIONS, which copies this set for those phases).
DECISION_ACTIONS: Set[ActionType] = {
    ActionType.WRITE_SPEC,
    ActionType.UPDATE_SPEC,
    ActionType.CHANGE_STRATEGY,
    ActionType.SUBMIT_SEARCH_ASYNC,
    ActionType.STOP_RUN,
    ActionType.TRADITIONAL,
    ActionType.VERIFY,
    ActionType.SUGGEST_STRATEGY,
    ActionType.CODEGEN,
    ActionType.FINALIZE,
}


# Maps each phase to the actions permitted while in that phase.
# Every entry is an independent set object, so mutating one phase's set
# affects neither DECISION_ACTIONS nor any other phase.
ALLOWED_ACTIONS: Dict[str, Set[ActionType]] = {
    PHASE_INIT: DECISION_ACTIONS.copy(),
    PHASE_PLANNING: DECISION_ACTIONS.copy(),
    PHASE_RUNNING: set(),  # no actions are permitted while running
    PHASE_EVALUATING: DECISION_ACTIONS.copy(),
    PHASE_FINALIZED: set(),
    PHASE_FAILED: set(),
    PHASE_STOPPED: set(),
}


def can_transition(current: str | None, new_phase: str) -> bool:
    """Return whether moving from ``current`` to ``new_phase`` is permitted.

    Args:
        current: The phase being left; ``None`` is treated as ``PHASE_INIT``.
        new_phase: The phase being entered.

    Returns:
        ``True`` when ``new_phase`` equals the current phase, when it is
        ``PHASE_FAILED`` or ``PHASE_STOPPED``, or when it is listed for the
        current phase in ``ALLOWED_TRANSITIONS``; ``False`` otherwise
        (including when the current phase is not in the table).
    """
    source = PHASE_INIT if current is None else current

    # Self-transitions are always accepted, and failing/stopping is accepted
    # from every phase without consulting the table.
    # NOTE(review): this also lets terminal phases move to failed/stopped
    # (e.g. finalized -> failed) -- confirm that this is intended.
    if source == new_phase or new_phase in (PHASE_FAILED, PHASE_STOPPED):
        return True

    reachable = ALLOWED_TRANSITIONS.get(source, set())
    return new_phase in reachable


def apply_transition(state: AgentState, new_phase: str) -> bool:
    """Move ``state`` to ``new_phase`` if ``can_transition`` permits it.

    Args:
        state: Object whose ``phase`` attribute is read and, on success,
            overwritten.
        new_phase: The phase to enter.

    Returns:
        ``True`` if ``state.phase`` was set to ``new_phase``; ``False`` if the
        transition was rejected, in which case ``state`` is left untouched.
    """
    if not can_transition(state.phase, new_phase):
        return False
    state.phase = new_phase
    return True


def action_allowed(phase: str | None, action: ActionType) -> bool:
    """Return whether ``action`` is permitted in ``phase``.

    Args:
        phase: The phase to check; ``None`` is treated as ``PHASE_INIT``.
        action: The action to look up.

    Returns:
        ``True`` if ``action`` is in the phase's ``ALLOWED_ACTIONS`` entry.
        A phase missing from the table permits nothing.
    """
    effective_phase = PHASE_INIT if phase is None else phase
    permitted = ALLOWED_ACTIONS.get(effective_phase, set())
    return action in permitted


def allowed_actions_for(phase: str | None) -> Iterable[str]:
    """List the ``value`` of every action permitted in ``phase``.

    Args:
        phase: The phase to look up; ``None`` is treated as ``PHASE_INIT``.

    Returns:
        A sorted list of the distinct ``.value`` attributes of the phase's
        allowed actions; empty when the phase is not in ``ALLOWED_ACTIONS``.
    """
    effective_phase = PHASE_INIT if phase is None else phase

    # Gather into a set first so equal values collapse before sorting.
    distinct_values = set()
    for permitted in ALLOWED_ACTIONS.get(effective_phase, set()):
        distinct_values.add(permitted.value)

    ordered = list(distinct_values)
    ordered.sort()
    return ordered
