#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
VLLM API沙盒工具调用客户端：基于run_sandbox_batch.py架构，使用自己部署的VLLM API

- 使用标准的message格式：system -> user -> assistant -> tool -> assistant -> ...
- 模型发起工具调用后，工具返回结果以 role: tool 的形式添加到messages中
- ExecuteCode -> run_code
- ExecuteShell -> sandbox_execute_command
- 并行执行（ThreadPoolExecutor）
- 每题最大迭代轮次保护
- 结果JSON包含 final_model_output / last_iteration_output / iteration_traces（含每轮完整messages）

Usage:
  python tool_vllm_client_sandbox.py -i sandbox_from_doc.json -o results.json -m stu-qwen2.5-72B-Instruct --max-iter 10 --workers 50 --seed 42
"""

from __future__ import annotations
import argparse
import json
import os
import random
import re
from typing import Any, Dict, List, Optional, Tuple
from concurrent.futures import ThreadPoolExecutor, as_completed
import base64
import threading

# ====== 使用VLLM API而不是外部API ======
try:
    from vllm_api import call_vllm_api, VLLMChatSession
except ImportError:
    print("⚠️ vllm_api模块未找到，工具调用功能仍可正常使用")

from fileagent_libs.fileagent_mcp_host import (
    create_sandbox,
    sandbox_write_file,
    sandbox_execute_command,
    destroy_sandbox,
    run_code,
)
# ================================================================

# ----------------- Regex helpers -----------------
# Matches a <tool_calls>...</tool_calls> or <tool_call>...</tool_call> section
# (tag names are matched case-insensitively). The payload between the tags is
# captured lazily in the named group "block" (plural tag) or "block2"
# (singular tag), so a search stops at the first closing tag.
RE_TOOL_BLOCK = re.compile(
    r"(?:<tool_calls>(?P<block>[\s\S]*?)</tool_calls>|<tool_call>(?P<block2>[\s\S]*?)</tool_call>)",
    re.IGNORECASE
)

def _normalize_tool_call(call: Dict[str, Any]) -> Dict[str, Any]:
    """规范化工具调用格式，统一为 tool_name + input_parameters"""
    normalized = {}
    
    # 处理工具名称：name 或 tool_name
    if "tool_name" in call:
        normalized["tool_name"] = call["tool_name"]
    elif "name" in call:
        normalized["tool_name"] = call["name"]
    
    # 处理参数：arguments 或 input_parameters
    if "input_parameters" in call:
        normalized["input_parameters"] = call["input_parameters"]
    elif "arguments" in call:
        normalized["input_parameters"] = call["arguments"]
    else:
        normalized["input_parameters"] = {}
    
    return normalized

def parse_tool_calls(used_output: str) -> List[Dict[str, Any]]:
    """Parse the first <tool_calls>/<tool_call> section of *used_output* into a list of calls.

    Returns an empty list when no tagged section is present.  The payload is
    first decoded as JSON (a single object is wrapped into a list) and each
    entry is normalized; if decoding or normalization raises, the payload is
    handed to the manual parser instead.
    """
    found = RE_TOOL_BLOCK.search(used_output)
    if found is None:
        return []
    payload = (found.group("block") or found.group("block2") or "").strip()

    # Fast path: the payload is valid JSON.
    try:
        decoded = json.loads(payload)
        if not isinstance(decoded, list):
            decoded = [decoded]
        return [_normalize_tool_call(entry) for entry in decoded]
    except Exception:
        pass

    # Fallback: tolerant hand-written parsing.
    return _manual_parse_tool_call(payload)

def _manual_parse_tool_call(text: str) -> List[Dict[str, Any]]:
    """Best-effort extraction of tool calls from text that is not valid JSON.

    Two strategies are tried in order:

    1. If the text looks like a JSON array (``[...]``), it is split into
       top-level objects and each object is parsed on its own.
    2. Otherwise (or if step 1 found nothing) the text is scanned for objects
       whose first key is ``"tool_name"`` or ``"name"``; each such object is cut
       out by brace matching and parsed.

    Args:
        text: Payload found between the tool-call tags.

    Returns:
        A list of ``{"tool_name": ..., "input_parameters": ...}`` dicts, in the
        order the calls appear in *text*; empty if nothing could be recovered.
    """
    result_calls: List[Dict[str, Any]] = []
    stripped = text.strip()

    # Strategy 1: split an array-looking payload into its top-level objects.
    if stripped.startswith('[') and stripped.endswith(']'):
        try:
            for part in _split_json_objects(stripped[1:-1].strip()):
                part = part.strip()
                if part.startswith('{') and part.endswith('}'):
                    result_calls.extend(_parse_single_tool_call(part))
        except Exception:
            # Discard partial results so strategy 2 cannot duplicate them.
            result_calls = []
        if result_calls:
            return result_calls

    # Strategy 2: locate objects that start with a tool-name key.  Whitespace
    # after "{" and around ":" is tolerated so pretty-printed output matches.
    start_re = re.compile(r'\{\s*"(?:tool_name|name)"\s*:')
    covered_until = -1
    for match in start_re.finditer(text):
        start_pos = match.start()
        if start_pos <= covered_until:
            # This "{" lies inside a call that was already parsed (e.g. an
            # argument value that itself has a "name" key) - not a new call.
            continue
        end_pos = _find_matching_brace(text, start_pos)
        if end_pos == -1:
            continue
        parsed_calls = _parse_single_tool_call(text[start_pos:end_pos + 1])
        if parsed_calls:
            result_calls.extend(parsed_calls)
            covered_until = end_pos

    return result_calls

def _split_json_objects(text: str) -> List[str]:
    """分割JSON对象"""
    parts = []
    current_part = ""
    brace_count = 0
    in_string = False
    escape_next = False
    
    for char in text:
        if escape_next:
            current_part += char
            escape_next = False
            continue
            
        if char == '\\':
            escape_next = True
            current_part += char
            continue
            
        if char == '"' and not escape_next:
            in_string = not in_string
            
        if not in_string:
            if char == '{':
                brace_count += 1
            elif char == '}':
                brace_count -= 1
                
        current_part += char
        
        # 如果找到完整的对象
        if not in_string and brace_count == 0 and current_part.strip():
            if current_part.strip().endswith(','):
                current_part = current_part.rstrip(',').strip()
            parts.append(current_part.strip())
            current_part = ""
    
    # 添加最后一部分
    if current_part.strip():
        parts.append(current_part.strip())
        
    return parts

def _find_matching_brace(text: str, start_pos: int) -> int:
    """找到匹配的结束大括号"""
    brace_count = 0
    in_string = False
    escape_next = False
    
    for i in range(start_pos, len(text)):
        char = text[i]
        
        if escape_next:
            escape_next = False
            continue
            
        if char == '\\':
            escape_next = True
            continue
            
        if char == '"' and not escape_next:
            in_string = not in_string
            continue
            
        if not in_string:
            if char == '{':
                brace_count += 1
            elif char == '}':
                brace_count -= 1
                
                if brace_count == 0:
                    return i
    
    return -1

def _parse_single_tool_call(text: str) -> List[Dict[str, Any]]:
    """解析单个工具调用，支持两种格式：
    1. {"tool_name": "Search", "input_parameters": {...}}
    2. {"name": "Search", "arguments": {...}}
    """
    # 提取tool_name或name
    name_pattern_old = r'"tool_name":\s*"([^"]+)"'
    name_pattern_new = r'"name":\s*"([^"]+)"'
    
    name_match = re.search(name_pattern_old, text)
    if not name_match:
        name_match = re.search(name_pattern_new, text)
    
    if not name_match:
        return []
        
    tool_name = name_match.group(1)
    
    # 提取input_parameters或arguments部分
    params_pattern_old = r'"input_parameters":\s*(\{.*\})'
    params_pattern_new = r'"arguments":\s*(\{.*\})'
    
    params_match = re.search(params_pattern_old, text, re.DOTALL)
    if not params_match:
        params_match = re.search(params_pattern_new, text, re.DOTALL)
    
    params = {}
    if params_match:
        params_str = params_match.group(1)
        
        # 尝试各种方法解析参数
        try:
            params = json.loads(params_str)
        except Exception:
            try:
                import ast
                params = ast.literal_eval(params_str)
            except Exception:
                # 如果JSON和ast都失败，尝试手动提取关键字段
                params = _extract_params_manually(params_str)
    
    return [{
        "tool_name": tool_name,
        "input_parameters": params
    }]

def _extract_params_manually(params_str: str) -> Dict[str, Any]:
    """手动提取参数，处理特殊情况"""
    params = {}
    
    # 提取code参数（最常见的复杂参数）
    code_patterns = [
        r'"code":\s*"(.*?)"(?=\s*[,}])',  # 标准情况
        r'"code":\s*"(.*?)$',             # 到字符串末尾
        r'"code":\s*\'(.*?)\'(?=\s*[,}])', # 单引号情况
    ]
    
    for pattern in code_patterns:
        code_match = re.search(pattern, params_str, re.DOTALL)
        if code_match:
            code = code_match.group(1)
            # 处理转义字符
            try:
                # 使用json.loads来处理转义字符，这是最安全的方法
                code = json.loads(f'"{code}"')
            except:
                # 如果json.loads失败，使用手动替换
                code = code.replace('\\n', '\n')
                code = code.replace('\\t', '\t')
                code = code.replace('\\"', '"')
                code = code.replace("\\'", "'")
                code = code.replace('\\\\', '\\')
            params["code"] = code
            break
    
    # 提取其他简单参数
    simple_params = [
        ("command", r'"command":\s*"([^"]+)"'),
        ("file_path", r'"file_path":\s*"([^"]+)"'),
        ("max_rows", r'"max_rows":\s*(\d+)'),
        ("format", r'"format":\s*"([^"]+)"'),
        ("output_file", r'"output_file":\s*"([^"]+)"'),
    ]
    
    for param_name, pattern in simple_params:
        match = re.search(pattern, params_str)
        if match:
            value = match.group(1)
            if param_name == "max_rows":
                params[param_name] = int(value)
            else:
                params[param_name] = value
    
    return params

def extract_answer(text: str) -> Optional[str]:
    """Return the stripped text after the first ``**Answer**:`` marker, or None if absent."""
    found = re.search(r"\*\*Answer\*\*:\s*(.*)", text, flags=re.DOTALL)
    return found.group(1).strip() if found else None

def ensure_system_prompt() -> str:
    """Read ``simple_system_prompt-qwen.txt`` located next to this file.

    Returns the file content, or an empty string if the file cannot be read
    for any reason.
    """
    try:
        base_dir = os.path.dirname(__file__)
        with open(os.path.join(base_dir, "simple_system_prompt-qwen.txt"), "r", encoding="utf-8") as fh:
            text = fh.read()
    except Exception:
        return ""
    return text

# ----------------- Normalizers for sandbox returns -----------------
def _as_dict_return(obj: Any) -> Any:
    """
    兼容 .as_dict()['return'] 模式：
      - JSON字符串：'{"session_id":"..."}'
      - dict：{"code":0, "stdout":"...", "stderr":"..."}
      - 普通字符串
    """
    if hasattr(obj, "as_dict"):
        d = obj.as_dict()
        val = d.get("return")
        if isinstance(val, str):
            try:
                return json.loads(val)
            except Exception:
                return val
        return val
    return obj

def create_session_id() -> str:
    """Create a sandbox and return its session id.

    The payload returned by ``create_sandbox`` is unwrapped with
    ``_as_dict_return``; if that does not directly yield a dict containing
    ``session_id``, the payload is decoded as JSON once more.

    Raises:
        RuntimeError: if no ``session_id`` can be found in the payload.
    """
    payload = _as_dict_return(create_sandbox())
    if isinstance(payload, dict) and "session_id" in payload:
        return payload["session_id"]

    # Second chance: the payload may still be a JSON-encoded string.
    try:
        decoded = json.loads(payload)
        if "session_id" in decoded:
            return decoded["session_id"]
    except Exception:
        pass

    raise RuntimeError(f"CreateSandBox returned unexpected payload: {payload!r}")

def exec_shell(session_id: str, command: str) -> Dict[str, Any]:
    """Run *command* in the sandbox and return ``{"code", "stdout", "stderr"}``.

    When the unwrapped sandbox response is a dict, its ``code`` (default 1),
    ``stdout`` and ``stderr`` (default "") are coerced to int / str.  Any
    other response is reported as exit code 1 with the response text placed
    in ``stderr``.
    """
    payload = _as_dict_return(
        sandbox_execute_command(session_id=session_id, command=command)
    )
    if not isinstance(payload, dict):
        return {"code": 1, "stdout": "", "stderr": str(payload)}
    return {
        "code": int(payload.get("code", 1)),
        "stdout": str(payload.get("stdout", "")),
        "stderr": str(payload.get("stderr", "")),
    }

# ----------------- 单题执行器（单 system / user / assistant，使用VLLM API） -----------------
class SingleQuestionRunner:
    """Run a single question through the LLM <-> sandbox tool-calling loop.

    One instance keeps one conversation (``self.messages``).  For every
    question a fresh sandbox is created, the model is queried repeatedly, any
    tool call found in its reply is executed in the sandbox and the result is
    fed back as a ``role: tool`` message, until the model answers without a
    tool call, an error occurs, or the iteration limit is exceeded.
    """

    # Maximum number of attempts for one chat-completion request.
    MAX_LLM_RETRIES = 50

    def __init__(self, model_name: str, debug: bool = False):
        """Store the configuration, load the system prompt and open a VLLM session.

        Args:
            model_name: Model identifier passed to the VLLM API.
            debug: When True, progress details are printed.  Off by default
                because parallel workers would interleave their output.
        """
        self.debug = debug
        self.model_name = model_name
        self._sys_prompt = ensure_system_prompt()
        self.messages: List[Dict[str, str]] = []
        if self._sys_prompt:
            self.messages.append({"role": "system", "content": self._sys_prompt})

        # VLLMChatSession may be undefined (the optional import at the top of
        # the file failed -> NameError) or may fail to construct; in both
        # cases the runner keeps working without an LLM session.
        try:
            self.chat_session = VLLMChatSession(model_name)
        except Exception as e:
            print(f"⚠️ VLLMChatSession 不可用: {e}")
            self.chat_session = None

    def _print(self, *a):
        """Print only when debug output is enabled."""
        if self.debug:
            print(*a, flush=True)

    @staticmethod
    def _empty_token_stats() -> Dict[str, int]:
        """Return a fresh all-zero token statistics dict."""
        return {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}

    @staticmethod
    def _retry_delay(attempt: int) -> int:
        """Seconds to wait after failed attempt number *attempt* (0-based).

        The first five retries wait 1 second, all later ones 2 seconds.
        Computing the delay avoids indexing past the end of a fixed table.
        """
        return 1 if attempt < 5 else 2

    def _call_llm(self) -> Tuple[str, dict]:
        """Send ``self.messages`` to the VLLM API.

        Returns:
            ``(reply_text, token_stats)``.  On failure (no session, context
            length exceeded, or all retries exhausted) the reply is ``""`` and
            all token counters are 0.
        """
        import time

        if not self.chat_session:
            self._print("❌ VLLM会话未初始化，无法调用LLM")
            return "", self._empty_token_stats()

        max_retries = self.MAX_LLM_RETRIES

        for attempt in range(max_retries):
            try:
                # Use the client owned by the VLLM session.
                client = self.chat_session.client
                rsp = client.chat.completions.create(
                    model=self.model_name,
                    messages=self.messages,
                    temperature=0.7,
                    max_tokens=4096,
                    timeout=60,  # per-request timeout in seconds
                )
                out = rsp.choices[0].message.content or ""

                # Explain an empty reply when the API tells us why it stopped.
                if not out and hasattr(rsp.choices[0], 'finish_reason'):
                    finish_reason = rsp.choices[0].finish_reason
                    self._print(f"⚠️ 模型返回空输出，finish_reason: {finish_reason}")
                    if finish_reason == "length":
                        self._print("⚠️ 可能是输出长度限制导致的空输出")
                    elif finish_reason == "stop":
                        self._print("⚠️ 模型正常停止但没有生成内容")

                # Token usage; getattr keeps this safe when usage is None.
                usage = rsp.usage
                token_stats = {
                    "input_tokens": getattr(usage, 'prompt_tokens', 0),
                    "output_tokens": getattr(usage, 'completion_tokens', 0),
                    "total_tokens": getattr(usage, 'total_tokens', 0)
                }

                self._print(f"\n🧠 LLM 原始回复长度: {len(out)}")
                self._print(f"📊 Token统计: 输入={token_stats['input_tokens']}, 输出={token_stats['output_tokens']}, 总计={token_stats['total_tokens']}")

                # output_tokens counts only what the model generated; tool
                # results are appended by this client and are not included.
                self._print(f"💡 模型实际生成内容token数: {token_stats['output_tokens']}")

                return out, token_stats

            except Exception as e:
                error_msg = str(e)
                self._print(f"❌ 第{attempt+1}次调用VLLM API失败: {type(e).__name__}: {e}")

                # A context-length overflow cannot be fixed by retrying.
                if "maximum context length" in error_msg and "tokens" in error_msg:
                    self._print(f"🚫 检测到上下文长度超限错误，停止重试")
                    return "", self._empty_token_stats()

                if attempt >= max_retries - 1:
                    break
                delay = self._retry_delay(attempt)
                self._print(f"🔄 等待{delay}秒后重试...")
                time.sleep(delay)

        self._print(f"❌ 重试{max_retries}次后仍然失败，返回空结果")
        return "", self._empty_token_stats()

    # ----- tool dispatch -----
    def _run_execute_code(self, session_id: str, params: Dict[str, Any]) -> Dict[str, Any]:
        """Run ``params["code"]`` as Python in the sandbox via ``run_code``.

        Returns ``{"code", "stdout", "stderr"}``; any exception is reported as
        exit code 1 with the error text in ``stderr``.
        """
        code_text = params.get("code", "")
        if not isinstance(code_text, str):
            code_text = str(code_text)
        try:
            resp = run_code(code_text, language="python", session_id=session_id)
            # run_code is expected to expose a JSON string in ``.result``;
            # if that does not hold, fall through and inspect resp itself.
            try:
                resp = json.loads(resp.result)
            except Exception:
                pass
            if hasattr(resp, "__dict__"):
                rc = int(getattr(resp, "code", getattr(resp, "returncode", 0)))
                stdout = str(getattr(resp, "stdout", ""))
                stderr = str(getattr(resp, "stderr", ""))
            elif isinstance(resp, dict):
                rc = int(resp.get("code", 0))
                stdout = str(resp.get("stdout", ""))
                stderr = str(resp.get("stderr", ""))
            else:
                rc, stdout, stderr = 1, "", f"[run_code] unexpected return type: {type(resp).__name__}"
        except Exception as e:
            rc, stdout, stderr = 1, "", f"[run_code ERROR] {type(e).__name__}: {e}"

        return {"code": rc, "stdout": stdout, "stderr": stderr}

    def _run_execute_shell(self, session_id: str, params: Dict[str, Any]) -> Dict[str, Any]:
        """Run ``params["command"]`` in the sandbox shell via ``exec_shell``.

        Returns ``{"code", "stdout", "stderr"}``; any exception is reported as
        exit code 1 with the error text in ``stderr``.
        """
        command = params.get("command", "")
        if not isinstance(command, str):
            command = str(command)
        try:
            resp = exec_shell(session_id=session_id, command=command)
            rc, stdout, stderr = resp["code"], resp["stdout"], resp["stderr"]
        except Exception as e:
            rc, stdout, stderr = 1, "", f"[sandbox_execute_command ERROR] {type(e).__name__}: {e}"
        return {"code": rc, "stdout": stdout, "stderr": stderr}

    def _dispatch_tool(self, session_id: str, call: Dict[str, Any]) -> Dict[str, Any]:
        """Route one tool call to its executor.

        Both key layouts are accepted: ``tool_name``/``input_parameters`` and
        ``name``/``arguments``.  Parameters given as a JSON-encoded string are
        decoded.  Unknown tools and parameters that are not a JSON object
        yield an error result (exit code 1) instead of raising.
        """
        if not isinstance(call, dict):
            call = {}
        tool_name = call.get("tool_name") or call.get("name")
        params = call.get("input_parameters") or call.get("arguments") or {}

        if tool_name == "ExecuteCode":
            handler = self._run_execute_code
        elif tool_name == "ExecuteShell":
            handler = self._run_execute_shell
        else:
            return {"code": 1, "stdout": "", "stderr": f"Unknown tool: {tool_name}"}

        # Some models emit the arguments as a JSON-encoded string.
        if isinstance(params, str):
            try:
                params = json.loads(params)
            except ValueError:
                pass
        if not isinstance(params, dict):
            return {
                "code": 1,
                "stdout": "",
                "stderr": f"Invalid parameters for {tool_name}: expected a JSON object",
            }
        return handler(session_id, params)

    def _extract_tool_calls(self, output_str: str) -> str:
        """Return the text inside the first tool-call tag pair, or "" if none.

        ``<tool_calls>`` is looked up first; ``<tool_call>`` is used when the
        plural opening or closing tag is missing.
        """
        start_pos = output_str.find('<tool_calls>')
        end_pos = output_str.find('</tool_calls>')
        tag_len = len('<tool_calls>')

        if start_pos == -1 or end_pos == -1:
            start_pos = output_str.find('<tool_call>')
            end_pos = output_str.find('</tool_call>')
            tag_len = len('<tool_call>')

        if start_pos == -1 or end_pos == -1:
            return ""

        tool_call_content = output_str[start_pos + tag_len:end_pos].strip()
        self._print(f"🔧 提取到第一个工具调用: {tool_call_content}")
        return tool_call_content

    def _execute_tools_from_calls(self, session_id: str, calls: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Execute *calls* in order and return their results in the same order."""
        return [self._dispatch_tool(session_id, call) for call in calls]

    def _write_prewrites(self, session_id: str, prewrites: Any) -> None:
        """Upload the optional pre-written files of a question into the sandbox.

        Each entry is a dict with ``filename`` (name inside the sandbox),
        ``fpath`` (local file to read) and an optional ``binary`` flag that is
        only used for logging; content is always sent base64-encoded with
        ``binary=True``.  Failures are logged and do not abort the question.
        """
        if not isinstance(prewrites, list):
            return
        for entry in prewrites:
            if not isinstance(entry, dict):
                continue  # malformed entry: nothing to upload
            fn = entry.get("filename")
            fpath = entry.get("fpath", "")
            is_bin = bool(entry.get("binary", False))
            if not (isinstance(fn, str) and fn):
                continue
            try:
                with open(fpath, 'rb') as f:
                    content = f.read()
                sandbox_write_file(session_id, fn, base64.b64encode(content).decode("ascii"), binary=True)
                self._print(f"📥 写入文件: {fn} ({'binary' if is_bin else 'text/url'})")
            except Exception as we:
                self._print(f"⚠️ 写入失败 {fn}: {we}")

    def run_question(self, item: Dict[str, Any], index: int, max_iterations: int = 3) -> Dict[str, Any]:
        """Solve one question inside a dedicated sandbox.

        Args:
            item: Question record.  Reads ``task_id``, ``prewrites``,
                ``formatted_question`` / ``question`` and ``answer``.
            index: Position of the question in the input list.
            max_iterations: Maximum number of LLM calls for this question.

        Returns:
            A result dict whose ``status`` is ``success``, ``api_error``,
            ``max_iter_exceeded`` or ``error``.

        Raises:
            RuntimeError: if the sandbox cannot be created.
        """
        task_id = item.get("task_id")
        label = task_id or f"Q{index}"

        # One sandbox per question; it is destroyed on every exit path.
        session_id = create_session_id()
        try:
            return self._run_in_sandbox(item, index, max_iterations, task_id, label, session_id)
        finally:
            try:
                destroy_sandbox(session_id=session_id)
            except Exception as de:
                # Best effort: a failed cleanup must not mask the result.
                self._print(f"⚠️ destroy_sandbox failed: {type(de).__name__}: {de}")

    def _run_in_sandbox(self, item: Dict[str, Any], index: int, max_iterations: int,
                        task_id: Any, label: Any, session_id: str) -> Dict[str, Any]:
        """Body of :meth:`run_question`, executed while the sandbox exists."""
        self._print(f"\n=== ▶️ 开始题目: {label} | session={session_id[:8]} ===")

        self._write_prewrites(session_id, item.get("prewrites"))

        formatted_question = item.get("formatted_question") or item.get("question")
        if not isinstance(formatted_question, str) or not formatted_question.strip():
            return {"index": index, "task_id": task_id, "session_id": session_id, "status": "error", "error": "formatted_question/question missing", "question": ""}

        # Start a fresh conversation: system (optional) + user.
        self.messages = []
        if self._sys_prompt:
            self.messages.append({"role": "system", "content": self._sys_prompt})
        self.messages.append({"role": "user", "content": formatted_question})
        self._print("\n📨 题目内容：", formatted_question)

        iterations = 0
        final_model_output = ""
        last_iteration_output = ""
        iteration_traces: List[Dict[str, Any]] = []

        # Running token totals plus the per-iteration breakdown.
        totals = self._empty_token_stats()
        iteration_token_stats: List[Dict[str, Any]] = []

        def finish(status: str, trace_token_stats: Dict[str, Any], error: Optional[str] = None) -> Dict[str, Any]:
            """Record a final trace without tool calls and build the result dict."""
            iteration_traces.append({
                "iteration": iterations,
                "messages": self.messages.copy(),
                "tool_calls": [],
                "tool_returns": [],
                "token_stats": trace_token_stats,
            })
            result: Dict[str, Any] = {
                "index": index, "task_id": task_id, "session_id": session_id,
                "status": status, "iterations": iterations,
            }
            if error is not None:
                result["error"] = error
            result.update({
                "final_model_output": final_model_output,
                "model_response": last_iteration_output,
                "final_answer": item.get("answer"),
                "iteration_traces": iteration_traces,
                "question": formatted_question,
                "token_stats": {
                    "total_input_tokens": totals["input_tokens"],
                    "total_output_tokens": totals["output_tokens"],
                    "total_tokens": totals["total_tokens"],
                    "iteration_token_stats": iteration_token_stats,
                },
            })
            return result

        while True:
            iterations += 1
            if iterations > max_iterations:
                return finish("max_iter_exceeded", self._empty_token_stats())

            llm_text, token_stats = self._call_llm()
            last_iteration_output = llm_text

            # Empty text together with zero tokens is how _call_llm signals failure.
            if not llm_text.strip() and token_stats["total_tokens"] == 0:
                self._print(f"❌ 第{iterations}轮模型输出为空且token统计为0，API调用失败，终止任务")
                return finish(
                    "api_error",
                    self._empty_token_stats(),
                    error=f"LLM API调用失败：重试{self.MAX_LLM_RETRIES}次后仍然失败或上下文长度超限",
                )

            for key in ("input_tokens", "output_tokens", "total_tokens"):
                totals[key] += token_stats[key]
            iteration_token_stats.append({
                "iteration": iterations,
                "input_tokens": token_stats["input_tokens"],
                "output_tokens": token_stats["output_tokens"],
                "total_tokens": token_stats["total_tokens"]
            })

            self.messages.append({"role": "assistant", "content": llm_text})
            final_model_output = llm_text

            # No tool call in the reply: the model has given its final answer.
            if not self._extract_tool_calls(llm_text):
                return finish("success", token_stats)

            # A tool-call section that cannot be parsed also ends the question.
            calls = parse_tool_calls(llm_text)
            if not calls:
                self._print(f"⚠️ 工具调用解析失败")
                return finish("success", token_stats)

            tool_returns_records = self._execute_tools_from_calls(session_id, calls)

            # Feed the results back as a role "tool" message: a single object
            # for one result, a JSON array for several.
            if len(tool_returns_records) == 1:
                tool_return_content = json.dumps(tool_returns_records[0], ensure_ascii=False)
            else:
                tool_return_content = json.dumps(tool_returns_records, ensure_ascii=False)
            self.messages.append({"role": "tool", "content": tool_return_content})

            iteration_traces.append({
                "iteration": iterations,
                "messages": self.messages.copy(),
                "tool_calls": calls,
                "tool_returns": tool_returns_records,
                "token_stats": token_stats
            })
            # Loop on: ask the model again with the tool result appended.

# ----------------- IO helpers -----------------
def load_questions(path: str) -> List[Dict[str, Any]]:
    """Load questions from a UTF-8 JSON file.

    A top-level list is returned as is and a top-level object is wrapped into
    a one-element list.

    Raises:
        ValueError: if the JSON root is neither a list nor an object.
    """
    with open(path, "r", encoding="utf-8") as fh:
        payload = json.load(fh)

    if isinstance(payload, list):
        return payload
    if not isinstance(payload, dict):
        raise ValueError("Unsupported JSON root type, expect list or object.")
    return [payload]

def get_output_path(path: str, model_name: str) -> str:
    """Compute the result file path for *path* and create its directories.

    The result is ``<dir of path>/vllm_result/<model_name>/<stem>_results<ext>``
    where ``<stem><ext>`` is the file name of *path*.  Only the file name is
    rewritten, so directory names containing ``.json`` are left untouched.

    Args:
        path: Input (or logical output) file path.
        model_name: Model name used as the sub-directory.

    Returns:
        The final output path.  The directory of *path* (if any) and the
        result directory are created if missing.
    """
    parent_dir = os.path.dirname(path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    result_dir = os.path.join(parent_dir, "vllm_result", model_name)
    os.makedirs(result_dir, exist_ok=True)

    # Insert "_results" before the extension of the file name only.
    stem, ext = os.path.splitext(os.path.basename(path))
    return os.path.join(result_dir, f"{stem}_results{ext}")

def save_results_incremental(output_path: str, results: List[Dict[str, Any]], lock: threading.Lock):
    """Write the current results to *output_path* while holding *lock*.

    Slots that are still ``None`` (unfinished tasks) are written as
    placeholder records carrying their index, so positions in the file match
    positions in the input.
    """
    with lock:
        snapshot = [
            entry if entry is not None else {
                "index": pos,
                "status": "pending",
                "message": "任务进行中..."
            }
            for pos, entry in enumerate(results)
        ]
        with open(output_path, "w", encoding="utf-8") as fh:
            json.dump(snapshot, fh, ensure_ascii=False, indent=2)

def save_results(path: str, results: List[Dict[str, Any]], model_name: str, already_final: bool = False):
    """Write the final results as JSON and print the destination.

    - ``already_final=False`` (default, legacy behaviour): the destination is
      derived from *path* and *model_name* via ``get_output_path``.
    - ``already_final=True``: *path* is used as the destination unchanged.
    """
    output_path = path if already_final else get_output_path(path, model_name)
    with open(output_path, "w", encoding="utf-8") as fh:
        json.dump(results, fh, ensure_ascii=False, indent=2)
    print(f"\n💾 结果已保存到: {output_path}")

# ----------------- 批量并行 -----------------
def run_one(index: int, item: Dict[str, Any], model_name: str, max_iter: int, debug: bool) -> Dict[str, Any]:
    """Process one question with its own runner; never raises.

    Args:
        index: Position of the question in the input list.
        item: Question record.
        model_name: Model name passed to the runner.
        max_iter: Maximum number of LLM iterations for the question.
        debug: Whether the runner prints debug output.

    Returns:
        The runner's result dict (``status`` defaults to ``"success"`` when
        missing or empty), or a minimal error record carrying ``index`` and
        ``task_id`` if anything raised.
    """
    try:
        # Constructing the runner inside the try block means a failing
        # constructor is reported as an error record as well.
        runner = SingleQuestionRunner(model_name=model_name, debug=debug)
        res = runner.run_question(item, index=index, max_iterations=max_iter)
        res["status"] = res.get("status") or "success"
        return res
    except Exception as e:
        # No result dict exists on this path, so the session id is unknown.
        return {
            "index": index,
            "task_id": item.get("task_id"),
            "session_id": "",
            "status": "error",
            "error": f"{type(e).__name__}: {e}",
            "question": item.get("formatted_question") or item.get("question", ""),
        }

# ----------------- 随机种子设置 -----------------
def set_random_seed(seed: Optional[int]):
    """Seed the available random number generators; do nothing when *seed* is None.

    Seeds ``random`` and sets ``PYTHONHASHSEED`` in the environment, then
    seeds NumPy and PyTorch (including all CUDA devices when available) if
    those packages can be imported.
    """
    if seed is None:
        return

    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)

    # NumPy is optional.
    try:
        import numpy as np
    except ImportError:
        pass
    else:
        np.random.seed(seed)

    # PyTorch is optional.
    try:
        import torch
    except ImportError:
        pass
    else:
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

    print(f"🎲 随机种子已设置为: {seed}")

# ----------------- CLI -----------------
def main():
    """Command-line entry point: run all questions in parallel and save the results.

    Questions are loaded from ``--input`` and processed by a thread pool of
    ``--workers`` threads.  The results file is rewritten after every finished
    question and once more at the end, followed by a printed summary of
    successes, token usage and iteration counts.
    """
    ap = argparse.ArgumentParser(description="VLLM API沙盒工具调用 - 并行处理问题")
    ap.add_argument("-i", "--input", default="/path/to/project/data/extracted_bench-v2.json", help="输入 JSON 文件路径")
    ap.add_argument("-o", "--output", default="/root/gaia/vllm_result/ReTool-Qwen-32B/extracted_bench-v2_results.json", help="输出 JSON 文件路径")
    ap.add_argument("-m", "--model", default="ReTool-Qwen-32B", help="模型名称（传给 VLLM API）")
    ap.add_argument("--max-iter", type=int, default=15, help="每题最大迭代轮次（默认15）")
    ap.add_argument("-w", "--workers", type=int, default=4, help="并行线程数（默认4）")
    ap.add_argument("--debug", action="store_true", help="打印每题详细日志（并行时会交错输出，谨慎使用）")
    ap.add_argument("--seed", type=int, default=None, help="随机种子（用于确保实验可重复性）")
    args = ap.parse_args()

    set_random_seed(args.seed)

    questions = load_questions(args.input)

    # Single source of truth for the output location, used for both the
    # incremental and the final save:
    # - a non-empty --output is used verbatim;
    # - otherwise the path is derived from the input file and model name
    #   (<input dir>/vllm_result/<model>/<input stem>_results<ext>).
    if args.output:
        final_output_path = args.output
    else:
        final_output_path = get_output_path(args.input, args.model)

    # Make sure the output directory exists before the first save.
    output_dir = os.path.dirname(final_output_path)
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)
        print(f"📁 已创建输出目录: {output_dir}")

    total = len(questions)
    print(f"🚀 并行运行：共 {total} 题，workers={args.workers}, max_iter={args.max_iter}")
    print(f"🤖 当前模型：{args.model}")
    print(f"📁 数据集：{args.input}")
    print(f"💾 结果将保存到：{final_output_path}")

    # Results are stored by input index so the output order matches the input.
    results: List[Optional[Dict[str, Any]]] = [None] * total
    done = 0
    save_lock = threading.Lock()  # serializes writes of the results file

    with ThreadPoolExecutor(max_workers=args.workers) as ex:
        fut2idx = {ex.submit(run_one, i, q, args.model, args.max_iter, args.debug): i for i, q in enumerate(questions)}
        for fut in as_completed(fut2idx):
            i = fut2idx[fut]
            try:
                res = fut.result()
            except Exception as e:
                res = {"index": i, "task_id": questions[i].get("task_id"), "status": "error", "error": f"{type(e).__name__}: {e}", "question": questions[i].get("formatted_question") or questions[i].get("question", "")}
            results[i] = res
            done += 1
            ok = res.get("status") == "success"
            tid = res.get("task_id") or f"Q{i}"
            print(f"✅ 完成 {done}/{total}: {tid}  —  {'OK' if ok else 'FAIL'}")

            # Persist progress after every finished question.
            save_results_incremental(final_output_path, results, save_lock)

    # Final save to the same file the incremental saves used.
    save_results(final_output_path, results, args.model, already_final=True)
    succ = sum(1 for r in results if isinstance(r, dict) and r.get("status") == "success")

    # Aggregate token usage and iteration counts over all results.
    total_input_tokens = 0
    total_output_tokens = 0
    total_tokens = 0
    total_iterations = 0

    for r in results:
        if not isinstance(r, dict):
            continue
        token_stats = r.get("token_stats")
        if token_stats:
            total_input_tokens += token_stats.get("total_input_tokens", 0)
            total_output_tokens += token_stats.get("total_output_tokens", 0)
            total_tokens += token_stats.get("total_tokens", 0)
        if r.get("iterations"):
            total_iterations += r["iterations"]

    avg_iterations = total_iterations / total if total > 0 else 0

    print(f"\n🎉 结束：成功 {succ}/{total}，输出：{final_output_path}")
    print(f"📊 Token统计：输入={total_input_tokens:,}, 输出={total_output_tokens:,}, 总计={total_tokens:,}")
    print(f"🔄 迭代统计：总轮次={total_iterations}, 平均轮次={avg_iterations:.2f}")
    print(f"💡 注意：输出token统计的是模型实际生成的内容，工具返回结果通过role:tool单独添加，不计入模型输出")

if __name__ == "__main__":
    main()
