
from dataclasses import dataclass, field
from enum import Enum
from typing import Literal, Union

from .schema import ErrorMarker, ToolCall, ToolCallResult


class StepStatus(str, Enum):
    """Status reported by a step once it has ended.

    The enum mixes in ``str``, so every member compares equal to its own
    name as a plain string (e.g. ``StepStatus.ERROR == "ERROR"``).
    """

    # Status used by user-input steps.
    USER = "USER"
    # The step closed its turn.
    FINISHED = "FINISHED"
    # The turn should keep going with another step.
    CONTINUE = "CONTINUE"
    # The step ended with an error.
    ERROR = "ERROR"


@dataclass
class ActionStepError:
    """Marker recording that something went wrong while running an action step.

    Its main purpose is to flag a failed step, so that a retry mechanism and
    the `continue` method of `Runner` can be built on top of it.
    """

    # Human-readable description of the failure.
    message: str
    # Where the failure came from; typed as the literal "llm".
    source: Literal["llm"] = "llm"


@dataclass
class ActionStep:
    """One action (model completion plus optional tool calls) taken by an agent."""

    step_status: StepStatus = StepStatus.CONTINUE
    """Status this step ended with; defaults to ``StepStatus.CONTINUE``."""

    content: str | None = None
    """Completion text produced in this step, if any."""

    reasoning_content: str | None = None
    """Reasoning text, possibly produced by reasoning models."""

    signature: str | None = None
    """Signature accompanying the reasoning, possibly produced by `claude-thinking` models."""

    tool_calls: list[ToolCall] = field(default_factory=list)
    """Tool calls requested in this step (empty when no tool was called)."""

    tool_call_results: list[ToolCallResult] = field(default_factory=list)
    """Results of the tool calls made in this step."""

    error_marker: ErrorMarker | None = None
    """Set when an error happens during the step; ``None`` otherwise."""


@dataclass
class UserInputStep:
    """A step that records input given by the user during agent execution."""

    user_input: str
    """Text entered by the user."""

    step_status: StepStatus = StepStatus.USER
    """Always ``StepStatus.USER``; callers should not change it."""


StepType = Union[ActionStep, UserInputStep]
"""Representing a single step during the execution of an agent (an action or a user input)."""


@dataclass
class MemoryTurn:
    """One turn of agent memory, made up of one or more steps."""

    steps: list[StepType] = field(default_factory=list)
    """Steps recorded in this turn, in the order they were appended."""

    @property
    def step_number(self):
        """Number of ``ActionStep`` entries in this turn (user-input steps are not counted)."""
        # Summing booleans counts the True values and yields an int (0 when empty).
        return sum(isinstance(step, ActionStep) for step in self.steps)

    def is_finished(self) -> bool:
        """Tell whether this turn is closed.

        A turn is closed when its most recent step reports ``StepStatus.FINISHED``;
        an empty turn is never closed.
        """
        if not self.steps:
            return False
        return self.steps[-1].step_status == StepStatus.FINISHED


@dataclass
class MemoryAgent:
    """Store the full memory (system instructions and turns) of an agent run.

    Memory is a list of ``MemoryTurn`` objects. A user input always opens a
    new turn; action steps are appended to the last turn until one of them
    reports ``StepStatus.FINISHED``, after which only a new user input is accepted.
    """

    system_instructions: str | None = None
    """The system instructions of the agent."""

    turns: list[MemoryTurn] = field(default_factory=list)
    """The turns in the memory."""

    @property
    def _need_new_turn(self) -> bool:
        """Whether the next insertion must start a new turn.

        True when there are no turns yet or the last turn is finished.
        """
        if not self.turns:
            return True
        return self._last_turn().is_finished()

    def _add_new_turn(self) -> None:
        """Append a fresh, empty turn to the memory."""
        self.turns.append(MemoryTurn())

    def _last_turn(self) -> MemoryTurn:
        """Return the reference of the last turn.

        Intentionally performs no boundary check: raises ``IndexError`` when
        there are no turns.
        """
        return self.turns[-1]

    def insert_user_input(self, user_input: str) -> MemoryTurn:
        """Create a new turn and insert a user input step into it.

        Args:
            user_input (str): The user input of this turn.

        Returns:
            MemoryTurn: The reference of the newly created (last) turn.

        Raises:
            AssertionError: If the last turn is not finished yet.
        """
        # Explicit raise rather than ``assert`` so the invariant still holds under
        # ``python -O``; AssertionError is kept so existing callers see the same type.
        if not self._need_new_turn:
            raise AssertionError("Failed because need_new_turn is False")

        self._add_new_turn()
        last_turn = self._last_turn()
        last_turn.steps.append(UserInputStep(user_input=user_input))
        return last_turn

    def insert_action_step(self, action_step: ActionStep) -> MemoryTurn:
        """Insert an action step into the last turn.

        Args:
            action_step (ActionStep): The action step to insert.

        Returns:
            MemoryTurn: The reference of the last turn.

        Raises:
            AssertionError: If there is no turn yet or the last turn is finished.
        """
        # Same reasoning as in ``insert_user_input``: must not be stripped by ``-O``.
        if self._need_new_turn:
            raise AssertionError("Failed because need_new_turn is True")

        last_turn = self._last_turn()
        last_turn.steps.append(action_step)
        return last_turn

    def to_message(
        self, is_claude_thinking: bool = False, default_system_prompt_insert: str = ""
    ) -> list[dict]:
        """Convert the memory into messages for the model.

        Reference of format: https://platform.openai.com/docs/guides/function-calling?api-mode=chat

        Args:
            is_claude_thinking: When True, ``reasoning_content`` and ``signature``
                (if set) are attached to assistant messages that carry tool
                calls. Assistant messages without tool calls never get them.
            default_system_prompt_insert: Text prepended, followed by a newline,
                to the system prompt. It is only used when ``system_instructions``
                is non-empty; otherwise no system message is emitted at all.

        Returns:
            list[dict]: The messages in chronological order. Action steps that
            carry an error marker are omitted.

        Raises:
            AssertionError: If an action step with tool calls does not hold
                exactly one tool call result per tool call.
        """
        messages: list[dict] = []

        if self.system_instructions:
            system_prompt = self.system_instructions
            if default_system_prompt_insert:
                system_prompt = default_system_prompt_insert + "\n" + system_prompt
            messages.append({"role": "system", "content": system_prompt})

        for turn in self.turns:
            for step in turn.steps:
                if isinstance(step, UserInputStep):
                    messages.append({"role": "user", "content": step.user_input})
                elif isinstance(step, ActionStep):
                    messages.extend(
                        self._action_step_to_messages(step, is_claude_thinking)
                    )
        return messages

    @staticmethod
    def _action_step_to_messages(
        step: ActionStep, is_claude_thinking: bool
    ) -> list[dict]:
        """Translate one action step into its assistant/tool messages (possibly none)."""
        if step.error_marker:
            # Some error occurred during this action step; it is left out of the history.
            return []

        if not step.tool_calls:
            return [{"role": "assistant", "content": step.content}]

        # Each tool call must be answered by exactly one ``tool`` message, otherwise
        # the emitted history would contain calls without results.
        if len(step.tool_call_results) != len(step.tool_calls):
            raise AssertionError(
                f"ActionStep has {len(step.tool_calls)} tool call(s) but "
                f"{len(step.tool_call_results)} tool call result(s)"
            )

        # First, the `assistant` message carrying the tool call information.
        assistant_message = {
            "role": "assistant",
            "content": step.content,
            "tool_calls": [
                {
                    "id": tc.tool_call_id,
                    "type": "function",
                    "function": {
                        "name": tc.tool_name,
                        # NOTE(review): passed through unchanged; the referenced
                        # format expects a JSON-encoded string — confirm ToolCall.arguments is one.
                        "arguments": tc.arguments,
                    },
                }
                for tc in step.tool_calls
            ],
        }
        # Special case for claude-thinking models.
        if is_claude_thinking:
            # XXX: Maybe unconditionally add the `reasoning_content` & `signature`?
            if step.reasoning_content:
                assistant_message["reasoning_content"] = step.reasoning_content
            if step.signature:
                assistant_message["signature"] = step.signature

        # Second, one `tool` message per tool call result, in stored order.
        tool_messages = [
            {
                "role": "tool",
                "content": tcr.get_content_or_error(),
                "tool_call_id": tcr.tool_call_id,
            }
            for tcr in step.tool_call_results
        ]
        return [assistant_message, *tool_messages]
