import warnings
from collections.abc import Iterable, Sequence

from agentgym.agent_pipeline.base_pipeline_element import (
    BasePipelineElement,
    DummyEnv,
)
from agentgym.functions_engine.functions_engine import Env, Function
from agentgym.logging import Logger
from agentgym.types import (
    ChatMessage,
    ChatSystemMessage,
    ChatUserMessage,
)


class AgentPipeline(BasePipelineElement):
    """Pipeline element that runs a series of other elements in order.

    The state returned by each element's ``query`` is passed as the input of
    the next one, so the pipeline can itself be used wherever a single
    element is expected.
    """

    def __init__(self, elements: Iterable[BasePipelineElement]) -> None:
        """Store the elements to run.

        Args:
            elements: The pipeline elements, in the order they should run.
        """
        # Materialize the iterable: a one-shot iterator (e.g. a generator)
        # would be exhausted by the first query() call, and every later call
        # would then silently run no elements at all.
        self.elements = list(elements)

    def query(
        self,
        query: str,
        tools: Sequence[Function],
        env: Env = DummyEnv(),
        messages: Sequence[ChatMessage] = [],
        extra_args: dict = {},
    ) -> tuple[str, Sequence[Function], Env, Sequence[ChatMessage], dict]:
        """Run every element in order and return the final state.

        Args:
            query: The query handed to the first element.
            tools: The tools handed to the first element.
            env: The environment handed to the first element.
            messages: The conversation so far.
            extra_args: Extra arguments forwarded from element to element.

        Returns:
            The ``(query, tools, env, messages, extra_args)`` tuple produced
            by the last element, or the inputs unchanged when the pipeline
            has no elements.
        """
        logger = Logger().get()
        for element in self.elements:
            # Each element receives the state produced by the previous one.
            query, tools, env, messages, extra_args = element.query(query, tools, env, messages, extra_args)
            # Log the conversation after every step.
            logger.log(messages)
        return query, tools, env, messages, extra_args


class InitQuery(BasePipelineElement):
    """Pipeline element that appends the query to the conversation as a user message."""

    def query(
        self,
        query: str,
        tools: Sequence[Function],
        env: Env = DummyEnv(),
        messages: Sequence[ChatMessage] = [],
        extra_args: dict = {},
    ) -> tuple[str, Sequence[Function], Env, Sequence[ChatMessage], dict]:
        """Append ``query`` as a user message and return the updated state.

        The query is expected to be either the first message or the one
        right after a single system message. Any other history triggers a
        warning; the query is still appended in that case.

        Args:
            query: The text of the user message to add.
            tools: Passed through unchanged.
            env: Passed through unchanged.
            messages: The conversation so far; it is not modified in place.
            extra_args: Passed through unchanged.

        Returns:
            The inputs, with ``messages`` replaced by a new list that ends
            with the user message.
        """
        # Well placed means: empty history, or exactly one system message.
        # Every longer history (including two messages) puts the query
        # somewhere other than first / right after the system message.
        only_system_message = len(messages) == 1 and messages[0]["role"] == "system"
        if len(messages) > 0 and not only_system_message:
            warnings.warn("The query is not being added as the first message or after the system message")
        query_message = ChatUserMessage(role="user", content=query)
        # Build a new list rather than appending, so the caller's sequence
        # (and the shared default argument) is left untouched.
        messages = [*messages, query_message]
        return query, tools, env, messages, extra_args


class SystemMessage(BasePipelineElement):
    """Pipeline element that starts the conversation with a fixed system message."""

    def __init__(self, system_message: str) -> None:
        """Remember the text to use as the system message content.

        Args:
            system_message: Content of the system message built by ``query``.
        """
        self.system_message = system_message

    def query(
        self,
        query: str,
        tools: Sequence[Function],
        env: Env = DummyEnv(),
        messages: Sequence[ChatMessage] = [],
        extra_args: dict = {},
    ) -> tuple[str, Sequence[Function], Env, Sequence[ChatMessage], dict]:
        """Return the state with a conversation holding only the system message.

        Args:
            query: Passed through unchanged.
            tools: Passed through unchanged.
            env: Passed through unchanged.
            messages: The conversation so far; must be empty.
            extra_args: Passed through unchanged.

        Returns:
            The inputs, with ``messages`` replaced by a one-element list
            containing the system message.

        Raises:
            ValueError: If ``messages`` already contains any message.
        """
        conversation_started = len(messages) != 0
        if conversation_started:
            raise ValueError("The system message should be the first message")
        opening = [ChatSystemMessage(role="system", content=self.system_message)]
        return query, tools, env, opening, extra_args
