from typing import List, Optional

import streamlit as st
from sqlalchemy import select

from synthetic_agents.app.chat_config import LanguageConfig, MemoryConfig
from synthetic_agents.app.playable_chat import PlayableChat
from synthetic_agents.common.constants import AI_CHAT_USER_AGENT_TYPE
from synthetic_agents.database.config import get_db
from synthetic_agents.database.entity.agent import Agent
from synthetic_agents.database.entity.life_memory import LifeMemory as DBLifeMemory
from synthetic_agents.model.constants import (
    DEFAULT_CHAT_HISTORY_LENGTH,
    DEFAULT_CONTEXT_WINDOW_CAPACITY,
    DEFAULT_LLM_TEMPERATURE,
    DEFAULT_LLM_TOP_P,
    DEFAULT_MEMORY_RETRIEVAL_CAPACITY,
    DEFAULT_WORKING_MEMORY_TOKEN_CAPACITY,
    HUGGING_FACE_EMBEDDING_MODEL,
)
from synthetic_agents.model.entity.life_memory import LifeMemory
from synthetic_agents.webapp.common.constants import CHAT_APPLICATION_TYPE, LLMS
from synthetic_agents.webapp.common.functions import display_agent_info


class ChatBuilderComponent:
    """
    This class handles the creation of a chat between a user agent and a coach (interviewer)
    agent. It renders widgets to select both agents and to configure their memory and language
    settings, and creates the chat when the user asks for it.
    """

    def __init__(self, component_key: str):
        """
        Creates the chat builder component.

        :param component_key: unique identifier for the component in a page. It prefixes every
            widget key created by this component.
        """
        self.component_key = component_key

        # Values saved within page loading and available to the next widgets to be rendered.
        # Not persisted through the session.
        self._user_agent: Optional[Agent] = None
        self._coach_agent: Optional[Agent] = None
        self._user_memory_config: Optional[MemoryConfig] = None
        self._user_language_config: Optional[LanguageConfig] = None
        self._coach_memory_config: Optional[MemoryConfig] = None
        self._coach_language_config: Optional[LanguageConfig] = None

    def create_component(self):
        """
        Creates all the widgets and logic of the chat builder component.

        The "Create Chat" button is only rendered when both a user and an interviewer agent are
        selected.
        """
        st.header("Agents", divider="gray")
        self._create_agent_selection_widgets()

        st.header("Memory", divider="gray")
        col_user, col_coach = st.columns(2)
        with col_user:
            self._user_memory_config = self._create_memory_widgets(area_key="user")
        with col_coach:
            self._coach_memory_config = self._create_memory_widgets(area_key="coach")

        st.header("Language", divider="gray")
        col_user, col_coach = st.columns(2)
        with col_user:
            self._user_language_config = self._create_language_widgets(area_key="user")
        with col_coach:
            self._coach_language_config = self._create_language_widgets(area_key="coach")

        st.divider()
        if self._user_agent is not None and self._coach_agent is not None:
            # Keyed like every other widget of this component so that several instances can
            # live on the same page without a duplicate widget id.
            if st.button("Create Chat", key=f"{self.component_key}_create_chat_button"):
                new_chat_id = PlayableChat.create(
                    user_agent_id=self._user_agent.agent_id,
                    user_memory_config=self._user_memory_config,
                    user_language_config=self._user_language_config,
                    coach_agent_id=self._coach_agent.agent_id,
                    coach_memory_config=self._coach_memory_config,
                    coach_language_config=self._coach_language_config,
                )
                st.success(f"Chat {new_chat_id} created successfully!")

    def _create_agent_selection_widgets(self):
        """
        Creates widgets for user and coach selection and stores the selected agents in
        self._user_agent and self._coach_agent (None when nothing is available/selected).
        """
        if st.button("Refresh Agent List", key=f"{self.component_key}_refresh_button"):
            # Every reader below is memoized with st.cache_data, so all of them must be
            # cleared for a refresh to reflect the current database content (both agent lists
            # and the life facts displayed for the selected agents).
            ChatBuilderComponent._read_available_user_agents.clear()
            ChatBuilderComponent._read_available_interviewer_agents.clear()
            ChatBuilderComponent._read_agent_life_facts.clear()

        def format_agent_display(agent: Agent) -> str:
            """
            Formats agent's info to display in the dropdown widget.

            :param agent: agent.
            :return: formatted agent's info.
            """
            format_template = "id: {id}, name: {name}, age: {age}, gender: {gender}"
            return format_template.format(
                id=agent.agent_id,
                name=agent.attributes.get("name", "?"),
                age=agent.attributes.get("age", "?"),
                gender=agent.attributes.get("gender", "?"),
            )

        col1, col2 = st.columns(2)
        with col1:
            self._user_agent = st.selectbox(
                "User",
                key=f"{self.component_key}_user_selector",
                options=ChatBuilderComponent._read_available_user_agents(),
                format_func=format_agent_display,
            )
            if self._user_agent:
                display_agent_info(
                    attributes=self._user_agent.attributes,
                    life_facts=ChatBuilderComponent._read_agent_life_facts(
                        self._user_agent.agent_id
                    ),
                )

        with col2:
            interviewer_agents = ChatBuilderComponent._read_available_interviewer_agents()
            self._coach_agent = st.selectbox(
                "Interviewer",
                key=f"{self.component_key}_coach_selector",
                options=interviewer_agents,
                format_func=format_agent_display,
            )
            if self._coach_agent:
                display_agent_info(
                    attributes=self._coach_agent.attributes,
                    life_facts=ChatBuilderComponent._read_agent_life_facts(
                        self._coach_agent.agent_id
                    ),
                )

    @staticmethod
    def _query_chat_agents(*conditions) -> List[Agent]:
        """
        Reads chat-application agents matching additional filter conditions from the database.

        :param conditions: SQLAlchemy filter expressions, combined (AND) with the chat
            application type filter.
        :return: list of agents ordered by agent id.
        """
        db = next(get_db())
        try:
            records = db.execute(
                select(Agent)
                .where(Agent.application_type == CHAT_APPLICATION_TYPE, *conditions)
                .order_by(Agent.agent_id)
            ).all()
            return [r[0] for r in records]
        finally:
            # Release the session even when the query fails.
            db.close()

    @staticmethod
    @st.cache_data
    def _read_available_user_agents() -> List[Agent]:
        """
        Reads chat user agents from the database. The result is cached with st.cache_data and
        invalidated by the "Refresh Agent List" button.

        :return: list of agents.
        """
        return ChatBuilderComponent._query_chat_agents(
            Agent.agent_type == AI_CHAT_USER_AGENT_TYPE
        )

    @staticmethod
    @st.cache_data
    def _read_available_interviewer_agents() -> List[Agent]:
        """
        Reads chat interviewer agents (any type that is not an AI user) from the database. The
        result is cached with st.cache_data and invalidated by the "Refresh Agent List" button.

        :return: list of agents.
        """
        # NOTE(review): under SQL semantics `!=` does not match rows whose agent_type is NULL;
        # confirm agent_type is never NULL if such agents should be listed here.
        return ChatBuilderComponent._query_chat_agents(
            Agent.agent_type != AI_CHAT_USER_AGENT_TYPE
        )

    @staticmethod
    @st.cache_data
    def _read_agent_life_facts(agent_id: int) -> List[LifeMemory]:
        """
        Reads life facts pertained to an agent from the database, ordered by creation time.

        :param agent_id: the id of the agent related to the life facts.
        :return: a list of life fact memories.
        """
        db = next(get_db())
        try:
            records = db.execute(
                select(DBLifeMemory)
                .where(DBLifeMemory.agent_id == agent_id)
                .order_by(DBLifeMemory.creation_timestamp)
            ).all()
            # Converted while the session is still open.
            return [LifeMemory.from_persistent_object(r[0]) for r in records]
        finally:
            # Release the session even when the query or conversion fails.
            db.close()

    def _create_memory_widgets(self, area_key: str) -> MemoryConfig:
        """
        Creates widgets to collect memory-related information required to configure a chat.

        :param area_key: unique identifier for the widgets created in the area such that this
            function can be called multiple times for different agents.
        :return: memory configuration built from the widgets' current values.
        """
        remember_previous_sessions = st.toggle(
            "Use memories from previous sessions",
            key=f"{self.component_key}_{area_key}_remember_previous_sessions_toggle",
        )
        retrieve_chat_memories = st.toggle(
            "Retrieve chat memories",
            key=f"{self.component_key}_{area_key}_retrieve_chat_memories_toggle",
            value=True,
        )
        retrieval_capacity = st.number_input(
            "Retrieval Capacity (max number of memories retrieved)",
            key=f"{self.component_key}_{area_key}_retrieval_capacity_input",
            value=DEFAULT_MEMORY_RETRIEVAL_CAPACITY,
            min_value=0,
            max_value=100,
        )
        working_memory_capacity = st.number_input(
            "Working Memory Capacity (max number of tokens)",
            key=f"{self.component_key}_{area_key}_wm_capacity_input",
            value=DEFAULT_WORKING_MEMORY_TOKEN_CAPACITY,
            min_value=0,
            max_value=2000,
        )
        context_window_capacity = st.number_input(
            "Context Window Capacity (number of messages to use as context for memory retrieval)",
            key=f"{self.component_key}_{area_key}_attention_window_size_input",
            value=DEFAULT_CONTEXT_WINDOW_CAPACITY,
            min_value=0,
            max_value=100,
        )

        return MemoryConfig(
            remember_previous_sessions=remember_previous_sessions,
            retrieve_chat_memories=retrieve_chat_memories,
            memory_retrieval_capacity=retrieval_capacity,
            working_memory_capacity=working_memory_capacity,
            context_window_capacity=context_window_capacity,
            # Fixed embedding model for now.
            memory_embedding_model_name=HUGGING_FACE_EMBEDDING_MODEL,
        )

    def _create_language_widgets(self, area_key: str) -> LanguageConfig:
        """
        Creates widgets to collect language-related information required to configure a chat.

        :param area_key: unique identifier for the widgets created in the area such that this
            function can be called multiple times for different agents.
        :return: language configuration built from the widgets' current values.
        """
        model = st.selectbox(
            "Model", key=f"{self.component_key}_{area_key}_llm_selector", options=LLMS
        )

        chat_history_length = st.number_input(
            "Chat History Length (number of messages to keep as history)",
            key=f"{self.component_key}_{area_key}_chat_history_length_input",
            value=DEFAULT_CHAT_HISTORY_LENGTH,
            min_value=0,
            max_value=100,
        )

        temperature = st.slider(
            "Temperature",
            key=f"{self.component_key}_{area_key}_llm_temperature_input",
            value=float(DEFAULT_LLM_TEMPERATURE),
            min_value=0.0,
            max_value=1.0,
        )

        top_p = st.slider(
            "Top-p",
            key=f"{self.component_key}_{area_key}_llm_top_p_input",
            value=float(DEFAULT_LLM_TOP_P),
            min_value=0.0,
            max_value=1.0,
        )

        return LanguageConfig(
            temperature=temperature,
            top_p=top_p,
            llm_name=model,
            chat_history_length=chat_history_length,
        )
