import colmind.utils as U

from langchain_openai import OpenAIEmbeddings
from langchain.schema import HumanMessage, SystemMessage
from langchain_community.vectorstores import Chroma
from langchain_openai import ChatOpenAI
from colmind.agents.llm import get_llm

class EpisodicManager:
    """Store past episodes in a Chroma vector store, then retrieve and summarize them.

    Each episode is a plain-text string embedded with ``OpenAIEmbeddings`` and
    stored in a Chroma collection whose ``persist_directory`` is
    ``<ckpt_dir>/episodic/vectordb``.
    """

    def __init__(
        self,
        config,
        username,
        logger,
        resume=False,
    ):
        """Create the vector store and the LLM used for summarization.

        Args:
            config: Mapping providing ``config["ckpt_dir"]``,
                ``config["parameters"]["llm"]`` (passed to ``get_llm``) and
                ``config["parameters"]["retrieval_top_k"]`` (read by
                ``retrieve_episodes``).
            username: Label prefixed to the retrieval log message.
            logger: Object exposing ``info(message)``.
            resume: Not read anywhere in this class.
                NOTE(review): confirm whether a non-resume run is expected to
                clear the existing collection.
        """
        self.config = config
        self.username = username
        self.logger = logger

        persist_dir = f"{config['ckpt_dir']}/episodic/vectordb"
        U.f_mkdir(persist_dir)

        self.llm = get_llm(config["parameters"]["llm"])
        self.vectordb = Chroma(
            # NOTE(review): the collection is named "skill_vectordb" although it
            # stores episodes; renaming it would orphan already-persisted data.
            collection_name="skill_vectordb",
            embedding_function=OpenAIEmbeddings(),
            persist_directory=persist_dir,
        )

    def add_new_episode(self, info):
        """Embed and store one episode, then persist the vector store.

        Args:
            info: Mapping whose ``"episode"`` entry is the episode text.

        The document gets id ``episode_<n>`` and metadata ``{"episode": n}``,
        where ``n`` is the collection size before the insertion.
        """
        episode: str = info["episode"]
        # Read the count once so the id and the metadata always agree.
        index = self.vectordb._collection.count()
        self.vectordb.add_texts(
            texts=[episode],
            ids=[f"episode_{index}"],
            # The keyword is ``metadatas`` (plural); the former ``metadata=``
            # fell into ``**kwargs`` and the metadata was silently dropped.
            metadatas=[{"episode": index}],
        )
        self.vectordb.persist()

    def generate_summary(self, episodes: list):
        """Ask the LLM for a concise summary of the episodes and why they failed.

        Args:
            episodes: Episode texts, e.g. as returned by ``retrieve_episodes``.

        Returns:
            The ``content`` of the LLM response. If ``episodes`` is empty (any
            empty sequence), the LLM is not called and a fixed message is
            returned instead.
        """
        if not episodes:
            return "There are no past failure episodes."
        combined_episodes = "\n\n".join(episodes)
        messages = [
            SystemMessage(
                content="You are a helpful assistant tasked with summarizing past experience episodes and pointing out the causes of failure. Create a concise summary."
            ),
            HumanMessage(
                content=f"Please summarize these episodes and why they failed:\n\n{combined_episodes}"
            ),
        ]
        return self.llm.invoke(messages).content

    def retrieve_episodes(self, query: str):
        """Return the stored episodes most similar to ``query``.

        At most ``config["parameters"]["retrieval_top_k"]`` episodes are
        returned, never more than the collection holds.

        Returns:
            Episode texts in the order produced by
            ``similarity_search_with_score``. The list is empty when the store
            is empty or ``retrieval_top_k`` is not positive.
        """
        k = min(self.vectordb._collection.count(), self.config["parameters"]["retrieval_top_k"])
        if k <= 0:
            return []
        self.logger.info(f"Retrieving top {k} episodes for query: {query}")
        docs_and_scores = self.vectordb.similarity_search_with_score(query, k=k)
        # Similarity scores are discarded; only the episode text is returned.
        episodes = [doc.page_content for doc, _ in docs_and_scores]
        self.logger.info(
            f"{self.username} Episode Manager retrieved episodes: "
            f"{', '.join(episodes)}"
        )
        return episodes
