from datetime import datetime
from unittest import TestCase

import chromadb
from langchain.vectorstores import Chroma
from langchain_community.embeddings import SentenceTransformerEmbeddings
from sqlalchemy import create_engine
from sqlalchemy.orm import Session

from synthetic_agents.common.config import MESSAGE_DATETIME_FORMAT
from synthetic_agents.database.functions import create_all_tables
from synthetic_agents.model.constants import HUGGING_FACE_EMBEDDING_MODEL
from synthetic_agents.model.entity.chat_memory import ChatMemory
from synthetic_agents.model.entity.life_memory import LifeMemory
from synthetic_agents.model.memory import MemoryModel


class TestMemoryModel(TestCase):
    """
    Tests for MemoryModel: persistence, working-memory capacity and filtering of chat memories
    per application session.

    Each test case gets a fresh "test" collection in an in-memory vector DB and a fresh
    in-memory SQLite database. Both are released in tearDown so test cases don't interfere.
    """

    def setUp(self) -> None:
        """
        Defines a set of initial life memories and instantiates an in-memory vector DB and an
        in-memory SQL database to be used by the test cases.
        """
        # (creation timestamp, content) pairs; memory ids are assigned sequentially from 1.
        life_events = [
            ("2013-04-09 02:50:00", "I graduated from college with honors"),
            ("2016-08-13 16:47:12", "I was promoted to a new position at work"),
            ("2018-10-04 10:17:15", "I got a new car"),
            ("2021-04-20 22:20:53", "I was able to pay off all my student loans"),
            ("2022-06-19 14:43:00", "I won a prestigious award for my work in the field"),
            (
                "2011-07-17 18:32:00",
                "I got into a car accident and had to pay a lot of money to repair it",
            ),
            (
                "2014-10-08 22:54:20",
                "I was passed up for a promotion that I had been working towards",
            ),
            ("2017-02-15 11:21:55", "I had to take a pay cut in order to keep my job"),
        ]
        self.initial_memories = [
            LifeMemory(
                memory_id=memory_id,
                creation_timestamp=datetime.strptime(timestamp, MESSAGE_DATETIME_FORMAT),
                content=content,
            )
            for memory_id, (timestamp, content) in enumerate(life_events, start=1)
        ]

        self.embedding_db_client = chromadb.Client()  # in-memory
        self.embedding_db_client.get_or_create_collection("test")
        self.embedding_db = Chroma(
            client=self.embedding_db_client,
            collection_name="test",
            embedding_function=SentenceTransformerEmbeddings(
                model_name=HUGGING_FACE_EMBEDDING_MODEL
            ),
        )

        # Keep a reference to the engine so its connection pool can be disposed in tearDown.
        self.engine = create_engine("sqlite://")  # in-memory
        create_all_tables(self.engine)
        self.db = Session(self.engine)

    def tearDown(self) -> None:
        """
        Deletes the test collection from the vector DB so that changes made by one test case
        don't affect another, then closes the SQL session and disposes of its engine.
        """
        self.embedding_db_client.delete_collection("test")
        self.db.close()
        self.engine.dispose()

    def test_persistence(self):
        """
        Checks that the number of memories and embeddings stored across two agents is consistent
        with the number of memory items that were inserted for them.
        """
        memory_model1 = MemoryModel(
            agent_id=1,
            embedding_db=self.embedding_db,
            memory_db=self.db,
        )
        memory_model2 = MemoryModel(
            agent_id=2,
            embedding_db=self.embedding_db,
            memory_db=self.db,
        )
        memory_model1.persist(self.initial_memories[:4])
        memory_model2.persist(self.initial_memories[4:])

        total_memories = memory_model1.num_memories + memory_model2.num_memories
        total_embeddings = memory_model1.num_embeddings + memory_model2.num_embeddings

        # Separate assertions make it clear which count is off. Unlike bare `assert`, they are
        # not stripped when Python runs with -O.
        self.assertEqual(total_memories, len(self.initial_memories))
        self.assertEqual(total_embeddings, len(self.initial_memories))

    def test_working_memory_capacity(self):
        """
        Checks that the working memory does not exceed its token capacity, even when the
        retrieval capacity would allow every persisted memory to be retrieved.
        """
        max_tokens = 40
        memory_model = MemoryModel(
            agent_id=1,
            embedding_db=self.embedding_db,
            memory_db=self.db,
            retrieval_capacity=len(self.initial_memories),
            working_memory_token_capacity=max_tokens,
        )
        memory_model.persist(self.initial_memories)
        memory_model.refresh_working_memory(
            context="any context",
        )

        self.assertLessEqual(memory_model.working_memory_size, max_tokens)
        # With such a small token budget, not every memory should fit in the buffer.
        self.assertLess(len(memory_model.working_memory_buffer), len(self.initial_memories))

    def test_filter_per_application_session(self):
        """
        Checks that, with remember_previous_sessions=False, only the chat memories from the
        requested application and session are retrieved, alongside all the life memories.
        """
        chat_memory1 = ChatMemory(
            memory_id=len(self.initial_memories) + 1,
            creation_timestamp=datetime.strptime("2020-01-02 02:50:00", MESSAGE_DATETIME_FORMAT),
            content="Hi, I am happy to talk to you.",
            chat_id=1,
            session_id="abc",
        )
        chat_memory2 = ChatMemory(
            memory_id=len(self.initial_memories) + 2,
            creation_timestamp=datetime.strptime("2021-02-01 02:50:00", MESSAGE_DATETIME_FORMAT),
            content="I am not sure what to do next. Can you help me figure it out?",
            chat_id=1,
            session_id="bcd",  # different session
        )
        chat_memory3 = ChatMemory(
            memory_id=len(self.initial_memories) + 3,
            creation_timestamp=datetime.strptime("2021-02-01 02:50:00", MESSAGE_DATETIME_FORMAT),
            content="I am going to the park tomorrow.",
            chat_id=1,
            session_id="abc",
        )
        chat_memory4 = ChatMemory(
            memory_id=len(self.initial_memories) + 4,
            creation_timestamp=datetime.strptime("2021-02-01 02:50:00", MESSAGE_DATETIME_FORMAT),
            content="I just need some time off, honestly.",
            chat_id=2,  # different chat
            session_id="bcd",
        )

        # The model has enough capacity to retrieve every persisted memory. Among the chat
        # memories, however, only the two from chat 1 and session "abc" should be returned.
        # The life memories are not tied to a session and are all expected back.
        memory_model = MemoryModel(
            agent_id=1,
            embedding_db=self.embedding_db,
            memory_db=self.db,
            retrieval_capacity=len(self.initial_memories) + 4,
            remember_previous_sessions=False,
        )
        memory_model.persist(self.initial_memories)
        memory_model.persist([chat_memory1, chat_memory2, chat_memory3, chat_memory4])

        # These two memories come from the same chat and session.
        expected_memories = [chat_memory1, chat_memory3]

        memory_model.refresh_working_memory(
            context="any context",
            application_id=1,
            session_id="abc",
        )

        num_retrieved_memories = len(memory_model.working_memory_buffer)
        retrieved_chat_memories = [
            m for m in memory_model.working_memory_buffer if isinstance(m, ChatMemory)
        ]
        chat_ids = {m.chat_id for m in retrieved_chat_memories}
        session_ids = {m.session_id for m in retrieved_chat_memories}

        # The expected number is all the life memories plus the chat memories from chat 1 and
        # session "abc".
        self.assertEqual(
            num_retrieved_memories, len(expected_memories) + len(self.initial_memories)
        )
        self.assertEqual(len(retrieved_chat_memories), len(expected_memories))
        self.assertEqual(chat_ids, {1})
        self.assertEqual(session_ids, {"abc"})
