import abc
import json
import re
from collections.abc import Sequence

from openai.types.chat import ChatCompletionMessage

from agentgym.agent_pipeline.base_pipeline_element import (
    DummyEnv,
)
from agentgym.agent_pipeline.llms.openai_llm import (
    OpenAILLM,
    _message_to_openai,
    chat_completion_request,
)
from agentgym.ast_utils import (
    ASTParsingError,
    create_python_function_from_tool_call,
    parse_tool_calls_from_python_function,
)
from agentgym.functions_engine.functions_engine import (
    Env,
    Function,
)
from agentgym.types import (
    ChatAssistantMessage,
    ChatMessage,
    ChatSystemMessage,
    ChatToolResultMessage,
    ChatUserMessage,
)


class InvalidModelOutputError(Exception):
    """Signals that the text produced by the model does not follow the expected output format."""


class BasePromptingLLM(OpenAILLM):
    """Base class for LLMs that perform tool calling through prompting.

    Instead of passing tool definitions to the completion API, the tools are
    described in the system prompt (`_make_tools_prompt`), tool results are
    sent back to the model as user messages (`_tool_message_to_user_message`),
    and the raw text returned by the model is converted into an assistant
    message (`_parse_model_output`). Subclasses implement these three hooks.
    """

    # Prompt template describing the tool-calling protocol; set by subclasses.
    _tool_calling_prompt: str
    # Maximum number of model outputs that `query` tries to parse before giving up.
    _MAX_ATTEMPTS = 3

    @abc.abstractmethod
    def _make_tools_prompt(
        self, system_message: ChatSystemMessage | None, tools: Sequence[Function]
    ) -> ChatSystemMessage | None:
        """Build the system message that is sent to the model.

        Args:
            system_message: The system message found at the start of the
                conversation, or `None` if there is none.
            tools: The tools that should be made available to the model.

        Returns:
            The system message to send, or `None` if no system message should
            be sent.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def _parse_model_output(self, message: ChatCompletionMessage) -> ChatAssistantMessage:
        """Convert a raw completion message into an assistant message.

        `query` treats `InvalidModelOutputError` and `ASTParsingError` raised
        from this method as recoverable and asks the model to try again.
        """
        raise NotImplementedError()

    def _get_system_message(
        self, messages: Sequence[ChatMessage]
    ) -> tuple[ChatSystemMessage | None, Sequence[ChatMessage]]:
        """Split `messages` into its leading system message and the rest.

        Returns:
            `(system_message, remaining_messages)`. `system_message` is `None`
            when `messages` is empty or does not start with a system message,
            in which case all messages are returned as the remainder.
        """
        if not messages:
            return None, []
        if messages[0]["role"] == "system":
            return messages[0], messages[1:]
        return None, messages

    @abc.abstractmethod
    def _tool_message_to_user_message(self, tool_message: ChatToolResultMessage) -> ChatUserMessage:
        """Convert a tool result message into a user message for the model."""
        raise NotImplementedError()

    def query(
        self,
        query: str,
        tools: Sequence[Function],
        env: Env = DummyEnv(),
        messages: Sequence[ChatMessage] = [],
        extra_args: dict = {},
    ) -> tuple[str, Sequence[Function], Env, Sequence[ChatMessage], dict]:
        """Request a completion for `messages` and append the assistant reply.

        The leading system message (if any) is replaced by the message built by
        `_make_tools_prompt`, and tool result messages are converted to user
        messages before the request is sent. If the reply contains a
        `<function-call>` tag, it is parsed with `_parse_model_output`; on a
        parsing error the model is asked again, for at most `_MAX_ATTEMPTS`
        parse attempts. If every attempt fails, the unparsed first reply (with
        no tool calls) is returned.

        Args:
            query: The query; returned unchanged.
            tools: The tools available to the model.
            env: The environment; returned unchanged.
            messages: The conversation so far.
            extra_args: Extra pipeline arguments; returned unchanged.

        Returns:
            `(query, tools, env, messages + [assistant_message], extra_args)`.
        """
        # NOTE: the default `messages` / `extra_args` objects are shared between
        # calls. They are never mutated here, but `extra_args` is returned as is.
        system_message, other_messages = self._get_system_message(messages)
        system_message = self._make_tools_prompt(system_message, tools)
        # Build the request from the messages *after* the original system
        # message: its content is already covered by the message returned by
        # `_make_tools_prompt`, so including it again would send it twice.
        openai_messages = [
            _message_to_openai(
                self._tool_message_to_user_message(message) if message["role"] == "tool" else message
            )
            for message in other_messages
        ]
        if system_message is not None:
            openai_messages = [_message_to_openai(system_message), *openai_messages]
        # The empty list is passed where `chat_completion_request` takes its
        # tools argument (presumably; confirm against its signature), as tools
        # are described in the prompt instead.
        completion = chat_completion_request(self.client, self.model, openai_messages, [], self.temperature)
        output = ChatAssistantMessage(
            role="assistant",
            content=completion.choices[0].message.content or "",
            tool_calls=[],
        )
        if not tools or "<function-call>" not in (output["content"] or ""):
            return query, tools, env, [*messages, output], extra_args
        for attempt in range(self._MAX_ATTEMPTS):
            try:
                output = self._parse_model_output(completion.choices[0].message)
                break
            except (InvalidModelOutputError, ASTParsingError) as e:
                if attempt == self._MAX_ATTEMPTS - 1:
                    # No parse attempt is left: do not request a completion
                    # that would never be read.
                    break
                error_message = ChatUserMessage(role="user", content=f"Invalid function calling output: {e!s}")
                completion = chat_completion_request(
                    self.client,
                    self.model,
                    [*openai_messages, _message_to_openai(error_message)],
                    [],
                    self.temperature,
                )
        return query, tools, env, [*messages, output], extra_args


class PromptingLLM(BasePromptingLLM):
    """Prompt-based function calling for Llama 3.

    The prompt is based on
    https://github.com/hamelsmu/replicate-examples/blob/master/cog-vllm-tools/predict.py
    """

    # Template of the tool-calling instructions; `{funcs}` is replaced with the
    # rendered tool definitions by `_make_tools_prompt`.
    _tool_calling_prompt = """You are a hepful assistant. You are given a task and a set of possible functions inside <function-definitions> tags.
Calling these functions are optional. Carefully consider the question and determine if one or more functions can be used to answer the question. Place your thoughts and reasoning behind your decision in <function-thoughts> tags.
If the given question lacks the parameters required by the function, point it out in <function-thoughts> tags. Below is a list of function definitions:
<function-definitions>
{funcs}
</function-definitions>

If you wish to call a particular function, specify the name of the function and any arguments in a way that conforms to that function's schema inside <function-call> tags.
Function calls should be in this format: <function-thoughts>Calling func1 would be helpful because of ...</function-thoughts>\
<function-call>[function1(a="1", b=3), function2(foo=["a", "b"], bar=None)]</function-call>, WITHOUT any answer. Pass the arguments in correct format, i.e., \
strings should be enclosed in quotes, lists should be enclosed in square brackets, integers should have no quotes, etc. \
If you do not wish to call any functions, say so in the <function-thoughts> tags followed by <function-call>[]</function-call><answer>...</answer>
If no tools are provided, act as if no tools are available to you, but still provide <function-call>[]</function-call> as part of your output.

If and only if NO function calls are made, answer the question to the best of your ability inside <answer> tags.  If you are unsure of the answer, say so in <answer> tags.

The user will provide the output of the function calls in the <function-result> tags. The function call that output the given result is provided in <function-call> tags. Give your reply based on the output of the function calls. If no function calls are made, the user will provide the answer without any tags.
If the tool returned an error, then the user will provide the error message in <function-error> tags. The function call that returned the error is provided in <function-call> tags. If the tool returns an error, call the tool again with the correct parameters. If the tool returns an error again, without asking the user. You might need to call other tools to solve the task. Be persistent!

Sometimes, you may need to call multiple functions in multiple rounds to get the desired output, based on the results of the previous function calls. When this is the case, then use the <function-thoughts> tags to explain your reasoning for calling the next function, and the <function-call> tags to specify the next function call. \

Once you receive the result of one tool, think if you need to call another tool in the <function-thoughts> tags. If you do, then call the next tool. If you are done and do not need to call another tool, then do explicitly give the answer in <answer> tags. Tell the answer to the user in the message if you don't need other tools.

If you think you will need to call multiple tools in multiple stages, but you don't have yet the information to call all of them, then wait to receive the information back from the tools you can already call."""

    def _make_tools_prompt(
        self, system_message: ChatSystemMessage | None, tools: Sequence[Function]
    ) -> ChatSystemMessage | None:
        """Build the system message that documents `tools` for the model.

        Without tools, `system_message` is returned untouched. Otherwise every
        tool is rendered as an indented JSON object (name, description and
        parameter JSON schema) wrapped in numbered `<function-N>` tags, the
        result is inserted into `_tool_calling_prompt`, and the content of the
        original system message (if any) is appended on a new line.
        """
        if not tools:
            return system_message
        original_prompt = "" if system_message is None else system_message["content"]
        rendered_tools = []
        for position, tool in enumerate(tools, start=1):
            definition = json.dumps(
                {
                    "name": tool.name,
                    "description": tool.description,
                    "parameters": tool.parameters.model_json_schema(),
                },
                indent=4,
            )
            rendered_tools.append(f"<function-{position}>\n{definition}\n</function-{position}>\n\n")
        instructions = self._tool_calling_prompt.format(funcs="".join(rendered_tools))
        return ChatSystemMessage(role="system", content=f"{instructions}\n{original_prompt}")

    def _parse_model_output(self, message: ChatCompletionMessage) -> ChatAssistantMessage:
        """Parse the completion by delegating to the module-level `parse_model_output`."""
        return parse_model_output(message)

    def _tool_message_to_user_message(self, tool_message: ChatToolResultMessage) -> ChatUserMessage:
        """Convert a tool result by delegating to the module-level `tool_message_to_user_message`."""
        return tool_message_to_user_message(tool_message)


def tool_message_to_user_message(
    tool_message: ChatToolResultMessage,
) -> ChatUserMessage:
    """Render a tool result message as a user message.

    The content consists of the originating call, rendered by
    `create_python_function_from_tool_call` and wrapped in `<function-call>`
    tags, immediately followed by either the tool output in
    `<function-result>` tags or, when the message carries an error, the error
    in `<function-error>` tags.
    """
    call_signature = create_python_function_from_tool_call(tool_message["tool_call"])
    error = tool_message["error"]
    if error is not None:
        outcome = f"<function-error>{error}</function-error>"
    else:
        outcome = f"<function-result>{tool_message['content']}</function-result>"
    return ChatUserMessage(
        role="user",
        content=f"<function-call>{call_signature}</function-call>{outcome}",
    )


def parse_model_output(message: ChatCompletionMessage) -> ChatAssistantMessage:
    """Turn the raw text of a completion into an assistant message.

    The first `<function-call>...</function-call>` section is handed to
    `parse_tool_calls_from_python_function` (an empty list `[]` is assumed when
    there is no such section). The message content is the remaining text with
    all function-call sections and the `<function-thoughts>` tags removed.
    When no tool call is found, the content is instead the text inside the
    `<answer>` tags.

    Returns:
        The assistant message. If `message.content` is `None`, the content is
        empty and `tool_calls` is `None`.

    Raises:
        InvalidModelOutputError: If there are no tool calls and no `<answer>`
            section either.
    """
    raw_text = message.content
    if raw_text is None:
        return ChatAssistantMessage(role="assistant", content="", tool_calls=None)

    call_section = re.compile(r"<function-call>(.*?)</function-call>", re.DOTALL)
    first_call = call_section.search(raw_text)
    calls_source = "[]" if first_call is None else first_call.group(1)

    # Everything except the function-call sections; the thoughts tags are
    # dropped but the text inside them is kept.
    remaining_text = call_section.sub("", raw_text)
    for tag in ("<function-thoughts>", "</function-thoughts>"):
        remaining_text = remaining_text.replace(tag, "")
    # Collapse whitespace-only gaps between lines into a single blank line.
    remaining_text = re.sub(r"\n\s*\n", "\n\n", remaining_text.strip())

    tool_calls = parse_tool_calls_from_python_function(calls_source)
    for call in tool_calls:
        # An `Ellipsis` argument value is stored as the string "...".
        call["args"] = {
            name: ("..." if value == Ellipsis else value) for name, value in call["args"].items()
        }

    if len(tool_calls) > 0:
        return ChatAssistantMessage(role="assistant", content=remaining_text, tool_calls=tool_calls)

    answer = re.search(r"<answer>(.*?)</answer>", remaining_text, flags=re.DOTALL)
    if answer is None:
        raise InvalidModelOutputError("The answer should be in <answer> tags if no tool calls are provided.")
    return ChatAssistantMessage(role="assistant", content=answer.group(1), tool_calls=tool_calls)
