# Description: Memory module for the agent. This module is responsible for storing and managing the agent's memory.
# Modified from AgentVerse (https://github.com/OpenBMB/AgentVerse/tree/4dd772de18ac5fb2ed9b8cf4a8731565d224afab/agentverse/memory)

from abc import abstractmethod
from pydantic import BaseModel, Field
import copy
from typing import Any, List, Optional, Tuple, Dict, Set
from pydantic import Field

from config.const import CONFIGS

from .registry import Registry

from metagpt.utils.token_counter import count_message_tokens, count_string_tokens
from prompts.summary_prompt import SUMMARIZATION_PROMPT as summary_prompt

# Registry of memory implementations; classes register themselves under a key via
# the @memory_registry.register("<key>") decorator (e.g. "chat_history" below).
memory_registry = Registry(name="MemoryRegistry")


class Message(BaseModel):
    """A single message exchanged between agents and stored in memory."""

    # Message payload; rendered with str()/f-strings when the history is serialized.
    content: Any = Field(default="")
    # Name of the sending agent ("" when there is no sender).
    sender: str = Field(default="")
    # Intended recipients. The special name "all" makes the message visible to every
    # reader in ChatHistoryMemory.get_memory. default_factory gives each instance its
    # own set instead of sharing one mutable default.
    receiver: Set[str] = Field(default_factory=lambda: {"all"})
    # Optional reference to the agent object that produced this message.
    sender_agent: object = Field(default=None)
    # NOTE(review): Tuple[str] validates only 1-tuples; confirm whether
    # Tuple[str, ...] (variable length) was intended.
    tool_response: List[Tuple[str]] = Field(default_factory=list)  # todo


class BaseMemory(BaseModel):
    """Abstract interface shared by all agent memory backends."""

    @abstractmethod
    def add_memory(self, messages: List[Message]) -> None:
        """Store ``messages`` in this memory (implemented by subclasses)."""

    @abstractmethod
    def to_string(self) -> str:
        """Render this memory as one string (implemented by subclasses)."""

    @abstractmethod
    def reset_memory(self) -> None:
        """Clear this memory (implemented by subclasses)."""

    def get_memory(self) -> List[dict]:
        """Return the memory as message dicts.

        The base version has no body and therefore returns ``None``; concrete
        memories override it.
        """


@memory_registry.register("chat_history")  # todo
class ChatHistoryMemory(BaseMemory):
    """Chat transcript memory with optional rolling LLM summarization.

    Messages are stored in insertion order. When ``has_summary`` is enabled,
    ``get_memory`` keeps only the newest messages that fit the token budget and
    folds the older ones into ``summary`` through the LLM client.
    """

    # Name of the owning agent; messages it sent are labelled "you" in get_memory.
    my_name: str = ""
    # default_factory gives each instance its own list.
    messages: List[Message] = Field(default_factory=list)
    # Controlled by CONFIGS["llm"]["use_summary"] when that key is present.
    has_summary: bool = False if "use_summary" not in CONFIGS["llm"] else CONFIGS["llm"]["use_summary"]
    # Token budget reserved for the summary when callers pass no explicit length.
    max_summary_tlength: int = 500
    # Index into ``messages`` from which get_memory reads when summarizing; messages
    # before it have been folded into ``summary``.
    last_trimmed_index: int = 0
    # Running LLM-generated summary of trimmed messages.
    summary: str = ""
    SUMMARIZATION_PROMPT: str = summary_prompt

    def add_memory(self, messages: List[Message]) -> None:
        """Append ``messages`` to the end of the stored history."""
        self.messages.extend(messages)

    def to_string(self, add_sender_prefix: bool = False) -> str:
        """Join the content of every stored message with newlines.

        Args:
            add_sender_prefix: When True, messages with a non-empty sender are
                rendered as ``"[sender]: content"``.

        Non-string content is converted with ``str()`` so the join cannot fail.
        """
        lines = []
        for message in self.messages:
            text = str(message.content)
            if add_sender_prefix and message.sender != "":
                text = f"[{message.sender}]: {text}"
            lines.append(text)
        return "\n".join(lines)

    async def get_memory(
        self,
        start_index: int = 0,
        max_summary_length: int = 0,
        max_send_token: int = 0,
        model: str = "gpt-3.5-turbo",
        for_who: str = "all",
    ) -> List[dict]:
        """Return the history visible to ``for_who`` as chat-completion dicts.

        Args:
            start_index: First index of ``self.messages`` to include. Ignored when
                ``has_summary`` is on (``last_trimmed_index`` is used instead).
            max_summary_length: Tokens reserved for the summary; 0 means
                ``max_summary_tlength``.
            max_send_token: Total token budget when summarizing. NOTE(review): the
                default 0 becomes negative once the summary reservation is
                subtracted, trimming every message; callers presumably always pass
                a real limit.
            model: Model name forwarded to the token counters.
            for_who: Viewer name. A message is included if the viewer sent it, is
                one of its receivers, or its receivers contain ``"all"``.

        Returns:
            ``{"role": "assistant", "content": "[sender]: content"}`` dicts, with
            messages sent by ``my_name`` labelled ``you``. With summarization on,
            the kept tail is followed by a summary message whenever a summary
            exists or was just produced.
        """
        if self.has_summary:
            start_index = self.last_trimmed_index

        messages = []
        for message in self.messages[start_index:]:
            visible = (
                for_who == message.sender
                or for_who in message.receiver
                or "all" in message.receiver
            )
            if not visible:
                continue
            sender_name = "you" if message.sender == self.my_name else message.sender
            messages.append(
                {
                    "role": "assistant",
                    "content": f"[{sender_name}]: {message.content}",
                }
            )

        if not self.has_summary:
            return messages

        # summary message
        if max_summary_length == 0:
            max_summary_length = self.max_summary_tlength
        max_send_token -= max_summary_length

        prompt: List[dict] = []
        trimmed_history = add_history_upto_token_limit(
            prompt, messages, max_send_token, model
        )
        if trimmed_history:
            # Fold the messages that did not fit into the running summary.
            new_summary_msg, _ = await self.trim_messages(
                list(prompt), model, messages
            )
            prompt.append(new_summary_msg)
        elif self.summary:
            # Everything fits now, but earlier calls already summarized (and moved
            # last_trimmed_index past) older messages; keep that context.
            prompt.append(self.summary_message())
        return prompt

    def reset_memory(self) -> None:
        """Forget all messages together with the summary derived from them."""
        self.messages = []
        # Without these, a stale index would slice away newly added messages and
        # the old summary would describe messages that no longer exist.
        self.summary = ""
        self.last_trimmed_index = 0

    async def trim_messages(
        self, current_message_chain: List[Dict], model: str, history: List[Dict]
    ) -> Tuple[Dict, List[Dict]]:
        """Summarize the messages of ``history`` missing from ``current_message_chain``.

        Args:
            current_message_chain: Messages that stay in the prompt verbatim.
            model: Model name forwarded to the token counters.
            history: Message dicts built from ``self.messages[last_trimmed_index:]``.

        Returns:
            ``(summary_message, trimmed_messages)``. ``last_trimmed_index`` is
            advanced to just past the last trimmed message. NOTE(review): when
            ``history`` is a receiver-filtered view, positions in it are <= the
            matching positions in ``self.messages``, so the advance is
            conservative (it never skips an unsummarized message).
        """
        new_messages_not_in_chain = [
            msg for msg in history if msg not in current_message_chain
        ]

        if not new_messages_not_in_chain:
            return self.summary_message(), []

        new_summary_message = await self.update_running_summary(
            new_events=new_messages_not_in_chain, model=model
        )

        last_message = new_messages_not_in_chain[-1]
        # +1: point at the first message that was NOT summarized, not at the last
        # one that was (which would otherwise be re-sent / re-summarized).
        self.last_trimmed_index += history.index(last_message) + 1

        return new_summary_message, new_messages_not_in_chain

    async def update_running_summary(
        self,
        new_events: List[Dict],
        model: str = "gpt-3.5-turbo",
        max_summary_length: Optional[int] = None,
    ) -> dict:
        """Fold ``new_events`` into ``self.summary`` and return the summary message.

        Events are deep-copied and user messages dropped. Roles are renamed
        ("assistant" -> "you", "system" -> "your computer") for first-person
        summaries. The events are then sent to the LLM in batches sized so that
        each prompt stays within the client's send limit minus
        ``max_summary_length``.
        """
        if not new_events:
            return self.summary_message()
        if max_summary_length is None:
            max_summary_length = self.max_summary_tlength

        # Build a filtered copy instead of removing from the list while iterating it
        # (which silently skipped the element after each removed one).
        events: List[Dict] = []
        for event in copy.deepcopy(new_events):
            role = event["role"].lower()
            # Delete all user messages
            if role == "user":
                continue
            # Replace "assistant" with "you". This produces much better first person past tense results.
            if role == "assistant":
                event["role"] = "you"
            elif role == "system":
                event["role"] = "your computer"
            events.append(event)

        # Measure the empty template in tokens (not characters) so it is comparable
        # with the other token counts below.
        prompt_template_tlength = count_string_tokens(
            self.SUMMARIZATION_PROMPT.format(summary="", new_events=""), model
        )
        from .llm import DEFAULT_CLIENT

        max_input_tokens = DEFAULT_CLIENT.send_token_limit() - max_summary_length
        summary_tlength = count_string_tokens(self.summary, model)
        batch: List[Dict] = []
        batch_tlength = 0

        for event in events:
            # count_message_tokens takes a list of message dicts.
            event_tlength = count_message_tokens([event], model)
            budget = max_input_tokens - prompt_template_tlength - summary_tlength

            if batch_tlength + event_tlength > budget:
                # Skip the LLM call when a single oversized event arrives first.
                if batch:
                    await self._update_summary_with_batch(
                        batch, model, max_summary_length
                    )
                    summary_tlength = count_string_tokens(self.summary, model)
                batch = [event]
                batch_tlength = event_tlength
            else:
                batch.append(event)
                batch_tlength += event_tlength

        if batch:
            await self._update_summary_with_batch(batch, model, max_summary_length)

        return self.summary_message()

    async def _update_summary_with_batch(
        self, new_events_batch: List[dict], model: str, max_summary_length: int
    ) -> None:
        """Replace ``self.summary`` with the LLM's summary of the old summary plus a batch.

        ``model`` and ``max_summary_length`` match the caller's arguments but are
        currently unused by the prompt.
        """
        prompt = self.SUMMARIZATION_PROMPT.format(
            summary=self.summary, new_events=new_events_batch  # todo event batch
        )
        from .llm import DEFAULT_CLIENT

        self.summary = await DEFAULT_CLIENT.aask(prompt)

    def summary_message(self) -> dict:
        """Wrap the current summary as a system chat message."""
        return {
            "role": "system",
            "content": f"This reminds you of these events from your past: \n{self.summary}",
        }


def add_history_upto_token_limit(
    prompt: List[dict], history: List[dict], t_limit: int, model: str
) -> List[dict]:
    """Prepend the newest messages of ``history`` that fit in ``t_limit`` tokens to ``prompt``.

    Messages are taken from the end of ``history`` backwards until the next one
    would exceed the budget. The kept messages are inserted, in their original
    order, at the front of ``prompt``, which is mutated in place.

    Args:
        prompt: List that receives the kept messages at its front.
        history: Chat-completion message dicts, oldest first.
        t_limit: Token budget for the kept messages.
        model: Model name passed to ``count_message_tokens``.

    Returns:
        The older messages that did not fit (a contiguous prefix of ``history``),
        oldest first.
    """
    kept = 0
    used_tokens = 0
    for message in reversed(history):
        # count_message_tokens takes a list of message dicts.
        cost = count_message_tokens([message], model)
        if used_tokens + cost > t_limit:
            # Stop at the first message that does not fit so the kept tail stays
            # contiguous; everything older is trimmed without being counted.
            break
        used_tokens += cost
        kept += 1

    split = len(history) - kept
    # Slice assignment replaces the repeated insert(0, ...) calls (quadratic).
    prompt[:0] = history[split:]
    return list(history[:split])


if __name__ == "__main__":
    # Manual smoke test: store one message addressed only to "assistant".
    demo_memory = ChatHistoryMemory()
    demo_message = Message(content="Hello", sender="user", receiver={"assistant"})
    print(demo_message.receiver)
    demo_memory.add_memory([demo_message])