import datetime
from uuid import uuid4

from click import prompt
from matplotlib.pyplot import hist
from RoboMemory.Modules.Memories.LTM.LongTermMem import LongTermMemory, Extractor, MemoryGenerator, LTMAggregator
from RoboMemory.BaseModules.agent_general import GeneralAsyncAgent
from RoboMemory.agent_utils import ModelConfig, VectorDBConfig
from langchain_core.documents import Document
from RoboMemory.Datas.ModuleLogger import ModuleLogger
import logging
import time



# logger = logging.getLogger(__name__)

class SemanticMemAggregator(LTMAggregator):
    """Render retrieved semantic-memory documents as one framed text listing.

    Documents whose ``metadata['is_task']`` equals ``'T'`` are labelled with
    ``metadata['task']``; all others are labelled with ``metadata['action']``.
    The listing starts and ends with a line of 50 dashes.
    """

    def __init__(self):
        super().__init__()

    def aggregate(self, info: list[Document]) -> str:
        """Return the documents in ``info`` formatted between dash separators."""
        separator = '-' * 50
        pieces = [separator + '\n']
        for entry in info:
            meta = entry.metadata
            if meta['is_task'] == 'T':
                pieces.append(f"Task: {meta['task']}\n {entry.page_content}\n\n")
            else:
                pieces.append(f"Action: {meta['action']}\n{entry.page_content}\n\n")
        pieces.append(separator)
        return ''.join(pieces)

class SemanticMemUpdater(GeneralAsyncAgent):
    """Async LLM agent configured with the ``Classification.prompt`` template.

    Intended as the ``updater`` of ``SemanticMemory``, whose
    ``_classify_operation`` fills the template with ``new_mem_str`` and
    ``similar_texts`` to obtain a merge-operation label.
    """

    def __init__(
        self,
        model_config,
        template_path="RoboMemory/Templates/prompts/LTM_Prompts/Classification.prompt",
    ):
        super().__init__(model_config, template_path)

class SemanticMemGenerator(MemoryGenerator):
    """Memory generator using the ``SemanticTaskSum.prompt`` template by default."""

    def __init__(
        self,
        model_config,
        template_path="RoboMemory/Templates/prompts/LTM_Prompts/SemanticTaskSum.prompt",
    ):
        super().__init__(model_config, template_path)

    async def generate(self, e: str):
        """Fill the template's ``report`` slot with ``e`` and return the completion."""
        return await self.async_create_completion(params={"report": e})

class SemExtractor(Extractor):
    """Extractor using the ``ActionSum.prompt`` template by default.

    NOTE(review): the default path previously used ``RoboMemory/templates``
    (lowercase). The other templates in this module live under
    ``RoboMemory/Templates``, and the two spellings differ on case-sensitive
    filesystems. This default now matches them; confirm against the
    repository layout.
    """

    def __init__(self, model_config, template_path = "RoboMemory/Templates/prompts/LTM_Prompts/ActionSum.prompt"):
        super().__init__(model_config, template_path)

class SemanticMemory(LongTermMemory):
    """Long-term semantic memory stored in a vector store.

    Two kinds of entries share one store and are told apart by
    ``metadata['is_task']``:

    * ``'T'`` -- task-level summaries, written when ``update`` receives a
      non-empty ``'task'``;
    * ``'F'`` -- action-level step summaries, written otherwise.

    Each new summary is shown to the ``updater`` together with similar stored
    entries, and the returned label decides the merge: ``"ADD"``, ``"UPDATE"``
    or ``"DELETE"``. Any other label leaves the store untouched.
    """

    # Set from ``infos['api_descriptions']`` on each ``update`` call and
    # interpolated into the extraction and task-summary prompts. The class-level
    # default means ``_extract_facts`` never reads an unset attribute.
    _api_descriptions: str = ''

    def __init__(
            self, 
            updater : SemanticMemUpdater, 
            aggregator : SemanticMemAggregator,
            extractor : Extractor,
            memory_generator : SemanticMemGenerator,
            embedding_config: ModelConfig,
            vectordb_config: VectorDBConfig,
            retrieve_k_task: int = 5,
            retrieve_k_action: int = 5,
            top_k: int = 10,
            log_path: str="./ckpt"
        ):
        """Build the memory and its per-instance log.

        Args:
            updater: agent whose completion labels how a new memory is merged
                (see ``_classify_operation``).
            aggregator: formats retrieved documents in ``retrieve``.
            extractor: agent used by ``_extract_facts`` for step summaries.
            memory_generator: produces task-level summaries.
            embedding_config: forwarded to ``LongTermMemory``.
            vectordb_config: forwarded to ``LongTermMemory``.
            retrieve_k_task: forwarded to ``LongTermMemory``; presumably stored
                as ``self._retrieval_k_task`` (read by ``retrieve``) -- TODO
                confirm the attribute name in the base class.
            retrieve_k_action: number of action entries returned by ``retrieve``.
            top_k: number of similar entries shown to the updater when a new
                memory is merged.
            log_path: checkpoint directory passed to ``ModuleLogger``; the log
                record is named ``SemanticMemory`` plus a timestamp.
        """
        super().__init__(updater, aggregator, extractor, memory_generator, embedding_config, vectordb_config, retrieve_k_task)
        self._top_k = top_k
        self._retrieve_k_action = retrieve_k_action
        self.logger = ModuleLogger(ckpt_path=log_path, record_name="SemanticMemory" + datetime.datetime.now().strftime("%Y%m%d_%H%M%S"))

    # extraction phase
    async def _extract_facts(self, iterations: str) -> str:
        """Ask the extractor to summarise the latest step in ``iterations``.

        ``iterations`` is split on blank lines. The last chunk is sent as
        ``report``. The earlier chunks, re-joined with single newlines, are
        sent as ``history``. Returns the completion converted to ``str``.
        """
        iter_lst = iterations.strip().split("\n\n")
        step_sum = iter_lst[-1]
        history = '\n'.join(iter_lst[:-1])
        result = await self._extractor.async_create_completion(
            params={
                "history": history,
                "report": step_sum,
                "apis": self._api_descriptions,
            })
        return str(result)

    @staticmethod
    def _actionSum_format(action: str) -> str:
        """Return the text after the first ``Semantic Memory:`` marker, stripped.

        Returns ``''`` when the marker is absent.
        """
        marker = 'Semantic Memory:'
        idx = action.find(marker)
        if idx == -1:
            return ''
        return action[idx + len(marker):].strip()

    async def update(self, infos: dict) -> str:
        """Summarise new experience from ``infos[self.name]`` and merge it in.

        The per-module dict must contain ``'iterations'`` and at least one of
        ``'task'`` / ``'action'``. It may contain ``'api_descriptions'``. A
        non-empty ``'task'`` produces a task-level summary; otherwise a
        step-level summary is extracted for ``'action'``.

        Returns:
            The operation label chosen by ``_update``, or ``''`` when the
            input is missing or invalid, or when no summary was produced.
        """
        try:
            infos = infos[self.name]
        except (KeyError, TypeError):
            # Narrow on purpose: a bare ``except`` would also swallow
            # asyncio.CancelledError and hide unrelated bugs.
            logging.warning("No SemanticMemory, No update") 
            return ''

        if not infos:
            logging.warning("'SemanticMemory': Empty dictionary")
            return ''
        if 'iterations' not in infos:
            logging.warning("Input dictionary must contain 'iterations' keys.")
            return ''
        if 'task' not in infos and 'action' not in infos:
            logging.warning("Input dictionary must contain 'task'/'action' keys.")
            return ''

        self._api_descriptions = infos.get('api_descriptions', '')
        task = infos.get("task", "")
        if task:
            return await self._update_from_task(task, infos)
        return await self._update_from_action(infos)

    async def _update_from_task(self, task: str, infos: dict) -> str:
        """Summarise a whole task (with recent step memories) and merge it as a task entry."""
        logging.info("Semantic Memory: task summary")

        task_str = infos.get("iterations", '').strip()
        recent_mem_str = '\n'.join(self._recent_mem)
        prompt_text = f"Task: {task}.\nLog:\n{task_str}\n\nReflection of Failed Actions:{recent_mem_str}\n\nCapable Actions:{self._api_descriptions}"
        task_sum = await self._memory_generator.generate(prompt_text)

        self.logger.log(f"{prompt_text}\n\nTask Summary:\n{task_sum}", module_step=self._num_iter)

        task_sum = self._actionSum_format(task_sum)
        if not task_sum:
            return ''
        o = await self._update(task_sum, action='', task=task)
        self.logger.log(f"Opperation Identified: {o}", module_step=self._num_iter)
        return o

    async def _update_from_action(self, infos: dict) -> str:
        """Extract a step summary for the executed action and merge it as an action entry."""
        logging.info("Semantic Memory: action summary")

        iterations = infos.get("iterations", "")
        action = infos.get("action", "")

        knowledge_str = await self._extract_facts(iterations)
        step_mem = self._actionSum_format(knowledge_str)

        self.logger.log(f"""Action Executed: {action}.\nExecution Report:\n{'-'*50}\n{iterations}\n{'-'*50}\nllm respsonse: {knowledge_str}\nStep Summary: {step_mem}\n""", 
                        module_step=self._num_iter)

        if not step_mem:
            # No summary marker in the LLM output: advance the step counter anyway.
            self._num_iter += 1
            return ''

        # Step summaries are kept so the next task summary can reflect on them.
        self._recent_mem.append(step_mem)
        o = await self._update(step_mem, action=action)
        self.logger.log(f"Opperation Identified: {o}", module_step=self._num_iter)

        self._num_iter += 1
        return o

    # update phase
    def _create_entry(self, mem_str, action = '', task = '') -> Document:
        """Wrap ``mem_str`` in a new ``Document`` with a fresh UUID id.

        Metadata records the current step (``doc_id``), ``action``, ``task``,
        and ``is_task`` (``'T'`` when ``task`` is non-empty, else ``'F'``).
        """
        return Document(
            page_content=task + action + ':\n' + mem_str,
            id=str(uuid4()),
            metadata={
                "doc_id": str(self._num_iter),
                "action": action,
                "task": task,
                "is_task": 'T' if task else 'F',
            },
        )

    async def _classify_operation(self, new_mem_str: str, similar_texts: str):
        """Ask the updater which merge operation applies to ``new_mem_str``."""
        params = {
            "new_mem_str": new_mem_str,
            "similar_texts": similar_texts,
        }
        return await self.updater.async_create_completion(params=params)

    async def _update(self, new_mem_str:str, action: str = '', task: str = ''):
        """Merge ``new_mem_str`` into the store according to the updater's label.

        With ``action`` set, the search is restricted to entries with the same
        action name (text before the first ``(``). Otherwise it is restricted to
        task entries. Returns the operation label.
        """
        if action:
            # Keep only the action name, e.g. "pick(apple)" -> "pick". Unlike
            # ``action[:action.find('(')]``, this keeps the whole string when
            # there is no "(" instead of dropping its last character.
            action = action.split('(', 1)[0]
            similar_docs = self._vectorstore.similarity_search(new_mem_str, self._top_k, filter={"action": action})
            report = f"Report of executing action {action}:\n{new_mem_str}"
        else:
            similar_docs = self._vectorstore.similarity_search(new_mem_str, self._top_k, filter={"is_task": 'T'})
            report = f"Report of tackling task {task}:\n{new_mem_str}"

        similar_texts = '\n'.join(doc.page_content for doc in similar_docs)
        operation = await self._classify_operation(report, similar_texts)
        if isinstance(operation, str):
            # Tolerate whitespace/newlines around the label returned by the LLM.
            operation = operation.strip()

        if operation == "ADD":
            entry = self._create_entry(new_mem_str, action=action, task=task)
            self._vectorstore.add_documents([entry])

        elif operation == "UPDATE" and similar_docs:
            # Overwrite the most similar entry in place.
            old_id = similar_docs[0].id
            new_doc = self._create_entry(new_mem_str, action=action, task=task)
            if old_id:
                self._vectorstore.update_document(
                    document_id=old_id,
                    document=new_doc
                )
            else:
                logging.warning("No document ID found for update operation, adding as new document.")
                self._vectorstore.add_documents([new_doc])

        elif operation == "DELETE" and similar_docs:
            # Remove the most similar entry, then store the new memory as a fresh entry.
            old_id = similar_docs[0].id
            if old_id:
                self._vectorstore.delete([old_id])
            else:
                logging.warning("No document ID found for delete operation, cannot delete document.")
            new_doc = self._create_entry(new_mem_str, action=action, task=task)
            self._vectorstore.add_documents([new_doc])

        # Any other label (e.g. NOOP) leaves the store unchanged.
        return operation

    def retrieve(self, queries: dict) -> str:
        """Return task and action memories relevant to this module's queries.

        ``queries[self.name]`` holds the query lines; a single string is also
        accepted. Task entries come first, then action entries, each group
        formatted by the aggregator. The result is also written to the log.

        Raises:
            KeyError: if ``queries`` has no entry for this module.
        """
        query_lst = queries[self.name]
        if isinstance(query_lst, str):
            # Joining a bare string would put every character on its own line.
            query_lst = [query_lst]
        query = '\n'.join(query_lst).strip()
        action_docs = self._vectorstore.similarity_search(query, k=self._retrieve_k_action, filter={"is_task": 'F'})
        # NOTE(review): the constructor argument is ``retrieve_k_task``; confirm
        # the base class stores it as ``_retrieval_k_task``.
        task_docs = self._vectorstore.similarity_search(query, k=self._retrieval_k_task, filter={"is_task": 'T'})

        rslt = f"{self.aggregator.aggregate(task_docs)}\n\n{self.aggregator.aggregate(action_docs)}"
        self.logger.log(f'Queries:{query}\n\n Result:\n{rslt}')
        return rslt


