from __future__ import annotations

import uuid
from copy import deepcopy
from datetime import datetime
from typing import List, Optional

from langchain.docstore.document import Document
from langchain.vectorstores import VectorStore
from sqlalchemy import func, select
from sqlalchemy.orm import Session

from synthetic_agents.common.config import MESSAGE_DATETIME_FORMAT
from synthetic_agents.database.entity.memory import Memory as DBMemory
from synthetic_agents.model.constants import (
    DEFAULT_MEMORY_RETRIEVAL_CAPACITY,
    DEFAULT_OPENAI_CHAT_MODEL_NAME,
    DEFAULT_WORKING_MEMORY_TOKEN_CAPACITY,
)
from synthetic_agents.model.entity.chat_memory import CHAT_MEMORY_TYPE, ChatMemory
from synthetic_agents.model.entity.life_memory import LIFE_MEMORY_TYPE, LifeMemory
from synthetic_agents.model.entity.memory import Memory


class MemoryModel:
    """
    This class represents a memory model. It allows memory retrieval with RAG and ranking.

    Memories live in two stores: a relational database (accessed through a SQLAlchemy session)
    and a vector store holding one embedded document per memory. A token-bounded working memory
    buffer holds the memories selected by the last refresh.
    """

    def __init__(
        self,
        agent_id: int,
        memory_db: Session,
        embedding_db: VectorStore,
        retrieval_capacity: int = DEFAULT_MEMORY_RETRIEVAL_CAPACITY,
        working_memory_token_capacity: int = DEFAULT_WORKING_MEMORY_TOKEN_CAPACITY,
        working_memory_buffer: Optional[List[Memory]] = None,
        llm_name: str = DEFAULT_OPENAI_CHAT_MODEL_NAME,
        remember_previous_sessions: bool = True,
        retrieve_chat_memories: bool = True,
    ):
        """
        Creates a memory model.

        :param agent_id: ID of the agent associated with the memory model.
        :param memory_db: database session that connects to a relational database where
            memories must be stored.
        :param embedding_db: vector store instance for memory embedding persistence.
        :param retrieval_capacity: the maximum number of messages to be retrieved from the
            long-term storage.
        :param working_memory_token_capacity: the maximum number of tokens to be consumed by the
            working memory buffer.
        :param working_memory_buffer: a list of initial working memories. When omitted, the
            buffer starts empty.
        :param llm_name: name of the model that will use the memories as part of the prompt.
            This is used to calculate the number of tokens consumed by each memory item such that
            we don't exceed the working memory buffer.
        :param remember_previous_sessions: the memory model can retrieve memories generated in any
            application session.
        :param retrieve_chat_memories: whether to retrieve the chat memories or only life memories.
        """

        self.agent_id = agent_id
        self.memory_db = memory_db
        self.embedding_db = embedding_db
        self.retrieval_capacity = retrieval_capacity
        self.working_memory_token_capacity = working_memory_token_capacity
        # A fresh list per instance avoids sharing a mutable default between models.
        self.working_memory_buffer = [] if working_memory_buffer is None else working_memory_buffer
        self.llm_name = llm_name
        self.remember_previous_sessions = remember_previous_sessions
        self.retrieve_chat_memories = retrieve_chat_memories

    @property
    def num_embeddings(self) -> int:
        """
        Gets the number of documents in the embeddings database for the agent.

        :raise Exception: if documents could not be read from the memory embeddings database.
        :return: number of documents.
        """
        filter_clause = {"agent_id": self.agent_id}
        try:
            docs = self.embedding_db.get(where=filter_clause, include=["documents"])["documents"]
        except Exception as ex:
            raise Exception(f"Could not read from the memory embeddings database. {ex}.") from ex

        return len(docs)

    @property
    def num_memories(self) -> int:
        """
        Gets the number of memories of the agent.

        :raise Exception: if memories could not be read from the database.
        :return: number of memories.
        """
        try:
            num_records = self.memory_db.scalar(
                select(func.count())
                .select_from(DBMemory)
                .where(DBMemory.agent_id == self.agent_id)
            )
        except Exception as ex:
            raise Exception(f"Could not read from {DBMemory.__tablename__} table. {ex}.") from ex

        return num_records

    @property
    def working_memory_size(self) -> int:
        """
        Gets the number of tokens in the working memory.

        :return: number of tokens in the working memory, estimated for the configured LLM.
        """
        return sum(
            memory.estimate_num_tokens(self.llm_name) for memory in self.working_memory_buffer
        )

    def refresh_working_memory(
        self,
        context: str,
        application_id: Optional[int] = None,
        session_id: Optional[str] = None,
    ) -> None:
        """
        Repopulates the working memory with memories retrieved from the long-term memory.

        Retrieved memories are ranked by their ``relevance_score`` (highest first) and added to
        the working memory in that order until the token capacity is reached. Any previous
        content of the working memory buffer is discarded.

        :param context: context retrieved memories are relevant to.
        :param application_id: ID of the application where retrieved memories were generated.
        :param session_id: ID of the application session where retrieved memories were generated.
        """
        memories = self._retrieve(
            context=context, application_id=application_id, session_id=session_id
        )
        # Ranking
        memories.sort(key=lambda m: m.relevance_score, reverse=True)
        self.working_memory_buffer = []
        self._fill_working_memory_to_capacity(memories)

    def _build_retrieval_filter(
        self, application_id: Optional[int], session_id: Optional[str]
    ) -> dict:
        """
        Builds the metadata filter used to query the vector store for this agent's memories.

        :param application_id: ID of the application where retrieved memories were generated.
        :param session_id: ID of the application session where retrieved memories were generated.
        :return: filter dictionary to be passed to the vector store similarity search.
        """
        if not self.retrieve_chat_memories:
            # Only life memories of this agent.
            return {
                "$and": [
                    {"agent_id": self.agent_id},
                    {"memory_type": LIFE_MEMORY_TYPE},
                ]
            }

        if self.remember_previous_sessions:
            # Any memory of this agent, regardless of where it was generated.
            return {"agent_id": self.agent_id}

        # Get any life memory or memories from a particular application and session from a
        # specific agent.
        # NOTE(review): application_id and session_id are placed in the filter as given, even
        # when they are None - confirm the vector store accepts None values here.
        return {
            "$and": [
                {"agent_id": self.agent_id},
                {
                    "$or": [
                        {"memory_type": LIFE_MEMORY_TYPE},
                        {
                            "$and": [
                                {"application_id": application_id},
                                {"session_id": session_id},
                            ]
                        },
                    ]
                },
            ]
        }

    @staticmethod
    def _document_to_memory(document: Document, similarity_score: float) -> Memory:
        """
        Converts a document retrieved from the vector store into a memory entity.

        :param document: retrieved document, whose metadata describes the memory.
        :param similarity_score: score returned by the vector store for the document.
        :raise ValueError: if the memory type stored in the document metadata is not supported.
        :return: a life or chat memory built from the document.
        """
        memory_type = document.metadata["memory_type"]
        memory_id = document.metadata["memory_id"]
        creation_timestamp = datetime.strptime(
            document.metadata["creation_timestamp"], MESSAGE_DATETIME_FORMAT
        )

        if memory_type == LIFE_MEMORY_TYPE:
            return LifeMemory(
                memory_id=memory_id,
                creation_timestamp=creation_timestamp,
                content=document.page_content,
                similarity_score=similarity_score,
            )

        if memory_type == CHAT_MEMORY_TYPE:
            return ChatMemory(
                memory_id=memory_id,
                creation_timestamp=creation_timestamp,
                content=document.page_content,
                similarity_score=similarity_score,
                chat_id=document.metadata["application_id"],
                session_id=document.metadata["session_id"],
            )

        raise ValueError(f"Unsupported memory type retrieved ({memory_type}).")

    def _retrieve(
        self, context: str, application_id: Optional[int], session_id: Optional[str]
    ) -> List[Memory]:
        """
        Gets a list of memories from the long-term storage.

        :param context: query to be used to look for relevant memories in the embedding space.
        :param application_id: ID of the application where retrieved memories were generated.
        :param session_id: ID of the application session where retrieved memories were generated.
        :raise ValueError: if a retrieved document has an unsupported memory type.
        :return: a list of relevant memories, at most ``retrieval_capacity`` items.
        """
        filter_options = self._build_retrieval_filter(application_id, session_id)

        documents = self.embedding_db.similarity_search_with_score(
            query=context, k=self.retrieval_capacity, filter=filter_options
        )

        return [
            self._document_to_memory(document, similarity_score)
            for document, similarity_score in documents
        ]

    def _fill_working_memory_to_capacity(self, memories: List[Memory]) -> None:
        """
        Adds memories in order to the working memory buffer up to its capacity.

        Filling stops at the first memory that would exceed the token capacity; later memories
        are not considered even if they would fit, so the given order is preserved.

        :param memories: memories to be added to the working memory.
        """
        buffer_size = 0
        for memory in memories:
            num_tokens = memory.estimate_num_tokens(self.llm_name)
            if num_tokens + buffer_size > self.working_memory_token_capacity:
                break

            buffer_size += num_tokens
            self.working_memory_buffer.append(memory)

    def persist(self, memories: List[Memory]) -> None:
        """
        Persists memories and embeddings.

        Memories are first committed to the relational database and only then written to the
        vector store. If the relational write fails, the session is rolled back so it remains
        usable and nothing is written to the vector store.

        :param memories: memories with prefilled IDs to be persisted.
        :raise Exception: if the memories or their embeddings could not be persisted.
        """
        if not memories:
            return

        db_memories = [m.to_persistent_object(self.agent_id) for m in memories]

        try:
            self.memory_db.add_all(db_memories)
            self.memory_db.commit()
        except Exception as ex:
            # Leave the shared session in a clean state; a failed flush/commit otherwise keeps
            # the session in a failed transaction until it is rolled back.
            self.memory_db.rollback()
            raise Exception(f"Could not persist memories. {ex}.") from ex

        try:
            # TODO: with Chroma, we could not find an equivalent to COMMIT, so once the code below
            #  is executed, memories will persist in the vector store. That's why it is placed at
            #  the end of the flow, so that embeddings are only persisted if the memories could be
            #  persisted. However, this is not the best solution. We need some sort of rollback
            #  for vector db as well.
            self._persist_embeddings(memories)
        except Exception as ex:
            raise Exception(f"Could not persist memories. {ex}.") from ex

    def _persist_embeddings(self, memories: List[Memory]) -> None:
        """
        Adds a list of memory embeddings to the long-term storage for Retrieval-Augmented
        Generation (RAG).

        :param memories: memories whose content and metadata are stored as documents.
        :raise Exception: if memory embeddings could not be persisted to the database.
        """
        if not memories:
            return

        documents = []
        ids = []
        for memory in memories:
            # Persist the memory in the vector DB for later retrieval. The metadata is copied
            # into a new dict so that tagging it with the agent ID does not write into the
            # memory's own metadata object.
            metadata = {**memory.metadata, "agent_id": self.agent_id}
            documents.append(Document(page_content=memory.content, metadata=metadata))

            ids.append(self._generate_new_document_id())

        try:
            self.embedding_db.add_documents(documents, ids=ids)
        except Exception as ex:
            raise Exception(f"Could not write to the memory embeddings database. {ex}.") from ex

    def _generate_new_document_id(self) -> str:
        """
        Generates a new document ID that does not clash with any existing document ID in the
        long-term storage.

        :raise Exception: if memory embeddings could not be read from the database.
        :return: generated document ID.
        """
        # Generate a random document ID. The document ID does not need to be traceable as it
        # won't be used for filtering. While extremely unlikely, we still check for duplicity
        # in the DB before assigning an ID to a document.
        while True:
            document_id = str(uuid.uuid4())
            try:
                existing_ids = self.embedding_db.get(ids=[document_id])["ids"]
            except Exception as ex:
                raise Exception(
                    f"Could not read from the memory embeddings database. {ex}"
                ) from ex

            if len(existing_ids) == 0:
                return document_id

    def _generate_new_memory_id(self) -> str:
        """
        Generates a new memory ID that does not clash with any existing memory ID of this agent
        in the relational database.

        :raise Exception: if memories could not be read from the database.
        :return: generated memory ID.
        """
        # Generate a random memory ID. The memory ID does not need to be traceable as it
        # won't be used for filtering. While extremely unlikely, we still check for duplicity
        # in the DB before assigning an ID to a memory.
        while True:
            memory_id = str(uuid.uuid4())
            try:
                num_records = self.memory_db.scalar(
                    select(func.count())
                    .select_from(DBMemory)
                    .where(DBMemory.agent_id == self.agent_id, DBMemory.memory_id == memory_id)
                )
            except Exception as ex:
                raise Exception(
                    f"Could not read from {DBMemory.__tablename__} table. {ex}."
                ) from ex

            if num_records == 0:
                return memory_id

    def fill_memory_ids(self, memories: List[Memory]) -> List[Memory]:
        """
        Returns a copy of a list of memories with IDs filled such that they do not clash with any
        existing memory ID in the relational database.

        The input memories are left untouched; each returned item is a deep copy with a newly
        generated ``memory_id``.

        :param memories: memories that need an ID.
        :raise Exception: if memories could not be read from the database.
        :return: list of memories filled with valid ids.
        """
        memories_with_ids = []
        for memory in memories:
            memory_id = self._generate_new_memory_id()
            memory_with_id = deepcopy(memory)
            memory_with_id.memory_id = memory_id
            memories_with_ids.append(memory_with_id)

        return memories_with_ids
