from collections.abc import Sequence
from typing import overload

import cohere
from cohere.types import (
    Message as CohereMessage,
)
from cohere.types import (
    Message_Chatbot,
    Message_System,
    Message_Tool,
    Message_User,
    NonStreamedChatResponse,
    ToolParameterDefinitionsValue,
)
from cohere.types import Tool as CohereTool
from cohere.types import ToolCall as CohereToolCall
from cohere.types import ToolResult as CohereToolResult
from tenacity import (
    retry,
    retry_if_not_exception_type,
    stop_after_attempt,
    wait_random_exponential,
)

from agentgym.agent_pipeline.base_pipeline_element import (
    BasePipelineElement,
    DummyEnv,
)
from agentgym.functions_engine.functions_engine import (
    Env,
    Function,
    FunctionCall,
)
from agentgym.types import (
    ChatAssistantMessage,
    ChatMessage,
    ChatSystemMessage,
    ChatToolResultMessage,
    ChatUserMessage,
)


@overload
def _message_to_cohere(message: ChatToolResultMessage) -> Message_Tool: ...


@overload
def _message_to_cohere(message: ChatUserMessage) -> Message_User: ...


@overload
def _message_to_cohere(message: ChatAssistantMessage) -> Message_Chatbot: ...


@overload
def _message_to_cohere(message: ChatSystemMessage) -> Message_System: ...


@overload
def _message_to_cohere(message: ChatMessage) -> CohereMessage: ...


def _message_to_cohere(message: ChatMessage) -> CohereMessage:
    match message["role"]:
        case "user":
            return Message_User(
                role="USER",
                message=message["content"],
            )
        case "tool":
            tool_call = CohereToolCall(
                name=message["tool_call"]["function"],
                parameters=dict(message["tool_call"]["args"]) if message["tool_call"]["args"] is not None else {},
            )
            tool_output = {"result": message["error"] or message["content"]}
            tool_result = CohereToolResult(
                call=tool_call,
                outputs=[tool_output],
            )
            return Message_Tool(role="TOOL", tool_results=[tool_result])
        case "assistant" if message["content"] is not None:
            tool_calls = [
                CohereToolCall(
                    name=tool_call["function"],
                    parameters=dict(tool_call["args"]) if tool_call["args"] is not None else {},
                )
                for tool_call in message["tool_calls"] or []
            ]
            return Message_Chatbot(
                role="CHATBOT",
                message=message["content"],
                tool_calls=tool_calls or None,
            )
        case "system":
            return Message_System(
                role="SYSTEM",
                message=message["content"],
            )
        case _:
            raise ValueError(f"Invalid message role for Cohere: {message['role']}")


def _merge_tool_result_messages(
    messages: Sequence[CohereMessage],
) -> Sequence[CohereMessage]:
    """Collapse each run of consecutive tool messages into a single message.

    The tool results of every tool message that directly follows another tool
    message are appended to the first message of the run, which is modified in
    place. The input sequence itself is not modified; a new list is returned.
    """
    merged: list[CohereMessage] = []
    for current in messages:
        previous = merged[-1] if merged else None
        if not (isinstance(previous, Message_Tool) and isinstance(current, Message_Tool)):
            merged.append(current)
            continue
        # Fold the current tool message into the one that opened the run.
        if previous.tool_results is None:
            previous.tool_results = []
        previous.tool_results.extend(current.tool_results or [])
    return merged


def _conversation_to_cohere(
    messages: Sequence[ChatMessage],
) -> tuple[str, list[CohereToolResult], Sequence[CohereMessage]]:
    """Split a conversation into the pieces passed to the Cohere chat call.

    All messages are converted to Cohere messages and consecutive tool
    messages are merged. The last resulting message becomes the "current"
    turn and everything before it becomes the chat history.

    Returns:
        A tuple ``(message, tool_results, chat_history)``. When the last
        message is a tool message, ``message`` is the empty string and
        ``tool_results`` holds its results; when it is a user message,
        ``message`` is its text and ``tool_results`` is empty.

    Raises:
        ValueError: If the conversation is empty, or if the last message is
            neither a user message nor a tool message.
    """
    if not messages:
        raise ValueError("There should be at least one message in the conversation")

    converted = _merge_tool_result_messages([_message_to_cohere(m) for m in messages])
    *history, final = converted

    if isinstance(final, Message_Tool):
        return "", final.tool_results or [], history
    if isinstance(final, Message_User):
        return final.message, [], history
    raise ValueError("The last message should be from the user or a tool")


# Maps JSON Schema type names to the Python-style type names used as the
# `type` of Cohere tool parameter definitions (see _convert_to_cohere_tool).
JSON_TO_PYTHON_TYPES = {
    "string": "str",
    "number": "float",
    "boolean": "bool",
    "integer": "int",
    "array": "List",
    "object": "Dict",
}


def _convert_to_cohere_tool(
    tool: Function,
) -> CohereTool:
    """Re-implementation with more types of langchain_cohere.cohere_agent._convert_to_cohere_tool

    Builds a Cohere tool from the JSON schema of ``tool.parameters``. Each
    schema property becomes one parameter definition:

    * ``anyOf`` properties are rendered as ``Union[...]`` of the listed types.
    * ``allOf`` properties are treated as enums: the first ``$ref`` entry is
      resolved against the schema's ``$defs``, its type is mapped through
      ``JSON_TO_PYTHON_TYPES`` and the possible values are appended to the
      description.
    * All other properties have their ``type`` mapped through
      ``JSON_TO_PYTHON_TYPES``.

    A parameter is marked as required when its schema has no ``default``.

    Raises:
        ValueError: If a ``$ref`` cannot be resolved because the schema has no
            ``$defs``, if an ``allOf`` property contains no ``$ref``, or if a
            property has no ``type``.
        KeyError: If a JSON type has no entry in ``JSON_TO_PYTHON_TYPES``.
    """
    parameter_definitions = {}
    schema = tool.parameters.model_json_schema()
    for param_name, param_definition in schema["properties"].items():
        # The description is optional in the schema, so track it locally
        # rather than assuming the key exists (and without mutating the schema).
        description = param_definition.get("description")
        # Handle Union types
        if "anyOf" in param_definition:
            # NOTE(review): the inner types are the raw JSON Schema names
            # (e.g. "string"), not mapped through JSON_TO_PYTHON_TYPES.
            inner_types = [t["type"] for t in param_definition["anyOf"]]
            param_type = f"Union[{', '.join(inner_types)}]"
        # Handle enums
        elif "allOf" in param_definition:
            param_type = ""
            for all_of_definition in param_definition["allOf"]:
                if "$ref" not in all_of_definition:
                    continue
                if "$defs" not in schema:
                    raise ValueError(f"Can't resolve $ref {all_of_definition['$ref']}")
                enum_name = all_of_definition["$ref"].split("/")[-1]
                enum_definition = schema["$defs"][enum_name]
                param_type = JSON_TO_PYTHON_TYPES[enum_definition["type"]]
                possible_values = f" (Possible values: {enum_definition['enum']})"
                if description is None:
                    # No description to extend: use the values hint on its own.
                    description = possible_values.lstrip()
                else:
                    description += possible_values
                break
            if not param_type:
                raise ValueError(f"Enum not found for parameter {param_name} in tool {tool.name}")
        elif "type" not in param_definition:
            raise ValueError(f"Type not found for parameter {param_name}")
        else:
            param_type = JSON_TO_PYTHON_TYPES[param_definition["type"]]
        parameter_definitions[param_name] = ToolParameterDefinitionsValue(
            description=description,
            type=param_type,
            required="default" not in param_definition,
        )
    return CohereTool(
        name=tool.name,
        description=tool.description,
        parameter_definitions=parameter_definitions,
    )


def _cohere_to_assistant_message(
    response: NonStreamedChatResponse,
) -> ChatAssistantMessage:
    """Convert a Cohere chat response into an assistant chat message.

    The response text becomes the message content and every tool call in the
    response becomes a ``FunctionCall``; missing tool calls or missing
    parameters are replaced by an empty list / empty dict respectively.
    """
    requested_calls = response.tool_calls or []
    function_calls = [
        FunctionCall(function=call.name, args=call.parameters or {})
        for call in requested_calls
    ]
    return ChatAssistantMessage(
        role="assistant",
        content=response.text,
        tool_calls=function_calls,
    )


@retry(
    wait=wait_random_exponential(multiplier=1, max=40),
    stop=stop_after_attempt(5),
    reraise=True,
    retry=retry_if_not_exception_type(cohere.BadRequestError),
)
def chat_completion_request(
    client: cohere.Client,
    model: str,
    message: str,
    tool_results: Sequence[CohereToolResult],
    chat_history: Sequence[CohereMessage] | None,
    tools: Sequence[CohereTool],
    temperature: float | None = 0.0,
    seed: float | None = None,
) -> NonStreamedChatResponse:
    """Send one chat request to Cohere and return the response.

    Empty ``tool_results`` and ``tools`` are passed to the client as ``None``.

    The call is attempted up to 5 times with a randomized exponential wait
    (capped at 40) between attempts. ``cohere.BadRequestError`` is never
    retried, and the last exception is re-raised once attempts are exhausted.
    """
    # Empty sequences are normalised to None before being handed to the client.
    results_or_none = tool_results or None
    tools_or_none = tools or None
    response = client.chat(
        message=message,
        model=model,
        chat_history=chat_history,
        tool_results=results_or_none,
        tools=tools_or_none,
        temperature=temperature,
        seed=seed,
        force_single_step=False,
    )
    return response


class CohereLLM(BasePipelineElement):
    """Pipeline element that queries a Cohere chat model.

    The element converts the conversation and the available tools to their
    Cohere counterparts, sends a chat request, and appends the assistant
    response to the conversation.
    """

    def __init__(
        self,
        client: cohere.Client,
        model: str,
        temperature: float = 0.0,
        seed: float | None = None,
    ) -> None:
        """Store the Cohere client and the sampling settings used for every request."""
        self.client = client
        self.model = model
        self.temperature = temperature
        self.seed = seed

    def query(
        self,
        query: str,
        tools: Sequence[Function],
        # NOTE(review): this default DummyEnv() instance is created once, at
        # definition time, and shared by all calls — confirm that is safe.
        env: Env = DummyEnv(),
        messages: Sequence[ChatMessage] | None = None,
        extra_args: dict | None = None,
    ) -> tuple[str, Sequence[Function], Env, Sequence[ChatMessage], dict]:
        """Run one Cohere chat completion over the conversation.

        Args:
            query: Passed through unchanged; the request is built from ``messages``.
            tools: Functions offered to the model as Cohere tools.
            env: Passed through unchanged.
            messages: The conversation so far. Defaults to an empty conversation.
            extra_args: Passed through unchanged. Defaults to a fresh empty dict.

        Returns:
            ``(query, tools, env, messages, extra_args)`` where ``messages`` is a
            new list containing the input messages followed by the assistant
            response.

        Raises:
            ValueError: If the conversation is empty or does not end with a
                user or tool message (raised by ``_conversation_to_cohere``).
        """
        # Use None defaults so the returned extra_args is never a dict shared
        # between calls (a mutable default would leak downstream mutations).
        if messages is None:
            messages = []
        if extra_args is None:
            extra_args = {}
        message, tool_results, chat_history = _conversation_to_cohere(messages)
        cohere_tools = [_convert_to_cohere_tool(tool) for tool in tools]
        completion = chat_completion_request(
            self.client,
            self.model,
            message,
            tool_results,
            chat_history,
            cohere_tools,
            temperature=self.temperature,
            seed=self.seed,
        )
        output = _cohere_to_assistant_message(completion)
        messages = [*messages, output]
        return query, tools, env, messages, extra_args
