import warnings
from collections.abc import Sequence
from typing import overload

import cohere
from cohere.types import Message as CohereMessage
from cohere.types import (
    NonStreamedChatResponse,
    ToolParameterDefinitionsValue,
)

try:
    from cohere.types import ChatbotMessage, SystemMessage, ToolMessage, UserMessage
except ImportError:
    # Compatibility with cohere<5.9.3
    warnings.warn(
        DeprecationWarning("You are using an old version of the Cohere SDK. Please upgrade to the latest version.")
    )
    from cohere.types import Message_Chatbot as ChatbotMessage  # type: ignore
    from cohere.types import Message_System as SystemMessage  # type: ignore
    from cohere.types import Message_Tool as ToolMessage  # type: ignore
    from cohere.types import Message_User as UserMessage  # type: ignore
from cohere.types import Tool as CohereTool
from cohere.types import ToolCall as CohereToolCall
from cohere.types import ToolResult as CohereToolResult
from tenacity import retry, retry_if_not_exception_type, stop_after_attempt, wait_random_exponential

from agentdojo.agent_pipeline.base_pipeline_element import BasePipelineElement
from agentdojo.functions_runtime import EmptyEnv, Env, Function, FunctionCall, FunctionsRuntime
from agentdojo.types import (
    ChatAssistantMessage,
    ChatMessage,
    ChatSystemMessage,
    ChatToolResultMessage,
    ChatUserMessage,
    text_content_block_from_string,
)


def _content_parts_to_cohere_message(assistant_message: ChatAssistantMessage) -> str:
    """Return the single text payload of an assistant message.

    Cohere chatbot messages carry one plain string, so the assistant message
    must hold exactly one content block whose ``content`` is set.

    Args:
        assistant_message: Assistant message to extract the text from.

    Returns:
        The ``content`` value of the only content block.

    Raises:
        ValueError: If the message content is ``None``, if it does not contain
            exactly one block, or if that block's ``content`` is ``None``.
    """
    content_blocks = assistant_message["content"]
    if content_blocks is None:
        raise ValueError("Content can't be `None` for Cohere")
    if len(content_blocks) != 1:
        raise ValueError("Number of content blocks must be exactly one for Cohere")
    text = content_blocks[0]["content"]
    if text is None:
        raise ValueError("Content can't be `None` for Cohere")
    return text


# Typing-only overloads: each chat message type maps to the Cohere message class
# it is converted to. They are kept directly above the implementation because a
# series of `@overload` definitions must be immediately followed by the
# non-overloaded definition of the same function.
@overload
def _message_to_cohere(message: ChatToolResultMessage) -> ToolMessage: ...


@overload
def _message_to_cohere(message: ChatUserMessage) -> UserMessage: ...


@overload
def _message_to_cohere(message: ChatAssistantMessage) -> ChatbotMessage: ...


@overload
def _message_to_cohere(message: ChatSystemMessage) -> SystemMessage: ...


@overload
def _message_to_cohere(message: ChatMessage) -> CohereMessage: ...


def _message_to_cohere(message: ChatMessage) -> CohereMessage:
    match message["role"]:
        case "user":
            return UserMessage(
                role="USER",
                message=message["content"][0]["content"],
            )
        case "tool":
            tool_call = CohereToolCall(
                name=message["tool_call"].function,
                parameters=dict(message["tool_call"].args) if message["tool_call"].args is not None else {},
            )
            tool_output = {"result": message["error"] or message["content"]}
            tool_result = CohereToolResult(
                call=tool_call,
                outputs=[tool_output],
            )
            return ToolMessage(role="TOOL", tool_results=[tool_result])
        case "assistant" if message["content"] is not None:
            tool_calls = [
                CohereToolCall(
                    name=tool_call.function,
                    parameters=dict(tool_call.args),
                )
                for tool_call in message["tool_calls"] or []
            ]
            return ChatbotMessage(
                role="CHATBOT",
                message=_content_parts_to_cohere_message(message),
                tool_calls=tool_calls or None,
            )
        case "system":
            return SystemMessage(
                role="SYSTEM",
                message=message["content"][0]["content"],
            )
        case _:
            raise ValueError(f"Invalid message role for Cohere: {message['role']}")


def _merge_tool_result_messages(messages: Sequence[CohereMessage]) -> Sequence[CohereMessage]:
    """Collapse every run of consecutive tool messages into the first one of the run.

    The tool results of each following tool message are appended to the first
    tool message of the run, which is modified in place; the absorbed messages
    are dropped. All other messages are kept in order.

    Args:
        messages: Cohere messages in conversation order.

    Returns:
        A new list with adjacent tool messages merged.
    """
    merged: list[CohereMessage] = []
    for current in messages:
        previous = merged[-1] if merged else None
        if isinstance(previous, ToolMessage) and isinstance(current, ToolMessage):
            if previous.tool_results is None:  # type: ignore this is due to the conditional import
                previous.tool_results = []
            previous.tool_results.extend(current.tool_results or [])  # type: ignore this is due to the conditional import
        else:
            merged.append(current)
    return merged


def _conversation_to_cohere(
    messages: Sequence[ChatMessage],
) -> tuple[str, list[CohereToolResult], Sequence[CohereMessage]]:
    """Split a conversation into the pieces passed to a Cohere chat request.

    All messages are converted to Cohere messages and adjacent tool messages
    are merged. The last resulting message is then separated from the rest.

    Args:
        messages: Conversation so far; must not be empty.

    Returns:
        A ``(message, tool_results, chat_history)`` tuple. If the last message
        is a tool message, ``message`` is ``""`` and ``tool_results`` holds its
        results; if it is a user message, ``message`` is its text and
        ``tool_results`` is empty. ``chat_history`` is every earlier message.

    Raises:
        ValueError: If ``messages`` is empty or the last message is neither a
            user nor a tool message.
    """
    if len(messages) == 0:
        raise ValueError("There should be at least one message in the conversation")
    converted = _merge_tool_result_messages([_message_to_cohere(message) for message in messages])
    *history, final_message = converted
    if isinstance(final_message, ToolMessage):
        # Tool results are sent with an empty text message.
        results = final_message.tool_results or []  # type: ignore this is due to the conditional import
        return "", results, history
    if isinstance(final_message, UserMessage):
        return final_message.message, [], history  # type: ignore this is due to the conditional import
    raise ValueError("The last message should be from the user or a tool")


# JSON Schema type name -> Python-style type name, used for the `type` field of
# the Cohere tool parameter definitions.
JSON_TO_PYTHON_TYPES: dict[str, str] = dict(
    string="str",
    number="float",
    boolean="bool",
    integer="int",
    array="List",
    object="Dict",
)


def _convert_to_cohere_tool(tool: Function) -> CohereTool:
    """Re-implementation with more types of langchain_cohere.cohere_agent._convert_to_cohere_tool

    Builds a Cohere tool description from the JSON schema of the tool's
    parameters model.

    Args:
        tool: Function whose name, description and parameter schema are used.

    Returns:
        The Cohere tool with one parameter definition per schema property. A
        parameter is marked as required when its schema has no ``default``.

    Raises:
        ValueError: If a property has no resolvable type: no ``type``/``anyOf``
            and no ``$ref``, or a ``$ref`` while the schema has no ``$defs``.
    """
    parameter_definitions = {}
    parameters = tool.parameters.model_json_schema()
    for param_name, param_definition in parameters["properties"].items():
        # Kept in a local so the schema dict is not modified and a missing
        # description does not raise when the possible values are appended.
        description = param_definition.get("description")
        # Handle Union types
        if "anyOf" in param_definition:
            inner_types = [t["type"] for t in param_definition["anyOf"]]
            param_type = f"Union[{', '.join(inner_types)}]"
        # Handle enums: the reference is either wrapped in `allOf` or given
        # directly as `$ref` on the property (the shape depends on the
        # Pydantic version that generated the schema).
        elif "allOf" in param_definition or "$ref" in param_definition:
            candidates = param_definition["allOf"] if "allOf" in param_definition else [param_definition]
            ref = next((candidate["$ref"] for candidate in candidates if "$ref" in candidate), None)
            if ref is None:
                raise ValueError(f"Enum not found for parameter {param_name} in tool {tool.name}")
            if "$defs" not in parameters:
                raise ValueError(f"Can't resolve $ref {ref}")
            referenced = parameters["$defs"][ref.split("/")[-1]]
            param_type = JSON_TO_PYTHON_TYPES[referenced["type"]]
            if "enum" in referenced:
                possible_values = f"(Possible values: {referenced['enum']})"
                description = f"{description} {possible_values}" if description else possible_values
        elif "type" not in param_definition:
            raise ValueError(f"Type not found for parameter {param_name}")
        else:
            param_type = JSON_TO_PYTHON_TYPES[param_definition["type"]]
        parameter_definitions[param_name] = ToolParameterDefinitionsValue(
            description=description,
            type=param_type,
            required="default" not in param_definition,
        )
    return CohereTool(
        name=tool.name,
        description=tool.description,
        parameter_definitions=parameter_definitions,
    )


def _cohere_to_assistant_message(response: NonStreamedChatResponse) -> ChatAssistantMessage:
    """Turn a Cohere chat response into an assistant chat message.

    Args:
        response: Non-streamed response returned by the Cohere chat call.

    Returns:
        An assistant message with the response text as its only content block
        and one ``FunctionCall`` per tool call in the response (an empty list
        when there are none).
    """
    function_calls = []
    for cohere_call in response.tool_calls or []:
        # Missing parameters are normalised to an empty dict.
        function_calls.append(FunctionCall(function=cohere_call.name, args=cohere_call.parameters or {}))
    text_block = text_content_block_from_string(response.text)
    return ChatAssistantMessage(
        role="assistant",
        content=[text_block],
        tool_calls=function_calls,
    )


@retry(
    retry=retry_if_not_exception_type(cohere.BadRequestError),
    wait=wait_random_exponential(multiplier=1, max=40),
    stop=stop_after_attempt(5),
    reraise=True,
)
def chat_completion_request(
    client: cohere.Client,
    model: str,
    message: str,
    tool_results: Sequence[CohereToolResult],
    chat_history: Sequence[CohereMessage] | None,
    tools: Sequence[CohereTool],
    temperature: float | None = 0.0,
    seed: int | None = None,
) -> NonStreamedChatResponse:
    """Send one chat request to Cohere, retrying on failure.

    Any exception other than ``cohere.BadRequestError`` triggers a retry, for
    at most 5 attempts with a randomized exponential wait (``max=40``); the
    last exception is re-raised.

    Args:
        client: Cohere client used for the request.
        model: Model name.
        message: Text of the latest message.
        tool_results: Tool results to send; an empty sequence is sent as ``None``.
        chat_history: Earlier messages of the conversation.
        tools: Tools offered to the model; an empty sequence is sent as ``None``.
        temperature: Sampling temperature.
        seed: Sampling seed.

    Returns:
        The response of ``client.chat``.
    """
    results_or_none = tool_results or None
    tools_or_none = tools or None
    response = client.chat(
        model=model,
        message=message,
        chat_history=chat_history,
        tool_results=results_or_none,
        tools=tools_or_none,
        temperature=temperature,
        seed=seed,
        force_single_step=False,
    )
    return response


class CohereLLM(BasePipelineElement):
    """Pipeline element that asks a Cohere model for the next assistant message.

    Args:
        client: Cohere client used to issue the requests
        model: Name of the model to query
        temperature: Sampling temperature
        seed: Sampling seed
    """

    def __init__(self, client: cohere.Client, model: str, temperature: float = 0.0, seed: int | None = None) -> None:
        self.client = client
        self.model = model
        self.seed = seed
        self.temperature = temperature

    # NOTE(review): the mutable defaults are never mutated in this method, but
    # `extra_args` is returned by reference, so a caller mutating the result
    # would alter the shared default. The defaults presumably mirror the base
    # class signature; confirm before changing them.
    def query(
        self,
        query: str,
        runtime: FunctionsRuntime,
        env: Env = EmptyEnv(),
        messages: Sequence[ChatMessage] = [],
        extra_args: dict = {},
    ) -> tuple[str, FunctionsRuntime, Env, Sequence[ChatMessage], dict]:
        """Request a completion for the conversation and append it to the messages.

        Args:
            query: Returned unchanged; not sent to the model by this method.
            runtime: Runtime whose functions are offered to the model as tools.
            env: Returned unchanged.
            messages: Conversation so far; must end with a user or tool message.
            extra_args: Returned unchanged.

        Returns:
            ``(query, runtime, env, messages, extra_args)`` where ``messages``
            is a new list with the assistant's reply appended.
        """
        last_message, tool_results, history = _conversation_to_cohere(messages)
        tool_specs = [_convert_to_cohere_tool(function) for function in runtime.functions.values()]
        response = chat_completion_request(
            client=self.client,
            model=self.model,
            message=last_message,
            tool_results=tool_results,
            chat_history=history,
            tools=tool_specs,
            temperature=self.temperature,
            seed=self.seed,
        )
        assistant_message = _cohere_to_assistant_message(response)
        return query, runtime, env, [*messages, assistant_message], extra_args
