# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========

import warnings
from typing import List, Optional

from camel.memories.base import AgentMemory, BaseContextCreator
from camel.memories.blocks import ChatHistoryBlock, VectorDBBlock
from camel.memories.records import ContextRecord, MemoryRecord
from camel.storages.key_value_storages.base import BaseKeyValueStorage
from camel.storages.vectordb_storages.base import BaseVectorStorage
from camel.types import OpenAIBackendRole


class ChatHistoryMemory(AgentMemory):
    r"""An agent memory wrapper of :obj:`ChatHistoryBlock`.

    Args:
        context_creator (BaseContextCreator): A model context creator.
        storage (BaseKeyValueStorage, optional): A storage backend for storing
            chat history. If `None`, an :obj:`InMemoryKeyValueStorage`
            will be used. (default: :obj:`None`)
        window_size (int, optional): The number of recent chat messages to
            retrieve. If not provided, the entire chat history will be
            retrieved.  (default: :obj:`None`)
        agent_id (str, optional): The ID of the agent associated with the chat
            history.
    """

    def __init__(
        self,
        context_creator: BaseContextCreator,
        storage: Optional[BaseKeyValueStorage] = None,
        window_size: Optional[int] = None,
        agent_id: Optional[str] = None,
    ) -> None:
        if window_size is not None and not isinstance(window_size, int):
            raise TypeError("`window_size` must be an integer or None.")
        if window_size is not None and window_size < 0:
            raise ValueError("`window_size` must be non-negative.")
        self._context_creator = context_creator
        self._window_size = window_size
        self._chat_history_block = ChatHistoryBlock(storage=storage)
        self._agent_id = agent_id

    @property
    def agent_id(self) -> Optional[str]:
        return self._agent_id

    @agent_id.setter
    def agent_id(self, val: Optional[str]) -> None:
        self._agent_id = val

    def retrieve(self) -> List[ContextRecord]:
        records = self._chat_history_block.retrieve(self._window_size)
        if self._window_size is not None and len(records) == self._window_size:
            warnings.warn(
                f"Chat history window size limit ({self._window_size}) "
                f"reached. Some earlier messages will not be included in "
                f"the context. Consider increasing window_size if you need "
                f"a longer context.",
                UserWarning,
                stacklevel=2,
            )
        return records

    def write_records(self, records: List[MemoryRecord]) -> None:
        for record in records:
            # assign the agent_id to the record
            if record.agent_id == "" and self.agent_id is not None:
                record.agent_id = self.agent_id
        self._chat_history_block.write_records(records)

    def get_context_creator(self) -> BaseContextCreator:
        return self._context_creator

    def clear(self) -> None:
        self._chat_history_block.clear()


class VectorDBMemory(AgentMemory):
    r"""An agent memory wrapper of :obj:`VectorDBBlock`. This memory queries
    messages stored in the vector database. Notice that the most recent
    messages will not be added to the context.

    Args:
        context_creator (BaseContextCreator): A model context creator.
        storage (BaseVectorStorage, optional): A vector storage storage. If
            `None`, an :obj:`QdrantStorage` will be used.
            (default: :obj:`None`)
        retrieve_limit (int, optional): The maximum number of messages
            to be added into the context.  (default: :obj:`3`)
        agent_id (str, optional): The ID of the agent associated with
            the messages stored in the vector database.
    """

    def __init__(
        self,
        context_creator: BaseContextCreator,
        storage: Optional[BaseVectorStorage] = None,
        retrieve_limit: int = 3,
        agent_id: Optional[str] = None,
    ) -> None:
        self._context_creator = context_creator
        self._retrieve_limit = retrieve_limit
        self._vectordb_block = VectorDBBlock(storage=storage)
        self._agent_id = agent_id

        self._current_topic: str = ""

    @property
    def agent_id(self) -> Optional[str]:
        return self._agent_id

    @agent_id.setter
    def agent_id(self, val: Optional[str]) -> None:
        self._agent_id = val

    def retrieve(self) -> List[ContextRecord]:
        return self._vectordb_block.retrieve(
            self._current_topic,
            limit=self._retrieve_limit,
        )

    def write_records(self, records: List[MemoryRecord]) -> None:
        # Assume the last user input is the current topic.
        for record in records:
            if record.role_at_backend == OpenAIBackendRole.USER:
                self._current_topic = record.message.content

            # assign the agent_id to the record
            if record.agent_id == "" and self.agent_id is not None:
                record.agent_id = self.agent_id

        self._vectordb_block.write_records(records)

    def get_context_creator(self) -> BaseContextCreator:
        return self._context_creator

    def clear(self) -> None:
        r"""Removes all records from the vector database memory."""
        self._vectordb_block.clear()


class LongtermAgentMemory(AgentMemory):
    r"""An implementation of the :obj:`AgentMemory` abstract base class for
    augmenting ChatHistoryMemory with VectorDBMemory.

    Args:
        context_creator (BaseContextCreator): A model context creator.
        chat_history_block (Optional[ChatHistoryBlock], optional): A chat
            history block. If `None`, a :obj:`ChatHistoryBlock` will be used.
            (default: :obj:`None`)
        vector_db_block (Optional[VectorDBBlock], optional): A vector database
            block. If `None`, a :obj:`VectorDBBlock` will be used.
            (default: :obj:`None`)
        retrieve_limit (int, optional): The maximum number of messages
            to be added into the context.  (default: :obj:`3`)
        agent_id (str, optional): The ID of the agent associated with the chat
            history and the messages stored in the vector database.
    """

    def __init__(
        self,
        context_creator: BaseContextCreator,
        chat_history_block: Optional[ChatHistoryBlock] = None,
        vector_db_block: Optional[VectorDBBlock] = None,
        retrieve_limit: int = 3,
        agent_id: Optional[str] = None,
    ) -> None:
        self.chat_history_block = chat_history_block or ChatHistoryBlock()
        self.vector_db_block = vector_db_block or VectorDBBlock()
        self.retrieve_limit = retrieve_limit
        self._context_creator = context_creator
        self._current_topic: str = ""
        self._agent_id = agent_id

    @property
    def agent_id(self) -> Optional[str]:
        return self._agent_id

    @agent_id.setter
    def agent_id(self, val: Optional[str]) -> None:
        self._agent_id = val

    def get_context_creator(self) -> BaseContextCreator:
        r"""Returns the context creator used by the memory.

        Returns:
            BaseContextCreator: The context creator used by the memory.
        """
        return self._context_creator

    def retrieve(self) -> List[ContextRecord]:
        r"""Retrieves context records from both the chat history and the vector
        database.

        Returns:
            List[ContextRecord]: A list of context records retrieved from both
                the chat history and the vector database.
        """
        chat_history = self.chat_history_block.retrieve()
        vector_db_retrieve = self.vector_db_block.retrieve(
            self._current_topic,
            self.retrieve_limit,
        )
        return chat_history[:1] + vector_db_retrieve + chat_history[1:]

    def write_records(self, records: List[MemoryRecord]) -> None:
        r"""Converts the provided chat messages into vector representations and
        writes them to the vector database.

        Args:
            records (List[MemoryRecord]): Messages to be added to the vector
                database.
        """
        self.vector_db_block.write_records(records)
        self.chat_history_block.write_records(records)

        for record in records:
            if record.role_at_backend == OpenAIBackendRole.USER:
                self._current_topic = record.message.content

    def clear(self) -> None:
        r"""Removes all records from the memory."""
        self.chat_history_block.clear()
        self.vector_db_block.clear()
