import time
import uuid
from datetime import datetime

import pandas as pd
import streamlit as st
from langchain_community.embeddings import SentenceTransformerEmbeddings
from sqlalchemy import select

from synthetic_agents.app.chat_config import LanguageConfig, MemoryConfig
from synthetic_agents.callback.token import TokenCallback
from synthetic_agents.common.config import settings
from synthetic_agents.common.constants import (
    AI_CHAT_AAI_INTERVIEWER_AGENT_TYPE,
    AI_CHAT_USER_AGENT_TYPE,
    HUMAN_AGENT_TYPE,
)
from synthetic_agents.database.config import get_db, get_vector_db
from synthetic_agents.database.entity.application_message import ApplicationMessage
from synthetic_agents.database.entity.chat import Chat
from synthetic_agents.database.entity.working_memory import WorkingMemory
from synthetic_agents.model.aai_interviewer import AAIInterviewer
from synthetic_agents.model.agent import Agent
from synthetic_agents.model.entity.world import TextMessage, WorldState
from synthetic_agents.model.memory import Memory
from synthetic_agents.prompt.builder import PromptBuilder
from synthetic_agents.prompt.loader import PersistedAgentPromptTemplateLoader
from synthetic_agents.webapp.common.constants import SECONDS_BETWEEN_MESSAGES
from synthetic_agents.webapp.session.chat import ChatSession


class ChatComponent:
    """
    This class handles the progress of an online chat between a user agent and a coach.

    The component renders a chat selector, the selected chat's details, sidebar controls for the
    conversation (start/stop an automated chat, export to CSV) and the chat messages. Agents and
    messages are stored in ``ChatSession`` so they survive Streamlit's page re-renders.
    """

    def __init__(self, component_key: str):
        """
        Creates the chat component.

        :param component_key: unique identifier for the component in a page. It is used as a
            prefix to build unique keys for the widgets created by this component.
        """
        self.component_key = component_key

        # Values saved within page loading and available to the next widgets to be rendered.
        # Not persisted through the session.
        self._chat = None
        self._max_automated_messages = None

    def create_component(self):
        """
        Creates all the widgets and logic of the chat component.

        Only the session controls and the chat selector are rendered until a chat is selected.
        """

        if st.button("New Session", key=f"{self.component_key}_new_session_button"):
            ChatSession.reset()
        st.write(f"**Chat Session:** {ChatSession.get_session_id()}")

        self._chat = st.selectbox(
            "Chat*",
            key=f"{self.component_key}_chat_selector",
            options=ChatComponent._read_available_chats(),
            format_func=lambda chat: chat.application_id,
        )
        if st.button("Refresh Chat list", key=f"{self.component_key}_refresh_button"):
            ChatComponent._read_available_chats.clear()
            # The selector above has already been rendered with the stale list, so rerun the
            # script to rebuild it from a fresh database read.
            st.rerun()

        if self._chat is None:
            return

        # The logger will be created every time the page reloads. We could encapsulate it in a
        # session object to prevent that but the overload to recreate a logger is minimal and this
        # only happens because of the way streamlit works anyway, which is by re-rendering the
        # whole page at every widget action.
        # Right now we don't have a logger fully implemented yet. Uncomment the line below when we
        # do.
        # self._logger = DatabaseLogWriter(self._chat.application_id, get_database_engine())

        with st.expander("Details", expanded=False):
            st.write("**Chat Attributes**")
            st.json(self._chat.attributes, expanded=False)

            with st.spinner("Loading agents..."):
                self._initialize_agents()

                ChatComponent._display_agent_attributes(
                    "**User's Attributes**", ChatSession.get_user()
                )
                ChatComponent._display_agent_attributes(
                    "**Coach's Attributes**", ChatSession.get_coach()
                )

        with st.sidebar:
            self._create_chat_control_widgets()

        st.header("Chat", divider="gray")
        self._create_chat_widgets()

    @staticmethod
    def _display_agent_attributes(title: str, agent: Agent):
        """
        Writes a title followed by a collapsed JSON view of an agent's attributes plus its ID.

        :param title: markdown title written above the JSON view.
        :param agent: agent whose attributes are displayed. Its attributes are copied, not mutated.
        """
        st.write(title)
        agent_attributes = agent.agent_attributes.copy()
        agent_attributes["id"] = agent.agent_id
        st.json(agent_attributes, expanded=False)

    @staticmethod
    @st.cache_data
    def _read_available_chats() -> list[Chat]:
        """
        Reads chats from the database, ordered by application ID.

        The result is cached by Streamlit. Call ``_read_available_chats.clear()`` to force a new
        read.

        :return: list of chats.
        """
        db = next(get_db())
        try:
            records = db.execute(select(Chat).order_by(Chat.application_id)).all()
            chats = [record[0] for record in records]
        finally:
            # Release the session even if the query fails.
            db.close()

        return chats

    def _initialize_agents(self):
        """
        Creates user and coach agents from the chat attributes and stores them in the session.

        Agents already stored in the session with the same agent ID are reused instead of being
        rebuilt. Each new agent is handed its own database sessions (from ``get_db``), which are
        not closed here.
        """
        for db_agent in self._chat.agents:
            is_user = db_agent.agent_type == AI_CHAT_USER_AGENT_TYPE
            stored_agent = ChatSession.get_user() if is_user else ChatSession.get_coach()
            if stored_agent is not None and stored_agent.agent_id == db_agent.agent_id:
                # Agent has already been initialized and it's saved in the chat session.
                continue

            # Chat attributes are keyed by the agent ID as a string.
            chat_attributes = self._chat.attributes[str(db_agent.agent_id)]
            memory_config = MemoryConfig(
                **{key: chat_attributes[key] for key in MemoryConfig.__annotations__}
            )
            language_config = LanguageConfig(
                **{key: chat_attributes[key] for key in LanguageConfig.__annotations__}
            )
            # TODO: Change this is the future to be more flexible. Maybe by having a deterministic
            #  agent and the series of questions saved as an extra parameter of this agent in the
            #  DB. Right now, the question is hard coded into the AAI interviewer agent.
            agent_class = (
                AAIInterviewer
                if db_agent.agent_type == AI_CHAT_AAI_INTERVIEWER_AGENT_TYPE
                else Agent
            )

            # Models whose name contains "gpt" use the OpenAI key; anything else the Anthropic one.
            api_key = (
                settings.open_ai_api_key
                if "gpt" in language_config.llm_name
                else settings.anthropic_api_key
            )
            agent = agent_class(
                agent_id=db_agent.agent_id,
                agent_type=db_agent.agent_type,
                application_type=db_agent.application_type,
                agent_attributes=db_agent.attributes,
                prompt_builder=PromptBuilder(
                    prompt_template_loader=PersistedAgentPromptTemplateLoader(
                        application_type=db_agent.application_type,
                        agent_type=db_agent.agent_type,
                        template_version=db_agent.prompt_template_version,
                        db=next(get_db()),
                    ),
                    placeholder_values=db_agent.attributes,
                ),
                llm_name=language_config.llm_name,
                temperature=language_config.temperature,
                top_p=language_config.top_p,
                chat_history_length=language_config.chat_history_length,
                token_callback=TokenCallback(),
                memory_retrieval_capacity=memory_config.memory_retrieval_capacity,
                working_memory_token_capacity=memory_config.working_memory_capacity,
                context_window_capacity=memory_config.context_window_capacity,
                remember_previous_sessions=memory_config.remember_previous_sessions,
                retrieve_chat_memories=memory_config.retrieve_chat_memories,
                memory_db=next(get_db()),
                memory_embedding_db=next(
                    get_vector_db(
                        SentenceTransformerEmbeddings(
                            model_name=memory_config.memory_embedding_model_name
                        )
                    )
                ),
                initial_world_state=WorldState(
                    application_id=self._chat.application_id,
                    session_id=ChatSession.get_session_id(),
                ),
                api_key=api_key,
            )

            agent.initialize_agent()
            if agent.agent_type == AI_CHAT_USER_AGENT_TYPE:
                ChatSession.store_user(agent)
            else:
                ChatSession.store_coach(agent)

    def _create_chat_control_widgets(self):
        """
        Creates widgets for controlling the chat.

        With a human coach only the CSV export is offered. Otherwise, a slider sets the maximum
        number of messages of the automated chat, a button starts/stops it, and the CSV export is
        available while the chat is not running.
        """

        if ChatSession.get_coach().agent_type == HUMAN_AGENT_TYPE:
            self._create_export_csv_button(disabled=not ChatSession.has_messages())
            return

        self._max_automated_messages = st.slider(
            "Maximum Number of Messages",
            min_value=2,
            max_value=100,
            key=f"{self.component_key}_max_messages_slider",
        )
        col1, col2 = st.columns(2)
        with col1:
            if ChatSession.is_automated_chat_in_progress():
                stop_chat = st.button(
                    "Stop Chat",
                    key=f"{self.component_key}_stop_chat_button",
                )
                if stop_chat:
                    ChatSession.toggle_automated_chat(chat_on=False)
                    # Rerun to update button label.
                    st.rerun()
            else:
                start_chat = st.button(
                    "Start Chat",
                    key=f"{self.component_key}_start_chat_button",
                    disabled=len(ChatSession.get_messages()) >= self._max_automated_messages,
                )
                if start_chat:
                    ChatSession.toggle_automated_chat(chat_on=True)
                    # Rerun to update button label.
                    st.rerun()
        with col2:
            self._create_export_csv_button(
                disabled=(
                    not ChatSession.has_messages()
                    or ChatSession.is_automated_chat_in_progress()
                )
            )

    def _create_export_csv_button(self, disabled: bool):
        """
        Creates the button that downloads the chat messages as a CSV file.

        :param disabled: whether the button is rendered disabled.
        """
        st.download_button(
            "Export to CSV",
            self._messages_to_csv(),
            f"chat_{self._chat.application_id}.csv",
            "text/csv",
            key=f"{self.component_key}_export_csv_button",
            disabled=disabled,
        )

    def _create_chat_widgets(self):
        """
        Creates chat widgets on the screen and produces the next message from a source agent to a
        target agent if the chat is in automated mode. Here we also persist working memories and
        chat messages to the database.
        """
        # Display chat messages from history on app rerun
        ChatComponent._display_all_messages()

        # Decides who speaks next. For now, the coach always initiates the conversation.
        user = ChatSession.get_user()
        coach = ChatSession.get_coach()
        last_speaker_type = ChatSession.get_last_speaker_type()
        if last_speaker_type is None or last_speaker_type == AI_CHAT_USER_AGENT_TYPE:
            source_agent = coach
            target_agent = user
        else:
            source_agent = user
            target_agent = coach

        message_number = len(ChatSession.get_messages()) + 1

        if source_agent.agent_type == HUMAN_AGENT_TYPE:
            # Input for manual entry. The start time is recorded on the first render after the
            # previous message so it reflects when the human could start typing.
            if ChatSession.get_current_manual_message_start_time() is None:
                ChatSession.store_current_manual_message_start_time(datetime.now())

            message_content = st.chat_input(
                f"{source_agent.name}, enter your message to {target_agent.name}."
            )
            if message_content:
                with st.chat_message(source_agent.name):
                    st.write(message_content)

                self._record_message(
                    source_agent=source_agent,
                    prompt="",
                    message_content=message_content,
                    message_number=message_number,
                    started_at=ChatSession.get_current_manual_message_start_time(),
                    finished_at=datetime.now(),
                )
                ChatSession.store_current_manual_message_start_time(None)
                st.rerun()

        elif (
            target_agent.agent_type == HUMAN_AGENT_TYPE
            or ChatSession.is_automated_chat_in_progress()
        ):
            with st.chat_message(source_agent.name):
                # Create an empty slate for writing the message as it is produced by the agent in a
                # call to agent.speak().
                message_placeholder = st.empty()
                started_at = datetime.now()
                message = self._generate_automated_message(
                    message_placeholder=message_placeholder,
                    source_agent=source_agent,
                    target_agent=target_agent,
                )

                if message is None:
                    ChatSession.toggle_automated_chat(False)
                else:
                    finished_at = datetime.now()
                    message_content = message["text"]
                    message_placeholder.write(message_content)

                    self._record_message(
                        source_agent=source_agent,
                        prompt=message["prompt"],
                        message_content=message_content,
                        message_number=message_number,
                        started_at=started_at,
                        finished_at=finished_at,
                    )

                    # Stop criteria. ">=" instead of "==" so the chat also stops when the limit was
                    # lowered below the current number of messages while the chat was running. The
                    # limit is None when no slider was rendered (human coach).
                    if (
                        self._max_automated_messages is not None
                        and len(ChatSession.get_messages()) >= self._max_automated_messages
                    ):
                        ChatSession.toggle_automated_chat(False)

            # Soft page refresh so the automated chat can keep going.
            time.sleep(SECONDS_BETWEEN_MESSAGES)
            st.rerun()

    def _record_message(
        self,
        source_agent: Agent,
        prompt: str,
        message_content: str,
        message_number: int,
        started_at: datetime,
        finished_at: datetime,
    ):
        """
        Persists a message to the database and appends it to the chat history in the session.

        :param source_agent: agent that produced the message.
        :param prompt: prompt used by the agent's internal LLM ("" for manual messages).
        :param message_content: content of the message.
        :param message_number: sequential number of the message in the chat.
        :param started_at: when the message started to be formed.
        :param finished_at: when the message was fully delivered; also the message timestamp.
        """
        self._persist_chat_message(
            prompt=prompt,
            message=message_content,
            message_number=message_number,
            agent_id=source_agent.agent_id,
            started_at=started_at,
            finished_at=finished_at,
            working_memory_items=source_agent.working_memory_items,
        )

        text_message = TextMessage(
            creation_timestamp=finished_at,
            content=message_content,
            agent_name=source_agent.name,
            agent_type=source_agent.agent_type,
        )
        ChatSession.store_chat_message(agent=source_agent, message=text_message)

    @staticmethod
    def _display_all_messages():
        """
        Re-renders all historic chat messages up to the moment on the screen. This is necessary
        because after each message is delivered, the webapp is refreshed and the web components
        re-rendered on the screen. So this step will make sure the chat messages are kept on the
        screen at all times.
        """
        for i, message in enumerate(ChatSession.get_messages()):
            with st.chat_message(message.agent_name, avatar=None):
                # Messages from the speaker who begins the conversations are always bold faced.
                st.write(message.content if i % 2 != 0 else f"**{message.content}**")

    def _generate_automated_message(
        self, message_placeholder: st.container, source_agent: Agent, target_agent: Agent
    ) -> dict[str, str]:
        """
        Asks a source agent to perceive the world and produce the next message to be delivered.

        :param message_placeholder: space in the chat to write tokens as they are produced by the
            agent to simulate they are typing in real time.
        :param source_agent: agent producing the next message.
        :param target_agent: agent the source agent is responding to.
        :return: whatever ``source_agent.speak()`` returns: a dictionary with "prompt" and "text"
            keys containing the prompt used in the internal agent's LLM and content of the message
            produced respectively, or None when no message was produced (the caller treats None
            as the end of the automated chat).
        """
        source_agent.token_callback.reset(message_placeholder)
        last_target_agent_message = ChatSession.get_last_message_from_agent(target_agent)
        world_state = WorldState(
            application_id=self._chat.application_id,
            session_id=ChatSession.get_session_id(),
            last_text_message=last_target_agent_message,
        )
        source_agent.perceive_world(world_state)
        return source_agent.speak()

    def _persist_chat_message(
        self,
        prompt: str,
        message: str,
        message_number: int,
        agent_id: int,
        started_at: datetime,
        finished_at: datetime,
        working_memory_items: list[Memory],
    ):
        """
        Saves the chat message, its prompt and the agent's working memory items to the database
        in a single commit.

        :param prompt: prompt used by the agent's internal LLM to generate the message.
        :param message: content of the message.
        :param message_number: sequential number of the message produced by the agent.
        :param agent_id: ID of the agent that produced the message.
        :param started_at: when the message started to be formed.
        :param finished_at: when the message was fully delivered.
        :param working_memory_items: memory items in the agent's working memory at the time the
            chat message was generated.
        """
        app_message = ApplicationMessage(
            application_id=self._chat.application_id,
            session_id=ChatSession.get_session_id(),
            message_id=str(uuid.uuid4()),
            message_number=message_number,
            agent_id=agent_id,
            started_at=started_at,
            finished_at=finished_at,
            message=message,
            prompt=prompt,
            flagged=False,
        )
        # Working memories reference the message through its generated message_id.
        working_memories = [
            WorkingMemory(
                message_id=app_message.message_id,
                memory_id=memory.memory_id,
                attributes={"relevance_score": memory.relevance_score},
            )
            for memory in working_memory_items
        ]

        db = next(get_db())
        try:
            db.add(app_message)
            db.add_all(working_memories)
            db.commit()
        finally:
            # Release the session even if the commit fails.
            db.close()

    def _messages_to_csv(self) -> bytes:
        """
        Transforms messages to UTF-8 encoded data in a csv format.

        :return: chat messages in a csv format, encoded as UTF-8 bytes.
        """
        header = [
            "chat_id",
            "chat_session",
            "agent_name",
            "agent_type",
            "message_timestamp",
            "message_content",
        ]
        chat_id = self._chat.application_id
        session_id = ChatSession.get_session_id()
        data = [
            [
                chat_id,
                session_id,
                message.agent_name,
                "user" if message.agent_type == AI_CHAT_USER_AGENT_TYPE else "coach",
                message.creation_timestamp,
                message.content,
            ]
            for message in ChatSession.get_messages()
        ]

        df = pd.DataFrame(data, columns=header)
        return df.to_csv(index=False).encode("utf-8")
