from datetime import datetime
from unittest import TestCase

from synthetic_agents.common.config import MESSAGE_DATETIME_FORMAT
from synthetic_agents.model.entity.chat_memory import ChatMemory
from synthetic_agents.model.entity.life_memory import LifeMemory
from synthetic_agents.model.language import LanguageModel
from synthetic_agents.prompt.builder import PromptBuilder
from synthetic_agents.prompt.loader import TextPromptTemplateLoader


class TestLanguageModel(TestCase):
    """
    Tests for LanguageModel: building the system prompt from agent attributes and memories,
    and generating text from a few-shot instruction template.
    """

    def setUp(self) -> None:
        """
        Defines agent attributes and memories to be used by a language model.
        """
        self.agent_attributes = {
            "demographics": "age = 30, gender = male",
            "personality": "anxious",
        }
        self.memories = [
            LifeMemory(
                memory_id=1,
                creation_timestamp=datetime.strptime(
                    "2013-04-09 02:50:00", MESSAGE_DATETIME_FORMAT
                ),
                content="I graduated from college with honors",
            ),
            LifeMemory(
                memory_id=2,
                creation_timestamp=datetime.strptime(
                    "2016-08-13 16:47:12", MESSAGE_DATETIME_FORMAT
                ),
                content="I was promoted to a new position at work",
            ),
            ChatMemory(
                memory_id=3,
                creation_timestamp=datetime.strptime(
                    "2018-10-04 10:17:15", MESSAGE_DATETIME_FORMAT
                ),
                content="C: How are you today?",
                chat_id=1,
                session_id="abc",
            ),
            ChatMemory(
                memory_id=4,
                creation_timestamp=datetime.strptime(
                    "2021-04-20 22:20:53", MESSAGE_DATETIME_FORMAT
                ),
                content="A: I am doing well.",
                chat_id=1,
                session_id="abc",
            ),
        ]

    def _create_prompt_builder(self, template: str) -> PromptBuilder:
        """
        Creates a prompt builder that fills the given template with the agent attributes.

        :param template: raw prompt template text.
        :return: a PromptBuilder reading the template through a TextPromptTemplateLoader.
        """
        prompt_template_loader = TextPromptTemplateLoader(template=template)
        return PromptBuilder(
            placeholder_values=self.agent_attributes, prompt_template_loader=prompt_template_loader
        )

    def test_system_prompt(self) -> None:
        """
        Checks that the prompt passed to the LLM is correct.
        """

        template = """
You are a human with demographics, personality and memories as defined below. You
 are chatting to another human. Respond to the last message in the chat.

Demographics:
{demographics}

Personality:
{personality}

Life Memories:
{life_memories}

Chat Memories:
{chat_memories}
"""
        language_model = LanguageModel(system_prompt_builder=self._create_prompt_builder(template))
        language_model.initialize_model()
        # {life_memories} and {chat_memories} are not in agent_attributes; per expected_prompt
        # below they are expected to be filled from the memories passed here.
        language_model.set_memories(self.memories)
        language_model.set_last_message("C: How can I help you?")

        expected_prompt = """
You are a human with demographics, personality and memories as defined below. You
 are chatting to another human. Respond to the last message in the chat.

Demographics:
age = 30, gender = male

Personality:
anxious

Life Memories:
I graduated from college with honors
I was promoted to a new position at work

Chat Memories:
C: How are you today?
A: I am doing well.
"""

        # assertEqual (not a bare assert) so the check survives `python -O` and reports a diff.
        self.assertEqual(language_model.build_system_message(), expected_prompt)

    def test_language_production(self) -> None:
        """
        Checks that the language model can produce a message according to an instruction.

        NOTE(review): this presumably runs a real model after initialize_model(), so the
        exact-match check depends on that model's output — confirm it is deterministic.
        """

        template = """
Answer the question with a thank you as in the examples below.

Example Input 1:

> Name
Alex

Example Output 1:
Thank you, Alex!

Example Input 2:

> Name
Ana

Example Output 2:
Thank you, Ana!

Input:

> Name
"""
        language_model = LanguageModel(
            system_prompt_builder=self._create_prompt_builder(template), text_streaming=False
        )
        language_model.set_last_message("Tricia")
        language_model.initialize_model()

        expected_response = "Thank you, Tricia!"
        response = language_model.generate_text()["text"].strip()

        self.assertEqual(response, expected_response)
