import os
from typing import List

from tenacity import retry, stop_after_attempt, wait_random_exponential

from ..message import Message
from .base import IntelligenceBackend

# Try to import the cohere package and check whether the API key is set
try:
    import cohere
except ImportError:
    is_cohere_available = False
else:
    if os.environ.get("COHEREAI_API_KEY") is None:
        is_cohere_available = False
    else:
        is_cohere_available = True

# Default config follows the [Cohere documentation](https://cohere-sdk.readthedocs.io/en/latest/cohere.html#cohere.client.Client.chat)
DEFAULT_TEMPERATURE = 0.8
DEFAULT_MAX_TOKENS = 200
DEFAULT_MODEL = "command-xlarge"


class CohereAIChat(IntelligenceBackend):
    """Interface to the Cohere API."""

    stateful = True
    type_name = "cohere-chat"

    def __init__(
        self,
        temperature: float = DEFAULT_TEMPERATURE,
        max_tokens: int = DEFAULT_MAX_TOKENS,
        model: str = DEFAULT_MODEL,
        **kwargs,
    ):
        super().__init__(
            temperature=temperature, max_tokens=max_tokens, model=model, **kwargs
        )

        self.temperature = temperature
        self.max_tokens = max_tokens
        self.model = model

        assert (
            is_cohere_available
        ), "Cohere package is not installed or the API key is not set"
        self.client = cohere.Client(os.environ.get("COHEREAI_API_KEY"))

        # Stateful variables
        self.session_id = None  # The session id for the last conversation
        self.last_msg_hash = (
            None  # The hash of the last message of the last conversation
        )

    def reset(self):
        self.session_id = None
        self.last_msg_hash = None

    @retry(stop=stop_after_attempt(6), wait=wait_random_exponential(min=1, max=60))
    def _get_response(self, new_message: str, persona_prompt: str):
        response = self.client.chat(
            new_message,
            persona_prompt=persona_prompt,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            session_id=self.session_id,
        )

        self.session_id = response.session_id  # Update the session id
        return response.reply

    def query(
        self,
        agent_name: str,
        role_desc: str,
        history_messages: List[Message],
        global_prompt: str = None,
        request_msg: Message = None,
        *args,
        **kwargs,
    ) -> str:
        """
        Format the input and call the Cohere API.

        args:
            agent_name: the name of the agent
            role_desc: the description of the role of the agent
            env_desc: the description of the environment
            history_messages: the history of the conversation, or the observation for the agent
            request_msg: the request for the CohereAI
        """
        # Find the index of the last message of the last conversation
        new_message_start_idx = 0
        if self.last_msg_hash is not None:
            for i, message in enumerate(history_messages):
                if message.msg_hash == self.last_msg_hash:
                    new_message_start_idx = i + 1
                    break

        new_messages = history_messages[new_message_start_idx:]
        assert len(new_messages) > 0, "No new messages found (this should not happen)"

        new_conversations = []
        for message in new_messages:
            if message.agent_name != agent_name:
                # Since there are more than one player, we need to distinguish between the players
                new_conversations.append(f"[{message.agent_name}]: {message.content}")

        if request_msg:
            new_conversations.append(
                f"[{request_msg.agent_name}]: {request_msg.content}"
            )

        # Concatenate all new messages into one message because the Cohere API only accepts one message
        new_message = "\n".join(new_conversations)
        persona_prompt = f"Environment:\n{global_prompt}\n\nYour role:\n{role_desc}"

        response = self._get_response(new_message, persona_prompt)

        # Only update the last message hash if the API call is successful
        self.last_msg_hash = new_messages[-1].msg_hash

        return response
