from langchain_openai import ChatOpenAI
import yaml
from RoboMemory import agent_utils
from RoboMemory.agent_utils import yaml_decoder, decode_cot
from RoboMemory.BaseModules.agent_general import GeneralAsyncAgent
from typing import Dict, Any, List, Literal
import re
import logging

def clean_yaml_text(text):
    """Downgrade stray colons so the model's reply parses as a flat YAML mapping.

    Every ``:`` in *text* becomes ``;``, except the key colon on lines that
    start (after optional whitespace) with ``Feedback:`` or
    ``Action Suitability:``. Free-text values written by the model may contain
    colons, which YAML would otherwise treat as extra mapping separators.

    Args:
        text: YAML-like string produced by the model.

    Returns:
        A string of the same length where only the protected key colons remain.
    """
    key_line = re.compile(
        r'^\s*(Feedback\s*:|Action\s+Suitability\s*:)',
        flags=re.MULTILINE
    )
    # Each match ends immediately after its key colon, so end() - 1 is that colon.
    protected = {m.end() - 1 for m in key_line.finditer(text)}
    return ''.join(
        ';' if ch == ':' and pos not in protected else ch
        for pos, ch in enumerate(text)
    )
class Critic(GeneralAsyncAgent):
    """Async agent that asks the model to judge an action and parses its YAML verdict.

    The model reply is expected to contain a YAML mapping with at least the
    keys ``Action Suitability`` and ``Feedback``; both are read by
    :meth:`get_feed_back`.
    """

    @staticmethod
    def _to_bool(value: Any) -> bool:
        """Coerce a parsed ``Action Suitability`` value to ``bool``.

        ``yaml.safe_load`` already turns unquoted ``True``/``false``/``yes``
        into real booleans. A quoted ``"False"`` or a value like ``False.``
        stays a string, however, and ``bool("False")`` is ``True``. Common
        textual spellings are therefore mapped explicitly. Any other value
        falls back to ``bool(value)``, as before.
        """
        if isinstance(value, str):
            token = value.strip().rstrip(".").strip().lower()
            if token in ("true", "yes", "on", "1"):
                return True
            if token in ("false", "no", "off", "0"):
                return False
        return bool(value)

    async def get_feed_back(
        self,
        params: Dict[str, Any] | None = None,
        image_paths: list | str = None,
        base64_image: bool = True,
        image_type: Literal["jpeg", "png", "webp", "gif"] = "jpeg",
    ) -> tuple:
        """Query the model and parse its suitability verdict.

        Args:
            params: Prompt parameters forwarded to ``async_create_completion``
                and returned unchanged.
            image_paths: Image path(s) forwarded to ``async_create_completion``.
            base64_image: Forwarded to ``async_create_completion``.
            image_type: Image encoding forwarded to ``async_create_completion``.

        Returns:
            A 5-tuple ``(suitable, feedback, params, cot, raw_reply)``:
            ``suitable`` is the ``Action Suitability`` value coerced to bool,
            ``feedback`` is the ``Feedback`` value as parsed by YAML,
            ``params`` is the argument passed in, ``cot`` is whatever
            ``decode_cot`` extracted (``""`` if it failed), and ``raw_reply``
            is the unmodified completion.

        Raises:
            yaml.YAMLError: If the cleaned text is not valid YAML.
            KeyError: If either expected key is missing from the mapping.
            TypeError: If the YAML does not parse to a mapping (e.g. ``None``
                or a plain string).
        """
        return_str = await self.async_create_completion(params, image_paths, base64_image, image_type)

        # Prefer the extracted YAML block. If extraction (or cleaning its
        # result) fails, fall back to treating the whole reply as YAML.
        try:
            yaml_str = clean_yaml_text(yaml_decoder(return_str))
        except Exception as err:  # yaml_decoder's failure modes are not visible here
            logging.warning(
                "Critic: YAML extraction failed (%s); parsing the raw reply instead", err
            )
            yaml_str = clean_yaml_text(return_str)

        # CoT extraction is independent of the YAML: a failure here must not
        # discard a YAML block that was already extracted successfully.
        try:
            cot = decode_cot(return_str)
        except Exception as err:
            logging.debug("Critic: CoT extraction failed (%s); using empty CoT", err)
            cot = ""

        # Log before parsing so the offending text is visible if safe_load raises.
        logging.info("\n##### \n%s\n #####\n", yaml_str)
        load_dict = yaml.safe_load(yaml_str)

        return (
            self._to_bool(load_dict["Action Suitability"]),
            load_dict["Feedback"],
            params,
            cot,
            return_str,
        )