#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
API 推理模块（并行 & 工具调用）
- 仅两种工具：code（Python执行）、terminal（Shell执行），工具内部连到 sandbox_livehouse v2
- 每题/每任务“独立沙盒会话”：Create → 使用 → Destroy（线程本地记录 session_id）
- 标准 function calling 流程：assistant -> tool_calls -> tool(messages) -> assistant ...
- 结果增强：
  * conversation_history（含 assistant / tool_call）
  * final_model_output：按你要求“单 assistant 累积 + 每轮紧接真实 <tool_returns>”的**离线重建串**
  * last_iteration_output：最后一轮 assistant 的原始文本（未拼 returns）
"""

import json
import os
import time
import argparse
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from datetime import datetime
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
from openai import AzureOpenAI, OpenAI, DefaultHttpxClient
import yaml
import base64
from fileagent_libs.fileagent_mcp_host import create_sandbox, destroy_sandbox, sandbox_write_file


# 工具分发 & 线程本地 session_id
from tool_utils import tool_dispatcher, set_session_id, clear_session_id

# 沙盒会话生命周期
from fileagent_libs.fileagent_mcp_host import create_sandbox, destroy_sandbox


# ---------------- 配置 & 客户端 ----------------
def build_api_client(base_url: str, api_version: str, api_key: str, model_config: Optional[Dict[str, Any]] = None) -> AzureOpenAI | OpenAI:
    """
    Create the API client that matches the model configuration.

    - ark / vLLM / plain OpenAI-compatible services get an ``OpenAI`` client
      that carries a custom httpx client (proxy handling + timeout).
    - Everything else gets an ``AzureOpenAI`` client built with the SDK's own
      HTTP defaults (unchanged from the previous behaviour).

    Args:
        base_url: Service endpoint; also inspected to pick the client type.
        api_version: API version, only used for the Azure client.
        api_key: Credential handed to whichever client is created.
        model_config: Optional per-model settings. Keys read here:
            ``disable_proxy``, ``client_timeout``, ``client_type``,
            ``is_vllm``, ``use_openai_client``.

    Returns:
        An ``OpenAI`` or ``AzureOpenAI`` client instance.
    """
    model_config = model_config or {}

    # vLLM deployments usually must bypass the environment proxy, otherwise
    # the connection fails. NOTE: this mutates os.environ for the whole process.
    disable_proxy = model_config.get("disable_proxy", False)
    if disable_proxy:
        for k in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"):
            os.environ.pop(k, None)

    # Tolerate a missing endpoint in the substring checks below.
    endpoint = base_url or ""
    use_openai_client = (
        model_config.get("client_type") == "openai"
        or model_config.get("is_vllm", False)
        or model_config.get("use_openai_client", False)
        or "ark" in endpoint  # legacy ark detection, kept as-is
        or endpoint.rstrip("/").endswith("/v1")  # common OpenAI-compatible vLLM path
    )

    if not use_openai_client:
        return AzureOpenAI(
            azure_endpoint=base_url,
            api_version=api_version,
            api_key=api_key,
        )

    # Build the httpx client only when it is actually used; previously one was
    # created (and never closed) on the Azure path as well.
    http_client = DefaultHttpxClient(
        trust_env=not disable_proxy,
        timeout=model_config.get("client_timeout", 60),
    )
    return OpenAI(
        api_key=api_key,
        base_url=base_url,
        http_client=http_client,
    )


def load_api_config(model_name: str) -> Dict[str, Any]:
    """
    Load the configuration for *model_name* from ``api_config.yaml``.

    The file is read from the current working directory. The ``defaults``
    section is used as a base and overridden by ``models[model_name]``.

    Args:
        model_name: Key to look up under the ``models`` section.

    Returns:
        The merged configuration; only the defaults when the model is not
        listed; an empty dict when the file cannot be read or parsed.
    """
    try:
        with open("api_config.yaml", 'r', encoding='utf-8') as f:
            # An empty YAML document parses to None.
            config = yaml.safe_load(f) or {}
        # A key that is present but empty (``defaults:``) also yields None.
        defaults = config.get('defaults') or {}
        models = config.get('models') or {}
        if model_name not in models:
            print(f"警告: 模型 '{model_name}' 在配置文件中未找到，使用默认配置")
            return defaults
        model_config = dict(defaults)
        model_config.update(models[model_name] or {})
        return model_config
    except Exception as e:
        # Best effort: callers fall back to their own defaults on any failure.
        print(f"加载模型配置失败: {e}")
        return {}


def load_tool_config() -> List[Dict[str, Any]]:
    """
    Load the tool definitions from ``tool_config.yaml`` (current directory).

    Only two tools (code / terminal) are expected in that file.

    Returns:
        The list stored under the ``tools`` key. An empty list is returned
        when the key is missing or null, when the file is empty, or when the
        file cannot be read or parsed.
    """
    try:
        with open("tool_config.yaml", 'r', encoding='utf-8') as f:
            # An empty YAML document parses to None.
            config = yaml.safe_load(f) or {}
            print(config)
        # ``tools:`` with no value parses to None; always hand back a list.
        return config.get('tools') or []
    except Exception as e:
        print(f"加载工具配置失败: {e}")
        return []


# ---------------- 重试 & thinking ----------------
def is_rate_limit_error(error_message: str) -> bool:
    """Return True when the error text contains a known rate-limit marker.

    Matching is a case-insensitive substring test against a fixed list of
    English and Chinese indicators.
    """
    lowered = error_message.lower()
    for marker in (
        "429", "rate limit", "rate_limit", "ratelimit",
        "too many requests", "quota exceeded",
        "throttled", "throttling", "限流", "频率限制",
        "请求过于频繁", "concurrent requests",
    ):
        if marker in lowered:
            return True
    return False


def is_context_length_error(error_message: str) -> bool:
    """Detect a context-length (HTTP 400) error from its message text.

    The status code must appear as a standalone number: a plain substring
    test would also fire on values such as "4000" or "40000", which show up
    in timeout and rate-limit messages and would wrongly disable retries.
    Explicit context-length phrases are accepted as well.

    Args:
        error_message: Text of the exception raised by the API call.

    Returns:
        True when the message looks like a 400 / context-length failure.
    """
    text = error_message.lower()
    # "400" not preceded or followed by another digit.
    if re.search(r"(?<!\d)400(?!\d)", text):
        return True
    return any(
        phrase in text
        for phrase in ("context_length_exceeded", "maximum context length")
    )


def exponential_backoff(attempt: int, base_delay: float = 1.0, max_delay: float = 60.0) -> float:
    """Compute the sleep time for a retry.

    The delay doubles with every attempt starting from *base_delay*, is capped
    at *max_delay*, and gets up to 10% added on top. The extra share is taken
    from the fractional part of the current wall-clock time.
    """
    capped = min(base_delay * (2 ** attempt), max_delay)
    fraction = time.time() % 1
    return capped + capped * 0.1 * fraction


def build_thinking_config(model_name: str, model_config: Dict[str, Any], thinking_mode: bool) -> Dict[str, Any]:
    """Build the extra request parameters that switch on model "thinking".

    Returns an empty dict when thinking is disabled (by *thinking_mode* or by
    ``thinking_enabled`` in *model_config*) or when the model name matches
    none of the known families. The family is picked by a case-insensitive
    substring test on *model_name*, checked in this order: claude, gemini,
    gpt, then one specific endpoint id.
    """
    thinking_allowed = thinking_mode and model_config.get("thinking_enabled", True)
    if not thinking_allowed:
        return {}

    budget = model_config.get("thinking_budget_tokens", 4096)
    lowered = model_name.lower()

    if "claude" in lowered:
        return {
            "extra_body": {
                "thinking": {"type": "enabled", "budget_tokens": budget},
                "anthropic_beta": ["interleaved-thinking-2025-05-14"],
            }
        }
    if "gemini" in lowered:
        return {"extra_body": {"thinking": {"include_thoughts": True, "budget_tokens": budget}}}
    if "gpt" in lowered:
        return {
            "reasoning_effort": "medium",
            "extra_body": {"reasoning": {"effort": "medium"}},
        }
    if "ep-20250826141533-cf9bk" in lowered:
        return {"extra_body": {"thinking": {"type": "enabled"}}}
    return {}


def extract_thinking_info(completion: Any, model_name: str) -> str:
    """Extract the reasoning ("thinking") text from a chat completion.

    The first choice's message is probed for ``reasoning_content``,
    ``thinking`` and ``reasoning`` in that order, and the first non-empty
    value wins. An attribute that exists but is ``None`` or empty no longer
    short-circuits the lookup, so the declared ``str`` return type holds for
    null fields.

    Args:
        completion: Response object; expected to expose ``choices``.
        model_name: Unused here; kept for interface compatibility.

    Returns:
        The reasoning text, or ``""`` when none is available.
    """
    choices = getattr(completion, 'choices', None)
    if not choices:
        return ""
    message = choices[0].message
    for attr in ("reasoning_content", "thinking", "reasoning"):
        value = getattr(message, attr, None)
        if value:
            return value
    return ""


# ---------------- 消息构造/工具回填 ----------------
def process_tool_calls(tool_calls: List[Any], model_name: str) -> Tuple[List[Dict], List[Dict]]:
    """
    Run every requested tool call through ``tool_dispatcher``.

    Returns two parallel lists:
    - API messages: {"role": "tool", "tool_call_id": ..., "content": ...}
    - history records: {"role": "tool_call", "name": ..., "args": ..., "content": ...}

    *model_name* is not used by this function.
    """
    api_messages: List[Dict] = []
    history_entries: List[Dict] = []
    for call in tool_calls:
        name = call.function.name
        arguments = call.function.arguments
        # The dispatcher result is passed on unchanged (per the original
        # comment it is a string, usually JSON).
        output = tool_dispatcher({"name": name, "arguments": arguments})
        api_messages.append({
            "role": "tool",
            "tool_call_id": call.id,
            "content": output,
        })
        history_entries.append({
            "role": "tool_call",
            "name": name,
            "args": arguments,
            "content": output,
        })
    return api_messages, history_entries


def build_assistant_message(completion: Any, tool_calls: List[Any], current_thinking: str, model_name: str) -> Dict[str, Any]:
    """Build the assistant message (with its tool calls) to send back to the API.

    For model names containing "claude" (case-insensitive) the message also
    carries ``reasoning_content`` and, when the reply exposes one, the
    ``signature`` attribute.
    """
    reply = completion.choices[0].message
    serialized_calls = [
        {
            "id": tc.id,
            "type": "function",
            "function": {"name": tc.function.name, "arguments": tc.function.arguments},
        }
        for tc in tool_calls
    ]
    message: Dict[str, Any] = {
        "role": "assistant",
        "content": reply.content,
        "tool_calls": serialized_calls,
    }
    if "claude" not in model_name.lower():
        return message

    if hasattr(reply, "signature"):
        message["signature"] = reply.signature
    message["reasoning_content"] = current_thinking
    return message


# ---------------- 结果/重建串 ----------------
def rebuild_final_assistant(conversation_history: List[Dict]) -> str:
    """
    Flatten the conversation history into a single assistant string.

    - Every assistant ``content`` is appended in order.
    - The ``tool_call`` entries that directly follow an assistant entry are
      merged into one ``<tool_returns>[...]</tool_returns>`` segment.
    - ``tool_call`` entries that do not directly follow an assistant entry,
      and all other roles, are ignored.

    Empty segments are dropped and the rest is joined with newlines.
    """
    segments: List[str] = []
    pending_returns: List[Dict] = []
    after_assistant = False  # True while collecting the tool results of one round

    for entry in conversation_history:
        role = entry.get("role")

        if role == "tool_call" and after_assistant:
            raw = entry.get("content", "")
            try:
                result = json.loads(raw)  # tool output is usually a JSON string
            except Exception:
                result = raw
            pending_returns.append({
                "name": entry.get("name"),
                "args": entry.get("args"),
                "result": result,
            })
            continue

        # Any other entry closes the group that was being collected.
        if pending_returns:
            segments.append("<tool_returns>" + json.dumps(pending_returns, ensure_ascii=False) + "</tool_returns>")
            pending_returns = []

        after_assistant = role == "assistant"
        if after_assistant:
            segments.append(entry.get("content") or "")

    if pending_returns:
        segments.append("<tool_returns>" + json.dumps(pending_returns, ensure_ascii=False) + "</tool_returns>")

    return "\n".join(part for part in segments if part).strip()


def get_last_assistant_output(conversation_history: List[Dict], error_msg: str = "", round_count: int = 0, max_iterations: int = 10) -> str:
    """
    Describe the last assistant turn, falling back to a meaningful placeholder.

    With at least one assistant entry, the result reflects the last one: its
    text when non-blank, otherwise a placeholder saying whether it only made
    tool calls or was empty. Without any assistant entry, the result is
    derived from *error_msg* (iteration limit, rate limit, generic failure)
    or is a "no reply" placeholder.
    """
    summary = ""
    seen_assistant = False

    for entry in conversation_history:
        if entry.get("role") != "assistant":
            continue
        seen_assistant = True
        text = entry.get("content") or ""
        if text.strip():
            summary = text
            continue
        # Blank content: distinguish a pure tool-call turn from an empty reply.
        calls = entry.get("tool_calls", [])
        if calls:
            summary = f"[工具调用模式] 模型进行了 {len(calls)} 个工具调用，未生成文本回复"
        else:
            summary = "[空回复] 模型未生成任何内容"

    if seen_assistant:
        return summary

    # No assistant message at all: explain why, based on the error text.
    if not error_msg:
        return "[无回复] 未收到任何 assistant 消息"
    if "最大迭代轮次限制" in error_msg:
        return f"[达到最大迭代次数] 在第 {round_count}/{max_iterations} 轮时达到限制，API 调用可能正常但未完成对话"
    if "rate limit" in error_msg.lower() or "限流" in error_msg:
        return f"[API 限流错误] {error_msg}"
    return f"[API 调用失败] {error_msg}"


def create_inference_result(success: bool, model_name: str, start_time: float,
                            conversation_history: List[Dict], thinking_mode: bool,
                            tools: List[Dict], attempt: int, total_retry_delay: float,
                            completion: Any = None, current_thinking: str = "",
                            error_msg: str = "", round_count: int = 0,
                            messages: Optional[List[Dict]] = None, iteration_count: int = 0,
                            max_iterations: int = 10) -> Dict[str, Any]:
    """
    Assemble the result record for one inference run.

    On success the final assistant reply (taken from *completion*) is appended
    to *conversation_history* in place before the derived fields are computed.
    On failure the error is recorded together with rate-limit / context-length
    flags and whatever output can still be reconstructed from the history.

    Args:
        success: Whether the run finished normally.
        model_name: Model identifier copied into the record.
        start_time: ``time.time()`` value taken when the run started.
        conversation_history: History entries; mutated on success.
        thinking_mode: Copied into ``thinking_enabled``.
        tools: Tool definitions; only their truthiness is used.
        attempt: Copied into ``retry_count``.
        total_retry_delay: Copied into ``retry_delay``.
        completion: Final API response (success path only).
        current_thinking: Reasoning text stored with the final assistant entry.
        error_msg: Error description (failure path only).
        round_count: Round in which the run stopped (failure path only).
        messages: API message list. Kept for interface compatibility; it holds
            role "tool" entries, so it is no longer used for counting.
        iteration_count: Copied into ``iteration_count``.
        max_iterations: Forwarded to the failure-path output description.

    Returns:
        The result dictionary.
    """
    total_time = time.time() - start_time
    response_data: Dict[str, Any] = {
        "success": success,
        "model_name": model_name,
        "response_time": round(total_time, 2),
        "timestamp": datetime.now().isoformat(),
        "conversation_history": conversation_history,
        "thinking_enabled": thinking_mode,
        "tools_enabled": bool(tools),
        "retry_count": attempt,
        "retry_delay": total_retry_delay,
        "iteration_count": iteration_count,
    }

    def _count_tool_calls() -> int:
        # Tool executions are recorded in the history with role "tool_call".
        if not tools:
            return 0
        return sum(1 for m in conversation_history if m.get("role") == "tool_call")

    if success:
        if completion and hasattr(completion, 'choices') and completion.choices:
            response_text = completion.choices[0].message.content
            response_data["response"] = response_text
            conversation_history.append({
                "role": "assistant",
                "content": response_text,
                "thinking": current_thinking
            })
        else:
            response_data["response"] = None
        response_data["error"] = None
        response_data["tool_calls_count"] = _count_tool_calls()
        # Enriched fields: rebuilt final string and the last assistant text.
        response_data["final_model_output"] = rebuild_final_assistant(conversation_history)
        response_data["last_iteration_output"] = get_last_assistant_output(conversation_history, error_msg="", round_count=iteration_count)
    else:
        response_data["response"] = None
        response_data["error"] = error_msg
        response_data["is_rate_limit"] = is_rate_limit_error(error_msg)
        response_data["is_context_length_error"] = is_context_length_error(error_msg)
        # Count from the history, as on the success path. The old code looked
        # for role "tool_call" in *messages*, which only contains role "tool",
        # so the count was always 0 on failure.
        response_data["tool_calls_count"] = _count_tool_calls()
        response_data["last_completed_round"] = round_count - 1 if round_count > 0 else 0
        # Even on failure, expose as much reconstructed output as possible.
        try:
            response_data["final_model_output"] = rebuild_final_assistant(conversation_history)
            response_data["last_iteration_output"] = get_last_assistant_output(conversation_history, error_msg=error_msg, round_count=round_count, max_iterations=max_iterations)
        except Exception:
            # Keep the record shape stable even if reconstruction fails.
            response_data.setdefault("final_model_output", "")
            response_data["last_iteration_output"] = f"[处理失败] 无法获取最后输出，错误: {error_msg}"
    return response_data


# ---------------- vLLM 兼容工具调用解析（与 tool_vllm_client_sandbox 对齐） ----------------
# Matches either a <tool_calls>...</tool_calls> or a <tool_call>...</tool_call>
# block (case-insensitive, non-greedy, may span newlines). The inner text is
# captured in the named group "block" or "block2" respectively.
RE_TOOL_BLOCK = re.compile(
    r"(?:<tool_calls>(?P<block>[\s\S]*?)</tool_calls>|<tool_call>(?P<block2>[\s\S]*?)</tool_call>)",
    re.IGNORECASE
)


def _normalize_tool_call(call: Dict[str, Any]) -> Dict[str, Any]:
    """Map a parsed tool call onto the ``tool_name`` / ``input_parameters`` keys.

    ``tool_name`` is taken from "tool_name" or, failing that, "name"; it is
    left out when neither key exists. ``input_parameters`` is taken from
    "input_parameters" or "arguments" and defaults to an empty dict.
    """
    result: Dict[str, Any] = {}

    for key in ("tool_name", "name"):
        if key in call:
            result["tool_name"] = call[key]
            break

    for key in ("input_parameters", "arguments"):
        if key in call:
            result["input_parameters"] = call[key]
            break
    else:
        # Neither parameter key is present.
        result["input_parameters"] = {}

    return result


def _find_matching_brace(text: str, start_pos: int) -> int:
    """Return the index of the ``}`` that brings the brace depth back to zero.

    Scanning starts at *start_pos*. Braces inside double-quoted strings are
    ignored and a backslash makes the scanner skip the next character.
    Returns -1 when the depth never returns to zero.
    """
    depth = 0
    quoted = False
    skip_next = False
    pos = start_pos
    end = len(text)
    while pos < end:
        ch = text[pos]
        if skip_next:
            # Character right after a backslash: consumed without inspection.
            skip_next = False
        elif ch == "\\":
            skip_next = True
        elif ch == '"':
            quoted = not quoted
        elif not quoted:
            if ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
                if depth == 0:
                    return pos
        pos += 1
    return -1


def _decode_params_object(candidate: str) -> Dict[str, Any]:
    """Decode the dict that starts at the first character of *candidate*.

    *candidate* begins with ``{`` but may carry trailing text (typically the
    closing brace of the enclosing tool-call object). Strict JSON is tried
    first, reading exactly one value and ignoring what follows. If that
    fails, progressively shorter prefixes ending in ``}`` are evaluated as
    Python literals. Returns ``{}`` when nothing decodes to a dict.
    """
    try:
        value, _ = json.JSONDecoder().raw_decode(candidate)
        if isinstance(value, dict):
            return value
    except (ValueError, RecursionError):
        # Not strict JSON (e.g. single quotes); fall through to the literal parser.
        pass

    import ast

    end = len(candidate)
    while True:
        # Longest prefix first, so a nested object is never cut short.
        end = candidate.rfind("}", 0, end)
        if end == -1:
            return {}
        try:
            value = ast.literal_eval(candidate[:end + 1])
        except Exception:
            continue
        if isinstance(value, dict):
            return value


def _parse_single_tool_call(text: str) -> List[Dict[str, Any]]:
    """Extract one tool call from loosely formatted text using regexes.

    The tool name comes from a ``"tool_name"`` (or ``"name"``) string field,
    the parameters from an ``"input_parameters"`` (or ``"arguments"``) object.

    The parameter regex is greedy and therefore also captures the closing
    brace(s) of the surrounding object. The capture is decoded with a parser
    that tolerates that trailing text; decoding it directly used to fail and
    silently dropped the parameters.

    Returns:
        ``[]`` when no tool name is found, otherwise a one-element list of
        ``{"tool_name": ..., "input_parameters": ...}``, where the parameters
        are ``{}`` if they cannot be decoded.
    """
    name_match = re.search(r'"tool_name":\s*"([^"]+)"', text) or re.search(r'"name":\s*"([^"]+)"', text)
    if not name_match:
        return []
    params_match = (
        re.search(r'"input_parameters":\s*(\{.*\})', text, re.DOTALL)
        or re.search(r'"arguments":\s*(\{.*\})', text, re.DOTALL)
    )
    params = _decode_params_object(params_match.group(1)) if params_match else {}
    return [{"tool_name": name_match.group(1), "input_parameters": params}]


def _split_json_objects(text: str) -> List[str]:
    """Split text into top-level pieces, cutting whenever the brace depth is zero.

    Braces inside double-quoted strings are ignored and a backslash protects
    the following character. A piece is emitted after every character that
    leaves the scanner outside a string at depth zero with non-blank pending
    text, so a separator between objects becomes its own piece; a lone comma
    is reduced to an empty string. Leftover text at the end is emitted too.
    """
    pieces: List[str] = []
    pending: List[str] = []
    depth = 0
    quoted = False
    escaped = False

    for ch in text:
        pending.append(ch)
        if escaped:
            escaped = False
            continue
        if ch == "\\":
            escaped = True
            continue
        if ch == '"':
            quoted = not quoted
        if not quoted:
            if ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
        if quoted or depth != 0:
            continue

        raw = "".join(pending)
        if not raw.strip():
            # Only whitespace so far; keep accumulating.
            continue
        if raw.strip().endswith(","):
            raw = raw.rstrip(",").strip()
        pieces.append(raw.strip())
        pending = []

    tail = "".join(pending).strip()
    if tail:
        pieces.append(tail)
    return pieces


def _manual_parse_tool_call(text: str) -> List[Dict[str, Any]]:
    """Recover tool calls from text without parsing it as a whole.

    Two strategies are tried in order:
    1. If the text looks like a ``[...]`` array, split its interior into
       top-level objects and parse each one; return as soon as this yields
       at least one call. Any error in this step is ignored.
    2. Otherwise locate every ``{"tool_name":`` and then every ``{"name":``
       occurrence, cut out the balanced object and parse it.
    """
    collected: List[Dict[str, Any]] = []
    stripped = text.strip()

    if stripped.startswith("[") and stripped.endswith("]"):
        try:
            for chunk in _split_json_objects(stripped[1:-1].strip()):
                chunk = chunk.strip()
                if chunk.startswith("{") and chunk.endswith("}"):
                    collected.extend(_parse_single_tool_call(chunk))
            if collected:
                return collected
        except Exception:
            # Best effort: fall back to the pattern scan below.
            pass

    for opener in (r'\{"tool_name":', r'\{"name":'):
        for hit in re.finditer(opener, text):
            begin = hit.start()
            finish = _find_matching_brace(text, begin)
            if finish == -1:
                continue
            collected.extend(_parse_single_tool_call(text[begin:finish + 1]))
    return collected


def parse_vllm_tool_calls(used_output: str) -> List[Dict[str, Any]]:
    """Extract and normalise the tool calls embedded in model output text.

    Every ``<tool_calls>...</tool_calls>`` / ``<tool_call>...</tool_call>``
    block is processed in document order. Previously only the first block
    was read, so parallel calls emitted as separate blocks were silently
    dropped. Each block is parsed as JSON (a single object or a list of
    objects); if that fails, the tolerant manual parser is used for that
    block instead.

    Args:
        used_output: Raw assistant text; ``None`` is treated as empty.

    Returns:
        A list of calls with ``tool_name`` / ``input_parameters`` keys, empty
        when no block is present or nothing could be parsed.
    """
    calls: List[Dict[str, Any]] = []
    for match in RE_TOOL_BLOCK.finditer(used_output or ""):
        block = (match.group("block") or match.group("block2") or "").strip()
        try:
            obj = json.loads(block)
            items = obj if isinstance(obj, list) else [obj]
            # Build the full list first so a failure adds nothing partial.
            normalized = [_normalize_tool_call(item) for item in items]
        except Exception:
            # Malformed JSON or unexpected item shape: try the manual parser.
            normalized = _manual_parse_tool_call(block)
        calls.extend(normalized)
    return calls


# ---------------- 推理主循环 ----------------
def handle_tool_call_round(completion: Any, tool_calls: List[Any], current_thinking: str,
                           model_name: str, conversation_history: List[Dict],
                           messages: List[Dict]) -> Tuple[List[Dict], List[Dict]]:
    """Record one tool-calling round and run its tools.

    Order of effects:
    1. the assistant turn (content, tool-call summary, thinking) is appended
       to *conversation_history*;
    2. the tools are executed via ``process_tool_calls``;
    3. the tool records are appended to *conversation_history*;
    4. the assistant message and the tool-role messages are appended to
       *messages* so the next request can continue the dialogue.

    Returns the ``(tool_messages, conversation_updates)`` pair produced by
    ``process_tool_calls``.
    """
    reply_text = completion.choices[0].message.content

    def _serialize_calls() -> List[Dict[str, Any]]:
        # A fresh list per use, so history and messages do not share objects.
        return [{
            "id": tc.id, "type": "function",
            "function": {"name": tc.function.name, "arguments": tc.function.arguments}
        } for tc in tool_calls]

    conversation_history.append({
        "role": "assistant",
        "content": reply_text,
        "tool_calls": _serialize_calls(),
        "thinking": current_thinking,
    })

    tool_messages, conversation_updates = process_tool_calls(tool_calls, model_name)
    conversation_history.extend(conversation_updates)

    messages.append({
        "role": "assistant",
        "content": reply_text,
        "tool_calls": _serialize_calls(),
    })
    messages.extend(tool_messages)
    return tool_messages, conversation_updates


def handle_retry_logic(error_msg: str, current_round_retry: int, max_retries: int,
                       round_count: int, total_retry_delay: float) -> Tuple[bool, float, int, int]:
    """Decide whether the failed round is retried and, if so, sleep first.

    Context-length errors are never retried. Rate-limit errors are retried
    regardless of *max_retries*; other errors only while
    ``current_round_retry < max_retries``. Rate-limit errors use a longer
    backoff (base 2s, cap 120s) than other errors (base 1s, cap 30s).

    Returns:
        ``(should_retry, total_retry_delay, round_count, current_round_retry)``.
        On retry the delay is added to the total, the retry counter goes up
        by one and the round counter goes down by one so the same round runs
        again; otherwise all three values are returned unchanged.
    """
    if is_context_length_error(error_msg):
        return False, total_retry_delay, round_count, current_round_retry

    rate_limited = is_rate_limit_error(error_msg)
    if not (current_round_retry < max_retries or rate_limited):
        return False, total_retry_delay, round_count, current_round_retry

    current_round_retry += 1
    if rate_limited:
        base_delay, max_delay = 2.0, 120.0
    else:
        base_delay, max_delay = 1.0, 30.0
    pause = exponential_backoff(current_round_retry - 1, base_delay, max_delay)
    total_retry_delay += pause
    time.sleep(pause)
    # Step the round counter back so the caller repeats the same round.
    return True, total_retry_delay, round_count - 1, current_round_retry


def chat_with_thinking_toolcall(client: AzureOpenAI | OpenAI, model_name: str, prompt: str,
                                tools: List[Dict[str, Any]], max_retries: int = 3,
                                timeout: int = 60, thinking_mode: bool = True, max_iterations: int = 10) -> Dict[str, Any]:
    """
    Run one prompt through the chat / tool-calling loop until the model stops
    calling tools, an unrecoverable error occurs, or *max_iterations* rounds
    are exceeded.

    Two tool-calling modes exist, selected by the model configuration:
    - native function calling: ``tools`` is sent with the request and
      ``message.tool_calls`` is executed;
    - vLLM text format (``use_vllm_toolcall_format``): tool calls are parsed
      out of the assistant text and results are sent back as plain
      tool-role messages.

    Args:
        client: API client used for ``chat.completions.create``.
        model_name: Model identifier; also selects the model configuration.
        prompt: User prompt for the first message.
        tools: Tool definitions (may be empty).
        max_retries: Retry budget per round for non-rate-limit errors.
        timeout: Accepted for interface compatibility; not used here.
        thinking_mode: Whether thinking parameters are added to the request.
        max_iterations: Maximum number of rounds.

    Returns:
        The record produced by ``create_inference_result``.
    """
    model_config = load_api_config(model_name)
    use_plain_content = bool(
        model_config.get("use_plain_content", False)
        or model_config.get("is_vllm", False)
        or model_config.get("client_type") == "openai"
    )
    use_vllm_toolcall = bool(model_config.get("use_vllm_toolcall_format", False))
    if use_vllm_toolcall:
        use_plain_content = True  # aligned with tool_vllm_client_sandbox
    max_tokens = model_config.get("max_tokens", 16000)
    temperature = model_config.get("temperature", 1.0)
    api_params: Dict[str, Any] = {
        "model": model_name,
        "max_tokens": max_tokens,
        "temperature": temperature,
        **build_thinking_config(model_name, model_config, thinking_mode)
    }
    if tools and not model_config.get("is_ark_bot", False) and not use_vllm_toolcall:
        api_params["tools"] = tools

    start_time = time.time()
    attempt = 0  # total number of retries across all rounds
    total_retry_delay = 0
    round_count = 0
    current_round_retry = 0

    if use_plain_content:
        messages = [{"role": "user", "content": prompt}]
    else:
        messages = [{"role": "user", "content": [{"text": prompt, "type": "text"}]}]
    conversation_history = [{"role": "user", "content": prompt}]

    while True:
        try:
            round_count += 1

            # Stop once the iteration budget is exhausted.
            if round_count > max_iterations:
                return create_inference_result(
                    success=False, model_name=model_name, start_time=start_time,
                    conversation_history=conversation_history, thinking_mode=thinking_mode,
                    tools=tools, attempt=attempt, total_retry_delay=total_retry_delay,
                    error_msg=f"达到最大迭代轮次限制: {max_iterations}", round_count=round_count, messages=messages,
                    iteration_count=round_count, max_iterations=max_iterations
                )
            completion = client.chat.completions.create(**api_params, messages=messages)
            current_thinking = extract_thinking_info(completion, model_name)

            if use_vllm_toolcall:
                content = completion.choices[0].message.content or ""
                tool_calls = parse_vllm_tool_calls(content)
                if tool_calls:
                    # Record the assistant turn only for tool rounds. The final
                    # turn is appended by create_inference_result, so appending
                    # it here as well used to duplicate it in the history.
                    conversation_history.append({
                        "role": "assistant",
                        "content": content,
                        "thinking": current_thinking
                    })
                    tool_messages = []
                    for call in tool_calls:
                        tool_name = call.get("tool_name") or call.get("name")
                        args = call.get("input_parameters") or call.get("arguments") or {}
                        tool_result = tool_dispatcher({"name": tool_name, "arguments": args})
                        conversation_history.append({
                            "role": "tool_call",
                            "name": tool_name,
                            "args": args,
                            "content": tool_result
                        })
                        tool_messages.append({"role": "tool", "content": tool_result})
                    messages.append({"role": "assistant", "content": content})
                    messages.extend(tool_messages)
                    current_round_retry = 0
                    continue

                # No tool call in the text: this is the final answer.
                return create_inference_result(
                    success=True, completion=completion, model_name=model_name,
                    start_time=start_time, conversation_history=conversation_history,
                    thinking_mode=thinking_mode, tools=tools, attempt=attempt,
                    total_retry_delay=total_retry_delay, current_thinking=current_thinking,
                    iteration_count=round_count, max_iterations=max_iterations
                )

            if tools and completion.choices[0].message.tool_calls:
                tool_calls = completion.choices[0].message.tool_calls
                handle_tool_call_round(
                    completion, tool_calls, current_thinking, model_name,
                    conversation_history, messages
                )
                current_round_retry = 0
                continue

            # No tool call requested: finished successfully.
            return create_inference_result(
                success=True, completion=completion, model_name=model_name,
                start_time=start_time, conversation_history=conversation_history,
                thinking_mode=thinking_mode, tools=tools, attempt=attempt,
                total_retry_delay=total_retry_delay, current_thinking=current_thinking,
                iteration_count=round_count, max_iterations=max_iterations
            )

        except Exception as e:
            error_msg = str(e)

            # Context-length errors are not retried; report a friendly message.
            if is_context_length_error(error_msg):
                friendly_error_msg = "模型回复超长"
                return create_inference_result(
                    success=False, model_name=model_name, start_time=start_time,
                    conversation_history=conversation_history, thinking_mode=thinking_mode,
                    tools=tools, attempt=attempt, total_retry_delay=total_retry_delay,
                    error_msg=friendly_error_msg, round_count=round_count, messages=messages,
                    iteration_count=round_count, max_iterations=max_iterations
                )

            should_retry_round, total_retry_delay, round_count, current_round_retry = handle_retry_logic(
                error_msg, current_round_retry, max_retries, round_count, total_retry_delay
            )
            if should_retry_round:
                # Track the overall retry count; it was never incremented
                # before, so "retry_count" in the result was always 0.
                attempt += 1
                # round_count was stepped back by handle_retry_logic, hence "+ 1".
                print(f"模型 {model_name} 第 {round_count + 1} 轮第 {current_round_retry} 次尝试，错误: {error_msg}")
                continue

            # Not retrying: round_count was NOT stepped back, so it already
            # names the failed round (same as the context-length branch above).
            return create_inference_result(
                success=False, model_name=model_name, start_time=start_time,
                conversation_history=conversation_history, thinking_mode=thinking_mode,
                tools=tools, attempt=attempt, total_retry_delay=total_retry_delay,
                error_msg=error_msg, round_count=round_count, messages=messages,
                iteration_count=round_count, max_iterations=max_iterations
            )


def _write_prewrites_to_sandbox(session_id: str, prewrites: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Upload pre-defined files into the sandbox before inference starts.

    Only one entry form is implemented:
      - {"filename": "a.docx", "fpath": "/host/path/file", ...}
        -> the local file is read as bytes, base64-encoded and written with
           ``binary=True``.
    Entries without a string ``fpath`` (e.g. ``content``-only entries) are
    NOT written; they are reported as failures in the returned list instead
    of being skipped silently.

    Args:
        session_id: Sandbox session that receives the files.
        prewrites: Entry dicts; ``None`` or empty means nothing to do.

    Returns:
        One status dict per entry (``filename``, ``ok`` and, depending on the
        outcome, ``via`` / ``binary`` / ``err``) for troubleshooting.
    """
    results: List[Dict[str, Any]] = []
    for it in prewrites or []:
        fn = it.get("filename")
        if not fn:
            results.append({"filename": None, "ok": False, "err": "missing filename"})
            continue
        if not isinstance(fn, str):
            results.append({"filename": fn, "ok": False, "err": "filename must be a string"})
            continue

        fpath = it.get("fpath")
        if not isinstance(fpath, str):
            results.append({
                "filename": fn,
                "ok": False,
                "err": "unsupported prewrite source: only 'fpath' entries are written",
            })
            continue

        # The flag only affects the log line; the upload is always base64/binary.
        is_bin = bool(it.get("binary", False))
        try:
            with open(fpath, 'rb') as f:
                content = f.read()
            sandbox_write_file(session_id, fn, base64.b64encode(content).decode("ascii"), binary=True)
        except Exception as we:
            print(f"⚠️ 写入失败 {fn}: {we}")
            # Record the failure so it shows up in the task result.
            results.append({"filename": fn, "ok": False, "via": "fpath:rb->b64", "err": str(we)})
            continue

        results.append({"filename": fn, "ok": True, "via": "fpath:rb->b64", "binary": True})
        print(f"📥 写入文件: {fn} ({'binary' if is_bin else 'text/url'})")

    return results



# ---------------- 并行封装：一题一沙盒 ----------------
def process_single_task(task_data: Tuple[int, str, str, Dict[str, Any], AzureOpenAI | OpenAI, List[Dict[str, Any]], int, bool, int]) -> Dict[str, Any]:
    """
    Run one task inside its own sandbox:
    create -> bind session id -> (pre-write files) -> inference -> destroy.

    Args:
        task_data: Tuple of ``(row_idx, model_name, prompt, base_data, client,
            tools, max_retries, thinking_mode, max_iterations)``.

    Returns:
        The inference result, updated with *base_data* and extended with
        ``row_idx``, ``session_id`` and, when non-empty, ``prewrite_status``.

    Raises:
        RuntimeError: If the sandbox creation response cannot be parsed.
    """
    row_idx, model_name, prompt, base_data, client, tools, max_retries, thinking_mode, max_iterations = task_data

    # Create the sandbox and pull the session id out of its response.
    sess_obj = create_sandbox()
    try:
        sess_payload = sess_obj.as_dict().get("return")
        if not isinstance(sess_payload, dict):
            # Strings (and anything else) are expected to be JSON text.
            sess_payload = json.loads(sess_payload)
        session_id = sess_payload["session_id"]
    except Exception as e:
        raise RuntimeError(f"CreateSandbox 解析失败: {e}") from e

    prewrite_status: List[Dict[str, Any]] = []
    try:
        # Bind the session id for tool_utils. Done inside the try so the
        # sandbox is destroyed even if the binding itself fails.
        set_session_id(session_id)

        # Pre-write files when the task provides a "prewrites" list.
        if isinstance(base_data.get("prewrites"), list):
            prewrite_status = _write_prewrites_to_sandbox(session_id, base_data["prewrites"])

        # Run the chat / tool-calling loop.
        result = chat_with_thinking_toolcall(
            client, model_name, prompt, tools, max_retries, thinking_mode=thinking_mode, max_iterations=max_iterations
        )

        # Attach task metadata, sandbox info and pre-write results.
        result.update(base_data)
        result['row_idx'] = row_idx
        result['session_id'] = session_id
        if prewrite_status:
            result['prewrite_status'] = prewrite_status
        return result
    finally:
        clear_session_id()
        try:
            destroy_sandbox(session_id=session_id)
        except Exception as e:
            # Cleanup stays best-effort, but a leaked sandbox is now logged.
            print(f"⚠️ 销毁沙盒失败 {session_id}: {e}")


def batch_infer(infer_data: List[Dict[str, Any]], model_name: str, base_url: str, api_version: str,
                api_key: str, max_retries: int = 3, max_workers: int = 1,
                delay_between_requests: float = 0.0, thinking_mode: bool = True,
                tool_mode: bool = True, save_results: bool = True,
                output_dir: str = "inference_results", max_iterations: int = 10) -> List[Dict[str, Any]]:
    """
    Run inference for every item of *infer_data* on a thread pool, one
    sandboxed task per item.

    Args:
        infer_data: Task dicts; each must contain ``formatted_question``.
        model_name: Model identifier.
        base_url: API endpoint; when empty, ``azure_endpoint`` from the model
            configuration is used.
        api_version: API version passed to the client builder.
        api_key: API credential.
        max_retries: Retry budget forwarded to each task.
        max_workers: Thread pool size.
        delay_between_requests: Seconds to sleep after each collected result.
            All tasks are submitted up front, so this paces result handling,
            not request submission.
        thinking_mode: Forwarded to each task.
        tool_mode: Load and offer the tool definitions when True.
        save_results: Append every result to a JSONL file as it completes.
        output_dir: Directory for the JSONL file (created if missing,
            including parent directories).
        max_iterations: Forwarded to each task.

    Returns:
        All results, sorted by ``row_idx``.
    """
    model_config = load_api_config(model_name)
    if not base_url:
        base_url = model_config.get("azure_endpoint", "")
    client = build_api_client(base_url, api_version, api_key, model_config)
    tools = load_tool_config() if tool_mode else []

    # Output file
    output_file = None
    if save_results:
        Path(output_dir).mkdir(parents=True, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        # Model ids such as "org/model" would otherwise point into a
        # non-existent sub-directory.
        safe_model_name = re.sub(r"[\\/]+", "_", model_name)
        output_file = Path(output_dir) / f"inference_{safe_model_name}_{timestamp}.jsonl"
        print(f"结果将保存到: {output_file}")

    # Task preparation
    tasks = []
    for i, input_item in enumerate(infer_data):
        base_data = input_item.copy()
        base_data["batch_index"] = i
        prompt = base_data["formatted_question"]
        tasks.append((i, model_name, prompt, base_data, client, tools, max_retries, thinking_mode, max_iterations))

    results: List[Dict[str, Any]] = []
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_task = {executor.submit(process_single_task, task): task for task in tasks}
        for future in tqdm(as_completed(future_to_task), total=len(tasks), desc="推理进度"):
            try:
                result = future.result()
            except Exception as e:
                # A task that raised still yields a record so the batch goes on.
                task = future_to_task[future]
                result = {
                    "success": False,
                    "error": str(e),
                    "row_idx": task[0],
                    "model_name": task[1],
                    "prompt": task[2],
                    "full_input_data": task[3],
                    "timestamp": datetime.now().isoformat()
                }
            results.append(result)

            # Persist incrementally so finished work survives a later crash.
            if save_results and output_file:
                with open(output_file, 'a', encoding='utf-8') as f:
                    f.write(json.dumps(result, ensure_ascii=False) + '\n')

            if delay_between_requests and delay_between_requests > 0:
                time.sleep(delay_between_requests)

    results.sort(key=lambda x: x.get('row_idx', 0))
    return results


def batch_infer_from_file(input_path: str,
                          model_name: str,
                          base_url: str,
                          api_version: str,
                          api_key: str,
                          max_retries: int = 3,
                          max_workers: int = 8,
                          delay_between_requests: float = 0.0,
                          thinking_mode: bool = True,
                          tool_mode: bool = True,
                          save_results: bool = True,
                          output_dir: str = "inference_results",
                          max_iterations: int = 10) -> List[Dict[str, Any]]:
    """
    Load questions from a JSON or JSONL file and run batch inference on them.

    Supported input layouts:
      - JSON array:  [ {task_id, level, formatted_question, prewrites, answer}, ... ]
      - JSON object: {task_id, level, formatted_question, prewrites, answer}
      - JSONL: one JSON document per line (blank lines are skipped)

    The layout is selected by file name: a path ending in ".jsonl" is read
    line by line; anything else is parsed as a single JSON document.

    Args:
        input_path: Path of the input file (a ``str`` or any ``os.PathLike``).
        model_name, base_url, api_version, api_key, max_retries, max_workers,
        delay_between_requests, thinking_mode, tool_mode, save_results,
        output_dir, max_iterations: Forwarded unchanged to ``batch_infer``.

    Returns:
        The list returned by ``batch_infer`` for the loaded questions.

    Raises:
        ValueError: The top-level JSON value is neither an object nor an array.
        json.JSONDecodeError: The file (or one JSONL line) is not valid JSON.
            For JSONL input the message names the file and the 1-based line
            number of the offending line.
    """
    # Accept pathlib.Path and friends; str.endswith below needs a real string.
    input_path = os.fspath(input_path)

    # 1) Read the input
    data: List[Dict[str, Any]] = []
    if input_path.endswith(".jsonl"):
        with open(input_path, "r", encoding="utf-8") as f:
            for line_no, raw_line in enumerate(f, start=1):
                text = raw_line.strip()
                if not text:
                    continue
                try:
                    data.append(json.loads(text))
                except json.JSONDecodeError as err:
                    # Re-raise with the same exception type so existing
                    # handlers keep working, but say which line is broken.
                    raise json.JSONDecodeError(
                        f"{input_path} 第 {line_no} 行 JSON 解析失败: {err.msg}",
                        err.doc,
                        err.pos,
                    ) from err
    else:
        with open(input_path, "r", encoding="utf-8") as f:
            obj = json.load(f)
        if isinstance(obj, list):
            data = obj
        elif isinstance(obj, dict):
            data = [obj]
        else:
            raise ValueError("输入 JSON 必须是对象或数组")

    print(f"✅ 从 {input_path} 加载 {len(data)} 道题")

    # 2) Delegate to the in-memory batch driver
    return batch_infer(
        infer_data=data,
        model_name=model_name,
        base_url=base_url,
        api_version=api_version,
        api_key=api_key,
        max_retries=max_retries,
        max_workers=max_workers,
        delay_between_requests=delay_between_requests,
        thinking_mode=thinking_mode,
        tool_mode=tool_mode,
        save_results=save_results,
        output_dir=output_dir,
        max_iterations=max_iterations
    )


# Default location of the models config file. When nothing exists at this
# path the script falls back to the built-in model list below.
_DEFAULT_MODELS_CONFIG_PATH = "/root/sand_box/models_config.json"


def _parse_cli_args() -> argparse.Namespace:
    """Define the command-line options of the batch-inference script and parse them."""
    parser = argparse.ArgumentParser(description="批量推理脚本")
    parser.add_argument("--input", "-i", type=str, default="/path/to/project/data/extracted_bench-v2.json", help="输入文件路径 (JSON/JSONL)")
    parser.add_argument("--output", "-o", type=str, default="/root/sand_box/extracted_bench-v2", help="输出目录路径")
    # Help text previously claimed a default of 3 while the real default is 4.
    parser.add_argument("--max_workers", type=int, default=4, help="最大并发数 (默认: 4)")
    parser.add_argument("--max_retries", type=int, default=10, help="最大重试次数 (默认: 10)")
    parser.add_argument("--delay", type=float, default=2.0, help="请求间隔延迟(秒) (默认: 2.0)")
    parser.add_argument("--max_iterations", type=int, default=15, help="最大迭代轮次 (默认: 15)")
    parser.add_argument("--models_config", type=str, default=_DEFAULT_MODELS_CONFIG_PATH, help="模型配置JSON文件路径或JSON字符串 (默认使用硬编码配置)")
    return parser.parse_args()


def _resolve_models_config(spec: str) -> List[Dict[str, Any]]:
    """
    Turn the ``--models_config`` value into a list of model config dicts.

    Resolution order:
      1. empty value                         -> built-in list
      2. existing file path                  -> JSON content of that file
      3. the (missing) default file path     -> built-in list
      4. anything else                       -> parsed as an inline JSON string

    A single JSON object is wrapped into a one-element list.

    Raises:
        json.JSONDecodeError: The file content or the inline string is not valid JSON.
    """
    # Built-in fallback. Alternative entries are kept as one-line comments so
    # they can be re-enabled quickly.
    builtin_configs: List[Dict[str, Any]] = [
        # {"model_name": "gemini-2.5-pro-preview-05-06", "thinking_mode": True, "tool_mode": True},
        # {"model_name": "gpt-4.1-2025-04-14", "thinking_mode": False, "tool_mode": True},
        # {"model_name": "gpt-oss-120b", "thinking_mode": True, "tool_mode": True},
        # {"model_name": "grok-3", "thinking_mode": True, "tool_mode": True},
        # {"model_name": "gcp-claude4.1-opus", "thinking_mode": False, "tool_mode": True},
        # {"model_name": "gcp-claude4-opus", "thinking_mode": False, "tool_mode": True},
        # {"model_name": "gpt-5-mini-2025-08-07", "thinking_mode": False, "tool_mode": True},
        # {"model_name": "o3-2025-04-16", "thinking_mode": False, "tool_mode": True},
        # {"model_name": "ep-20250328185550-g5zcd", "thinking_mode": False, "tool_mode": True},
        # {"model_name": "ep-20250826141533-cf9bk", "thinking_mode": True, "tool_mode": True},  # deepseekv3.1
        {
            "model_name": "glm-4.6",
            "thinking_mode": True,
            "tool_mode": True,
        },
        # {"model_name": "gpt-5-2025-08-07", "thinking_mode": False, "tool_mode": True},
        # {"model_name": "gcp-claude37-sonnet", "thinking_mode": False, "tool_mode": True},
        # {"model_name": "gpt-4.1-mini-2025-04-14", "thinking_mode": False, "tool_mode": True},
        # {"model_name": "bot-20250515142114-xmfh2", "thinking_mode": True, "tool_mode": True},
    ]

    if not spec:
        return builtin_configs

    if os.path.exists(spec):
        with open(spec, 'r', encoding='utf-8') as f:
            loaded = json.load(f)
        print(f"✅ 从文件加载模型配置: {spec}")
    elif spec == _DEFAULT_MODELS_CONFIG_PATH:
        # The user did not override the option and the default file is absent:
        # use the built-in list instead of trying to parse the path as JSON.
        print(f"⚠️ 默认模型配置文件不存在，改用内置默认配置: {spec}")
        return builtin_configs
    else:
        # Not a file on disk, so treat the value as an inline JSON string.
        try:
            loaded = json.loads(spec)
        except json.JSONDecodeError as e:
            print(f"❌ 模型配置解析失败: {e}")
            raise
        print("✅ 从JSON字符串加载模型配置")

    # Allow a single model to be given as a bare JSON object.
    if isinstance(loaded, dict):
        loaded = [loaded]
    return loaded


def _run_cli() -> None:
    """Run batch inference for every configured model and print a summary."""
    args = _parse_cli_args()
    test_config_list = _resolve_models_config(args.models_config)

    input_path = args.input
    output_dir = args.output

    # Create the output directory (including parents) once, up front.
    os.makedirs(output_dir, exist_ok=True)

    # Results of all models, so the printout and the summary cover every run
    # and remain defined even when the model list is empty.
    all_results: List[Dict[str, Any]] = []

    for config in test_config_list:
        model_name = config["model_name"]
        model_config = load_api_config(model_name)
        base_url = model_config.get("azure_endpoint", "")
        api_key = model_config.get("api_key", "")
        api_version = model_config.get("api_version", "2024-03-01-preview")

        results = batch_infer_from_file(
            input_path=input_path,
            model_name=model_name,
            base_url=base_url,
            api_key=api_key,
            api_version=api_version,
            max_retries=args.max_retries,
            max_workers=args.max_workers,
            delay_between_requests=args.delay,
            thinking_mode=config["thinking_mode"],
            tool_mode=config["tool_mode"],
            save_results=True,
            output_dir=output_dir,
            max_iterations=args.max_iterations
        )
        all_results.extend(results)

    print("结果：")
    for r in all_results:
        print(json.dumps(r, ensure_ascii=False))
        print("-" * 80)

    # Summary printed once all models have finished
    print("\n" + "="*80)
    print("🎯 推理任务完成总结")
    print("="*80)
    print(f"📁 输入文件路径: {input_path}")
    print(f"📁 输出文件路径: {output_dir}")
    print(f"🤖 使用的模型: {', '.join([config['model_name'] for config in test_config_list])}")
    print(f"📊 总处理题目数: {len(all_results)}")
    print(f"⏰ 完成时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print("="*80)


if __name__ == "__main__":
    _run_cli()
