"""RAG pipeline"""

import asyncio

from pydantic import BaseModel

from metagpt.const import DATA_PATH, EXAMPLE_DATA_PATH
from metagpt.logs import logger
from metagpt.rag.engines import SimpleEngine
from metagpt.rag.schema import (
    ChromaIndexConfig,
    ChromaRetrieverConfig,
    ElasticsearchIndexConfig,
    ElasticsearchRetrieverConfig,
    ElasticsearchStoreConfig,
    FAISSRetrieverConfig,
    LLMRankerConfig,
)
from metagpt.utils.exceptions import handle_exception

LLM_TIP = "If you not sure, just answer I don't know."

DOC_PATH = EXAMPLE_DATA_PATH / "rag/writer.txt"
QUESTION = f"What are key qualities to be a good writer? {LLM_TIP}"

TRAVEL_DOC_PATH = EXAMPLE_DATA_PATH / "rag/travel.txt"
TRAVEL_QUESTION = f"What does Bob like? {LLM_TIP}"


class Player(BaseModel):
    """To demonstrate rag add objs."""

    name: str = ""
    goal: str = "Win The 100-meter Sprint."
    tool: str = "Red Bull Energy Drink."

    def rag_key(self) -> str:
        """For search"""
        return self.goal


class RAGExample:
    """Show how to use RAG."""

    def __init__(self, engine: SimpleEngine = None, use_llm_ranker: bool = True):
        self._engine = engine
        self._use_llm_ranker = use_llm_ranker

    @property
    def engine(self):
        if not self._engine:
            ranker_configs = [LLMRankerConfig()] if self._use_llm_ranker else None

            self._engine = SimpleEngine.from_docs(
                input_files=[DOC_PATH],
                retriever_configs=[FAISSRetrieverConfig()],
                ranker_configs=ranker_configs,
            )
        return self._engine

    @engine.setter
    def engine(self, value: SimpleEngine):
        self._engine = value

    @handle_exception
    async def run_pipeline(self, question=QUESTION, print_title=True):
        """This example run rag pipeline, use faiss retriever and llm ranker, will print something like:

        Retrieve Result:
        0. Productivi..., 10.0
        1. I wrote cu..., 7.0
        2. I highly r..., 5.0

        Query Result:
        Passion, adaptability, open-mindedness, creativity, discipline, and empathy are key qualities to be a good writer.
        """
        if print_title:
            self._print_title("Run Pipeline")

        nodes = await self.engine.aretrieve(question)
        self._print_retrieve_result(nodes)

        answer = await self.engine.aquery(question)
        self._print_query_result(answer)

    @handle_exception
    async def add_docs(self):
        """This example show how to add docs.

        Before add docs llm anwser I don't know.
        After add docs llm give the correct answer, will print something like:

        [Before add docs]
        Retrieve Result:

        Query Result:
        Empty Response

        [After add docs]
        Retrieve Result:
        0. Bob like..., 10.0

        Query Result:
        Bob likes traveling.
        """
        self._print_title("Add Docs")

        travel_question = f"{TRAVEL_QUESTION}"
        travel_filepath = TRAVEL_DOC_PATH

        logger.info("[Before add docs]")
        await self.run_pipeline(question=travel_question, print_title=False)

        logger.info("[After add docs]")
        self.engine.add_docs([travel_filepath])
        await self.run_pipeline(question=travel_question, print_title=False)

    @handle_exception
    async def add_objects(self, print_title=True):
        """This example show how to add objects.

        Before add docs, engine retrieve nothing.
        After add objects, engine give the correct answer, will print something like:

        [Before add objs]
        Retrieve Result:

        [After add objs]
        Retrieve Result:
        0. 100m Sprin..., 10.0

        [Object Detail]
        {'name': 'Mike', 'goal': 'Win The 100-meter Sprint', 'tool': 'Red Bull Energy Drink'}
        """
        if print_title:
            self._print_title("Add Objects")

        player = Player(name="Mike")
        question = f"{player.rag_key()}"

        logger.info("[Before add objs]")
        await self._retrieve_and_print(question)

        logger.info("[After add objs]")
        self.engine.add_objs([player])

        try:
            nodes = await self._retrieve_and_print(question)

            logger.info("[Object Detail]")
            player: Player = nodes[0].metadata["obj"]
            logger.info(player.name)
        except Exception as e:
            logger.error(f"nodes is empty, llm don't answer correctly, exception: {e}")

    @handle_exception
    async def init_objects(self):
        """This example show how to from objs, will print something like:

        Same as add_objects.
        """
        self._print_title("Init Objects")

        pre_engine = self.engine
        self.engine = SimpleEngine.from_objs(retriever_configs=[FAISSRetrieverConfig()])
        await self.add_objects(print_title=False)
        self.engine = pre_engine

    @handle_exception
    async def init_and_query_chromadb(self):
        """This example show how to use chromadb. how to save and load index. will print something like:

        Query Result:
        Bob likes traveling.
        """
        self._print_title("Init And Query ChromaDB")

        # 1. save index
        output_dir = DATA_PATH / "rag"
        SimpleEngine.from_docs(
            input_files=[TRAVEL_DOC_PATH],
            retriever_configs=[ChromaRetrieverConfig(persist_path=output_dir)],
        )

        # 2. load index
        engine = SimpleEngine.from_index(index_config=ChromaIndexConfig(persist_path=output_dir))

        # 3. query
        answer = await engine.aquery(TRAVEL_QUESTION)
        self._print_query_result(answer)

    @handle_exception
    async def init_and_query_es(self):
        """This example show how to use es. how to save and load index. will print something like:

        Query Result:
        Bob likes traveling.
        """
        self._print_title("Init And Query Elasticsearch")

        # 1. create es index and save docs
        store_config = ElasticsearchStoreConfig(index_name="travel", es_url="http://127.0.0.1:9200")
        engine = SimpleEngine.from_docs(
            input_files=[TRAVEL_DOC_PATH],
            retriever_configs=[ElasticsearchRetrieverConfig(store_config=store_config)],
        )

        # 2. load index
        engine = SimpleEngine.from_index(index_config=ElasticsearchIndexConfig(store_config=store_config))

        # 3. query
        answer = await engine.aquery(TRAVEL_QUESTION)
        self._print_query_result(answer)

    @staticmethod
    def _print_title(title):
        logger.info(f"{'#'*30} {title} {'#'*30}")

    @staticmethod
    def _print_retrieve_result(result):
        """Print retrieve result."""
        logger.info("Retrieve Result:")

        for i, node in enumerate(result):
            logger.info(f"{i}. {node.text[:10]}..., {node.score}")

        logger.info("")

    @staticmethod
    def _print_query_result(result):
        """Print query result."""
        logger.info("Query Result:")

        logger.info(f"{result}\n")

    async def _retrieve_and_print(self, question):
        nodes = await self.engine.aretrieve(question)
        self._print_retrieve_result(nodes)
        return nodes


async def main():
    """RAG pipeline.

    Note:
    1. If `use_llm_ranker` is True, then it will use LLM Reranker to get better result, but it is not always guaranteed that the output will be parseable for reranking,
       prefer `gpt-4-turbo`, otherwise might encounter `IndexError: list index out of range` or `ValueError: invalid literal for int() with base 10`.
    """
    e = RAGExample(use_llm_ranker=False)

    await e.run_pipeline()
    await e.add_docs()
    await e.add_objects()
    await e.init_objects()
    await e.init_and_query_chromadb()
    await e.init_and_query_es()


if __name__ == "__main__":
    asyncio.run(main())
