from typing import Union

from pydantic import BaseModel, Field

from openhands.core.logger import openhands_logger as logger
from openhands.events.action import (
    Action,
    ChangeAgentStateAction,
    MessageAction,
    NullAction,
)
from openhands.events.event import EventSource
from openhands.events.observation import (
    AgentStateChangedObservation,
    NullObservation,
    Observation,
)
from openhands.events.serialization.event import event_to_dict
from openhands.security.invariant.nodes import Function, Message, ToolCall, ToolOutput

# Union of the Invariant node types that may appear in a trace list.
# NOTE(review): the parse_* helpers in this module only ever append Message,
# ToolCall and ToolOutput; Function appears nested inside ToolCall, yet is also
# listed here — confirm whether bare Function elements are expected in traces.
TraceElement = Union[Message, ToolCall, ToolOutput, Function]


def get_next_id(trace: list[TraceElement]) -> str:
    """Return the smallest positive integer id (as a string) not used by a ToolCall.

    Only ``ToolCall`` elements of ``trace`` are considered. Ids are compared as
    strings, so ToolCalls with non-numeric ids never block a candidate.

    Args:
        trace: The trace built so far; it is only read.

    Returns:
        ``'1'``, ``'2'``, ... — the first candidate not already taken.
    """
    # A set gives O(1) membership; the previous list made this loop O(n^2).
    used_ids = {el.id for el in trace if isinstance(el, ToolCall)}
    candidate = 1
    # Terminates within len(used_ids) + 1 steps because used_ids is finite.
    while str(candidate) in used_ids:
        candidate += 1
    return str(candidate)


def get_last_id(
    trace: list[TraceElement],
) -> str | None:
    """Return the id of the most recent ``ToolCall`` in ``trace``.

    Args:
        trace: The trace built so far; it is only read.

    Returns:
        The ``id`` of the last ``ToolCall`` element, or ``None`` if the trace
        contains no ``ToolCall``.
    """
    return next(
        (el.id for el in reversed(trace) if isinstance(el, ToolCall)),
        None,
    )


def parse_action(trace: list[TraceElement], action: Action) -> list[TraceElement]:
    """Convert an OpenHands action into Invariant trace elements.

    Args:
        trace: The trace built so far. It is only read, to pick a fresh
            ToolCall id, and is never modified here.
        action: The action to convert.

    Returns:
        The new elements to append to ``trace``:

        - ``MessageAction``: one ``Message`` with role ``'user'`` when the
          action's source is ``EventSource.USER``, otherwise ``'assistant'``.
        - ``NullAction`` / ``ChangeAgentStateAction``: nothing.
        - Any other action whose ``action`` attribute is not None: the
          serialized ``args`` (minus ``thought``) become a ``ToolCall``. If a
          ``thought`` value was present, it is emitted first as an assistant
          ``Message``.
        - Anything else: an error is logged and nothing is returned.
    """
    inv_trace: list[TraceElement] = []
    # isinstance (rather than exact type equality) keeps subclasses of these
    # actions out of the generic tool-call branch below.
    if isinstance(action, MessageAction):
        role = 'user' if action.source == EventSource.USER else 'assistant'
        inv_trace.append(Message(role=role, content=action.content))
    elif isinstance(action, (NullAction, ChangeAgentStateAction)):
        pass
    elif getattr(action, 'action', None) is not None:
        event_dict = event_to_dict(action)
        args = event_dict.get('args', {})
        # The thought is not an argument of the call; surface it as a message.
        thought = args.pop('thought', None)
        function = Function(name=action.action, arguments=args)
        if thought is not None:
            inv_trace.append(Message(role='assistant', content=thought))
        inv_trace.append(
            ToolCall(id=get_next_id(trace), type='function', function=function)
        )
    else:
        logger.error(f'Unknown action type: {type(action)}')
    return inv_trace


def parse_observation(
    trace: list[TraceElement], obs: Observation
) -> list[TraceElement]:
    """Convert an OpenHands observation into Invariant trace elements.

    Args:
        trace: The trace built so far. It is only read, to find the ToolCall
            this observation answers.
        obs: The observation to convert.

    Returns:
        - ``NullObservation`` / ``AgentStateChangedObservation``: an empty list.
        - An observation with non-None ``content``: a single ``ToolOutput``
          linked to the most recent ``ToolCall`` in ``trace``. The link is
          ``None`` if the trace has no ToolCall.
        - Anything else: an error is logged and an empty list is returned.
    """
    # isinstance so that subclasses of these no-op observations are also skipped.
    if isinstance(obs, (NullObservation, AgentStateChangedObservation)):
        return []
    if getattr(obs, 'content', None) is not None:
        return [
            ToolOutput(
                role='tool', content=obs.content, tool_call_id=get_last_id(trace)
            )
        ]
    logger.error(f'Unknown observation type: {type(obs)}')
    return []


def parse_element(
    trace: list[TraceElement], element: Action | Observation
) -> list[TraceElement]:
    """Route ``element`` to the matching parser and return its trace elements.

    ``Action`` instances go to ``parse_action``; every other value is treated as
    an observation and handed to ``parse_observation``.
    """
    parser = parse_action if isinstance(element, Action) else parse_observation
    return parser(trace, element)


def parse_trace(trace: list[tuple[Action, Observation]]):
    """Build a complete Invariant trace from ordered (action, observation) pairs.

    Within each pair the action is converted and appended before the observation
    is converted. This lets the observation's ToolOutput see any ToolCall the
    action just produced.

    Returns:
        A new list of trace elements; the input pairs are not modified.
    """
    result: list[TraceElement] = []
    for step_action, step_obs in trace:
        result += parse_action(result, step_action)
        result += parse_observation(result, step_obs)
    return result


class InvariantState(BaseModel):
    """Holds an Invariant trace that is grown incrementally from OpenHands events."""

    # Trace elements in the order they were added; a fresh empty list per instance.
    trace: list[TraceElement] = Field(default_factory=list)

    def add_action(self, action: Action):
        """Convert ``action`` and append the resulting elements to ``trace``."""
        produced = parse_action(self.trace, action)
        self.trace.extend(produced)

    def add_observation(self, obs: Observation):
        """Convert ``obs`` and append the resulting elements to ``trace``."""
        produced = parse_observation(self.trace, obs)
        self.trace.extend(produced)

    def concatenate(self, other: 'InvariantState'):
        """Append all of ``other.trace`` to this state's trace.

        Elements are appended by reference; they are not copied.
        """
        self.trace.extend(other.trace)
