import datetime
import logging
from RoboMemory.Datas.ModuleLogger import ModuleLogger
from RoboMemory.Modules.Memories.LTM.LongTermMem import LongTermMemory, Extractor, MemoryGenerator, LTMAggregator, get_document_count
from RoboMemory.BaseModules.agent_general import GeneralAsyncAgent
from langchain_core.documents import Document
from RoboMemory.agent_utils import ModelConfig, VectorDBConfig


class EpisodicMemoryAggregator(LTMAggregator):
    """Render retrieved episodic-memory documents as one chronological string.

    Each document becomes a ``Time:<id>.`` header line followed by its
    ``page_content``. Documents are ordered by ``id``; ``EpisodicMemory``
    assigns ids as an increasing document counter, so id order is
    insertion order.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def _id_sort_key(doc) -> tuple:
        """Return a sort key that orders documents by numeric id.

        Ids may arrive as strings (e.g. ``"10"``), so numeric ids are
        compared as integers rather than lexicographically ("10" < "2").
        Documents without an id sort first; ids that are not numeric sort
        last, in string order. The tuple keys keep these groups comparable,
        so mixed id types never raise ``TypeError`` during sorting.
        """
        doc_id = doc.id
        if doc_id is None:
            return (0, 0, "")
        try:
            return (1, int(doc_id), "")
        except (TypeError, ValueError):
            return (2, 0, str(doc_id))

    def aggregate(self, info: list[Document]) -> str:
        """Format ``info`` into a single newline-separated string.

        Args:
            info: Retrieved documents. The list is not modified; a sorted copy
                is used.

        Returns:
            ``"Time:<id>.\\n<content>"`` entries joined by newlines, ordered by
            id, with surrounding whitespace stripped (``""`` for no documents).
        """
        ordered = sorted(info, key=self._id_sort_key)
        entries = [f"Time:{doc.id}.\n{doc.page_content}" for doc in ordered]
        return "\n".join(entries).strip()
    



class AutobioGenerator(MemoryGenerator):
    """Memory generator backed by the ``EMemAutobio`` prompt template.

    It renders the template with the supplied experience text and returns
    whatever completion the base class's ``async_create_completion`` produces.
    """

    def __init__(self, model_config, template_path="RoboMemory/Templates/prompts/LTM_Prompts/EMemAutobio.prompt"):
        super().__init__(model_config, template_path)

    async def generate(self, e: str):
        """Produce the completion for experience text ``e``.

        ``e`` is bound to the template's ``experience`` parameter; the
        completion is returned unchanged.
        """
        prompt_params = {"experience": e}
        completion = await self.async_create_completion(params=prompt_params)
        return completion
 
 
class EpisodicMemory(LongTermMemory):
    """Long-term memory that stores generated autobiographical task reports.

    ``update`` turns one task's iteration log into memory text via the
    autobio generator and adds it to the vector store as a ``Document``.
    ``retrieve`` runs a similarity search and formats the hits with the
    aggregator.
    """

    def __init__(
            self,
            aggregator: EpisodicMemoryAggregator,
            autobio_generator: AutobioGenerator,
            embedding_config: ModelConfig,
            vectordb_config: VectorDBConfig,
            retrieve_k_task: int = 5,
            log_path: str = "./ckpt"
        ) -> None:
        """Initialise the base LongTermMemory and a timestamped module logger.

        Args:
            aggregator: Formats retrieved documents into one string.
            autobio_generator: Generates the memory text. It is passed to the
                base class as both ``updater`` and ``memory_generator``.
            embedding_config: Forwarded to the base class; also used to build
                the placeholder ``Extractor``.
            vectordb_config: Vector-store configuration forwarded to the base.
            retrieve_k_task: Forwarded to the base class; presumably the number
                of documents ``retrieve`` returns (confirm in LongTermMemory).
            log_path: Checkpoint directory handed to ``ModuleLogger``.
        """
        # Only the memory generator is used by ``update``; the base constructor
        # still takes an extractor and an updater, so placeholders are passed.
        dummy_extractor = Extractor(model_config=embedding_config)
        dummy_updater = autobio_generator
        super().__init__(updater=dummy_updater, aggregator=aggregator, extractor=dummy_extractor,
                         memory_generator=autobio_generator, embedding_config=embedding_config,
                         vectordb_config=vectordb_config, retrieve_k_task=retrieve_k_task)

        # NOTE(review): "EpsodicMemory" is misspelled, but it determines the log
        # record/file name, so it is left as-is to avoid breaking log tooling.
        self.logger = ModuleLogger(ckpt_path=log_path, record_name="EpsodicMemory" + datetime.datetime.now().strftime("%Y%m%d_%H%M%S"))

    def _create_entry(self, mem_str, action = '', task = '') -> Document:
        """Wrap ``mem_str`` in a Document tagged with ``task`` metadata.

        The id is the current vector-store document count plus one. The
        count is read synchronously just before ``add_documents`` in
        ``update``, with no ``await`` in between. ``action`` is accepted for
        interface compatibility but is unused.
        """
        return Document(
            page_content=mem_str,
            id=get_document_count(self._vectorstore) + 1,
            metadata={
                'task': task,
            }
        )

    async def update(self, infos: dict):
        """Generate an autobiographical memory for one task and store it.

        Args:
            infos: Mapping keyed by ``self.name``. The value must contain
                ``'task'`` and ``'iterations'``. ``iterations`` is concatenated
                onto a string prefix, so it is assumed to be a ``str``.

        Logs a warning and returns without storing anything when the entry
        for this memory is missing or empty, or lacks the required keys.
        """
        # .get so a missing key takes the warning path below instead of
        # raising KeyError.
        infos = infos.get(self.name)

        if not infos:
            logging.warning("No inner dictionary value to key '%s'", self.name)
            return

        if 'iterations' not in infos or 'task' not in infos:
            logging.warning("Input dictionary must contain 'task' and 'iterations' keys.")
            return

        task = infos["task"]

        logging.info("Updating episodic memory for task: %s", task)

        # Generate the autobio memory string
        mem_str = await self._memory_generator.generate(f"Report of tackling task {task}:\n" + infos['iterations'])

        self.logger.log(f"""Updating episodic memory for task: {task}
                        Generated autobio memory: 
                        {mem_str}""")
        new_entry = self._create_entry(mem_str, task=task)
        self._vectorstore.add_documents([new_entry])

    def retrieve(self, queries: dict) -> str:
        """Return formatted memories most similar to this memory's queries.

        Args:
            queries: Mapping keyed by ``self.name``. The value is either a
                single query string or an iterable of query strings; an
                iterable is joined with newlines into one search query.

        Returns:
            The aggregator's rendering of the similarity-search hits.

        Raises:
            KeyError: if ``queries`` has no entry for ``self.name``.
        """
        query_lst = queries[self.name]
        if isinstance(query_lst, str):
            # Joining a bare string would insert a newline between every character.
            query = query_lst.strip()
        else:
            query = '\n'.join(query_lst).strip()
        docs = self._vectorstore.similarity_search(query, k=self._retrieval_k_task)
        return self.aggregator.aggregate(docs)