"""
tool_router.py
An MLLM acts as a router that decides which tools to call to gather extra information that helps the PRM judge an action.
"""

import json
import re
import copy
import base64
import os
from datetime import datetime

from io import BytesIO
from PIL import Image
from typing import List, Dict, Any, Optional

from android_world.agents.BoN.Tool.qwen_api import parallel_call_gpt_qwen
from android_world.agents.BoN.vllm_api import parallel_call_vllm
from android_world.agents.BoN.Tool.template_instruct import eval_prompt, system_prompt, system_prompt_omni_parser, system_prompt_Point, system_prompt_OmniParser_Point
from android_world.agents.BoN.Tool.base_manager import ToolManager
from android_world.agents.BoN.Tool.tool_server.utils.utils import pil_to_base64

from android_world.agents.BoN.Tool.tool_info_process import tool_info_process
def _parse_tool_config(model_response: str) -> Optional[Dict]:
    """从模型的JSON响应中解析出工具调用配置。"""
    try:
        actions_pattern = r'Action:\s*(\{.*?\})'
        match = re.search(actions_pattern, model_response, re.DOTALL)
        if not match:
            # actions_pattern_fallback = r'"actions"\s*:\s*(\[.*?\])'
            actions_pattern_fallback = r'"actions"\s*:\s*(\[\s*\{[\s\S]*\}\s*\])'
            match_fallback = re.search(actions_pattern_fallback, model_response, re.DOTALL)
            if not match_fallback:
                print("警告: 在模型响应中未找到 'Action:' 或 '\"actions\": [...]' 结构。")
                return None
            action_str = json.loads(match_fallback.group(1))[0]
        else:
            action_str = json.loads(match.group(1))

        if 'API_name' in action_str and 'API_params' in action_str:
            return {'API_name': action_str['API_name'], 'API_params': action_str['API_params']}
        elif 'name' in action_str and 'arguments' in action_str:
            return {'API_name': action_str['name'], 'API_params': action_str['arguments']}

    except (json.JSONDecodeError, IndexError, KeyError) as e:
        print(f"错误: 解析工具配置失败: {e}\n模型响应: {model_response}")
    return None


def _build_prompt(initial_prompt: str, history_log: Dict[str, str], oi_instruction: bool) -> str:
    """构建发送给模型的完整文本提示。"""
    prompt_parts = [f"User Question: {initial_prompt}\n Try to use tool provided to get external information."]

    if history_log:
        history_str = json.dumps(history_log, ensure_ascii=False)
        prompt_parts.append(f"You have already taken some steps. Here is the history of your actions and their observations:")
        prompt_parts.append(f"Current tool calling history: {history_str}\n")

    if oi_instruction:
        prompt_parts.append(
            "**Your Task (OI - Observation & Introspection):**\n"
            "1. **Summarize:** Briefly summarize what you have learned from the history.\n"
            "2. **Decide:** Based on your summary and the initial goal, decide on the next step. Do you have enough information to answer the request?\n"
            "   - If YES, call the tool `{\"name\": \"Terminate\", \"arguments\": {\"ans\": \"<your final answer>\"}}`.\n"
            "   - If NO, call another tool to get the missing information.\n **Don't call any tool that you have used before.**"
        )
    else:
        prompt_parts.append(
            "Based on the user's request and the current screen, determine if you need to call a tool to gather more information. "
            "If so, specify the tool and its parameters. If not, call the \"Terminate\" tool with the final answer."
            "Don't call any tool that you have used before."
        )

    return "\n".join(prompt_parts)


def tool_router(raw_prompt: str, image: Image.Image, debug: bool = False):
    """Let an MLLM iteratively call external tools to gather information for the PRM.

    Each round, the router model (queried via ``parallel_call_vllm``) receives
    three things: the user prompt, the history of earlier tool calls, and the
    current image. It answers with either a tool call or ``Terminate``. Tools
    are executed through a ``ToolManager``. When a tool returns an edited
    image, that image replaces the current image, except for ``crop``, whose
    sub-image is appended to ``final_images`` instead.

    Behavior notes:
      - At most ``max_rounds`` (2) rounds are run.
      - A response that cannot be parsed skips the round.
      - Re-calling an already-used tool is skipped and reported back to the
        model through the history.
      - A failing tool call ends the loop.
      - With ``debug=True``, edited images and a conversation log are written
        under ``debug_logs_20250801-tool-server-num37/<timestamp>/``.

    Args:
        raw_prompt: The user question / task description.
        image: The screenshot the model and tools operate on.
        debug: Save intermediate images and the conversation log.

    Returns:
        A ``(final_result, debug_conversation_log)`` tuple on every path,
        including when no tools are available. ``final_result`` has these keys:

        - ``tool_chain``
        - ``tool_chain_filter``
        - ``final_images``: base64 strings; index 0 is the latest main image.
        - ``image_is_edited``

        ``debug_conversation_log`` stays empty unless ``debug`` is True.
    """
    # --- 1. Initialization ---
    max_rounds = 2
    CONTROLLER_ADDR = "http://localhost:20001"
    model = "qwen2.5-vl-72b-instruct"
    temperature = 0.0
    max_tokens = 3000
    max_processes = 4

    debug_dir = None
    if debug:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        debug_dir = os.path.join("debug_logs_20250801-tool-server-num37", timestamp)
        os.makedirs(debug_dir, exist_ok=True)
        print(f"Debug模式已启动。日志和图片将保存到: {debug_dir}")

    tool_manager = ToolManager(controller_url_location=CONTROLLER_ADDR)
    if not tool_manager.available_tools:
        print("错误：ToolManager 未能发现任何可用工具，流程中止。")
        # Return the same (result, log) tuple shape and result keys as the
        # normal path, so callers that unpack two values keep working.
        return {
            "tool_chain": [{"error": "No tools available"}],
            "tool_chain_filter": [],
            "final_images": [pil_to_base64(image)],
            "image_is_edited": False
        }, []

    current_image = image
    image_was_edited = False
    history_log = {}
    tool_chain_history = []
    debug_conversation_log = []
    used_tools = set()

    # Index 0 always holds the latest "main" image; crop sub-images are appended.
    final_images = [pil_to_base64(current_image)]

    # --- 2. Reasoning loop ---
    for current_round in range(max_rounds):
        print(f"\n{'='*20} 开始轮次 {current_round + 1}/{max_rounds} {'='*20}")

        # 2.1 Build this round's prompt (OI instructions from round 2 onward).
        prompt_text = _build_prompt(raw_prompt, history_log, oi_instruction=(current_round > 0))
        current_image_b64 = pil_to_base64(current_image)

        messages = [
            {"role": "system", "content": [{"type": "text", "text": system_prompt_OmniParser_Point}]},
            {"role": "user","content": [{"type": "text", "text": prompt_text}, {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{current_image_b64}"}}]},
        ]

        # 2.2 Query the router model.
        # Alternative backend: parallel_call_gpt_qwen([messages], model, temperature, max_tokens, max_processes)
        responses = parallel_call_vllm([messages], model, temperature, max_tokens, max_processes,port="8005")
        model_response_text = responses[0]
        print("模型响应成功！")

        # 2.3 Parse the tool configuration.
        tool_cfg = _parse_tool_config(model_response_text)

        if not tool_cfg:
            print("警告: 模型响应格式不正确，无法解析出有效行动。跳过此轮，尝试下一轮。")
            if debug:
                debug_conversation_log.append({
                    "round": current_round + 1, "prompt": prompt_text, "model_response": model_response_text,
                    "parsed_tool_config": "Error: Parsing Failed", "tool_chain_step": "Skipped"
                })
            continue

        api_name = tool_cfg.get('API_name')

        # 2.4 Stop if the model chose to terminate.
        if api_name == 'Terminate':
            print("模型决定终止。流程结束。")
            final_ans = "流程正常终止，但未提供最终答案。"
            if 'ans' in tool_cfg.get('API_params', {}):
                final_ans = tool_cfg['API_params']['ans']

            step_summary = {"tool": "Terminate", "params": tool_cfg.get('API_params', {}), "observation": final_ans}
            tool_chain_history.append(step_summary)

            if debug:
                debug_conversation_log.append({
                    "round": current_round + 1, "prompt": prompt_text, "model_response": model_response_text,
                    "parsed_tool_config": tool_cfg, "tool_chain_step": step_summary
                })
            break

        # --- Refuse repeated tool calls ---
        if api_name in used_tools:
            print(f"警告: 模型尝试重复调用工具 '{api_name}'。此轮调用将被跳过。")
            # Tell the model, via the history, that its call was blocked.
            observation_log_str = f"Error: You attempted to call a tool ('{api_name}') that was already used before. This action was skipped. Please choose a different tool that you haven't used yet."
            action_log_str = json.dumps({"name": api_name, "arguments": tool_cfg.get('API_params', {})}, ensure_ascii=False)
            history_log[f"action_{current_round}"] = action_log_str
            history_log[f"observation_{current_round}"] = observation_log_str

            # NOTE(review): this step uses key "action" while other steps use
            # "tool"; left unchanged because tool_info_process may depend on it.
            step_summary = { "action": api_name, "params": tool_cfg.get('API_params', {}), "observation": "Skipped due to repetition.", "status": "Skipped" }
            tool_chain_history.append(step_summary)
            if debug:
                debug_conversation_log.append({
                    "round": current_round + 1, "prompt": prompt_text, "model_response": model_response_text,
                    "parsed_tool_config": tool_cfg, "tool_chain_step": step_summary
                })

            continue

        # 2.5 Prepare and execute the tool call.
        api_params = tool_cfg.get('API_params', {})
        call_params = copy.deepcopy(api_params)

        # The model only names the "image" parameter; the real image is
        # substituted here so the history keeps the model's original arguments.
        if 'image' in api_params:
            call_params['image'] = pil_to_base64(current_image)
        tool_response = tool_manager.call_tool(tool_name=api_name, params=call_params)
        print(f"工具 '{api_name}' 文本响应成功！")

        used_tools.add(api_name)

        # 2.6 Process the tool result and update the history.
        action_log_str = json.dumps({"name": api_name, "arguments": api_params}, ensure_ascii=False)
        observation_log_str = ""
        step_summary = {"tool": api_name, "params": api_params}

        if tool_response.get("error_code", 1) != 0:
            print(f"错误: 工具 '{api_name}' 执行失败: {tool_response.get('text')}")
            observation_log_str = f"Error: Tool '{api_name}' failed. Details: {tool_response.get('text')}"
            step_summary["observation"] = observation_log_str
            step_summary["status"] = "Error"
            tool_chain_history.append(step_summary)
            # Record the failing round before aborting; the break would
            # otherwise skip the debug logging at the end of the loop body.
            if debug:
                debug_conversation_log.append({
                    "round": current_round + 1, "prompt": prompt_text, "model_response": model_response_text,
                    "parsed_tool_config": tool_cfg, "tool_chain_step": step_summary
                })
            break
        else:
            print(f"工具 '{api_name}' 执行成功。")
            observation_log_str = tool_response.get('text', 'No text output.')
            step_summary["observation"] = observation_log_str
            step_summary["status"] = "Success"

            edited_image_b64 = tool_response.get("edited_image")
            if edited_image_b64:
                print("工具返回了编辑过的图片。")
                image_was_edited = True
                edited_image = Image.open(BytesIO(base64.b64decode(edited_image_b64)))

                if api_name == "crop":
                    # Crops are extra views: keep the main image, append the sub-image.
                    print("检测到crop工具，将子图添加到返回列表。")
                    final_images.append(edited_image_b64)

                # Ablation switch: to remove the SOM effect when testing Point+OCR,
                # uncomment below; for pure OCR tests also disable the "add edit img"
                # step in BoN and use rollout=1.
                # elif api_name == "omni_parser":
                #     pass

                else:
                    # Any other tool replaces the current (main) image.
                    current_image = edited_image
                    final_images[0] = edited_image_b64

                if debug and debug_dir:
                    save_path = os.path.join(debug_dir, f"round_{current_round+1}_edited_by_{api_name}.png")
                    try:
                        edited_image.save(save_path)
                        print(f"Debug: 已保存编辑后的图片到 {save_path}")
                    except Exception as e:
                        print(f"Debug: 保存编辑后的图片失败. 错误: {e}")

            tool_chain_history.append(step_summary)

        history_log[f"action_{current_round}"] = action_log_str
        history_log[f"observation_{current_round}"] = observation_log_str

        if debug:
            debug_conversation_log.append({
                "round": current_round + 1, "prompt": prompt_text, "model_response": model_response_text,
                "parsed_tool_config": tool_cfg, "tool_chain_step": step_summary
            })

    # --- 3. Assemble and return the final result ---
    if debug and debug_dir:
        log_file_path = os.path.join(debug_dir, "conversation_log.json")
        try:
            with open(log_file_path, 'w', encoding='utf-8') as f:
                json.dump(debug_conversation_log, f, indent=2, ensure_ascii=False)
            print(f"Debug: 已保存完整对话日志到 {log_file_path}")
        except Exception as e:
            print(f"Debug: 保存对话日志文件失败. 错误: {e}")

    tool_chain_filter = tool_info_process(raw_prompt, tool_chain_history)
    final_result = {
        "tool_chain": tool_chain_history,
        "tool_chain_filter": tool_chain_filter,
        "final_images": final_images,
        "image_is_edited": image_was_edited
    }
    print("tool_chain_filter:", tool_chain_filter)
    if debug:
        # Deep copy so later mutation of final_result does not alter the log entry.
        final_result_1 = copy.deepcopy(final_result)
        debug_conversation_log.append({"tool_result": final_result_1})

    print("\n流程完成，返回最终结果")
    return final_result, debug_conversation_log

# --- 示例调用 ---
if __name__ == '__main__':
    # Example run: route one screenshot through the tools. If the example path
    # does not exist, a blank grey placeholder image is used instead.
    user_prompt = "The current user goal/request is: Record an audio clip and save it with name \"presentation_fGwr.m4a\" using Audio Recorder app.\n\nHere is a history of what you have done so far:\nYou just started, no action has been performed yet." 
    image_path = r"path/to/your/screenshot.png" 
    # Keep the try narrow: a FileNotFoundError raised inside tool_router must
    # not be mistaken for a missing example image (and trigger a second run).
    try:
        user_image = Image.open(image_path)
    except FileNotFoundError:
        print(f"示例图片未找到，请检查路径: {image_path}")
        user_image = Image.new('RGB', (400, 300), color='grey')
    tool_router(raw_prompt=user_prompt, image=user_image, debug=True)
