from RoboMemory.BaseModules.BaseMemory import BaseMemory
from RoboMemory.BaseModules.agent_general import GeneralAsyncAgent
from RoboMemory.BaseModules.BaseAggregator import BaseAggregator
import yaml
import os
import logging
class TemporalShortTermMemoryAggregator(BaseAggregator):
    """Aggregator that renders a list of memory entries as newline-terminated text."""

    def __init__(self):
        super().__init__()

    def aggregate(self, info: list) -> str:
        """Join ``info`` into a single string, one entry per line.

        Each entry is formatted with an f-string and followed by a newline
        character, so the result ends with a newline unless ``info`` is empty,
        in which case the empty string is returned.

        Args:
            info: entries to render, in order.

        Returns:
            The concatenated, newline-terminated entries.
        """
        # str.join builds the result in one pass; repeated ``+=`` on a str
        # may copy the accumulated text on every iteration.
        return "".join(f"{entry}\n" for entry in info)
class BufferSummarizer(GeneralAsyncAgent):
    """Async agent used by TemporalShortTermMemory to summarize its buffer."""

    def __init__(self, model_config, template_path):
        super().__init__(model_config, template_path)

    async def summarizer(self, params, image_paths, base64_image, image_type):
        """Run a completion with the given inputs and return its result.

        All arguments are forwarded unchanged to ``async_create_completion``.

        Returns:
            The awaited result of ``async_create_completion``.
        """
        # ``async_create_completion`` is awaitable (TemporalShortTermMemory.update
        # awaits it), so it must be awaited here; returning it un-awaited would
        # give callers of ``await summarizer(...)`` a coroutine object instead
        # of the completion result.
        return await self.async_create_completion(
            params, image_paths, base64_image, image_type
        )


class TemporalShortTermMemory(BaseMemory):
    """Rolling short-term memory of step summaries that is periodically condensed.

    Each call to :meth:`update` appends one step summary to ``memory_buffer``.
    Once the buffer holds ``max_capacity + 1`` entries, its aggregated text is
    summarized by ``updater`` and the buffer is replaced by that single summary.
    The buffer and the current ``working_memory`` can be written to and
    restored from a YAML file inside ``memory_path``.
    """

    def __init__(
            self,
            updater: BufferSummarizer,
            aggregator: TemporalShortTermMemoryAggregator,
            memory_path: os.PathLike,
            storage_name="short_term_memory.yaml",
            max_capacity=10,
            # When True, restore state from the storage file during construction.
            restore=True
        ) -> None:
        """
        Args:
            updater: agent whose ``async_create_completion`` summarizes the buffer.
            aggregator: turns a list of entries into one string.
            memory_path: directory for the storage file; forwarded to BaseMemory
                (presumably stored there as ``self.memory_path`` — it is read below).
            storage_name: file name of the YAML storage file inside ``memory_path``.
            max_capacity: number of new entries accumulated on top of the
                previous summary before the buffer is condensed again.
            restore: whether to call :meth:`load` immediately.
        """
        super().__init__(updater, aggregator, memory_path)
        self.memory_buffer = []
        # One extra slot: after a condensation the buffer already holds the
        # previous summary, so it is condensed again once ``max_capacity`` new
        # entries have been added on top of it.
        self.max_capacity = max_capacity + 1
        self.working_memory = None

        self.restore = restore
        self.storage_file = os.path.join(self.memory_path, storage_name)

        if self.restore:
            self.load()

    def save(self) -> bool:
        """Write ``memory_buffer`` and ``working_memory`` to the storage file as YAML.

        Returns:
            True on success, False if the file could not be written (the error
            is logged).
        """
        state = {
            "memory_buffer": self.memory_buffer,
            "working_memory": self.working_memory,
        }

        yaml_str = yaml.dump(
            state,
            indent=2,
            allow_unicode=True,        # keep non-ASCII text readable instead of escaped
            default_flow_style=False,  # block style: one item per line
            sort_keys=False,           # preserve key insertion order
        )

        try:
            with open(self.storage_file, 'w', encoding='utf-8') as f:
                f.write(yaml_str)
        except OSError as e:
            logging.error(f"save error: {e}")
            return False
        return True

    def load(self) -> bool:
        """Restore ``memory_buffer`` and ``working_memory`` from the storage file.

        A missing file is treated as "nothing saved yet" and leaves the current
        state untouched. Unreadable, malformed, or non-mapping files are logged
        and also leave the state untouched.

        Returns:
            True if state was restored, False otherwise.
        """
        try:
            with open(self.storage_file, 'r', encoding='utf-8') as f:
                # safe_load: yaml.load without an explicit Loader raises
                # TypeError on PyYAML >= 6 and is unsafe on older versions.
                # NOTE(review): assumes stored entries are plain YAML types
                # (str / list / dict / numbers) — confirm for summary objects.
                memory_dict = yaml.safe_load(f)
        except FileNotFoundError:
            logging.info(
                f"no short-term memory file at {self.storage_file}; starting empty"
            )
            return False
        except (OSError, yaml.YAMLError) as e:
            logging.error(f"load error: {e}")
            return False

        # An empty file parses to None; anything other than a mapping is unusable.
        if not isinstance(memory_dict, dict):
            logging.error(f"load error: unexpected content in {self.storage_file}")
            return False

        self.memory_buffer = memory_dict.get("memory_buffer") or []
        self.working_memory = memory_dict.get("working_memory")
        return True

    def need_update(self) -> bool:
        """Return True once the buffer has reached ``self.max_capacity`` entries."""
        return len(self.memory_buffer) >= self.max_capacity

    def update_working_memory(self, working_memory: str):
        """Replace the current working memory appended by :meth:`retrieve`."""
        self.working_memory = working_memory

    async def update(self, infos: dict) -> bool:
        """Append a step summary and condense the buffer when it is full.

        Args:
            infos: mapping that must contain the key ``'step_summary'``.

        Returns:
            True if the buffer was condensed into a new summary, else False.

        Raises:
            KeyError: if ``infos`` has no ``'step_summary'`` key.
        """
        information = infos['step_summary']

        # 1. Record the new step.
        self.memory_buffer.append(information)

        # 2. Once full, summarize everything and keep only the summary, which
        #    then serves as the first entry of the next window.
        need_update = self.need_update()
        if need_update:
            info = self.aggregator.aggregate(self.memory_buffer)
            params = {"trajectory": info}
            summary = await self.updater.async_create_completion(params)
            self.memory_buffer = [summary]

        return need_update

    def retrieve(self, queries: dict[str, str] = None) -> str:
        """Return the buffer (plus working memory, if set) as aggregated text.

        Args:
            queries: accepted for interface compatibility; not used here.
        """
        # Build a new list so the stored buffer is never mutated.
        buffer = list(self.memory_buffer)
        if self.working_memory is not None:
            buffer.append(self.working_memory)

        return self.aggregator.aggregate(buffer)