import warnings
from collections.abc import Sequence

from agentdojo.agent_pipeline.base_pipeline_element import BasePipelineElement
from agentdojo.functions_runtime import EmptyEnv, Env, FunctionsRuntime
from agentdojo.types import ChatMessage, ChatSystemMessage, ChatUserMessage, text_content_block_from_string


class InitQuery(BasePipelineElement):
    """Initializes the pipeline execution with the user query, by adding a
    [`ChatUserMessage`][agentdojo.types.ChatUserMessage] to the messages list.

    A warning is emitted (execution still continues) when the query would not
    end up as the first message or directly after a single system message.
    """

    def query(
        self,
        query: str,
        runtime: FunctionsRuntime,
        env: Env = EmptyEnv(),
        messages: Sequence[ChatMessage] = [],
        extra_args: dict = {},
    ) -> tuple[str, FunctionsRuntime, Env, Sequence[ChatMessage], dict]:
        """Append the user query to the conversation.

        Args:
            query: the text of the user query.
            runtime: passed through unchanged.
            env: passed through unchanged.
            messages: the messages so far; expected to be empty or to hold
                only a system message.
            extra_args: passed through unchanged.

        Returns:
            The inputs, with `messages` replaced by a new list that ends with
            the user message built from `query`.
        """
        # The only accepted histories are "no messages" and "exactly one system
        # message". Anything longer than one message, or a single message with
        # another role, means the query lands in an unexpected position.
        if len(messages) > 1 or (len(messages) == 1 and messages[0]["role"] != "system"):
            warnings.warn("The query is not being added as the first message or after the system message")
        query_message = ChatUserMessage(role="user", content=[text_content_block_from_string(query)])
        # Build a new list rather than appending in place, so the caller's
        # sequence (and the shared default argument) is never mutated.
        messages = [*messages, query_message]
        return query, runtime, env, messages, extra_args


class SystemMessage(BasePipelineElement):
    """Pipeline element that opens the conversation with a system message
    (a [`ChatSystemMessage`][agentdojo.types.ChatSystemMessage]).

    Args:
        system_message: the text to use as the content of the system message.
    """

    def __init__(self, system_message: str) -> None:
        self.system_message = system_message

    def query(
        self,
        query: str,
        runtime: FunctionsRuntime,
        env: Env = EmptyEnv(),
        messages: Sequence[ChatMessage] = [],
        extra_args: dict = {},
    ) -> tuple[str, FunctionsRuntime, Env, Sequence[ChatMessage], dict]:
        """Produce a message list containing only the system message.

        Raises:
            ValueError: if `messages` is not empty, since the system message
                has to come first.

        Returns:
            `query`, `runtime`, `env` and `extra_args` unchanged, together with
            a one-element list holding the system message.
        """
        if len(messages) != 0:
            raise ValueError("The system message should be the first message")
        content_blocks = [text_content_block_from_string(self.system_message)]
        conversation = [ChatSystemMessage(role="system", content=content_blocks)]
        return query, runtime, env, conversation, extra_args
