from datetime import datetime
from unittest import TestCase

import chromadb
from langchain.vectorstores import Chroma
from langchain_community.embeddings import SentenceTransformerEmbeddings
from sqlalchemy import MetaData, create_engine
from sqlalchemy.orm import Session

from synthetic_agents.common.config import MESSAGE_DATETIME_FORMAT
from synthetic_agents.database.functions import create_all_tables
from synthetic_agents.model.agent import Agent
from synthetic_agents.model.brain import Brain
from synthetic_agents.model.constants import HUGGING_FACE_EMBEDDING_MODEL
from synthetic_agents.model.entity.life_memory import LifeMemory
from synthetic_agents.model.entity.world import TextMessage, WorldState
from synthetic_agents.model.language import LanguageModel
from synthetic_agents.model.memory import MemoryModel
from synthetic_agents.prompt.builder import PromptBuilder
from synthetic_agents.prompt.loader import TextPromptTemplateLoader


class TestBrain(TestCase):
    """
    Tests the brain's short-term context window and the persistence of perceived world states
    and produced speech in its long-term memory storage.
    """

    # Number of world states perceived (or utterances produced) by the brain in each test case.
    NUM_UPDATES = 5

    # Capacity of the brain's context window used in these tests.
    CONTEXT_WINDOW_CAPACITY = 3

    def setUp(self) -> None:
        """
        Defines a brain model to be used in the test cases implemented in this script.
        """
        negative_memories = [
            LifeMemory(
                memory_id=1,
                creation_timestamp=datetime.strptime(
                    "2011-07-17 18:32:00", MESSAGE_DATETIME_FORMAT
                ),
                content="I got into a car accident and had to pay a lot of money to repair it",
            ),
            LifeMemory(
                memory_id=2,
                creation_timestamp=datetime.strptime(
                    "2014-10-08 22:54:20", MESSAGE_DATETIME_FORMAT
                ),
                content="I was passed up for a promotion that I had been working towards",
            ),
            LifeMemory(
                memory_id=3,
                creation_timestamp=datetime.strptime(
                    "2017-02-15 11:21:55", MESSAGE_DATETIME_FORMAT
                ),
                content="I had to take a pay cut in order to keep my job",
            ),
        ]

        # The fake agent will only be used to be the target of a message. The brain prepends its
        # name before the messages produced by that agent.
        self.fake_target_agent = Agent(
            agent_id=None,
            agent_type=None,
            application_type=None,
            agent_attributes={"name": "other"},
            prompt_builder=None,
            memory_db=None,
            memory_embedding_db=None,
            initial_world_state=WorldState(application_id=1, session_id="abc"),
        )

        # In-memory vector and relational DBs for testing.
        self.client = chromadb.Client()
        self.client.get_or_create_collection("test")
        self.embedding_database = Chroma(
            client=self.client,
            collection_name="test",
            embedding_function=SentenceTransformerEmbeddings(
                model_name=HUGGING_FACE_EMBEDDING_MODEL
            ),
        )

        self.database_engine = create_engine("sqlite://")
        create_all_tables(self.database_engine)
        # Keep a handle on the session so tearDown can release its connection.
        self.memory_session = Session(self.database_engine)

        memory_model = MemoryModel(
            agent_id=1,
            embedding_db=self.embedding_database,
            memory_db=self.memory_session,
            retrieval_capacity=len(negative_memories),
        )
        memory_model.persist(negative_memories)

        prompt_template_loader = TextPromptTemplateLoader(template="")
        prompt_builder = PromptBuilder(
            placeholder_values={}, prompt_template_loader=prompt_template_loader
        )
        language_model = LanguageModel(system_prompt_builder=prompt_builder)

        self.brain = Brain(
            agent_name="Test",
            language_model=language_model,
            memory_model=memory_model,
            context_window_capacity=self.CONTEXT_WINDOW_CAPACITY,
            initial_world_state=WorldState(application_id=1, session_id="abc"),
        )

    def tearDown(self) -> None:
        """
        Releases the resources created in setUp so that a change made to the databases by one
        test case doesn't affect another.

        It deletes the test collection from the vector DB and closes the memory session. It then
        drops all tables in the relational DB and disposes of the engine's connection pool.
        """
        self.client.delete_collection("test")

        # Close the session before dropping tables so any connection/transaction it holds is
        # released first.
        self.memory_session.close()

        # Drop all tables in the relational database
        metadata = MetaData()
        metadata.reflect(self.database_engine)
        metadata.drop_all(self.database_engine)
        self.database_engine.dispose()

    def _perceive_text_messages(self, num_messages: int) -> None:
        """
        Makes the brain perceive a sequence of world states. Each world state carries a text
        message from the fake target agent, with contents "Context 0", "Context 1", ... in order.

        :param num_messages: number of world states to be perceived by the brain.
        """
        for i in range(num_messages):
            self.brain.perceive_world(
                WorldState(
                    application_id=1,
                    session_id="abc",
                    last_text_message=TextMessage(
                        creation_timestamp=datetime.now(),
                        content=f"Context {i}",
                        agent_name=self.fake_target_agent.name,
                        agent_type=self.fake_target_agent.agent_type,
                    ),
                )
            )

    def test_brain_context_window_capacity(self):
        """
        Checks that the brain context window caps information according to its capacity.
        """
        self._perceive_text_messages(self.NUM_UPDATES)

        # Context 0 and 1 are removed from the brain's context because its window capacity
        # (CONTEXT_WINDOW_CAPACITY) is only 3.
        expected_context = "Context 2, Context 3, Context 4"
        context = ", ".join(c.last_text_message.content for c in self.brain._context)

        self.assertEqual(expected_context, context)

    def test_brain_world_state_persistence_during_world_perception(self):
        """
        Checks that the world state perceived by the brain is saved on its long-term memory
        storage.
        """
        num_initial_memories = self.brain.memory_model.num_memories
        self._perceive_text_messages(self.NUM_UPDATES)
        num_final_memories = self.brain.memory_model.num_memories

        self.assertEqual(num_initial_memories + self.NUM_UPDATES, num_final_memories)

    def test_brain_speech_persistence(self):
        """
        Checks that the speech produced by the brain is saved on its long-term memory storage.
        """
        self.brain.initialize_brain()
        num_initial_memories = self.brain.memory_model.num_memories
        for _ in range(self.NUM_UPDATES):
            self.brain.speak()
        num_final_memories = self.brain.memory_model.num_memories

        self.assertEqual(num_initial_memories + self.NUM_UPDATES, num_final_memories)
