from typing import List,Any,Dict
import re
import asyncio
import json
from HeGFlow.graph.node import Node
from HeGFlow.agents.agent_registry import AgentRegistry
from HeGFlow.llm.llm_registry import LLMRegistry
from HeGFlow.prompt.prompt_set_registry import PromptSetRegistry
from HeGFlow.tools.search.Serper import search_serper_main
from HeGFlow.tools.General.general_tools import str_to_list
from HeGFlow.graph.tool_nodes.coder import Coder
from openai import OpenAI


def find_strings_between_pluses(text):
    """Return every substring enclosed between a pair of ``@`` markers.

    Despite the function name, the delimiter is ``@`` (not ``+``). Matching is
    non-greedy and ``.`` does not cross newlines, so ``"a@x@b@y@"`` yields
    ``['x', 'y']``.
    """
    marker_pattern = r'\@(.*?)\@'
    return [match.group(1) for match in re.finditer(marker_pattern, text)]

@AgentRegistry.register('AgentGraph_gsm8k')
class AgentGraph_gsm8k(Node):
    """Graph node that answers a task with an LLM, optionally refining it with a CriticNode.

    The system prompt is the prompt set's analysis constraint for this node's role.
    The user prompt is built from the task, plus an optional tool result for
    ``Tool_*`` roles. It also includes the outputs of spatial neighbours (same
    round) and temporal neighbours (previous round).
    """

    # Maximum number of LLM generations made in critic-guided reflection mode.
    MAX_REFLECTION_ATTEMPTS = 3

    # Appended to the original user prompt when the CriticNode reports issues.
    _REFLECTION_PROMPT = "You are a {role},   Your previous output is:  {previous_output}  This is the feedback from the CriticNode: {feedback}    Please refine your answer based on the questions and suggestions mentioned in the feedback"

    def __init__(self, id: str | None = None, role: str | None = None, domain: str = "", llm_name: str = "",):
        super().__init__(id, "AgentGraph_gsm8k", domain, llm_name)
        self.llm = LLMRegistry.get(llm_name)
        self.prompt_set = PromptSetRegistry.get(domain)
        # Use prompt_set.get_role() when no explicit role is given.
        self.role = self.prompt_set.get_role() if role is None else role
        self.constraint = self.prompt_set.get_analyze_constraint(self.role)
        # Default used by _async_execute when its critic_enabled argument is omitted.
        self.critic_enabled = False

    async def _run_tool(self, task: str):
        """Run the tool bound to this node's ``Tool_*`` role on ``task``.

        Returns:
            The tool's output, which is appended verbatim to the user prompt
            (presumably a str; confirm against the tool classes). Returns None
            when this node's role is not a tool role.
        """
        # NOTE(review): Serper_Searcher, Calculator, PdfReader and Weather are not
        # imported at the top of this file, so those roles raise NameError at
        # runtime until the corresponding imports are added.
        if self.role == 'Tool_search':
            return await Serper_Searcher()._execute_tool(task)
        if self.role == 'Tool_coder':
            return Coder()._execute_tool(task)
        # 'caculator' (sic): must match the configured role string exactly.
        if self.role == 'Tool_caculator':
            return Calculator()._execute_tool(task)
        if self.role == 'Tool_pdfreader':
            return PdfReader()._execute_tool(task)
        if self.role == 'Tool_weather':
            return Weather()._execute_tool(task)
        return None

    @staticmethod
    def _format_agent_outputs(info_by_agent: Dict[str, Dict]) -> str:
        """Render neighbour outputs as prompt text.

        Agents whose output is, or whose output list contains, exactly 'None.'
        are skipped.
        """
        sections = []
        for agent_id, info in info_by_agent.items():
            output = info['output']
            candidates = output if isinstance(output, list) else [output]
            if 'None.' in candidates:
                continue
            sections.append(f"Agent {agent_id}, role is {info['role']}, output is:\n\n {output}\n\n")
        return "".join(sections)

    async def _process_inputs(self, raw_inputs: Dict[str, str], spatial_info: Dict[str, Dict], temporal_info: Dict[str, Dict], **kwargs) -> tuple[str, str]:
        """Build the ``(system_prompt, user_prompt)`` pair for this node.

        Args:
            raw_inputs: must contain the key ``'task'``.
            spatial_info: agent id -> dict with ``'role'`` and ``'output'`` keys,
                for agents in the current round.
            temporal_info: the same structure, for agents in the last round.

        Returns:
            A tuple ``(system_prompt, user_prompt)``.
        """
        task = raw_inputs['task']
        system_prompt = f"{self.constraint}"
        if self.role != 'Fake':
            user_prompt = f"The task is: {task}\n"
        else:
            user_prompt = self.prompt_set.get_adversarial_answer_prompt(task)

        # The tool result depends only on the task, so run it once per call.
        # Previously it ran once per spatial neighbour, and not at all without neighbours.
        tool_output = await self._run_tool(task)
        if tool_output is not None:
            user_prompt += tool_output

        spatial_str = self._format_agent_outputs(spatial_info)
        temporal_str = self._format_agent_outputs(temporal_info)
        if spatial_str:
            user_prompt += f"At the same time, the outputs of other agents are as follows:\n\n{spatial_str} \n\n"
        if temporal_str:
            user_prompt += f"In the last round of dialogue, the outputs of other agents were: \n\n{temporal_str}"

        return system_prompt, user_prompt

    def _execute(self, input: Dict[str, str], spatial_info: Dict[str, Dict], temporal_info: Dict[str, Dict], **kwargs):
        """Synchronously build the prompts and return a single LLM response (no reflection).

        Raises:
            ValueError: if no ``critic_node`` keyword argument is supplied.
            RuntimeError: if called while an asyncio event loop is already running
                in this thread (``asyncio.run`` cannot nest); use ``_async_execute``
                there instead.
        """
        critic_node = kwargs.get("critic_node")
        if not critic_node:
            # NOTE(review): critic_node is not used on this synchronous path;
            # confirm whether this requirement is still needed.
            raise ValueError("CriticNode required for reflection")

        # _process_inputs is a coroutine function. Its result must be driven to
        # completion; unpacking the coroutine object directly raises TypeError.
        system_prompt, user_prompt = asyncio.run(self._process_inputs(input, spatial_info, temporal_info))

        message = [{'role': 'system', 'content': system_prompt}, {'role': 'user', 'content': user_prompt}]
        return self.llm.gen(message)

    async def _get_critic_feedback(self, critic_node, task, output_text) -> Dict:
        """Ask ``critic_node`` to review ``output_text`` and return this node's feedback dict.

        The critic's result is scanned for the entry whose ``node_id`` equals this
        node's id. If there is none, a single "No feedback received" issue is
        returned, which causes another refinement round.
        """
        critic_input = {
            "task": task,
            "role_outputs": {self.role: {self.id: output_text}},
        }
        feedback = await critic_node._async_execute_single_node(critic_input, {}, {})
        return next(
            (f["feedback"] for f in feedback if f["node_id"] == self.id),
            {"issues": ["No feedback received"], "suggestions": []},
        )

    async def _reflect_with_critic(self, task, system_prompt, user_prompt, critic_node):
        """Generate an answer and iteratively refine it with CriticNode feedback.

        At most ``MAX_REFLECTION_ATTEMPTS`` generations are made. After each one the
        critic reviews the answer. Refinement stops when:
        - the critic reports no issues;
        - the LLM returns an empty response (the result is then "None.");
        - the attempt budget is used up.

        Returns:
            The latest answer. If the critic raises, the current answer is kept.
            If the LLM raises, the last successful answer is returned. Only when
            the very first generation fails is ``"Error: <exception>"`` returned.
        """
        output_text = None
        prompt = user_prompt
        for _ in range(self.MAX_REFLECTION_ATTEMPTS):
            message = [{'role': 'system', 'content': system_prompt}, {'role': 'user', 'content': prompt}]
            try:
                response = await self.llm.agen(message)
            except Exception as e:
                print("CriticNode Error！")
                return output_text if output_text is not None else f"Error: {str(e)}"

            output_text = response if response else "None."
            if output_text == "None.":
                break

            try:
                feedback_dict = await self._get_critic_feedback(critic_node, task, output_text)
                if not feedback_dict.get("issues"):
                    break
                feedback_json = json.dumps(feedback_dict)
            except Exception:
                # A critic failure does not invalidate the answer already produced.
                print("CriticNode Error！")
                break

            # Keep the original task and neighbour context. The reflection prompt
            # alone does not restate the task the answer is supposed to solve.
            prompt = user_prompt + "\n\n" + self._REFLECTION_PROMPT.format(
                role=self.role,
                previous_output=output_text,
                feedback=feedback_json,
            )
        return output_text

    async def _async_execute(self, input: Dict[str, str], spatial_info: Dict[str, Dict], temporal_info: Dict[str, Dict], critic_enabled: bool | None = None, **kwargs):
        """Build the prompts and return the LLM answer, optionally refined by a CriticNode.

        Args:
            input: must contain the key ``'task'``.
            spatial_info / temporal_info: neighbour outputs (see ``_process_inputs``).
            critic_enabled: enables critic-guided reflection. When None, falls back
                to ``self.critic_enabled``.
            **kwargs: ``critic_node`` is required when reflection is enabled.

        Raises:
            ValueError: if reflection is enabled and no ``critic_node`` is supplied.
        """
        if critic_enabled is None:
            critic_enabled = self.critic_enabled
        critic_node = kwargs.get("critic_node")
        if critic_enabled and not critic_node:
            raise ValueError("CriticNode required for reflection")

        system_prompt, user_prompt = await self._process_inputs(input, spatial_info, temporal_info)

        if critic_enabled:
            return await self._reflect_with_critic(input["task"], system_prompt, user_prompt, critic_node)

        message = [{'role': 'system', 'content': system_prompt}, {'role': 'user', 'content': user_prompt}]
        return await self.llm.agen(message)