import json
from typing import Any

from openai.types.chat import ChatCompletionMessageToolCall
from PIL import Image

from crab import Action, ActionOutput, BackendModel, BackendOutput, MessageType
from crab.utils.common import base64_to_image

try:
    from camel.agents import ChatAgent
    from camel.configs import ChatGPTConfig
    from camel.messages import BaseMessage
    from camel.models import ModelFactory
    from camel.toolkits import OpenAIFunction
    from camel.types.enums import ModelPlatformType, ModelType

    CAMEL_ENABLED = True
except ImportError:
    CAMEL_ENABLED = False


def _find_model_platform_type(model_platform_name: str) -> "ModelPlatformType":
    """Return the camel ``ModelPlatformType`` whose value matches the given name.

    The comparison is case-insensitive.

    Args:
        model_platform_name: Platform name supplied by the caller.

    Returns:
        The first ``ModelPlatformType`` member whose ``value`` equals
        ``model_platform_name`` ignoring case.

    Raises:
        ValueError: If no platform matches; the message lists every supported
            platform value.
    """
    wanted = model_platform_name.lower()
    for platform in ModelPlatformType:
        if platform.value.lower() == wanted:
            return platform
    supported = [platform.value for platform in ModelPlatformType]
    # The message names *platforms*: this function validates the platform
    # argument, not the model name.
    raise ValueError(
        f"Model platform {model_platform_name} not found. "
        f"Supported platforms are {supported}"
    )


def _find_model_type(model_name: str) -> "str | ModelType":
    """Map ``model_name`` to a camel ``ModelType`` member, ignoring case.

    Args:
        model_name: Model identifier supplied by the caller.

    Returns:
        The first ``ModelType`` member whose ``value`` equals ``model_name``
        case-insensitively, or ``model_name`` itself (unchanged) when no
        member matches.
    """
    target = model_name.lower()
    return next(
        (member for member in ModelType if member.value.lower() == target),
        model_name,
    )


def _convert_action_to_schema(
    action_space: list[Action] | None,
) -> "list[OpenAIFunction] | None":
    """Wrap each action's ``entry`` in a camel ``OpenAIFunction``.

    Args:
        action_space: Actions to expose as tools, or ``None``.

    Returns:
        ``None`` when ``action_space`` is ``None``; otherwise a list with one
        ``OpenAIFunction`` per action, in the same order.
    """
    return (
        None
        if action_space is None
        else [OpenAIFunction(act.entry) for act in action_space]
    )


def _convert_tool_calls_to_action_list(
    tool_calls: list[ChatCompletionMessageToolCall] | None,
) -> list[ActionOutput] | None:
    """Convert OpenAI tool-call objects into crab ``ActionOutput`` objects.

    Args:
        tool_calls: Tool calls returned by the model, or ``None``. ``None``
            entries inside the list are skipped. A turn without a tool request,
            wrapped as ``[None]``, therefore yields an empty list instead of
            raising ``AttributeError``.

    Returns:
        ``None`` if ``tool_calls`` is ``None``. Otherwise, one ``ActionOutput``
        per non-``None`` call, named after ``call.function.name``. Its
        ``arguments`` are decoded from the call's JSON ``arguments`` string,
        and an empty string is treated as ``"{}"``.

    Raises:
        json.JSONDecodeError: If a call's arguments string is non-empty but
            not valid JSON.
    """
    if tool_calls is None:
        return None

    return [
        ActionOutput(
            name=call.function.name,
            # json.loads("") raises, so an empty arguments string is decoded
            # as an empty object instead.
            arguments=json.loads(call.function.arguments or "{}"),
        )
        for call in tool_calls
        if call is not None
    ]


class CamelModel(BackendModel):
    """Crab backend model that delegates chat turns to a camel-ai ``ChatAgent``.

    A new camel agent is built on every :meth:`reset`. :meth:`chat` forwards
    one user turn (text plus optional base64 JPEG images) to that agent. It
    accumulates reported token usage and converts any external tool-call
    request into crab ``ActionOutput`` objects.
    """

    def __init__(
        self,
        model: str,
        model_platform: str,
        parameters: dict[str, Any] | None = None,
        history_messages_len: int = 0,
    ) -> None:
        """Resolve camel model/platform identifiers and initialise the backend.

        Args:
            model: Model name. It is matched case-insensitively against camel's
                ``ModelType`` values; an unknown name is kept as a raw string.
            model_platform: Platform name. It must match a
                ``ModelPlatformType`` value case-insensitively.
            parameters: Extra keyword arguments for ``ChatGPTConfig``, applied
                on every :meth:`reset`. ``None`` means no extra parameters.
            history_messages_len: Forwarded to the base class. :meth:`reset`
                later reads it as the agent's ``message_window_size``.

        Raises:
            ImportError: If camel-ai is not installed.
            ValueError: If ``model_platform`` is not a known platform.
        """
        if not CAMEL_ENABLED:
            raise ImportError("Please install camel-ai to use CamelModel")
        self.parameters = parameters or {}
        # TODO: a better way?
        self.model_type = _find_model_type(model)
        self.model_platform_type = _find_model_platform_type(model_platform)
        self.client: ChatAgent | None = None
        self.token_usage = 0

        # Forward the normalised dict rather than the raw argument. If the base
        # class stores ``parameters`` on ``self``, passing ``None`` through
        # would overwrite ``self.parameters`` and make ``reset`` fail on
        # ``None.copy()``.
        super().__init__(
            model,
            self.parameters,
            history_messages_len,
        )

    def get_token_usage(self) -> int:
        """Return the total tokens reported since the last :meth:`reset`."""
        return self.token_usage

    def reset(self, system_message: str, action_space: list[Action] | None) -> None:
        """Build a fresh camel ``ChatAgent`` and zero the token counter.

        Args:
            system_message: Text installed as the agent's system message.
            action_space: Actions exposed to the model as external tools. With
                ``None`` or an empty list, no tools are configured. Otherwise,
                the model is configured with ``tool_choice="required"``.
        """
        # Treat an empty action list like "no tools". OpenAI-style APIs reject
        # ``tools=[]`` as well as ``tool_choice`` without any tools.
        action_schema = _convert_action_to_schema(action_space) or None
        config = self.parameters.copy()
        if action_schema is not None:
            config["tool_choice"] = "required"
            config["tools"] = action_schema

        chatgpt_config = ChatGPTConfig(
            **config,
        )
        backend_model = ModelFactory.create(
            self.model_platform_type,
            self.model_type,
            model_config_dict=chatgpt_config.as_dict(),
        )
        sysmsg = BaseMessage.make_assistant_message(
            role_name="Assistant",
            content=system_message,
        )
        self.client = ChatAgent(
            model=backend_model,
            system_message=sysmsg,
            external_tools=action_schema,
            message_window_size=self.history_messages_len,
        )
        self.token_usage = 0

    def chat(self, messages: list[tuple[str, MessageType]]) -> BackendOutput:
        """Send one user turn to the camel agent and return its reply.

        Messages of type ``MessageType.IMAGE_JPG_BASE64`` are decoded and
        attached as images. For any other type, the payload becomes the text
        content, and later text messages override earlier ones.

        Args:
            messages: ``(payload, message_type)`` pairs forming the user turn.

        Returns:
            The agent's reply text. ``action_list`` holds the converted actions
            when the agent reported an ``external_tool_request``, and is
            ``None`` otherwise.

        Raises:
            RuntimeError: If :meth:`reset` has not been called yet.
        """
        if self.client is None:
            raise RuntimeError("CamelModel.reset() must be called before chat()")

        # TODO: handle multiple text messages after message refactoring
        image_list: list[Image.Image] = []
        content = ""
        for payload, message_type in messages:
            if message_type == MessageType.IMAGE_JPG_BASE64:
                image_list.append(base64_to_image(payload))
            else:
                content = payload
        usermsg = BaseMessage.make_user_message(
            role_name="User",
            content=content,
            image_list=image_list,
        )
        response = self.client.step(usermsg)

        # Tolerate a missing or None usage entry instead of raising; it
        # counts as zero tokens.
        usage = response.info.get("usage") or {}
        self.token_usage += usage.get("total_tokens") or 0

        tool_call_request = response.info.get("external_tool_request")
        # Only convert when a tool call was actually requested. Wrapping None
        # in a list would make the conversion dereference ``None.function``.
        if tool_call_request is None:
            action_list = None
        else:
            action_list = _convert_tool_calls_to_action_list([tool_call_request])

        return BackendOutput(
            message=response.msg.content,
            action_list=action_list,
        )
