import os
import re
import json
import importlib
from typing import Dict, List, Optional, Union, AsyncGenerator, Tuple
from copy import deepcopy
import random
from pathlib import Path
import asyncio
from collections import defaultdict
import time

from configs import logger
from utils import load_json, save_json, replace_str, extract_json_dict, get_corresponding_functions, extract_json, is_high_risk_command, encode_media, is_base64_image

from frontierbench.memory.graph import Graph
from frontierbench.tools.mcp_client import MCPClient, SSEMCPClient
from frontierbench.models.openai import generate_with_openai
from prompts.theme_w_parts import parts_prompt_lists, context_refinement_prompt


def post_process_tool_response(srv_name, results_raw, is_success) -> Tuple[str, bool]:
    """
    Convert a raw tool result into the string that is sent back to the model.

    For successful calls to the image-producing servers ('screenshot-website-fast',
    'get_image') whose result is a dict with a 'content' entry, a deep copy is made in
    which every base64 image item has its 'data' replaced by a placeholder and its
    'mimeType' removed; the copy is JSON-encoded and the second return value is True so
    the caller can forward the real images in a follow-up user message.

    Otherwise dicts/lists are JSON-encoded, strings are returned unchanged and any other
    value (e.g. None) is converted with ``str()`` so the result is always a string.

    Args:
        srv_name: Name of the server that executed the tool.
        results_raw: Raw value returned by the server's ``call_tool``.
        is_success: Whether the tool call succeeded.

    Returns:
        (result_str, is_add_user_message)
    """
    is_image_result = (
        is_success
        and srv_name in ('screenshot-website-fast', 'get_image')
        and isinstance(results_raw, dict)
        and 'content' in results_raw
    )
    if is_image_result:
        # Work on a copy: the caller still needs the original image data.
        redacted = deepcopy(results_raw)
        for item in redacted['content']:
            if isinstance(item, dict) and is_base64_image(item):
                item['data'] = 'Multimedia data is obtained, please refer to the next user message.'
                # pop() instead of del: tolerate image items that carry no mimeType key.
                item.pop('mimeType', None)
        return json.dumps(redacted, ensure_ascii=False), True

    if isinstance(results_raw, (dict, list)):
        return json.dumps(results_raw, ensure_ascii=False), False
    if isinstance(results_raw, str):
        return results_raw, False
    # Non-string scalars (None, numbers, ...) would otherwise break callers that slice the result.
    return str(results_raw), False

def generate_text(conversation: List[Dict], model_cfg: Dict,
                  all_functions: Optional[List[Dict]] = None) -> Tuple[Dict, Dict]:
    """
    Generate the next assistant turn for a conversation.

    Currently always delegates to ``generate_with_openai`` (there is no provider switch).

    Args:
        conversation: The conversation history (list of chat messages).
        model_cfg: Configuration for the model.
        all_functions: Tool definitions the model may call; ``None`` means no tools.
            A fresh list is used per call instead of a shared mutable default.

    Returns:
        (response_dict, usage) as returned by ``generate_with_openai``. Callers merge
        response_dict into an assistant message and read its 'content' / 'tool_calls'.
    """
    if all_functions is None:
        all_functions = []

    response_dict, usage = generate_with_openai(conversation, model_cfg, all_functions)

    return response_dict, usage

async def process_tool_call(tc: Dict, servers: Dict[str, MCPClient], mcp_tool2server: Dict[str, str], return_content_limit: int = 4096) -> Tuple[Dict, Dict | str | None, bool, bool]:
    """
    Validate and execute a single tool call requested by the model.

    Args:
        tc: Tool call with ``id`` and ``function`` (``name``, ``arguments``).
        servers: Map from server key to client. MCP tools are routed through
            ``mcp_tool2server``; local tools are keyed by their own function name.
        mcp_tool2server: Map from MCP tool name to the name of the server providing it.
        return_content_limit: Maximum number of characters of the tool result kept in
            the returned tool message.

    Returns:
        (tool_message, results_raw, is_success, is_add_user_message). When validation
        fails before execution, tool_message carries an error text, results_raw is None
        and both flags are False.
    """
    func_name = tc["function"]["name"]

    def _tool_error(content):
        # Validation failures are reported back to the model as a normal tool message.
        return {
            "role": "tool",
            "tool_call_id": tc["id"],
            "name": func_name,
            "content": content
        }, None, False, False

    # 1) Parse the arguments. Accept an already-decoded dict, and treat an empty string /
    #    null as "no arguments" (backends may emit these for parameterless tools).
    func_args_str = tc["function"].get("arguments", "{}")
    if isinstance(func_args_str, dict):
        func_args = func_args_str
    elif func_args_str is None or (isinstance(func_args_str, str) and not func_args_str.strip()):
        func_args = {}
    else:
        try:
            func_args = json.loads(func_args_str)
        except (ValueError, TypeError):
            return _tool_error(f"Error: Failed to parse function arguments. The provided string is not valid JSON: {func_args_str}.")
    if not isinstance(func_args, dict):
        # Valid JSON that is not an object (e.g. a list) cannot be used as keyword arguments.
        return _tool_error(f"Error: Function arguments must be a JSON object, but got: {func_args_str}.")

    # 2) Resolve the server: MCP tools via the tool->server map, local tools by their own name.
    mcp_flag = func_name in mcp_tool2server
    srv_name = mcp_tool2server[func_name] if mcp_flag else func_name

    if srv_name not in servers:
        return _tool_error(f"Error: function name {func_name} does not match. Please check for incorrect naming.")

    # 3) Check required parameters against the tool's JSON schema.
    if mcp_flag:
        tool_schema = servers[srv_name].tools[func_name].get("parameters", {})
    else:
        # A local tool exposes its whole function schema ({"name", "parameters", ...});
        # the JSON-schema "required" list lives under "parameters" when that key exists.
        local_schema = servers[srv_name].tool_json_schema
        tool_schema = local_schema.get("parameters", local_schema)

    if tool_schema:
        required_params = tool_schema.get("required", [])
        missing_params = [param for param in required_params if param not in func_args]
        if missing_params:
            return _tool_error(f"Error: Missing required parameter {missing_params}. Please check for incorrect naming.")

    # 4) Shell execution: block commands flagged as high-risk.
    if func_name == 'shell_exec' and func_args.get('command', ''):
        is_risk, reason = is_high_risk_command(func_args['command'])
        if is_risk:
            return _tool_error(f"Error: High-risk command detected. Reason: {reason}. Please revise your command.")

    # 5) Execute the tool and convert its result into the message content.
    results_raw, is_success = await servers[srv_name].call_tool(tool_name=func_name, arguments=func_args)
    result_str, is_add_user_message = post_process_tool_response(srv_name, results_raw, is_success)
    if not isinstance(result_str, str):
        # Guard the slice below against non-string results (e.g. None).
        result_str = str(result_str)

    return {
        "role": "tool",
        "tool_call_id": tc["id"],
        "name": func_name,
        "content": result_str[:return_content_limit]
    }, results_raw, is_success, is_add_user_message

def prepare_act_prompt(act_prompt: str, task_description: str, dependent_task_results: str) -> str:
    """
    Fill the act-agent prompt template.

    Substitutes ``__dependent_task_results__`` (a short header followed by the results, or
    nothing when there are no results) and ``__task_description__`` via ``replace_str``,
    then strips surrounding whitespace.
    """
    if len(dependent_task_results) > 0:
        dependency_section = '\nHere is verified information that it relies on and may be useful:\n' + dependent_task_results
    else:
        dependency_section = ''

    placeholders = {
        '__dependent_task_results__': dependency_section,
        '__task_description__': task_description,
    }
    filled_prompt = replace_str(act_prompt, placeholders)
    return filled_prompt.strip()

class MCPAgent:
    @classmethod
    async def create(cls, args) -> "MCPAgent":
        """
        Async factory: construct an ``MCPAgent`` and await its initialisation.

        ``__init__`` cannot await, so callers use ``await MCPAgent.create(args)``.

        Args:
            args: Object exposing the attributes model_name, log_messages_path,
                model_config_path, mcp_server_config, mcp_server_config_path,
                mcp_server_config_class, conversation_path, gen_part_prompt and
                all_deleted_functions_file (presumably an argparse namespace).

        Returns:
            The fully initialised agent.
        """
        agent = cls()
        init_kwargs = {
            field: getattr(args, field)
            for field in (
                "model_name",
                "log_messages_path",
                "model_config_path",
                "mcp_server_config",
                "mcp_server_config_path",
                "mcp_server_config_class",
                "conversation_path",
                "gen_part_prompt",
                "all_deleted_functions_file",
            )
        }
        await agent._initialize(**init_kwargs)
        return agent

    def __init__(self):
        """Construct an uninitialised agent; all real setup is done by the awaited ``_initialize`` (see ``create``)."""

    async def _initialize(self,
                          model_name: str,
                          log_messages_path: str,
                          model_config_path: str,
                          mcp_server_config: dict | None,
                          mcp_server_config_path: str,
                          mcp_server_config_class: str,
                          conversation_path: str,
                          gen_part_prompt: str,
                          all_deleted_functions_file: Optional[str] = None
            ) -> None:
        """
        Load configuration, start every tool server and collect the tool schemas.

        Args:
            model_name: Key of the model entry inside the model config file.
            log_messages_path: Path of the JSON conversation log; an ``all_functions``
                directory next to it receives the final tool schemas.
            model_config_path: JSON file mapping model names to model configs.
            mcp_server_config: Server config mapping; when None it is loaded from
                ``mcp_server_config_path``.
            mcp_server_config_path: JSON file with the server config.
            mcp_server_config_class: Optional iterable of class keys selecting a subset of
                servers from the loaded file (only applied when loading from file).
            conversation_path: Optional path of a previously saved conversation.
            gen_part_prompt: Stored on the instance as-is.
            all_deleted_functions_file: Optional JSON file mapping server name to the list of
                tool names to hide from the model.

        Raises:
            ValueError: If the same MCP tool name appears more than once across servers.
            NotImplementedError: If no tools are available after loading.
        """
        self.log_messages_path = log_messages_path
        self.conversation_path = conversation_path
        self.gen_part_prompt = gen_part_prompt
        self.all_deleted_functions = load_json(all_deleted_functions_file) if all_deleted_functions_file else {}

        # 1) Load the MCP server config from disk unless it was passed in directly.
        if mcp_server_config is None:
            mcp_server_config = self._resolve_mcp_server_config(
                load_json(mcp_server_config_path), mcp_server_config_class
            )

        logger.info(f"Tools configuration: {mcp_server_config}")

        # 2) Choose a model
        self.chosen_model = load_json(model_config_path)[model_name]

        # 3) Start servers
        self.servers = {}
        self.server2configs = defaultdict(list)
        self.mcp_tool2server = {}
        self.server2mcp_tools = {}
        tools_module = importlib.import_module("frontierbench.tools")
        server_set = set()
        for server_name, config in mcp_server_config.items():
            if config.get("is_local", False):
                # In-process tool: instantiate the configured class from frontierbench.tools,
                # passing the remaining config entries as constructor kwargs.
                tool_class = getattr(tools_module, config.get('tool_class'))
                client = tool_class(**{k: v for k, v in config.items() if k not in ['tool_class', 'is_local']})
                self.server2configs['local'].append(client.tool_json_schema)
                server_key = client.tool_json_schema["name"]
            else:
                server_key = server_name
                if "url" in config:  # SSE server
                    client = SSEMCPClient(server_name, config["url"])
                else:  # Local process-based server
                    client = MCPClient(
                        server_name=server_name,
                        command=config.get("command"),
                        args=config.get("args", []),
                        env=config.get("env", {}),
                        cwd=config.get("cwd", None)
                    )
                if not await client.start():
                    logger.warning(f"[WARNING] Could not start server {server_name}")
                    continue
                logger.info(f"[START] {server_name}")

                tool_configs, tool_list = await client.list_tools()
                self.server2configs[server_name].extend(tool_configs)
                self.server2mcp_tools[server_name] = tool_list

            self.servers[server_key] = client
            server_set.add(server_key)

        save_json('temp_files/all_server2configs.json', self.server2configs)

        # Hide deleted tools, then make sure no MCP tool name is exposed twice.
        all_tools = []
        for srv, tool_list in self.server2mcp_tools.items():
            if srv in self.all_deleted_functions:
                deleted = self.all_deleted_functions[srv]
                self.server2mcp_tools[srv] = [tl for tl in tool_list if tl not in deleted]
                self.server2configs[srv] = [tl_config for tl_config in self.server2configs[srv] if tl_config['name'] not in deleted]
            all_tools.extend(self.server2mcp_tools[srv])

        seen_tools, duplicate_tools = set(), set()
        for tool in all_tools:
            if tool in seen_tools:
                duplicate_tools.add(tool)
            seen_tools.add(tool)
        if duplicate_tools:
            # Explicit raise instead of `assert`, which is stripped under `python -O`.
            raise ValueError(f"Duplicate tools found: {duplicate_tools}")

        self.all_functions = [func for configs in self.server2configs.values() for func in configs]

        # Remove schema keys treated as invalid for the model's tool definitions
        # (mutates the schema dicts in place).
        for func_schema in self.all_functions:
            self._sanitize_function_schema(func_schema)

        if len(self.all_functions) == 0:
            raise NotImplementedError("No tools available. Please check the MCP configurations.")

        for srv, t_list in self.server2mcp_tools.items():
            for tl in t_list:
                self.mcp_tool2server[tl] = srv

        all_functions_path = os.path.join(os.path.dirname(self.log_messages_path), 'all_functions')
        Path(all_functions_path).mkdir(parents=True, exist_ok=True)
        save_json(os.path.join(all_functions_path, os.path.basename(self.log_messages_path).replace('.json', '_f.json')), self.all_functions)
        logger.info(f"All Servers: {server_set}")
        self.all_conversation = []

    @staticmethod
    def _resolve_mcp_server_config(raw_config: dict, server_classes) -> dict:
        """
        Normalise a loaded server config file into a ``{server_name: server_cfg}`` mapping.

        Layouts handled:
          * ``{"mcpServers": {...}}``: the inner mapping is used.
          * ``{"Perception": {cls: {...}}, "Action": {cls: {...}}}``: both groups are merged
            into one ``{cls: {server_name: cfg}}`` mapping.
          * anything else is used as-is.

        When ``server_classes`` is not None, only servers listed under those class keys are
        kept; otherwise a Perception/Action config is flattened across all classes.
        """
        is_grouped = False
        if "mcpServers" in raw_config:
            config = raw_config["mcpServers"]
        elif "Perception" in raw_config and "Action" in raw_config:
            config = {**raw_config["Perception"], **raw_config["Action"]}
            is_grouped = True
        else:
            config = raw_config

        if server_classes is not None:
            # Filter by class BEFORE flattening; once flattened the class keys are gone and
            # the lookup would match nothing.
            return {k: v for i_cls in server_classes for k, v in config.get(i_cls, {}).items()}
        if is_grouped:
            return {k: v for group in config.values() for k, v in group.items()}
        return config

    @staticmethod
    def _sanitize_function_schema(func_schema: Dict) -> None:
        """
        Strip unsupported JSON-schema keys from one tool definition, in place.

        For each property: drop ``optional``; if every ``anyOf`` alternative has a ``const``,
        append the serialised alternatives to ``description`` and drop ``anyOf``; then drop
        ``const``. A property whose ``anyOf`` has any non-``const`` alternative keeps its
        ``anyOf`` and ``const``. A top-level ``$schema`` in ``parameters`` is dropped too.
        """
        if 'parameters' not in func_schema:
            return
        parameters = func_schema['parameters']
        if 'properties' in parameters:
            for p_config in parameters['properties'].values():
                p_config.pop('optional', None)
                if 'anyOf' in p_config:
                    if any('const' not in item_anyof for item_anyof in p_config['anyOf']):
                        continue
                    # An enumeration of constants: describe it in text instead of anyOf.
                    p_config['description'] = p_config.get('description', '') + json.dumps(p_config['anyOf'], ensure_ascii=False)
                    del p_config['anyOf']
                p_config.pop('const', None)
        parameters.pop('$schema', None)

    def log_messages_to_file(self, conversation: List[Dict], log_messages_path: Optional[str] = None):
        """
        Save the conversation as JSON (via ``save_json``) on a best-effort basis.

        The target is ``log_messages_path`` if given, otherwise ``self.log_messages_path``;
        nothing is written when both are empty. Write errors are logged, not raised.
        """
        target_path = log_messages_path or self.log_messages_path
        if not target_path:
            return
        try:
            save_json(target_path, conversation)
        except Exception as e:
            # Report the path that was actually used, not always the default one.
            logger.error(f"Error logging messages to {target_path}: {str(e)}")

    async def _initialize_conversation(self) -> Tuple[List[Dict], List[Dict]]:
        """
        Load a previously saved conversation and its usage records.

        Returns:
            (conversation, usage). Both are empty lists when ``self.conversation_path`` is
            not set. The usage list is read from the sibling ``*_usage.json`` file and is
            empty when that file does not exist.
        """
        if not self.conversation_path:
            return [], []

        usage_path = self.conversation_path.replace('.json', '_usage.json')
        conversation = load_json(self.conversation_path)
        usage = load_json(usage_path) if os.path.isfile(usage_path) else []
        return conversation, usage

    async def cleanup(self, conversation: Optional[List[Dict]] = None):
        """
        Save the conversation (if non-empty) and stop every server that has a ``stop`` method.

        Shutdown is best-effort: an exception from one server's ``stop()`` is logged and the
        remaining servers are still stopped. ``self.servers`` is always cleared afterwards.
        """
        if conversation:
            self.log_messages_to_file(conversation)
        try:
            # Iterate over a snapshot so the loop is unaffected by changes to self.servers.
            for name, cli in list(self.servers.items()):
                if not hasattr(cli, 'stop'):
                    continue
                try:
                    await cli.stop()
                except Exception as e:
                    logger.error(f"Error stopping server {name}: {str(e)}")
        finally:
            self.servers.clear()

    async def get_tool_call(self, tool_calls) -> Tuple[List[Dict], List[Dict]]:
        """
        Execute all tool calls of one assistant turn concurrently.

        Returns:
            (messages, raw_records): ``messages`` holds one tool message per call (an
            "Error: ..." message when the call raised); ``raw_records`` holds per call the
            tool name, raw content, success flag and whether a multimedia user message
            should follow.

        Raises:
            ValueError: If ``process_tool_call`` yields something other than a tuple or an
                exception.
        """
        outcomes = await asyncio.gather(
            *[process_tool_call(call, self.servers, self.mcp_tool2server) for call in tool_calls],
            return_exceptions=True,
        )

        messages: List[Dict] = []
        raw_records: List[Dict] = []
        for call, outcome in zip(tool_calls, outcomes):
            if isinstance(outcome, Exception):
                # Surface the exception to the model as a failed tool message.
                error_text = f"Error: {str(outcome)}"
                messages.append({
                    "role": "tool",
                    "tool_call_id": call["id"],
                    "name": call["function"]["name"],
                    "content": error_text,
                })
                raw_content, succeeded, needs_user_message = error_text, False, False
            elif isinstance(outcome, tuple):
                message, raw_content, succeeded, needs_user_message = outcome
                messages.append(message)
            else:
                raise ValueError("Unexpected result type from process_tool_call")

            raw_records.append({
                'name': call["function"]["name"],
                'content': raw_content,
                'is_tool_success': succeeded,
                'is_add_user_message': needs_user_message,
            })

        return messages, raw_records

    async def run(self, conversation, tool_use_limited: int = 99999, all_functions: List[str] | None = None, log_messages_path: Optional[str] = None, restart: Optional[bool] = False, model_cfg: Optional[dict] = None, finish_tools: Optional[list] = None, info_graph: Graph | None = None, not_used_tools: Optional[list] = None) -> Tuple[List[Dict], List[Dict]]:
        """
        Drive the assistant/tool loop until the model stops calling tools, a finish tool is
        called, or the tool budget is exhausted.

        ``conversation`` is mutated in place (assistant turns, tool results and optional
        multimedia user messages are appended) and saved after every change.

        Args:
            conversation: Message list to continue.
            tool_use_limited: Tool-call budget; decreased by the number of tool results of
                each turn. The loop stops once it drops below zero, and the model is then
                asked for a final answer without tools.
            all_functions: If non-empty, offer only the tools with these names.
            log_messages_path: Where to save the conversation (default: self.log_messages_path).
            restart: If True, re-execute every tool call already present in ``conversation``
                before continuing; those results are discarded.
            model_cfg: Overrides merged on top of ``self.chosen_model`` for every generation.
            finish_tools: Tool names that end the loop when called; 'step_summary' is always
                included. The caller's list is not modified.
            info_graph: Used to answer ``get_step_information`` calls.
            not_used_tools: Tool names that are never offered to the model.

        Returns:
            (conversation, tool_results), where tool_results holds the raw record of every
            executed tool call (name, content, is_tool_success, is_add_user_message).
        """
        # None defaults + a copy: the former `= {}` / `= []` defaults were shared across
        # calls, and `finish_tools.append` mutated both that default and the caller's list.
        model_cfg = model_cfg or {}
        finish_tools = list(finish_tools or [])
        if 'step_summary' not in finish_tools:
            finish_tools.append('step_summary')
        finish_tool_set = set(finish_tools)
        not_used_tools = not_used_tools or []
        log_path = log_messages_path or self.log_messages_path

        self.log_messages_to_file(conversation, log_path)

        if restart:
            # On restart, re-execute every tool call already in the conversation.
            for message in conversation:
                if "tool_calls" in message:
                    await self.get_tool_call(message["tool_calls"])

        # If all_functions is given, offer only those tools; never offer not_used_tools.
        if all_functions:
            cur_functions = [i_tool for i_tool in self.all_functions if i_tool['name'] in all_functions]
        else:
            cur_functions = self.all_functions
        cur_functions = [i_tool for i_tool in cur_functions if i_tool['name'] not in not_used_tools]
        if all_functions:
            # Record how many tools are offered in this restricted run.
            self.all_update(upd_dict={"role": "log", "content": {"partial_function_length": len(cur_functions)}})

        tool_results = []
        while True:
            response, usage = generate_text(conversation, {**self.chosen_model, **model_cfg}, cur_functions)

            # Record the assistant turn.
            response = {'role': 'assistant', **response}
            conversation.append(response)
            self.log_messages_to_file(conversation, log_path)

            # No tool calls: the model has answered, stop.
            if not response.get("tool_calls"):
                break

            # Execute the requested tool calls.
            tool_results_to_append, tool_results_raw = await self.get_tool_call(response["tool_calls"])
            tool_results.extend(tool_results_raw)

            # Memory management: `context_refinement` asks the model to rewrite a recent tool
            # message (conversation[-1] or [-2]) according to the tool's refine_prompt.
            try:
                for tr_to_append, tr in zip(tool_results_to_append, tool_results_raw):
                    if tr['name'] in ['context_refinement']:
                        is_refined = False
                        for idx_context in range(-1, -3, -1):
                            if conversation[idx_context]['role'] == 'tool':
                                cur_refine_prompt = context_refinement_prompt(
                                    tool_call_context=conversation[idx_context - 1],
                                    returned_context=conversation[idx_context],
                                    refine_prompt=tr['content']['refine_prompt']
                                )
                                if len(cur_refine_prompt) > 0:
                                    conversation_tool_refine = [{
                                        "role": "user",
                                        "content": cur_refine_prompt
                                    }]
                                    response_tool_refine, usage_tool_refine = generate_text(conversation_tool_refine, {**self.chosen_model, **model_cfg})
                                    conversation[idx_context]["content"] = response_tool_refine["content"]
                                    is_refined = True
                        if is_refined:
                            tr_to_append['content'] = json.dumps({"content": "Successfully refined the last tool return."}, ensure_ascii=False)
                        else:
                            tr_to_append['content'] = json.dumps({"content": "No tool return found to refine."}, ensure_ascii=False)
            except Exception as e:
                logger.error(f"Error processing tool results for context_refinement: {str(e)}")

            # `get_step_information` calls are answered from info_graph.
            try:
                for tr_to_append, tr in zip(tool_results_to_append, tool_results_raw):
                    if tr['name'] in ['get_step_information']:
                        tr_to_append['content'] = info_graph.get_node_info(tr['content']['step_id'])
                        if len(tr_to_append['content']) == 0:
                            tr_to_append['content'] = json.dumps({"content": f"Error: step_id {tr['content']['step_id']} is not in history."}, ensure_ascii=False)

            except Exception as e:
                logger.error(f"Error processing tool results for parts generation: {str(e)}")

            # Charge the budget, drop results marked 'NOT INCLUDE', and record the rest.
            tool_use_limited -= len(tool_results_to_append)
            tool_results_to_append = [i_tool_result for i_tool_result in tool_results_to_append if i_tool_result['content'] != 'NOT INCLUDE']
            if tool_results_to_append:
                conversation.extend(tool_results_to_append)
            self.log_messages_to_file(conversation, log_path)

            called_tools = {tc["function"]["name"] for tc in response["tool_calls"]}
            if tool_use_limited < 0 or (finish_tool_set & called_tools):
                break

            # Forward images returned by multimedia tools to the model as a user message.
            base64_images = []
            other_text = ""
            try:
                for tr in tool_results_raw:
                    if tr['is_add_user_message'] and tr['is_tool_success']:
                        for cur_image in tr['content'].get('content', []):
                            if 'data' in cur_image.keys() and 'mimeType' in cur_image.keys():
                                base64_images.append(
                                    f"data:{cur_image['mimeType']};base64,{cur_image['data']}"
                                )
                            elif 'text' in cur_image.keys():
                                other_text += cur_image['text'] + "\n"

            except Exception as e:
                logger.error(f"Error processing tool results for base64-encoded image: {str(e)}")

            if len(base64_images) > 0:
                other_text = 'Now we have the multimedia data. Please continue.' if len(other_text) == 0 else other_text
                conversation.append({
                    "role": "user",
                    "content": [{
                        "type": "image_url",
                        "image_url": {"url": base64_image_str}
                    } for base64_image_str in base64_images] + \
                        [{"type": "text", "text": other_text}]
                })
                self.log_messages_to_file(conversation, log_path)

        if tool_use_limited < 0 and conversation[-1].get("role") == "tool" and not (finish_tool_set & {tc["function"]["name"] for tc in response["tool_calls"]}):
            # The budget ran out right after a tool turn: ask the model for a final answer
            # from what it has gathered, using the same model_cfg overrides as the loop.
            conversation.append({"role": "user", "content": "The tool call attempts have been exhausted. Please provide the answer based on the information gathered so far."})
            response, usage = generate_text(conversation, {**self.chosen_model, **model_cfg})
            response = {'role': 'assistant', **response}
            conversation.append(response)
            self.log_messages_to_file(conversation, log_path)

        return conversation, tool_results

    def all_update(self, conversation: Optional[List] = None, upd_idx: int = 0, upd_dict: Optional[Dict] = None, upd_list: Optional[List] = None, log_messages_path: Optional[str] = None):
        """
        Append entries to the aggregated run log and rewrite it to disk.

        Exactly one source is used, checked in this order (first truthy one wins):
          * upd_dict: appended as a single entry;
          * upd_list: all of its entries are appended;
          * conversation: entries ``conversation[upd_idx:]`` are appended.

        The whole ``self.all_conversation`` is then saved to ``log_messages_path`` or, by
        default, to ``self.log_messages_path`` with '.json' replaced by '_all.json'.

        Raises:
            ValueError: If none of upd_dict, upd_list or conversation is given (non-empty).
        """
        if upd_dict:
            self.all_conversation.append(upd_dict)
        elif upd_list:
            self.all_conversation.extend(upd_list)
        elif conversation:
            self.all_conversation.extend(conversation[upd_idx:])
        else:
            raise ValueError("Either conversation, upd_list or upd_dict must be provided.")
        save_json(log_messages_path or self.log_messages_path.replace('.json', '_all.json'), self.all_conversation)


    async def run_tool(
            self,
            system_prompt,
            cold_start_prompt,
            act_system_prompt,
            act_prompt,
            replan_prompt,
            theme_prompt,
            task_prefix,
            tool_use_limited: int = 999999,
            cold_start_restart_file: Optional[str] = None,
            cold_start_model_cfg: dict = {},
        ):

        # 准备初始对话
        conversation = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": replace_str(cold_start_prompt, {'__theme__': theme_prompt, '__current_prompt__': 'First, think about what your tools can do, then start planning!'})}
        ]
        if cold_start_restart_file:
            conversation = load_json(cold_start_restart_file)
        else:
            logger.info(f"Cold Start ...")
            # --- Cold Start 生成初始 方向、plan ---
            conversation, cold_start_tool_results = await self.run(
                conversation=conversation,
                tool_use_limited=tool_use_limited,
                log_messages_path=self.log_messages_path.replace('.json', f'_{task_prefix}_cold_start.json'),
                restart=False,
                model_cfg=cold_start_model_cfg,
                not_used_tools=['context_refinement', 'get_step_information']
            )
        logger.info(f"Cold Start Finished.")
        # self.all_update(conversation=conversation)

        # 从对话中提取JSON格式的最新的计划步骤
        plan_in_json, is_extract = None, False
        cold_start_retry_num = 3

        while cold_start_retry_num > 0:
            logger.info(f"Extracting plan from cold start conversation, extracted attempts left: {cold_start_retry_num}")
            cold_start_retry_num -= 1

            # 找倒数3个消息：
            for message in conversation[-1:-4:-1]:
                if message["role"] == "assistant":
                    # 用正则表达式提取 ```json ... ``` 之间的内容
                    plan_in_json, is_extract = extract_json(message["content"])
                    if is_extract:
                        break
            if is_extract:
                logger.info(f"Successfully extracted plan.")
                break

            # 如果没有提取到有效的计划，则再尝试一次：
            conversation.append({
                "role": "user",
                "content": f"The previous JSON format is invalid with {plan_in_json}. Please correct it and provide a valid JSON format."
            })

            response, usage = generate_text(conversation, self.chosen_model) # 生成

            # 记录当前对话
            response = {'role': 'assistant', **response}
            conversation.append(response)
            # self.log_messages_to_file(conversation, self.log_messages_path.replace('.json', '_cold_start.json'))

        # 如果没有提取出计划，就直接放弃此次 infer
        if not is_extract:
            logger.warning(f"Failed to extract a valid plan after multiple attempts. Aborting the process.")
            return

        self.log_messages_to_file(conversation, self.log_messages_path.replace('.json', f'_{task_prefix}_cold_start.json'))


        self.all_update(upd_dict={"role": "log", "content": plan_in_json})
        logger.info("Initial plan successfully extracted and logged.")

        # --- Step 0 Act Agent 根据计划进行信息挖掘 ---
        info_graph = Graph()

        while len(info_graph.get_node_ids()) < len(plan_in_json['plan']):
            logger.info(f"Begin a new iteration, current info_graph nodes: {info_graph.get_node_ids()}")
            plan_in_json: Dict

            is_do = False
            for cur_idx, cur_plan in enumerate(plan_in_json['plan']):
                if cur_plan['status'] != 'pending':
                    continue

                # 判断前序任务是否都已处理过
                cur_dependencies = cur_plan['dependent_ids']
                is_depend_success = all([dep_id in info_graph.get_node_ids() for dep_id in cur_dependencies])
                if not is_depend_success:
                    continue

                is_do = True
                async def act():
                    # 组织 context，调用 act agent
                    act_conversation = [
                        {"role": "system", "content": act_system_prompt},
                        {"role": "user", "content": prepare_act_prompt(act_prompt=act_prompt, task_description=cur_plan['question'], dependent_task_results=info_graph.get_nodes_info(include_nodes=cur_dependencies))}
                    ]
                    act_conversation, tool_results = await self.run(
                        conversation=act_conversation,
                        tool_use_limited=tool_use_limited,
                        log_messages_path=self.log_messages_path.replace('.json', f'_{task_prefix}_act_{cur_plan["id"]}.json'),
                        not_used_tools=['get_step_information']
                    )
                    return act_conversation, tool_results

                logger.info(f"Executing step {cur_plan['id']} ..")
                cur_act_conversation, cur_tool_results = await act()
                self.all_update(conversation=cur_act_conversation)

                def get_act_summary(tool_results) -> Dict | None:
                    step_summary = None
                    for tr in tool_results[-1:-4:-1]:
                        if tr['name'] in ['step_summary']:
                            step_summary = tr['content']
                            break
                    return step_summary
                step_summary = get_act_summary(cur_tool_results)

                step_summary_retry_num = 3
                act_retry_conversation = deepcopy(cur_act_conversation)
                while step_summary is None and step_summary_retry_num > 0:
                    # 如果没有提取到有效的执行结果，则再试一次：
                    logger.info(f"Step {cur_plan['id']} did not produce a step_summary, retrying, attempts left: {step_summary_retry_num}")
                    self.all_update(upd_dict={"role": "log", "content": f"Step {cur_plan['id']} did not produce a step_summary. Retrying..."})
                    log_idx = len(act_retry_conversation)

                    act_retry_conversation.append({
                        "role": "user",
                        "content": f"Summarize the conversation by the `step_summary` tool."
                    })
                    act_retry_conversation, upd_tool_results = await self.run(
                        conversation=act_retry_conversation,
                        tool_use_limited=tool_use_limited,
                        log_messages_path=self.log_messages_path.replace('.json', f'_{task_prefix}_act_{cur_plan["id"]}.json'),
                        not_used_tools=['get_step_information']
                    )
                    self.all_update(conversation=act_retry_conversation, upd_idx=log_idx)

                    step_summary = get_act_summary(upd_tool_results)
                    step_summary_retry_num -= 1

                if step_summary is None:
                    # Sum Agent 接入，生成 step_summary
                    logger.info(f"Step {cur_plan['id']} did not produce a step_summary after multiple attempts. Using Summary Agent to summarize the conversation.")
                    summary_conversation = deepcopy(cur_act_conversation)[1:] # 去掉 system prompt
                    summary_conversation.append({
                        "role": "user",
                        "content": f"Please summarize the conversation by the `step_summary` tool."
                    })
                    summary_conversation, sum_tool_results = await self.run(
                        conversation=summary_conversation,
                        tool_use_limited=tool_use_limited,
                        all_functions=['step_summary'],
                        log_messages_path=self.log_messages_path.replace('.json', f'_{task_prefix}_act_{cur_plan["id"]}.json'),
                        not_used_tools=['context_refinement', 'get_step_information']
                    )
                    step_summary = get_act_summary(sum_tool_results)
                    self.all_update(conversation=cur_act_conversation, upd_idx=log_idx)
                    # step_final_response = cur_act_conversation[-1].get("content", '')

                    step_summary_retry_num -= 1

                if step_summary is None:
                    step_summary = {
                        "is_success": False,
                        "content": "No valid step_summary produced.",
                        "observations": "Could not produce a valid step_summary after multiple attempts."
                    }
                logger.info(f"Step {cur_plan['id']} successfully extracted summary.")

                # 放入 info_graph 标志着已经执行完毕
                step_summary: Dict
                info_graph.add_node(
                    node_id=cur_plan['id'],
                    question=cur_plan['question'],
                    node_info=step_summary,
                    status="executed",
                    dependent_ids=cur_dependencies
                )
                self.all_update(upd_dict={"role": "log", "content": f"Step {cur_plan['id']} executed successfully.", "info_graph": info_graph.graph})
                plan_in_json["plan"][cur_idx]['status'] = "executed"


                # 查看是否需要调整计划
                cur_plan_str = ''
                for p in plan_in_json['plan']:
                    if p['id'] not in info_graph.get_node_ids():
                        cur_plan_str += json.dumps(p, ensure_ascii=False) + '\n'
                    else:
                        cur_plan_str += info_graph.get_node_info(p['id']) + '\n'
                cur_replan_conversation = [
                    {"role": "system", "content": system_prompt},
                    {
                        "role": "user",
                        "content": replace_str(
                            cold_start_prompt,
                            {
                                '__theme__': theme_prompt,
                                '__current_prompt__': replace_str(replan_prompt, {'__current_plan__': cur_plan_str})
                            }
                        )
                    }
                ]
                logger.info(f"Re-planning for step {cur_plan['id']} ...")
                cur_replan_conversation, replan_tool_results = await self.run(
                    conversation=cur_replan_conversation,
                    tool_use_limited=tool_use_limited,
                    log_messages_path=self.log_messages_path.replace('.json', f'_{task_prefix}_replan_{cur_plan["id"]}.json'),
                    not_used_tools=['context_refinement']
                )
                self.all_update(conversation=cur_replan_conversation)

                replan_in_json, replan_is_extract = None, False
                for replan_message in cur_replan_conversation[-1:-4:-1]:
                    if replan_message["role"] == "assistant":
                        # 用正则表达式提取 ```json ... ``` 之间的内容
                        replan_in_json, replan_is_extract = extract_json(replan_message["content"])
                        if replan_is_extract:
                            logger.info(f"Step {cur_plan['id']} re-planning produced new plan.")
                            break

                replan_in_json: Dict
                if replan_is_extract:
                    for p in replan_in_json['plan']:
                        if p['id'] in info_graph.get_node_ids():
                            # 保留已执行的步骤
                            continue
                        else:
                            # 更新未执行的步骤
                            for idx in range(len(plan_in_json['plan'])):
                                if plan_in_json['plan'][idx]['id'] == p['id']:
                                    plan_in_json['plan'][idx] = p
                                    break
                    self.all_update(upd_dict={"role": "log", "content": f"Step {cur_plan['id']} re-planning succeeded.", "plan": plan_in_json})
                    logger.info(f"Step {cur_plan['id']} re-planning succeeded.")
                elif "requires no adjustments" in cur_replan_conversation[-1].get("content", '').strip().lower():
                    self.all_update(upd_dict={"role": "log", "content": f"Step {cur_plan['id']} re-planning requires no adjustments."})
                    logger.info(f"Step {cur_plan['id']} re-planning requires no adjustments.")
                else:
                    self.all_update(upd_dict={"role": "log", "content": f"Error: Step {cur_plan['id']} re-planning did not produce a valid JSON. Continuing with the previous plan."})
                    logger.info(f"Step {cur_plan['id']} re-planning did not produce a valid JSON.")
