from abc import abstractmethod
import logging

from regex import B
from RoboMemory.BaseModules.BaseMemory import BaseMemory
from RoboMemory.BaseModules.agent_general import GeneralAsyncAgent
from RoboMemory.BaseModules.BaseAggregator import BaseAggregator
from RoboMemory.SpatialMemory.vector_db import DashScopeEmbeddings
from RoboMemory.agent_utils import ModelConfig, VectorDBConfig
from langchain_core.documents import Document
from typing import Any
from RoboMemory.Datas.ModuleLogger import ModuleLogger
from langchain_chroma import Chroma

def get_document_count(vs: Chroma) -> int:
    """Return the number of documents stored in ``vs``'s collection.

    The count is read through the vector store's private ``_collection``
    handle. This is best-effort: if anything goes wrong, a warning is logged
    and ``0`` is returned. Callers therefore cannot tell an empty store apart
    from a failed lookup.

    Args:
        vs: The Chroma vector store to inspect.

    Returns:
        The document count, or ``0`` if it could not be obtained.
    """
    try:
        # Reaches into a private attribute: there is no public count accessor
        # used here, so this may break across langchain_chroma versions.
        return vs._collection.count()
    except Exception as e:
        logging.warning(f"Error getting document count: {e}")
        return 0

class LTMAggregator(BaseAggregator):
    """Abstract aggregator for long-term-memory results.

    Subclasses must override :meth:`aggregate`. ``@abstractmethod`` only
    prevents instantiation if ``BaseAggregator`` uses ``ABCMeta`` (not visible
    here; TODO confirm). Otherwise, calling the base ``aggregate`` raises
    ``NotImplementedError``.
    """

    def __init__(self):
        # Takes no arguments and forwards none to the base initializer.
        super().__init__()

    @abstractmethod
    def aggregate(self, info) -> str:
        """Combine ``info`` into a single string.

        Args:
            info: Input to aggregate; its shape is defined by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        reason = "aggregate method must be implemented in subclasses"
        raise NotImplementedError(reason)

class Extractor(GeneralAsyncAgent):
    """Async agent configured with the MemExtractor prompt template by default.

    This is a thin wrapper: all behaviour comes from ``GeneralAsyncAgent``,
    and this class only supplies a default ``template_path``.
    """

    def __init__(
        self,
        model_config,
        template_path="RoboMemory/templates/prompts/LTM_Prompts/MemExtractor.prompt",
    ):
        # NOTE(review): the default is a relative path. Presumably it is
        # resolved against the working directory by GeneralAsyncAgent; confirm.
        super().__init__(model_config, template_path)

# Abstract base class: subclasses must implement the async `generate` method.
class MemoryGenerator(GeneralAsyncAgent):
    """Abstract async agent whose subclasses implement :meth:`generate`.

    ``@abstractmethod`` is only enforced at instantiation if
    ``GeneralAsyncAgent`` uses ``ABCMeta`` (not visible here; TODO confirm).
    """

    def __init__(self, model_config, template_path):
        # Both arguments are forwarded unchanged to GeneralAsyncAgent.
        super().__init__(model_config, template_path)

    @abstractmethod
    async def generate(self, e: str):
        """Produce output for the input string ``e``.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        reason = "generate method must be implemented in subclasses"
        raise NotImplementedError(reason)


class LongTermMemory(BaseMemory):
    """Abstract long-term memory backed by a Chroma vector store.

    On construction it builds a DashScope embedding function and a Chroma
    collection from the supplied configs, and it keeps the extractor and
    memory-generator agents for use by subclasses.

    Subclasses implement :meth:`update`, :meth:`_create_entry` and
    :meth:`retrieve`. The ``@abstractmethod`` markers are only enforced at
    instantiation if ``BaseMemory`` uses ``ABCMeta`` (not visible here;
    TODO confirm). Otherwise, the base implementations raise
    ``NotImplementedError``.
    """

    _vectorstore: Chroma
    _extractor: Extractor
    _recent_mem: list[str]
    _memory_generator: MemoryGenerator
    _num_iter: int
    _retrieve_k_task: int
    # Not assigned in this class; presumably set by BaseMemory or a subclass.
    # TODO confirm.
    logger: ModuleLogger

    def __init__(self,
                 updater: GeneralAsyncAgent,
                 aggregator: BaseAggregator,
                 extractor: Extractor,
                 memory_generator: MemoryGenerator,
                 embedding_config: ModelConfig,
                 vectordb_config: VectorDBConfig,
                 retrieve_k_task: int = 5
                 ):
        """Wire up the agents and open the Chroma collection.

        Args:
            updater: Agent forwarded to ``BaseMemory``.
            aggregator: Aggregator forwarded to ``BaseMemory``.
            extractor: Extractor agent kept for subclasses.
            memory_generator: Generator agent kept for subclasses.
            embedding_config: Supplies ``api_key``, ``model`` and ``base_url``
                for the DashScope embeddings.
            vectordb_config: Supplies ``collection_name`` and
                ``persist_directory``. The directory is also passed to
                ``BaseMemory`` as ``memory_path``.
            retrieve_k_task: Count stored for subclasses' retrieval
                (presumably top-k; confirm in subclasses).
        """
        super().__init__(updater, aggregator, memory_path=vectordb_config.persist_directory)
        self._extractor = extractor
        self._memory_generator = memory_generator
        self._recent_mem = ['']
        # Stored under the declared attribute name. The name previously
        # assigned here, `_retrieval_k_task`, remains available through the
        # alias property below, so existing subclasses keep working.
        self._retrieve_k_task = retrieve_k_task
        self._num_iter = 0

        embedding = DashScopeEmbeddings(
            api_key=embedding_config.api_key,
            model=embedding_config.model,
            base_url=embedding_config.base_url
        )

        self._vectorstore = Chroma(
            collection_name=vectordb_config.collection_name,
            embedding_function=embedding,
            persist_directory=vectordb_config.persist_directory,
        )

    @property
    def _retrieval_k_task(self) -> int:
        """Backward-compatible alias for ``_retrieve_k_task``."""
        return self._retrieve_k_task

    @_retrieval_k_task.setter
    def _retrieval_k_task(self, value: int) -> None:
        # Writes through so that both names always refer to the same value.
        self._retrieve_k_task = value

    # extraction phase
    @abstractmethod
    def update(self, infos: dict):
        """Incorporate ``infos`` into memory. Must be overridden.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("update method is not implemented in this memory!")

    # update phase
    @abstractmethod
    def _create_entry(self, mem_str, action='', task='') -> Document:
        """Build a ``Document`` from ``mem_str`` plus optional action and task.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("This method should be implemented in subclasses")

    @abstractmethod
    def retrieve(self, queries: dict) -> str:
        """Return memory content relevant to ``queries``. Must be overridden.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("retrieve method is not implemented in this memory!")

    def save(self):
        """Reset the per-session counters ``_num_iter`` and ``_recent_mem``.

        The vector store is not flushed explicitly here.
        """
        self._num_iter = 0
        self._recent_mem = ['']
        # NOTE(review): there is no explicit persist call. Recent
        # chromadb/langchain_chroma versions write to `persist_directory`
        # automatically; confirm for the pinned version.

    def load(self):
        """No-op: the Chroma collection is opened in ``__init__``."""
        pass

    def need_update(self) -> bool:
        """Always request an update."""
        return True