from __future__ import annotations

from abc import ABC, abstractmethod
from datetime import datetime
from typing import Optional, Union

import tiktoken

from synthetic_agents.common.config import MESSAGE_DATETIME_FORMAT
from synthetic_agents.database.entity.memory import Memory as DBMemory


class Memory(ABC):
    """
    Abstract base class for a single memory associated with an agent.

    A memory carries its textual content, the moment it was created, the name of its type, an
    optional persistence ID and the scores attached to it during retrieval and ranking.
    Concrete subclasses provide the conversion to and from the relational DB representation.
    """

    def __init__(
        self,
        named_type: str,
        creation_timestamp: datetime,
        content: str,
        memory_id: Optional[int] = None,
        similarity_score: float = 0,
        relevance_score: float = 0,
    ):
        """
        Creates a memory.

        :param named_type: name of the type of memory, expected to be "life" or "chat". The
            value is stored as given; it is not validated here.
        :param creation_timestamp: when the memory was created.
        :param content: textual content of the memory.
        :param memory_id: ID of the memory. It is optional because memories that have not been
            persisted do not have an ID yet.
        :param similarity_score: similarity score between the memory and the query used to
            retrieve the memory.
        :param relevance_score: score attributed to relevance during the memory ranking phase.
        """

        self.named_type = named_type
        self.creation_timestamp = creation_timestamp
        self.content = content
        self.memory_id = memory_id
        self.similarity_score = similarity_score
        self.relevance_score = relevance_score

    @classmethod
    def from_json(cls, memory_json: dict[str, Union[str, float, dict[str, float]]]) -> Memory:
        """
        Creates a memory from the contents of a json object.

        :param memory_json: json object whose entries are passed to the constructor as keyword
            arguments. The "creation_timestamp" entry must be a string formatted according to
            MESSAGE_DATETIME_FORMAT; "memory_id" may be omitted, in which case it is set to None.
        :return: a memory object.
        :raise KeyError: if "creation_timestamp" is not present in the json object.
        """
        # Work on a shallow copy so the caller's dictionary is not modified.
        constructor_args = dict(memory_json)
        constructor_args.setdefault("memory_id", None)
        # The timestamp is serialized as text, so parse it back into a datetime.
        constructor_args["creation_timestamp"] = datetime.strptime(
            constructor_args["creation_timestamp"], MESSAGE_DATETIME_FORMAT
        )
        return cls(**constructor_args)

    @property
    def metadata(self) -> dict[str, Optional[Union[str, int]]]:
        """
        Gets the memory metadata (type, ID and creation timestamp, but not the content) to be
        saved in the embeddings database as a dictionary.

        The creation timestamp is rendered as a string using MESSAGE_DATETIME_FORMAT, and the
        memory ID is None when the memory has no ID.

        :return: memory metadata.
        """
        return {
            "memory_type": self.named_type,
            "memory_id": self.memory_id,
            "creation_timestamp": self.creation_timestamp.strftime(MESSAGE_DATETIME_FORMAT),
        }

    def estimate_num_tokens(self, model_name: str) -> int:
        """
        Estimates the number of tokens consumed by the content of the memory.

        Any error raised by tiktoken while resolving the encoding for the model is propagated
        to the caller.

        :param model_name: name of the LLM that will process the memory as part of the prompt.
        :return: number of tokens.
        """
        encoding = tiktoken.encoding_for_model(model_name)
        return len(encoding.encode(self.content))

    def __repr__(self) -> str:
        """
        Gets a textual representation of the memory: its content rendered as a list item.

        :return: textual representation of the memory.
        """
        return f"- {self.content}"

    @abstractmethod
    def to_persistent_object(self, agent_id: int) -> DBMemory:
        """
        Gets a relational DB representation of the memory for persistence. The concrete DB object
        depends on the type of memory so this method must be implemented by child memory classes.

        :param agent_id: id of the agent associated with the memory.
        :return: a persistent memory object.
        """
        pass

    def to_simplified_json(self, attributes: list[str]) -> dict[str, Union[str, int, float]]:
        """
        Returns a json representation of the memory with the subset of attributes specified.

        Attributes holding a datetime are converted to strings using MESSAGE_DATETIME_FORMAT;
        every other value is copied as is.

        :param attributes: names of the attributes to include.
        :return: dictionary mapping each requested attribute name to its value.
        :raise ValueError: if one of the requested attributes does not exist on the memory.
        """

        simplified_json = {}
        for attribute in attributes:
            try:
                attribute_value = getattr(self, attribute)
            except AttributeError:
                # Report unknown names with the documented exception type.
                raise ValueError(
                    f"The attribute {attribute} does not belong to the Memory class."
                ) from None

            if isinstance(attribute_value, datetime):
                attribute_value = attribute_value.strftime(MESSAGE_DATETIME_FORMAT)
            simplified_json[attribute] = attribute_value

        return simplified_json

    @classmethod
    @abstractmethod
    def from_persistent_object(cls, db_memory: DBMemory) -> Memory:
        """
        Gets a memory object from its relational DB representation. It must be implemented by
        child memory classes.

        :param db_memory: relational DB instance of a memory.
        :return: a memory object.
        """
        pass
