from langchain_openai import ChatOpenAI
import re
import json
import datetime
import os
from typing import Dict, List, Optional, Any
import networkx as nx
import logging
import yaml
from RoboMemory.BaseModules.BaseMemory import BaseMemory
from RoboMemory.BaseModules.agent_general import GeneralAsyncAgent
from RoboMemory.BaseModules.BaseAggregator import BaseAggregator
from RoboMemory.SpatialMemory.vector_db import VectorDB
from RoboMemory.SpatialMemory.graph_db import GraphDB
from RoboMemory.Datas.ModuleLogger import ModuleLogger
from RoboMemory.agent_utils import ModelConfig, VectorDBConfig



class SpatialMemoryAggregator(BaseAggregator):
    """Render retrieved graph triplets as a numbered, human-readable list."""

    def __init__(
            self,
            max_relations = 50
        ):
        super().__init__()
        # Cap on how many triplets are rendered; the *last* entries of the
        # input are kept. ``None`` means no cap; values <= 0 render nothing.
        self.max_relations = max_relations

    def _format_retrieval_results(self, graph_results: List[List[str]]) -> str:
        """Format triplets as ``"1. a rel b\\n2. c rel d"``.

        Args:
            graph_results: Sequence of ``[node_a, relation, node_b]`` triplets.
                Only the last ``self.max_relations`` entries are rendered.

        Returns:
            The numbered lines joined by newlines with surrounding whitespace
            stripped, or ``""`` when there is nothing to render.
        """
        if not graph_results:
            return ""

        if self.max_relations is None:
            recent = graph_results
        elif self.max_relations <= 0:
            # ``xs[-0:]`` is the whole list, so 0 must be handled explicitly.
            recent = []
        else:
            recent = graph_results[-self.max_relations:]

        lines = (
            f"{i}. {triplet[0]} {triplet[1]} {triplet[2]}"
            for i, triplet in enumerate(recent, 1)
        )
        return "\n".join(lines).strip()

    def aggregate(self, info: list) -> str:
        """Return the formatted retrieval results for ``info`` (a list of triplets)."""
        return self._format_retrieval_results(info)


class SpatialAnalyzer(GeneralAsyncAgent):
    """LLM agent that turns step summaries into spatial relation triplets.

    Two prompts are used. The primary template (given to ``GeneralAsyncAgent``
    as ``template_path``) drives :meth:`extract_triplets`. The optional
    integration template drives :meth:`integrate_and_extract`, which rewrites
    a retrieved subgraph. Triplets are ``[node_a, relation, node_b]`` string
    lists.
    """

    # Request/parse failures are reported here instead of being silently dropped.
    _analyzer_log = logging.getLogger(__name__)

    # Keys every relation object in the LLM's JSON output must carry.
    _TRIPLET_KEYS = ('node_a', 'relation', 'node_b')

    # A whole response wrapped in one markdown code fence with an optional
    # language tag (```json, ```JSON, bare ```, ...); group 1 is the payload.
    _FENCE_RE = re.compile(r"```[\w-]*\s*(.*?)\s*```", re.DOTALL)

    def __init__(self, model_config: ModelConfig, template_path: str,
                 integration_template_path: Optional[str] = None):
        """Create the agent and read the optional integration prompt.

        Args:
            model_config: Forwarded to ``GeneralAsyncAgent``.
            template_path: Prompt template used by :meth:`extract_triplets`.
            integration_template_path: Optional prompt file for
                :meth:`integrate_and_extract`; read eagerly as UTF-8.
        """
        super().__init__(model_config, template_path)
        self.integration_template_path = integration_template_path

        # Without an integration prompt, integrate_and_extract() returns an
        # empty graph instead of calling the LLM with a missing template.
        self.integration_template: Optional[str] = None
        if integration_template_path:
            with open(integration_template_path, "r", encoding='utf-8') as fp:
                self.integration_template = fp.read()

    async def extract_triplets(
            self,
            aggregated_info: str,
            possible_objects : str # additional object information
        ) -> List[List[str]]:
        """Ask the LLM for spatial triplets describing ``aggregated_info``.

        Args:
            aggregated_info: Text to analyse (e.g. joined step summaries).
            possible_objects: Extra object hints inserted into the prompt.

        Returns:
            Parsed ``[node_a, relation, node_b]`` lists. Returns ``[]`` when the
            request fails or the response cannot be parsed; both cases are
            logged.
        """
        params = {
            "aggregated_info": aggregated_info,
            "possible_objects": possible_objects
        }
        try:
            response = await self.async_create_completion(params)
        except Exception:
            self._analyzer_log.exception("Triplet extraction request failed")
            return []
        return self._parse_triplets_response(response)

    async def integrate_and_extract(
            self,
            entities: List[str],
            relationships: List[List[str]],
            triplets: List[List[str]],
            step_summary: str,
            knowledge_content: str,
            possible_objects : str # additional object infomation
        ) -> Dict[str, Any]:
        """Ask the LLM to merge retrieved and new relations into a new graph.

        Args:
            entities: Entity names returned by vector search.
            relationships: Existing relations (graph search results).
            triplets: Newly extracted relations for this step.
            step_summary: Text describing the current step.
            knowledge_content: Extra knowledge text for the prompt.
            possible_objects: Extra object hints for the prompt.

        Returns:
            ``{"new_graph": [[node_a, relation, node_b], ...]}``. The list is
            empty when there is no integration template, when the request
            fails, or when the response cannot be parsed.
        """
        if self.integration_template is None:
            self._analyzer_log.warning(
                "integrate_and_extract called without an integration template; "
                "returning an empty graph"
            )
            return {"new_graph": []}

        entities_str = "\n".join([f"- {entity}" for entity in entities]) if entities else "no relation entities"
        relationships_str = self._format_relations(relationships, "no relation relationships")
        new_relationships_str = self._format_relations(triplets, "no new relationships")

        params = {
            "vector_search_entities": entities_str,
            "graph_search_relationships": relationships_str,
            "new_relationships": new_relationships_str,
            "this_step_summary": step_summary,
            "knowledge_content": knowledge_content,
            "possible_objects": possible_objects
        }

        # The base agent renders self.template, so it is swapped for this one
        # call. ``finally`` guarantees the primary template is restored even on
        # failure. NOTE: the swap spans an ``await``, so one analyzer must not
        # run this concurrently with extract_triplets().
        previous_template = self.template
        self.template = self.integration_template
        try:
            response = await self.async_create_completion(params)
        except Exception:
            self._analyzer_log.exception("Graph integration request failed")
            return {"new_graph": []}
        finally:
            self.template = previous_template

        return self._parse_optimization_response(response)

    @staticmethod
    def _format_relations(relations: Optional[List[List[str]]], empty_text: str) -> str:
        """Render relations as ``- a --[rel]--> b`` lines, or ``empty_text`` if none.

        Entries with fewer than three elements are skipped rather than
        raising ``IndexError``.
        """
        lines = [
            f"- {rel[0]} --[{rel[1]}]--> {rel[2]}"
            for rel in (relations or [])
            if len(rel) >= 3
        ]
        return "\n".join(lines) if lines else empty_text

    @classmethod
    def _strip_code_fence(cls, response: str) -> str:
        """Strip ``response`` and remove one enclosing markdown code fence, if any."""
        text = response.strip()
        fenced = cls._FENCE_RE.fullmatch(text)
        return fenced.group(1) if fenced else text

    @classmethod
    def _triplets_from_objects(cls, objects: Any) -> List[List[str]]:
        """Convert a JSON list of relation objects into ``[a, rel, b]`` lists.

        Non-dict items and objects missing any of ``node_a``/``relation``/
        ``node_b`` are skipped. Values are ``str()``-converted and stripped.
        Triplets with an empty part are dropped. Non-list input yields ``[]``.
        """
        if not isinstance(objects, list):
            return []
        triplets = []
        for obj in objects:
            if not isinstance(obj, dict) or not all(key in obj for key in cls._TRIPLET_KEYS):
                continue
            node_a, relation, node_b = (str(obj[key]).strip() for key in cls._TRIPLET_KEYS)
            if node_a and relation and node_b:
                triplets.append([node_a, relation, node_b])
        return triplets

    def _parse_triplets_response(self, response: str) -> List[List[str]]:
        """Parse the extraction output: a JSON array of relation objects.

        The array may be wrapped in a markdown code fence. Returns ``[]`` (and
        logs) if the response is not a string, is not valid JSON, or is not a
        list.
        """
        try:
            payload = json.loads(self._strip_code_fence(response))
        except (AttributeError, TypeError, ValueError):
            # AttributeError/TypeError: response was not a string (e.g. None);
            # ValueError covers json.JSONDecodeError.
            self._analyzer_log.warning("Could not parse triplet response as JSON: %r", response)
            return []
        return self._triplets_from_objects(payload)

    def _parse_optimization_response(self, response: str) -> Dict[str, Any]:
        """Parse the integration output: ``{"new_graph": [relation objects]}``.

        Returns ``{"new_graph": [[a, rel, b], ...]}``. The list is empty when
        the response is invalid JSON, is not an object, or lacks a list-valued
        ``new_graph``.
        """
        try:
            result = json.loads(self._strip_code_fence(response))
        except (AttributeError, TypeError, ValueError):
            self._analyzer_log.warning("Could not parse integration response as JSON: %r", response)
            return {"new_graph": []}

        if not isinstance(result, dict):
            return {"new_graph": []}
        return {"new_graph": self._triplets_from_objects(result.get("new_graph"))}


class SpatialMemory(BaseMemory):
    

    def __init__(
        self,
        updater: SpatialAnalyzer,
        aggregator: SpatialMemoryAggregator,
        embedding_config: ModelConfig,
        vectordb_config: VectorDBConfig,
        logging_path : str,
        graph_db_path: str = "./database/graphDB",
        graph_db_name: str = "spatial_graph.pkl",
        memory_path: Optional[str] = None,
        enable_forgetting: bool = False,
        enable_visualization: bool = False,
        visualization_config: Optional[Dict[str, Any]] = None,
        graph_search_top_k: int = 5,
        graph_search_adj_degree: int = 4,
    ) -> None:
        """Wire up the vector DB, graph DB and logger backing spatial memory.

        Args:
            updater: LLM agent used to extract and integrate triplets.
            aggregator: Formats retrieved triplets.
            embedding_config: Embedding model configuration passed to ``VectorDB``.
            vectordb_config: Vector store configuration passed to ``VectorDB``.
            logging_path: Checkpoint path for ``ModuleLogger``; records go to
                ``spatial_memory.txt``.
            graph_db_path: Directory of the graph store (created if missing).
            graph_db_name: File name of the graph store inside ``graph_db_path``.
            memory_path: JSON file used by ``save``/``load``; forwarded to the base class.
            enable_forgetting: Forwarded to ``GraphDB``.
            enable_visualization: Accepted but not used by this constructor.
            visualization_config: Accepted but not used by this constructor.
            graph_search_top_k: Stored on the instance (used by search code
                outside this constructor).
            graph_search_adj_degree: Stored on the instance (used by search
                code outside this constructor).
        """
        super().__init__(updater, aggregator, memory_path)
        self.updater: SpatialAnalyzer = updater
        self.aggregator: SpatialMemoryAggregator = aggregator

        # Step summaries waiting to be turned into triplets.
        self.memory_buffer = []

        self.graph_search_top_k = graph_search_top_k
        self.graph_search_adj_degree = graph_search_adj_degree

        # State of the most recent retrieval; a non-empty 'query' makes the
        # next update() run the graph-integration pass.
        self.retrieval_buffer = {
            'query': None,
            'vector_results': None,
            'graph_results': None
        }

        self.vector_db = VectorDB(
            embedding_conf=embedding_config,
            vectordb_conf=vectordb_config
        )

        os.makedirs(graph_db_path, exist_ok=True)
        self.graph_db = GraphDB(
            db_path=graph_db_path,
            db_name=graph_db_name,
            enable_forgetting=enable_forgetting
        )

        self.logger = ModuleLogger(ckpt_path=logging_path, record_name = "spatial_memory.txt")

        # Step index attached to every log record.
        self.local_step = 0

        # update() and _perform_optimization() read self.visualizer, but this
        # constructor never creates one. Keep any value the base class may
        # have set; otherwise default to None so those reads cannot raise
        # AttributeError.
        if not hasattr(self, "visualizer"):
            self.visualizer = None

        self._load_existing_graph()
        

    def _load_existing_graph(self):
        """(Re)load ``self.graph_db.graph`` from the pickle at ``graph_db.db_path``.

        If that path does not exist the graph is reset to an empty
        ``nx.DiGraph``. If loading fails for any reason, the failure is logged
        and the graph is also reset to an empty ``nx.DiGraph``.
        """
        # NOTE(review): assumes graph_db.db_path is the pickle file itself
        # (GraphDB is constructed with a directory plus db_name). If it is the
        # directory, open() fails and the graph is reset to empty. Confirm
        # against GraphDB.
        # SECURITY: pickle.load can execute arbitrary code. Only load graph
        # files written by this application.
        db_file = None
        try:
            db_file = self.graph_db.db_path
            if os.path.exists(db_file):
                import pickle
                with open(db_file, "rb") as f:
                    self.graph_db.graph = pickle.load(f)
            else:
                self.graph_db.graph = nx.DiGraph()
        except Exception:
            # Previously silent. An empty graph here can later be saved over
            # the unreadable file, so make the failure visible.
            logging.getLogger(__name__).exception(
                "Could not load spatial graph from %r; starting with an empty graph", db_file
            )
            self.graph_db.graph = nx.DiGraph()
            
    def __aggregate(self, info: list) -> str:
        """Join buffered step summaries into one newline-separated string.

        Returns ``""`` for an empty buffer. Whitespace around the combined
        result is stripped.
        """
        if not info:
            return ""
        # str.join is linear; the previous `+=` loop is not guaranteed to be.
        return "\n".join(f"{action}" for action in info).strip()
        

    def need_update(self) -> bool:
        """Decide whether buffered step summaries should be processed now.

        The current policy is unconditional: every update() call triggers
        triplet extraction over the whole memory buffer.
        """
        should_flush = True
        return should_flush
    

        
    def __log_info(self, infos : dict, graph : Optional[List[List[str]]] = None) -> None:
        """Write ``infos`` as indented JSON to the module logger at ``local_step``.

        Args:
            infos: Record to log. Non-ASCII text is kept as-is and key order is
                preserved.
            graph: Unused; kept so existing call signatures remain valid.
        """
        try:
            log_info = json.dumps(infos, indent=2, ensure_ascii=False, sort_keys=False)
        except TypeError:
            # Logging must never abort the operation being logged. For a debug
            # record, stringifying values json cannot encode (sets, custom
            # objects) is an acceptable, deliberate fallback.
            log_info = json.dumps(
                infos, indent=2, ensure_ascii=False, sort_keys=False, default=str
            )
        self.logger.log(log_info, self.local_step)
    
    async def update(self, infos: Dict[str, Any]) -> List[bool]:
        """Ingest one step summary into spatial memory.

        Args:
            infos: Either the step summary as a plain string, or a dict with a
                ``'step_summary'`` entry and an optional ``'possible_objects'``
                entry (object hints forwarded to the LLM prompts).

        Flow:
            1. The summary is appended to ``memory_buffer``. When
               ``need_update()`` is true, the whole buffer is sent to the
               updater to extract ``[node_a, relation, node_b]`` triplets.
            2. Entities go to the vector DB and nodes to the graph DB. The
               buffer is cleared only when triplets were found.
            3. If a retrieval is pending and an integration template is
               configured, ``_perform_optimization`` rewrites the retrieved
               subgraph. Otherwise the triplets are merged as edges directly.

        Returns:
            ``[update_ok]``, or ``[update_ok, optimization_ok]`` when the
            optimization pass ran. ``[False]`` for invalid input or on any
            unexpected error (the traceback is printed).
        """
        # Snapshot of the buffer processed this step, for the log record.
        log_memory_buffer = ['No memory buffer']

        try:
            results: List[bool] = []
            update_success = False
            update_status: Optional[Dict[str, Any]] = None
            triplets: List[List[str]] = []
            # Bound up front so every later branch can read them. Previously a
            # str input, or a dict without 'possible_objects', hit an unbound
            # local and the whole update was reported as failed.
            step_summary = None
            possible_objects = ""
            all_info = ""

            if isinstance(infos, str):
                step_summary = infos
            elif isinstance(infos, dict):
                step_summary = infos.get('step_summary')
                possible_objects = infos.get('possible_objects', possible_objects)
            else:
                return [False]

            if not step_summary:
                return [False]

            self.memory_buffer.append(step_summary)

            if self.need_update():
                log_memory_buffer = self.memory_buffer.copy()
                all_info = self.__aggregate(self.memory_buffer)

                triplets = await self.updater.extract_triplets(
                    aggregated_info=all_info, possible_objects=possible_objects
                )

                if triplets:
                    vector_success = self._store_entities_to_vector_db(triplets)
                    # Only nodes are written here; edges come from the
                    # optimization pass or the direct merge further below.
                    node_success = self._sotre_node_to_graph_db(triplets)

                    self.memory_buffer.clear()
                    update_success = vector_success and node_success
                    update_status = {
                        "vector_success": vector_success,
                        "node_success": node_success,
                        "update_success": update_success,
                        "triplets_count": len(triplets)
                    }

                    # Add the fresh entities to the pending retrieval so the
                    # optimization pass also considers them.
                    if update_success and self.retrieval_buffer['query']:
                        self._add_triplet_entities_to_retrieval_buffer(triplets)

                    results.append(update_success)
                else:
                    # The buffer is kept, so these summaries are re-sent
                    # together with the next step's summary.
                    update_success = False
                    update_status = {
                        "update_success": False,
                        "reason": "no_triplets_extracted",
                        "triplets_count": 0
                    }
                    results.append(False)
            else:
                update_success = True
                update_status = {
                    "update_success": True,
                    "reason": "no_update_needed",
                    "memory_buffer_empty": len(self.memory_buffer) == 0
                }
                results.append(True)

            if self.retrieval_buffer['query'] and self.updater.integration_template_path:
                optimization_result = await self._perform_optimization(
                    triplets, all_info, possible_objects=possible_objects
                )
                results.append(optimization_result)
                self.clear_retrieval_buffer()
            elif triplets:
                # No optimization pass, so merge the edges directly. This used
                # to run only when no retrieval was pending, so a pending query
                # without an integration template silently dropped relations.
                self._store_triplets_to_graph_db(triplets)

            self.__log_info({
                "memory_buffer": log_memory_buffer,
                "triplets": triplets,
                "update_status": update_status
            })

            visualizer = getattr(self, "visualizer", None)
            if visualizer and update_success:
                visualizer.visualize_after_update(spatial_memory=self)

            return results

        except Exception:
            import traceback
            traceback.print_exc()
            return [False]

    def _add_triplet_entities_to_retrieval_buffer(self, triplets: List[List[str]]) -> None:
        """Append the head/tail entities of ``triplets`` to the pending retrieval.

        Existing and new names are both compared after
        ``_standardize_entity_name``, so ``'I'`` is not added twice. The old
        check casefolded existing entries but not the special-cased ``'I'``.
        New entities are appended in triplet order. ``graph_results`` is
        initialized to ``[]`` if unset.
        """
        if self.retrieval_buffer['vector_results'] is None:
            self.retrieval_buffer['vector_results'] = []
        vector_results = self.retrieval_buffer['vector_results']

        seen = {self._standardize_entity_name(entity) for entity in vector_results}
        for triplet in triplets:
            for raw_name in (triplet[0], triplet[2]):
                entity = self._standardize_entity_name(raw_name)
                if entity not in seen:
                    seen.add(entity)
                    vector_results.append(entity)

        if self.retrieval_buffer['graph_results'] is None:
            self.retrieval_buffer['graph_results'] = []
    
    def _standardize_entity_name(self, entity: str) -> str:
        """Normalize an entity name for storage and comparison.

        Surrounding whitespace is removed and the name is casefolded. The one
        exception is the pronoun ``i``, which is always returned as ``'I'``.
        """
        normalized = entity.strip().casefold()
        return 'I' if normalized == 'i' else normalized

    def _store_entities_to_vector_db(self, triplets: List[List[str]]) -> bool:
        """Insert each distinct head/tail entity of ``triplets`` into the vector DB.

        Names are standardized first, so duplicates collapse. Per-entity
        errors are logged and the remaining entities are still attempted.
        When anything failed, a summary record is logged.

        Returns:
            True if at least one entity was inserted (partial failures still
            count as success). False if none were inserted or an unexpected
            error occurred.
        """
        try:
            unique_entities = {
                self._standardize_entity_name(part)
                for triplet in triplets
                for part in (triplet[0], triplet[2])
            }

            inserted = 0
            failures = []

            for entity in unique_entities:
                try:
                    if self.vector_db.insert_message2DB(entity):
                        inserted += 1
                    else:
                        failures.append(entity)
                except Exception as err:
                    failures.append(entity)
                    self.__log_info({
                        "vector_db_entity_error": entity,
                        "error": str(err)
                    })

            if failures:
                self.__log_info({
                    "vector_db_summary": {
                        "success_count": inserted,
                        "failed_count": len(failures),
                        "failed_entities": failures
                    }
                })

            return inserted > 0

        except Exception as err:
            self.__log_info({
                "vector_db_critical_error": str(err),
                "operation": "_store_entities_to_vector_db"
            })
            return False
        
    def _sotre_node_to_graph_db(self, nodes: List[List[str]]) -> bool:
        """Make sure both endpoints of every triplet exist as graph nodes.

        ``nodes`` holds ``[node_a, relation, node_b]`` triplets, but only the
        endpoints are used. They are standardized via
        ``_standardize_entity_name``, and nodes already in the graph are left
        untouched. No edges are created.

        Returns:
            False for empty input or on error (the error is logged), True
            otherwise.
        """
        try:
            if not nodes:
                return False

            for triple in nodes:
                endpoints = (
                    self._standardize_entity_name(triple[0]),
                    self._standardize_entity_name(triple[2]),
                )
                for name in endpoints:
                    if name not in self.graph_db.graph.nodes:
                        self.graph_db.add_node(name)

            return True
        except Exception as err:
            self.__log_info({
                "graph_db_node_error": str(err),
                "operation": "_sotre_node_to_graph_db"
            })
            return False
    
    def _store_triplets_to_graph_db(self, triplets: List[List[str]]) -> bool:
        """Merge ``triplets`` into the graph DB as standardized edges.

        Each 3-element ``[entity1, relation, entity2]`` becomes a
        ``{"node_a", "relation", "node_b"}`` dict. Entity names go through
        ``_standardize_entity_name``; the relation is stripped and casefolded.
        Triplets of any other length are skipped. The resulting subgraph is
        merged with ``save_graph=True``.

        Returns:
            True if at least one edge was merged. False if there was nothing to
            merge or an error occurred; errors are now logged instead of being
            silently swallowed.
        """
        if not triplets:
            return False

        try:
            sub_graph_dict = []
            for triplet in triplets:
                if len(triplet) != 3:
                    continue
                entity1, relation, entity2 = triplet
                sub_graph_dict.append({
                    "node_a": self._standardize_entity_name(entity1),
                    "relation": relation.strip().casefold(),
                    "node_b": self._standardize_entity_name(entity2)
                })

            if not sub_graph_dict:
                return False

            sub_graph = self.graph_db.create_sub_graph(sub_graph_dict)
            self.graph_db.merge_graph(sub_graph, save_graph=True)
            return True

        except Exception as err:
            self.__log_info({
                "graph_db_triplet_error": str(err),
                "operation": "_store_triplets_to_graph_db"
            })
            return False
    
    async def _perform_optimization(
            self,
            triplets: List[List[str]],
            step_summary : str,
            possible_objects : str # additional object input
        ) -> bool:
        """Rewrite the retrieved subgraph with the integration LLM.

        Steps:
            1. Expand the pending retrieval (``self.retrieval_buffer``) with
               ``_get_enhanced_graph_results``.
            2. Send the expansion to ``updater.integrate_and_extract`` along
               with the new ``triplets``, ``step_summary``, the optional
               knowledge file and ``possible_objects``.
            3. Replace the expanded relations in the graph DB with the
               returned ``new_graph``.

        Returns:
            True if the new graph was applied, otherwise False.

        If the integration yields an empty ``new_graph`` (which is also what
        ``integrate_and_extract`` returns on failure), the existing relations
        are kept and ``triplets`` are merged directly. The old subgraph is not
        wiped.
        """
        try:
            query = self.retrieval_buffer['query']
            vector_results = self.retrieval_buffer['vector_results'] or []
            graph_results = self.retrieval_buffer['graph_results'] or []

            # Retrieved relations plus every edge among the involved nodes;
            # this is the subgraph the LLM is asked to rewrite.
            current_graph_results = self._get_enhanced_graph_results(
                vector_results, graph_results
            )

            self.__log_info({
                "optimization_query": query,
                "vector_results": vector_results,
                "original_graph_results": graph_results,
                "enhanced_graph_results": current_graph_results,
                "enhanced_relationships_count": len(current_graph_results),
                "all_extracted_triplets": triplets,
                "optimization_step": "pre_optimization"
            })

            # Optional domain knowledge for the prompt; a missing or unreadable
            # file is not an error.
            knowledge_content = ""
            try:
                with open("RoboMemory/templates/Knowledge/knowledge.txt", 'r', encoding='utf-8') as f:
                    knowledge_content = f.read()
            except (OSError, UnicodeDecodeError):
                knowledge_content = ""

            optimization_result = await self.updater.integrate_and_extract(
                vector_results, current_graph_results, triplets, step_summary,
                knowledge_content, possible_objects=possible_objects
            )

            self.__log_info({
                "new_graph": optimization_result,
                "new_graph_step": "gen new graph"
            })

            if not optimization_result.get("new_graph"):
                # Deleting the old relations now would remove the retrieved
                # subgraph without any replacement, and possibly its nodes via
                # delete_isolated_nodes() during apply. Keep the old subgraph
                # and merge the raw triplets instead.
                self.__log_info({
                    "optimization_skipped": "empty_new_graph",
                    "fallback": "store_extracted_triplets"
                })
                if triplets:
                    self._store_triplets_to_graph_db(triplets)
                return False

            # Replace the old subgraph only once a replacement exists.
            # NOTE(review): still not atomic. If applying the new graph fails
            # after this deletion, the old relations are lost.
            self.graph_db.delete_relationships(current_graph_results)

            optimization_applied = self._apply_optimization_suggestions(optimization_result)

            visualizer = getattr(self, "visualizer", None)
            if visualizer and optimization_applied:
                # Exposed for the visualizer.
                self._last_optimization_result = optimization_result
                self._last_optimization_summary = {
                    "optimization_applied": optimization_applied,
                    "step": self.local_step
                }
                visualizer.visualize_after_optimization(
                    current_graph_results=current_graph_results,
                )

            return optimization_applied

        except Exception:
            import traceback
            traceback.print_exc()
            return False
    
    def _get_enhanced_graph_results(self, vector_results: List[str],
                                  graph_results: List[List[str]]) -> List[List[str]]:
        """Expand retrieved relations with every edge among the involved nodes.

        Relevant nodes are:
            * the standardized names from ``vector_results``;
            * the head/tail of every ``graph_results`` entry with at least
              three elements.

        The induced subgraph over the relevant nodes that exist in the graph
        is taken. Each edge ``[u, relation, v]`` of that subgraph is appended
        after ``graph_results``, unless an equivalent relation is already
        present. Equivalence is checked after name/relation normalization.

        Returns:
            ``graph_results`` followed by the newly found edges. A copy of
            ``graph_results`` is returned when none of the relevant nodes is in
            the graph, or when an error occurs.
        """
        try:
            vector_nodes = {self._standardize_entity_name(entity) for entity in vector_results}

            # Only well-formed triplets contribute nodes. The old log metric
            # indexed t[0] on every triplet, so an empty one raised IndexError
            # and aborted the whole enhancement.
            graph_nodes = set()
            for triplet in graph_results:
                if len(triplet) >= 3:
                    graph_nodes.add(self._standardize_entity_name(triplet[0]))
                    graph_nodes.add(self._standardize_entity_name(triplet[2]))

            relevant_nodes = vector_nodes | graph_nodes

            self.__log_info({
                "operation": "collect_relevant_nodes",
                "relevant_nodes": list(relevant_nodes),
                "nodes_from_vector": len(vector_results),
                "nodes_from_graph": len(graph_nodes),
                "total_relevant_nodes": len(relevant_nodes)
            })

            graph = self.graph_db.graph
            existing_nodes = {node for node in relevant_nodes if node in graph.nodes}

            if not existing_nodes:
                self.__log_info({
                    "graph_enhancement_result": "no_existing_nodes",
                    "returning_original_results": True
                })
                return graph_results.copy()

            subgraph = graph.subgraph(existing_nodes)

            def relation_key(rel):
                # Normalize the way _store_triplets_to_graph_db stores edges,
                # so a retrieved relation and the same graph edge compare equal.
                if len(rel) >= 3 and all(isinstance(part, str) for part in rel[:3]):
                    return (
                        self._standardize_entity_name(rel[0]),
                        rel[1].strip().casefold(),
                        self._standardize_entity_name(rel[2]),
                    )
                return tuple(rel)

            enhanced_relationships = list(graph_results)
            seen = {relation_key(rel) for rel in graph_results}

            new_relationships = []
            for u, v, data in subgraph.edges(data=True):
                relationship = [u, data.get('relation', 'unknown'), v]
                key = relation_key(relationship)
                if key not in seen:
                    seen.add(key)
                    enhanced_relationships.append(relationship)
                    new_relationships.append(relationship)

            self.__log_info({
                "graph_enhancement_complete": True,
                "subgraph_nodes": len(subgraph.nodes),
                "subgraph_edges": len(subgraph.edges),
                "original_relationships": len(graph_results),
                "new_relationships_found": len(new_relationships),
                "new_relationships": new_relationships,
                "total_enhanced_relationships": len(enhanced_relationships)
            })

            return enhanced_relationships

        except Exception:
            # Best effort: fall back to the unexpanded retrieval.
            import traceback
            traceback.print_exc()
            return graph_results.copy()

    def _apply_optimization_suggestions(self, optimization_result: Dict[str, Any]) -> bool:
        """Merge the LLM's ``new_graph`` into the graph DB, then prune isolated nodes.

        ``optimization_result["new_graph"]`` holds ``[entity1, relation,
        entity2]`` lists. Entries of other lengths are skipped. Names are
        standardized and relations stripped and casefolded before
        ``create_sub_graph``/``merge_graph(save_graph=True)``.
        ``delete_isolated_nodes()`` runs even when nothing was merged.

        Returns:
            True on success. False if any step raised; the traceback is printed.
        """
        try:
            proposed = optimization_result.get("new_graph")

            if proposed:
                edge_dicts = []
                for candidate in proposed:
                    if len(candidate) != 3:
                        continue
                    head, rel, tail = candidate
                    head_name = self._standardize_entity_name(head)
                    tail_name = self._standardize_entity_name(tail)
                    edge_dicts.append({
                        "node_a": head_name,
                        "relation": rel.strip().casefold(),
                        "node_b": tail_name
                    })

                if edge_dicts:
                    merged = self.graph_db.create_sub_graph(edge_dicts)
                    self.graph_db.merge_graph(merged, save_graph=True)

            self.graph_db.delete_isolated_nodes()
            return True

        except Exception:
            import traceback
            traceback.print_exc()
            return False
    
    def clear_retrieval_buffer(self) -> bool:
        """Forget the pending retrieval by resetting every buffer entry to ``None``.

        Returns True on success, False if the reset raised.
        """
        try:
            self.retrieval_buffer = dict.fromkeys(('query', 'vector_results', 'graph_results'))
        except Exception:
            return False
        return True

    def get_memory_stats(self) -> Dict[str, Any]:
        """Summarize graph size, buffer size and the pending retrieval.

        Counts for unset (``None``) or empty retrieval results are reported
        as 0.
        """
        graph = self.graph_db.graph
        pending = self.retrieval_buffer
        vectors = pending['vector_results']
        relations = pending['graph_results']

        return {
            "graph_nodes": graph.number_of_nodes(),
            "graph_edges": graph.number_of_edges(),
            "memory_buffer_size": len(self.memory_buffer),
            "retrieval_buffer_has_data": bool(pending['query']),
            "retrieval_buffer_vector_count": len(vectors) if vectors else 0,
            "retrieval_buffer_graph_count": len(relations) if relations else 0
        }

    def save(self):
        """Persist buffers and stats to ``self.memory_path`` as JSON, then save the graph.

        ``memory_path`` defaults to ``"./database/unified_memory_data.json"``
        when unset.

        Returns:
            True on success. False on any error; the error is logged.
        """
        try:
            if not self.memory_path:
                self.memory_path = "./database/unified_memory_data.json"

            save_data = {
                "memory_buffer": self.memory_buffer,
                "retrieval_buffer": self.retrieval_buffer,
                "memory_stats": self.get_memory_stats(),
                "timestamp": datetime.datetime.now().isoformat()
            }
            # Serialize before opening the file. Otherwise a non-serializable
            # value leaves a truncated file behind.
            serialized = json.dumps(save_data, ensure_ascii=False, indent=2)

            parent_dir = os.path.dirname(self.memory_path)
            if parent_dir:
                # A bare filename has no directory part, and os.makedirs("")
                # raises, which used to make save() fail outright.
                os.makedirs(parent_dir, exist_ok=True)

            with open(self.memory_path, 'w', encoding='utf-8') as f:
                f.write(serialized)

            self.graph_db._save_graph()

            return True

        except Exception:
            logging.getLogger(__name__).exception(
                "Failed to save spatial memory to %r", self.memory_path
            )
            return False

    def load(self):
        """Restore buffers from ``self.memory_path`` and reload the graph.

        ``memory_path`` defaults to ``"./database/unified_memory_data.json"``
        when unset. Instance state is changed only after the file has been
        parsed and validated, so a malformed file cannot leave the buffers
        half-updated.

        Returns:
            False if the file is missing, its top level is not a JSON object,
            or reading fails (logged). True otherwise.
        """
        try:
            if not self.memory_path:
                self.memory_path = "./database/unified_memory_data.json"

            if not os.path.exists(self.memory_path):
                return False

            with open(self.memory_path, 'r', encoding='utf-8') as f:
                load_data = json.load(f)

            if not isinstance(load_data, dict):
                logging.getLogger(__name__).warning(
                    "Ignoring %r: expected a JSON object at top level", self.memory_path
                )
                return False

            # Fall back to empty values for missing, null or wrongly typed
            # sections instead of storing them as-is.
            memory_buffer = load_data.get("memory_buffer")
            if not isinstance(memory_buffer, list):
                memory_buffer = []

            saved_retrieval = load_data.get("retrieval_buffer")
            if not isinstance(saved_retrieval, dict):
                saved_retrieval = {}

            self.memory_buffer = memory_buffer
            self.retrieval_buffer = {
                'query': saved_retrieval.get('query', None),
                'vector_results': saved_retrieval.get('vector_results', None),
                'graph_results': saved_retrieval.get('graph_results', None)
            }

            self._load_existing_graph()
            return True

        except Exception:
            logging.getLogger(__name__).exception(
                "Failed to load spatial memory from %r", self.memory_path
            )
            return False

    # ===    ===
    def get_forgetting_stats(self) -> Dict[str, Any]:
        """Return the forgetting statistics reported by the spatial graph DB.

        This is a direct pass-through to ``self.graph_db.get_forgetting_stats()``.
        """
        graph_store = self.graph_db
        return graph_store.get_forgetting_stats()
    
    def manual_forgetting(self) -> Dict[str, Any]:
        """Run the graph DB's manual forgetting pass and mirror removals in the vector DB.

        The dict returned by ``self.graph_db.manual_forgetting()`` is returned
        with an added ``"vector_db_deletion"`` entry:

        * no successful graph removal -> ``attempted`` 0, ``success`` True,
          empty ``deleted_entities``;
        * removed nodes -> ``delete_entities`` is called with them and its
          result is reported as ``success`` (``deleted_entities`` lists the
          nodes only when it is truthy);
        * ``delete_entities`` raised -> ``success`` False plus the ``error`` text.
        """
        outcome = self.graph_db.manual_forgetting()

        removed = outcome.get("removed_node_list") if outcome.get("success") else None

        if not removed:
            # Nothing left the graph, so the vector DB is not touched.
            outcome["vector_db_deletion"] = {
                "attempted": 0,
                "success": True,
                "deleted_entities": []
            }
            return outcome

        try:
            deleted_ok = self.vector_db.delete_entities(removed)
            outcome["vector_db_deletion"] = {
                "attempted": len(removed),
                "success": deleted_ok,
                "deleted_entities": removed if deleted_ok else []
            }
        except Exception as err:
            outcome["vector_db_deletion"] = {
                "attempted": len(removed),
                "success": False,
                "error": str(err)
            }

        return outcome
    
    def get_comprehensive_stats(self) -> Dict[str, Any]:
        """Combine basic memory stats with forgetting stats and the vector DB size.

        Returns a new dict containing every key from ``get_memory_stats()``
        plus ``"spatial_forgetting_algorithm"`` (from ``get_forgetting_stats()``)
        and ``"vector_db_count"`` (from ``vector_db.get_collection_count()``,
        or 0 when the vector DB has no such method).
        """
        merged = dict(self.get_memory_stats())
        merged["spatial_forgetting_algorithm"] = self.get_forgetting_stats()

        # Fall back to a zero-count callable for vector DBs without a counter.
        collection_count = getattr(self.vector_db, 'get_collection_count', lambda: 0)
        merged["vector_db_count"] = collection_count()

        return merged

    @property
    def spatial_vector_db(self):
        """Read-only alias for ``self.vector_db``, the vector store of this memory."""
        vector_store = self.vector_db
        return vector_store
    
    @property
    def spatial_graph_db(self):
        """Read-only alias for ``self.graph_db``, the graph store of this memory."""
        graph_store = self.graph_db
        return graph_store

    def _get_queries(self, queries: list) -> str:

        query_text = ""
        for query in queries:
            query_text += query + "\n"

        return query_text.strip()

    def retrieve(self, queries: Dict[str, list]) -> str:
        """Answer this memory's queries with relations from the spatial graph.

        Every call advances ``self.local_step``. The queries listed under
        ``queries[self.name]`` are joined into one text and used for a vector
        search; any hits seed a graph search. The retrieval buffer is replaced,
        the optional visualizer and the log are updated, and the graph
        relations are aggregated into text.

        Args:
            queries: Mapping from memory name to a list of query strings.

        Returns:
            The aggregated relations; "No relevant relationships found." when
            the graph search yields nothing; "No relevant memory information."
            when there is no query for this memory; or
            "retrieve memory failed: <error>" if anything raised.
        """
        self.local_step += 1  # each retrieve call counts as one step

        graph_results = []

        try:
            if not (self.name in queries and queries[self.name]):
                # Nothing addressed to this memory.
                return "No relevant memory information."

            query_text = self._get_queries(queries[self.name])

            # Vector search first; its hits are the entry points for the graph search.
            vector_results = self.vector_db.search_DB(query_text, k=self.graph_search_top_k)

            if vector_results:
                count_before = self.graph_db.access_count
                graph_results = self.graph_db.search_graph(
                    vector_results, adj_degree=self.graph_search_adj_degree
                )

                # When this search moved the access counter onto a multiple of 10,
                # drop vector entries whose entity is no longer a graph node.
                if (
                    self.graph_db.enable_forgetting
                    and self.graph_db.access_count % 10 == 0
                    and self.graph_db.access_count != count_before
                ):
                    self._sync_vector_db_cleanup()

            self.retrieval_buffer = {
                'query': query_text,
                'vector_results': vector_results,
                'graph_results': graph_results
            }

            if self.visualizer:
                self.visualizer.visualize_retrieval_buffer(graph_results=graph_results)

            self.__log_info({
                "query_text": query_text,
                "vector_results": vector_results,
                "graph_results": graph_results
            })

            if not graph_results:
                return "No relevant relationships found."
            return self.aggregator.aggregate(graph_results)

        except Exception as err:
            import traceback
            traceback.print_exc()
            return f"retrieve memory failed: {err}"

    def _sync_vector_db_cleanup(self):
     
        try:
            #             
            graph_nodes = set(self.graph_db.graph.nodes())
            
            #              
            vector_entities = set(self.vector_db.message_set.keys())
            
            #                     
            orphaned_entities = vector_entities - graph_nodes
            
            if orphaned_entities:
                orphaned_list = list(orphaned_entities)
                delete_success = self.vector_db.delete_entities(orphaned_list)

        except Exception as e:
          
            print(f"clean up vector db failed: {e}")