

from RoboMemory.Modules.InfoProcessor.StepSumarizer import StepSumGenerator
from RoboMemory.Modules.InfoProcessor.QueryGenerator import QueryGenerator
from typing import Literal, Any
from RoboMemory.agent_utils import ModelConfig, yaml_decoder
import asyncio
import yaml

class InfoProcessor:
    """Run the step summarizer and the three memory-query generators concurrently.

    Holds one ``StepSumGenerator`` and three ``QueryGenerator`` instances
    (spatial, semantic, episodic), each built from its own model config and
    prompt. Query-generator replies are decoded from YAML and reduced to the
    value stored under their top-level ``queries`` key.
    """

    def __init__(
        self,
        step_summarizer_config: ModelConfig,
        step_summarizer_prompt: str,
        spatial_query_generator_config: ModelConfig,
        spatial_query_generator_prompt: str,
        semantic_query_generator_config: ModelConfig,
        semantic_query_generator_prompt: str,
        episodic_query_generator_config: ModelConfig,
        episodic_query_generator_prompt: str,
    ):
        """Build the four model wrappers, one per (config, prompt) pair."""
        self.step_sum_generator = StepSumGenerator(step_summarizer_config, step_summarizer_prompt)
        self.spatial_query_generator = QueryGenerator(spatial_query_generator_config, spatial_query_generator_prompt)
        self.semantic_query_generator = QueryGenerator(semantic_query_generator_config, semantic_query_generator_prompt)
        self.episodic_query_generator = QueryGenerator(episodic_query_generator_config, episodic_query_generator_prompt)

    def __query_preprocess(self, return_str: str) -> list[str]:
        """Extract the ``queries`` entry from a query generator's raw reply.

        Args:
            return_str: Raw completion text. It is passed through
                ``yaml_decoder`` (presumably to isolate the YAML payload from
                surrounding text / code fences -- TODO confirm) and then parsed
                with ``yaml.safe_load``.

        Returns:
            The value stored under the top-level ``queries`` key, expected to
            be a list of query strings.

        Raises:
            TypeError: If the decoded YAML is not a mapping (for example empty
                output, which ``safe_load`` turns into ``None``).
            KeyError: If the mapping has no ``queries`` key.
        """
        output = yaml_decoder(return_str)

        # safe_load (never yaml.load): the text comes from a model, so
        # arbitrary Python object construction must not be possible.
        parsed: Any = yaml.safe_load(output)

        # Raise the same exception types a bare ``parsed['queries']`` lookup
        # would, but with a message showing what the model actually returned.
        preview = repr(output)[:200]
        if not isinstance(parsed, dict):
            raise TypeError(
                f"Expected a YAML mapping with a 'queries' key, "
                f"got {type(parsed).__name__}: {preview}"
            )
        if 'queries' not in parsed:
            raise KeyError(
                f"YAML reply has no 'queries' key (keys: {list(parsed)}): {preview}"
            )
        return parsed['queries']

    def _query_completions(
        self,
        query_params: dict[str, Any],
        image_paths: list | str | None,
        base64_image: bool,
        image_type: Literal["jpeg", "png", "webp", "gif"],
    ) -> list:
        """Create (without awaiting) the spatial, semantic and episodic requests, in that order."""
        return [
            generator.async_create_completion(query_params, image_paths, base64_image, image_type)
            for generator in (
                self.spatial_query_generator,
                self.semantic_query_generator,
                self.episodic_query_generator,
            )
        ]

    async def generate_querys(
        self,
        query_params: dict[str, Any],
        image_paths: list | str | None = None,
        base64_image: bool = True,
        image_type: Literal["jpeg", "png", "webp", "gif"] = "jpeg",
    ) -> tuple[list[str], list[str], list[str]]:
        """Run the three query generators concurrently and parse their replies.

        All arguments are forwarded unchanged to each generator's
        ``async_create_completion``.

        Returns:
            ``(spatial_queries, semantic_queries, episodic_queries)``.

        Raises:
            Whatever the first failing completion raises (``asyncio.gather``
            propagates it), or TypeError/KeyError from reply parsing.
        """
        replies = await asyncio.gather(
            *self._query_completions(query_params, image_paths, base64_image, image_type)
        )
        spatial, semantic, episodic = (self.__query_preprocess(reply) for reply in replies)
        return spatial, semantic, episodic

    async def generate_infos(
        self,
        query_params: dict[str, Any],
        step_sum_params: dict[str, Any],
        image_paths: list | str | None = None,
        base64_image: bool = True,
        image_type: Literal["jpeg", "png", "webp", "gif"] = "jpeg",
    ) -> tuple:
        """Run the step summarizer and the three query generators concurrently.

        ``step_sum_params`` goes to the step summarizer, ``query_params`` to
        each query generator; the image arguments are shared by all four.

        Returns:
            ``(step_summary, spatial_queries, semantic_queries, episodic_queries)``
            where ``step_summary`` is the step summarizer's completion returned
            as-is (not YAML-parsed).

        Raises:
            Whatever the first failing completion raises (``asyncio.gather``
            propagates it), or TypeError/KeyError from query-reply parsing.
        """
        step_summary, *query_replies = await asyncio.gather(
            self.step_sum_generator.async_create_completion(step_sum_params, image_paths, base64_image, image_type),
            *self._query_completions(query_params, image_paths, base64_image, image_type),
        )
        spatial, semantic, episodic = (self.__query_preprocess(reply) for reply in query_replies)
        return step_summary, spatial, semantic, episodic
        