import time
import uuid
from datetime import datetime
from typing import Callable, Optional

import pandas as pd
from langchain_community.embeddings import SentenceTransformerEmbeddings
from sqlalchemy import select
from tqdm import tqdm

from synthetic_agents.app.chat_config import LanguageConfig, MemoryConfig
from synthetic_agents.callback.token import TokenCallback
from synthetic_agents.common.config import settings
from synthetic_agents.common.constants import (
    AI_CHAT_AAI_INTERVIEWER_AGENT_TYPE,
    AI_CHAT_USER_AGENT_TYPE,
)
from synthetic_agents.database.config import get_db, get_vector_db
from synthetic_agents.database.entity.agent import Agent as DBAgent
from synthetic_agents.database.entity.application_agent import ApplicationAgent
from synthetic_agents.database.entity.application_message import ApplicationMessage
from synthetic_agents.database.entity.chat import Chat
from synthetic_agents.database.entity.working_memory import WorkingMemory
from synthetic_agents.model.aai_interviewer import AAIInterviewer
from synthetic_agents.model.agent import Agent
from synthetic_agents.model.entity.world import TextMessage, WorldState
from synthetic_agents.model.memory import Memory
from synthetic_agents.prompt.builder import PromptBuilder
from synthetic_agents.prompt.loader import PersistedAgentPromptTemplateLoader


class PlayableChat:
    """
    This class represents a playable chat. One can use this class to play a conversation between a
    user and a coach agent. To use this class, a chat with the same ID must have been previously
    created (e.g. via `PlayableChat.create`) and agents assigned to it. Call `initialize_chat` to
    load the chat and its agents from the database before calling `play`.
    """

    def __init__(
        self,
        chat_id: int,
        session_id: Optional[str] = None,
        before_turn_callback: Optional[Callable] = None,
        after_turn_callback: Optional[Callable] = None,
        stream_tokens: bool = False,
        seconds_between_messages: float = 1.0,
        llm_name: Optional[str] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
    ):
        """
        Creates a playable chat.

        :param chat_id: the ID of the chat to play.
        :param session_id: an optional ID of the chat session. If not provided, a random session
            ID will be generated.
        :param before_turn_callback: an optional function to call (with no arguments) before each
            chat turn.
        :param after_turn_callback: an optional function to call after each chat turn. It
            receives the speaking agent, the listening agent and the raw message produced (or
            None if the speaker produced nothing).
        :param stream_tokens: whether to stream tokens as they are produced.
        :param seconds_between_messages: the number of seconds to wait between messages.
        :param llm_name: an optional LLM to use. If provided, this will override the default
            LLM assigned to the chat on its creation.
        :param temperature: an optional temperature to use. If provided, this will override the
            default temperature assigned to the chat on its creation.
        :param top_p: an optional top_p to use. If provided, this will override the default
            top_p assigned to the chat on its creation.
        """
        self.chat_id = chat_id
        self.session_id = session_id
        self.before_turn_callback = before_turn_callback
        self.after_turn_callback = after_turn_callback
        self.stream_tokens = stream_tokens
        self.seconds_between_messages = seconds_between_messages
        self.llm_name = llm_name
        self.temperature = temperature
        self.top_p = top_p

        # All messages produced so far in this session, in chronological order.
        self.messages_: list[TextMessage] = []

        self._user: Optional[Agent] = None
        self._coach: Optional[Agent] = None
        self._last_speaker: Optional[Agent] = None
        # Messages produced so far, grouped by the ID of the agent that produced them.
        self._messages_by_agent: dict[int, list[TextMessage]] = {}

        if self.session_id is None:
            self.create_new_session()

    def initialize_chat(self):
        """
        Initializes chat and agents from information persisted in the DB.

        For every agent assigned to the chat, the memory and language configurations stored in
        the chat's attributes (looked up by the agent's ID as a string) are used to build it.

        :raises Exception: if no chat with `self.chat_id` exists in the database.
        """
        db = next(get_db())
        try:
            # `.scalars().first()` yields None when no row matches, which keeps the not-found
            # check below reachable. Indexing the result of `.first()` directly would raise a
            # TypeError instead.
            db_chat = (
                db.execute(select(Chat).where(Chat.application_id == self.chat_id))
                .scalars()
                .first()
            )

            if db_chat is None:
                raise Exception(f"Chat ({self.chat_id}) not found in the database.")

            for db_agent in db_chat.agents:
                # `create` writes these keys as ints but they are read back as strings here;
                # presumably the attributes column is JSON-serialized (TODO confirm).
                chat_attributes = db_chat.attributes[str(db_agent.agent_id)]
                memory_config = MemoryConfig(
                    **{key: chat_attributes[key] for key in MemoryConfig.__annotations__}
                )
                language_config = LanguageConfig(
                    **{key: chat_attributes[key] for key in LanguageConfig.__annotations__}
                )
                self._initialize_agent(
                    db_agent=db_agent,
                    memory_config=memory_config,
                    language_config=language_config,
                )
        finally:
            # Release the session even when the chat is missing or an agent fails to load.
            db.close()

    def _initialize_agent(
        self,
        db_agent: DBAgent,
        memory_config: MemoryConfig,
        language_config: LanguageConfig,
    ):
        """
        Initializes an agent from information persisted in the DB and registers it as either
        the user (agent type `AI_CHAT_USER_AGENT_TYPE`) or the coach (any other type).

        :param db_agent: persisted agent entity.
        :param memory_config: configurations of the agent's memory model.
        :param language_config: configurations of the agent's language model.
        """

        # TODO: Change this is the future to be more flexible. Maybe by having a deterministic
        #  agent and the series of questions saved as an extra parameter of this agent in the
        #  DB. Right now, the question is hard coded into the AAI interviewer agent.
        agent_class = (
            AAIInterviewer if db_agent.agent_type == AI_CHAT_AAI_INTERVIEWER_AGENT_TYPE else Agent
        )

        # Default language parameters assigned to the chat upon its creation can be overridden.
        llm_name = language_config.llm_name if self.llm_name is None else self.llm_name
        temperature = language_config.temperature if self.temperature is None else self.temperature
        top_p = language_config.top_p if self.top_p is None else self.top_p

        # Models whose name contains "gpt" use the OpenAI key; every other model uses the
        # Anthropic key.
        api_key = settings.open_ai_api_key if "gpt" in llm_name else settings.anthropic_api_key
        agent = agent_class(
            agent_id=db_agent.agent_id,
            agent_type=db_agent.agent_type,
            application_type=db_agent.application_type,
            agent_attributes=db_agent.attributes,
            prompt_builder=PromptBuilder(
                prompt_template_loader=PersistedAgentPromptTemplateLoader(
                    application_type=db_agent.application_type,
                    agent_type=db_agent.agent_type,
                    template_version=db_agent.prompt_template_version,
                    db=next(get_db()),
                ),
                placeholder_values=db_agent.attributes,
            ),
            llm_name=llm_name,
            temperature=temperature,
            top_p=top_p,
            chat_history_length=language_config.chat_history_length,
            token_callback=TokenCallback() if self.stream_tokens else None,
            memory_retrieval_capacity=memory_config.memory_retrieval_capacity,
            working_memory_token_capacity=memory_config.working_memory_capacity,
            remember_previous_sessions=memory_config.remember_previous_sessions,
            retrieve_chat_memories=memory_config.retrieve_chat_memories,
            memory_db=next(get_db()),
            memory_embedding_db=next(
                get_vector_db(
                    SentenceTransformerEmbeddings(
                        model_name=memory_config.memory_embedding_model_name
                    )
                )
            ),
            initial_world_state=WorldState(
                application_id=self.chat_id,
                session_id=self.session_id,
            ),
            context_window_capacity=memory_config.context_window_capacity,
            api_key=api_key,
        )

        agent.initialize_agent()
        self._messages_by_agent[db_agent.agent_id] = []
        if db_agent.agent_type == AI_CHAT_USER_AGENT_TYPE:
            self._user = agent
        else:
            self._coach = agent

    def create_new_session(self):
        """
        Creates a new chat session ID (a random UUID4 string).
        """
        self.session_id = str(uuid.uuid4())

    def play(self, number_messages: int, progress_bar: Optional[tqdm] = None):
        """
        Play a conversation between a user and a coach agent.

        :param number_messages: number of new messages to exchange between the agents in this
            call. Messages produced by previous calls are not counted.
        :param progress_bar: an optional tqdm progress bar that is updated at every message
            exchanged.
        :raises Exception: if `initialize_chat` has not loaded both a user and a coach agent.
        """
        if self._user is None:
            raise Exception(
                "User agent is None. Cannot play the conversation. Make sure to call "
                "the `initialize_chat` function to initialize the chat and agents "
                "before playing the chat."
            )
        if self._coach is None:
            raise Exception(
                "Coach agent is None. Cannot play the conversation. Make sure to "
                "call the `initialize_chat` function to initialize the chat and "
                "agents before playing the chat."
            )

        max_number_messages = len(self.messages_) + number_messages
        self._run_conversation_loop(max_number_messages, progress_bar)

    def _run_conversation_loop(
        self, max_number_messages: int, progress_bar: Optional[tqdm] = None
    ):
        """
        Run a conversation loop until the maximum number of messages is achieved or one of the
        agents has no more messages to generate.

        :param max_number_messages: total number of messages (including ones produced earlier in
            this session) at which the loop stops.
        :param progress_bar: an optional tqdm progress bar that is updated at every message
            exchanged.
        """

        if progress_bar is not None:
            progress_bar.total = max_number_messages
            progress_bar.set_description("Message")

        while len(self.messages_) < max_number_messages:
            if not self._play_next_turn():
                # The last agent to speak has no more messages to produce.
                break

            if progress_bar is not None:
                progress_bar.update()

            time.sleep(self.seconds_between_messages)

    def _play_next_turn(self) -> bool:
        """
        Plays the turn to generate a new message in the loop.

        The coach speaks first; afterwards the user and the coach alternate. The speaker perceives
        the last message produced by the other agent (if any) before speaking. A produced message
        is persisted to the database and appended to the in-memory history.

        :return: a flag indicating if a message was produced at the turn.
        """

        if self.before_turn_callback is not None:
            self.before_turn_callback()

        if self._last_speaker is None or self._last_speaker.agent_id == self._user.agent_id:
            source_agent = self._coach
            target_agent = self._user
        else:
            source_agent = self._user
            target_agent = self._coach

        started_at = datetime.now()
        target_history = self._messages_by_agent[target_agent.agent_id]
        last_message_from_target = target_history[-1] if target_history else None
        world_state = WorldState(
            application_id=self.chat_id,
            session_id=self.session_id,
            last_text_message=last_message_from_target,
        )
        source_agent.perceive_world(world_state)
        # `speak` returns None when the agent has nothing more to say; otherwise a mapping with
        # at least the "prompt" and "text" keys.
        generated = source_agent.speak()
        finished_at = datetime.now()

        if self.after_turn_callback is not None:
            self.after_turn_callback(source_agent, target_agent, generated)

        if generated is None:
            return False

        # Persist to the database
        message_number = len(self.messages_) + 1
        self._persist_chat_message(
            prompt=generated["prompt"],
            message=generated["text"],
            message_number=message_number,
            agent_id=source_agent.agent_id,
            started_at=started_at,
            finished_at=finished_at,
            working_memory_items=source_agent.working_memory_items,
        )

        text_message = TextMessage(
            creation_timestamp=finished_at,
            content=generated["text"],
            agent_name=source_agent.name,
            agent_type=source_agent.agent_type,
        )

        self.messages_.append(text_message)
        self._messages_by_agent[source_agent.agent_id].append(text_message)
        self._last_speaker = source_agent

        return True

    def _persist_chat_message(
        self,
        prompt: str,
        message: str,
        message_number: int,
        agent_id: int,
        started_at: datetime,
        finished_at: datetime,
        working_memory_items: list[Memory],
    ):
        """
        Saves the chat message, its prompt and the agent's working-memory snapshot to the
        database in a single commit.

        :param prompt: prompt used by the agent's internal LLM to generate the message.
        :param message: content of the message.
        :param message_number: 1-based position of the message among all messages produced in
            this chat session (by either agent).
        :param agent_id: ID of the agent that produced the message.
        :param started_at: when the message started to be formed.
        :param finished_at: when the message was fully delivered.
        :param working_memory_items: memory items in the agent's working memory at the time the
            chat message was generated.
        """
        db = next(get_db())
        try:
            app_message = ApplicationMessage(
                application_id=self.chat_id,
                session_id=self.session_id,
                message_id=str(uuid.uuid4()),
                message_number=message_number,
                agent_id=agent_id,
                started_at=started_at,
                finished_at=finished_at,
                message=message,
                prompt=prompt,
                flagged=False,
            )
            db.add(app_message)

            # Link each working-memory item to the message, keeping its relevance score.
            db.add_all(
                [
                    WorkingMemory(
                        message_id=app_message.message_id,
                        memory_id=memory.memory_id,
                        attributes={"relevance_score": memory.relevance_score},
                    )
                    for memory in working_memory_items
                ]
            )
            db.commit()
        finally:
            # Closing also discards any uncommitted work if an error occurred above.
            db.close()

    def to_data_frame(self) -> pd.DataFrame:
        """
        Exports the messages exchanged so far to a pandas DataFrame.

        :return: pandas dataframe with one row per message (in chronological order) and the
            columns chat_id, chat_session, agent_name, agent_type, message_number (1-based),
            message_timestamp and message_content.
        """
        header = [
            "chat_id",
            "chat_session",
            "agent_name",
            "agent_type",
            "message_number",
            "message_timestamp",
            "message_content",
        ]
        data = [
            [
                self.chat_id,
                self.session_id,
                message.agent_name,
                message.agent_type,
                number,
                message.creation_timestamp,
                message.content,
            ]
            for number, message in enumerate(self.messages_, start=1)
        ]

        return pd.DataFrame(data, columns=header)

    @staticmethod
    def create(
        user_agent_id: int,
        user_memory_config: MemoryConfig,
        user_language_config: LanguageConfig,
        coach_agent_id: int,
        coach_memory_config: MemoryConfig,
        coach_language_config: LanguageConfig,
    ) -> int:
        """
        Creates a new chat between a user and a coach.

        The per-agent memory and language configurations are merged and stored in the chat's
        attributes under each agent's ID, and both agents are assigned to the chat.

        :param user_agent_id: ID of the user agent.
        :param user_memory_config: configurations for the user agent's memory model.
        :param user_language_config: configurations for the user agent's language model.
        :param coach_agent_id: ID of the coach agent.
        :param coach_memory_config: configurations for the coach agent's memory model.
        :param coach_language_config: configurations for the coach agent's language model.
        :return: the ID of the newly created chat.
        """

        db = next(get_db())
        try:
            chat = Chat(
                online=True,
                flagged=False,
                attributes={
                    user_agent_id: {
                        **user_memory_config.__dict__,
                        **user_language_config.__dict__,
                    },
                    coach_agent_id: {
                        **coach_memory_config.__dict__,
                        **coach_language_config.__dict__,
                    },
                },
            )

            db.add(chat)
            db.flush()  # to update the chat's ID
            new_chat_id = chat.application_id

            user_chat = ApplicationAgent(application_id=new_chat_id, agent_id=user_agent_id)
            coach_chat = ApplicationAgent(application_id=new_chat_id, agent_id=coach_agent_id)
            db.add_all([user_chat, coach_chat])
            db.commit()
        finally:
            # Closing also discards the flushed-but-uncommitted chat if an error occurred above.
            db.close()

        return new_chat_id
