import json

from lita.core.protos import Message


class Memory:
    id: str = None
    messages: list[Message] = []

    max_messages: int = 100
    context_window_tokens: int = 128_000
    total_tokens: int = 0

    def __init__(
        self,
        id: str = None,
        max_messages: int = 100,
        context_window_tokens: int = 128_000,
    ) -> None:
        self.id = id
        self.max_messages = max_messages
        self.context_window_tokens = context_window_tokens
        self.messages = []
        self.total_tokens = 0

    def set_id(self, session_id: str) -> None:
        self.session_id = session_id

    def add_message(
        self,
        message: Message = None,
        role: str = None,
        content: str = None,
        tool_calls: list[dict] | None = None,
        tool_call_id: str | None = None,
    ) -> None:
        if message is not None and isinstance(message, Message):
            self.messages.append(message)
        else:
            m = Message(
                role=role,
                content=content,
                tool_calls=tool_calls,
                tool_call_id=tool_call_id,
            )
            self.messages.append(m)

    def get_messages(self) -> list[dict]:
        include_fields = ["role", "content", "tool_calls", "tool_call_id"]
        # This returned `msgs` is a copy, not a reference of `self.messages`
        msgs = [msg.json(exclude_none=True) for msg in self.messages]
        for msg in msgs:
            # Check by key
            # Use copy to avoid dict change during iteration
            for field in msg.copy():
                if field not in include_fields:
                    del msg[field]
        return msgs

    def get_full_messages(self) -> list[dict]:
        return [msg.json(exclude_none=True) for msg in self.messages]

    def save(self, save_path: str):
        with open(save_path, "w") as f:
            for msg in self.messages:
                f.write(json.dumps(msg.json(exclude_none=True)) + "\n")

    def reset_summarization(self) -> None:
        head_idx = -1
        end_idx = len(self.messages) - 1
        for msg in self.messages:
            if msg.role == "system":
                head_idx += 1
            elif msg.role == "user":
                head_idx += 1
            else:
                break
        self.messages = self.messages[: head_idx + 1] + self.messages[end_idx :]
