# EXAMPLE COMMAND: from folder examples/open_deep_research, run: python run_gaia.py --concurrency 32 --run-name generate-traces-03-apr-noplanning --model-id gpt-4o
import argparse
import json
import os
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
from pathlib import Path
from typing import Any

import datasets
import pandas as pd
from dotenv import load_dotenv
from huggingface_hub import login, snapshot_download
from scripts.reformulator import prepare_response
from scripts.run_agents import (
    get_single_file_description,
    get_zip_description,
)
from scripts.text_inspector_tool import TextInspectorTool
from scripts.text_web_browser import (
    ArchiveSearchTool,
    FinderTool,
    FindNextTool,
    PageDownTool,
    PageUpTool,
    SimpleTextBrowser,
    VisitTool,
)
from scripts.visual_qa import visualizer
from tqdm import tqdm

from smolagents import (
    CodeAgent,
    GoogleSearchTool,
    LiteLLMModel,
    Model,
    ToolCallingAgent,
    SupervisorAgent,
    SupervisorKBManager,
)

from phoenix.otel import register
from openinference.instrumentation.smolagents import SmolagentsInstrumentor

from smolagents.agents import RunResult
from smolagents.memory import FinalAnswerStep, ActionStep, TaskStep

from smolagents.models import MonitoredModel    # 用于全局监督 token cost
import re
import traceback


load_dotenv(override=True)

# Log in to the Hugging Face Hub only when a token is configured: login(None)
# falls back to an interactive prompt, which blocks unattended runs. Without
# HF_TOKEN, huggingface_hub can still use locally cached credentials.
_hf_token = os.getenv("HF_TOKEN")
if _hf_token:
    login(_hf_token)

# Held by append_answer while writing so concurrent writers don't interleave records.
append_answer_lock = threading.Lock()


def _str_to_bool(value: str) -> bool:
    """Parse a command-line boolean ("true"/"false", "yes"/"no", "1"/"0", "on"/"off")."""
    normalized = value.strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    if normalized in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args():
    """Parse the command-line options of a GAIA evaluation run.

    Returns:
        argparse.Namespace with the parsed options; ``--run-name`` is required.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--concurrency", type=int, default=1)
    parser.add_argument("--num-examples", type=int, default=None)
    parser.add_argument("--model-id", type=str, default="gpt-4.1")
    parser.add_argument("--model-id-summary", type=str, default="gpt-4.1")
    parser.add_argument("--run-name", type=str, required=True)
    parser.add_argument("--set-to-run", type=str, default="validation")
    # ``type=bool`` would turn any non-empty string, including "False", into True.
    # Accept an explicit value ("--use-open-models false") or a bare flag.
    parser.add_argument("--use-open-models", type=_str_to_bool, nargs="?", const=True, default=False)
    parser.add_argument("--use-raw-dataset", action="store_true")
    parser.add_argument("--supervisor", action="store_true")
    parser.add_argument("--task-id", type=str, default=None, help="Only run the task with this specific ID")
    parser.add_argument("--level", type=int, default=None, help="Only run tasks with this specific level")
    return parser.parse_args()


### IMPORTANT: EVALUATION SWITCHES

print("Make sure you deactivated any VPN like Tailscale, else some URLs will be blocked!")

# Role remapping handed to LiteLLMModel: tool calls are sent as assistant
# turns and tool responses as user turns.
custom_role_conversions = {"tool-call": "assistant", "tool-response": "user"}


user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36 Edg/119.0.0.0"

# Keyword arguments shared by every SimpleTextBrowser built from this config.
BROWSER_CONFIG = dict(
    viewport_size=1024 * 5,
    downloads_folder="downloads_folder",
    request_kwargs=dict(
        headers={"User-Agent": user_agent},
        timeout=300,
    ),
    serpapi_key=os.getenv("SERPAPI_API_KEY"),
)

# Make sure the browser's download directory exists before any agent runs.
os.makedirs("./" + BROWSER_CONFIG["downloads_folder"], exist_ok=True)

class TokenUsage:
    """Simple holder for an (input, output) token-count pair.

    Used for usage figures that are not read from an agent monitor, such as
    the ``prepare_response`` step passed to ``calculate_and_log_token_usage``.
    """

    def __init__(self, input_tokens=0, output_tokens=0):
        # Both counts default to zero, so a bare instance means "nothing consumed".
        self.input_tokens, self.output_tokens = input_tokens, output_tokens

def calculate_and_log_token_usage(
    agent_dict: dict[str, Any],
    monitored_model: MonitoredModel,
    prep_token_usage: TokenUsage,
    supervisor: bool
) -> dict:
    """Print per-source token consumption and return the totals to store.

    Agent-level sources are summed first:
    - the manager and web search agents;
    - the supervisor, when ``supervisor`` is True;
    - the ``prepare_response`` step;
    - the summary model, if present;
    - the text inspector tool, if present.

    The ``monitored_model``'s own (model-level) counter is then added on top.

    Args:
        agent_dict: Agents/tools keyed by role. "manager" and "web_search" are
            required; "supervisor" is required when ``supervisor`` is True;
            "summary_model" and "text_inspector" are optional.
        monitored_model: Wrapped model whose ``total_token_usage`` is reported
            as the model-level (internal) usage.
        prep_token_usage: Token usage returned by ``prepare_response``.
        supervisor: Whether supervisor mode was used.

    Returns:
        dict with ``total_token_count`` (agent-level plus model-level tokens)
        and ``supervisor_token`` (supervisor tokens plus summary-model tokens).
    """

    def _log_usage(label, usage):
        # Print one source's usage and hand back its (input, output) counts.
        input_tokens, output_tokens = usage.input_tokens, usage.output_tokens
        print(f"{label} token usage - Input: {input_tokens}, Output: {output_tokens}")
        return input_tokens, output_tokens

    # (input, output) pairs of every agent-level source, in reporting order.
    agent_level = [
        _log_usage("Manager agent", agent_dict["manager"].monitor.get_total_token_counts()),
        _log_usage("Web Search agent", agent_dict["web_search"].monitor.get_total_token_counts()),
    ]

    # NOTE: usage of the separate "verification" agent is not read here.
    supervisor_total_tokens = 0
    if supervisor:
        sup_in, sup_out = _log_usage(
            "Supervisor/Verification", agent_dict["supervisor"].monitor.get_total_token_counts()
        )
        supervisor_total_tokens = sup_in + sup_out
        agent_level.append((sup_in, sup_out))

    agent_level.append(_log_usage("Preparation step", prep_token_usage))

    summary_total_tokens = 0
    if "summary_model" in agent_dict:
        summary_in, summary_out = _log_usage("Summary model", agent_dict["summary_model"].total_token_usage)
        summary_total_tokens = summary_in + summary_out
        agent_level.append((summary_in, summary_out))

    if "text_inspector" in agent_dict:
        agent_level.append(
            _log_usage("Text Inspector Tool", agent_dict["text_inspector"].get_total_token_counts())
        )

    agent_total_tokens = sum(i for i, _ in agent_level) + sum(o for _, o in agent_level)

    # Model-level (internal) usage as counted by the MonitoredModel wrapper.
    internal_usage = monitored_model.total_token_usage
    internal_total_tokens = internal_usage.input_tokens + internal_usage.output_tokens
    print(f"Model-level (internal) total token usage - Input: {internal_usage.input_tokens}, Output: {internal_usage.output_tokens}\n")

    grand_total_tokens = agent_total_tokens + internal_total_tokens
    print(f"Grand total token usage (all sources) = {grand_total_tokens}\n")

    return {
        "total_token_count": grand_total_tokens,
        "supervisor_token": supervisor_total_tokens + summary_total_tokens,
    }

def create_supervisor_callback(supervisor: SupervisorAgent, example: dict):
    """Create a step callback that records agent steps and selectively invokes the supervisor.

    Every ``ActionStep`` is recorded with ``supervisor.record_step``. A
    ``supervisor.supervise_and_correct`` review runs only when a trigger fires:

    1. the step has an error (errors from ``find_archived_url`` are ignored);
    2. otherwise, any of the following holds:
       - the observation is longer than 3000 characters;
       - the supervisor reports inefficiency for the agent;
       - the observation contains "### 1. Task outcome (short version):".

    ``web_search`` steps with an observation shorter than 5000 characters skip
    the review, to keep raw search results as intact as possible.
    Independently, observations containing "## Search Results" are passed
    through ``supervisor.filter_search_results`` and updated in place. That
    filtering alone does not trigger a review.

    Args:
        supervisor (SupervisorAgent): Supervisor instance.
        example (dict): The GAIA row currently being solved; its "question" is
            passed as the global task.

    Returns:
        Callable ``(memory_step, agent)`` suitable for an agent's ``step_callbacks``.
    """
    # Simplified call-hierarchy map (child agent name -> parent agent name).
    # A more robust framework would need a real call stack for this.
    parent_agent_map = {"search_agent": "manager_agent"}

    def supervisor_callback(memory_step: ActionStep, agent: CodeAgent):
        # Only action steps are recorded/supervised; other step types pass through.
        if not isinstance(memory_step, ActionStep):
            return

        global_task = example["question"]  # the global task is always the original question
        parent_agent_name = parent_agent_map.get(agent.name)

        if agent.name == "manager_agent":
            # The manager's task is the global task; keep local empty to avoid duplication.
            clean_local_task = ""
        else:
            # Strip the prompt scaffolding around a managed agent's original local task.
            clean_local_task = extract_clean_task(agent.task)

        supervisor.record_step(
            step=memory_step,
            agent_name=agent.name,
            local_task=clean_local_task,
            parent_agent_name=parent_agent_name,
        )

        # tool_calls may be None or empty, so resolve the first tool name defensively.
        first_tool_name = memory_step.tool_calls[0].name if memory_step.tool_calls else None

        should_supervise = False

        # Priority 1: any step error triggers a review, except the known/safe
        # failures of 'find_archived_url', which are deliberately ignored.
        if memory_step.error is not None:
            if first_tool_name == "find_archived_url":
                print("--- [Supervisor] Ignoring a known/safe error from 'find_archived_url'. ---")
            else:
                should_supervise = True

        # Priority 2: only checked when no error triggered a review.
        if not should_supervise:
            observations = memory_step.observations
            if observations and len(observations) > 3000:
                should_supervise = True
            elif supervisor._check_for_inefficiency(agent.name)[0]:
                should_supervise = True
            elif observations and "### 1. Task outcome (short version):" in observations:
                should_supervise = True

        # Keep short web_search results as close to raw as possible by skipping the
        # review. observations may be None, so guard before calling len().
        if (
            first_tool_name == "web_search"
            and memory_step.observations is not None
            and len(memory_step.observations) < 5000
        ):
            should_supervise = False
            print("--- [Supervisor] Skip web_search_tool supervision")

        # Final-answer supervision is currently disabled; these stay at their defaults.
        is_final_answer_attempt = False
        proposed_answer = None

        # Filter repeatedly browsed URLs out of search results, updating the step in place.
        observations = memory_step.observations
        if observations and "## Search Results" in observations and first_tool_name != "final_answer":
            filtered_observation = supervisor.filter_search_results(observations)
            if filtered_observation != observations:
                print("--- [Supervisor] Silently filtering observation without forcing API-based intervention. ---")
                memory_step.observations = filtered_observation

        if should_supervise:
            supervisor.supervise_and_correct(
                step=memory_step,
                agent_name=agent.name,
                local_task=clean_local_task,
                global_task=global_task,
                is_final_check=is_final_answer_attempt,
                proposed_answer=proposed_answer,
            )
        else:
            print(f"--- [Supervisor] Skipped review for step {memory_step.step_number} of '{agent.name}'. ---")

    return supervisor_callback

def extract_clean_task(full_task_prompt: str) -> str:
    """Return the core task text embedded in an augmented task prompt.

    The core task is expected between the first two '---' separators, optionally
    introduced by "Task:"; the extracted text is stripped of surrounding
    whitespace. If the prompt contains fewer than two separators, or parsing
    fails (e.g. the argument is not a string), the input is returned unchanged.
    """
    try:
        _, first_sep, after_first = full_task_prompt.partition('---')
        task_block, second_sep, _ = after_first.partition('---')
        if not (first_sep and second_sep):
            # Unexpected format: fall back to the original prompt.
            return full_task_prompt
        _, task_marker, after_marker = task_block.partition("Task:")
        return (after_marker if task_marker else task_block).strip()
    except Exception:
        # Any failure falls back to the original input for robustness.
        return full_task_prompt

def create_verification_agent(model: Model) -> ToolCallingAgent:
    """Create the lightweight fact-checking agent that is injected into the SupervisorAgent.

    The agent is a ``ToolCallingAgent`` (it previously was a ``CodeAgent``). It
    carries the same web-browsing and text-inspection tools as the main search
    agent, plus ``visualizer``. It uses a smaller text limit and step budget to
    keep token usage down.

    Args:
        model: Model used by the agent and by its TextInspectorTool.

    Returns:
        The configured agent, named "verification_agent".
    """
    browser = SimpleTextBrowser(**BROWSER_CONFIG)
    # Lower text limit than the main agents (10k vs 100k) to control token usage.
    text_limit = 10000

    # Full tool set, including visualizer, so the verifier can match the supervised agents.
    VERIFICATION_TOOLS = [
        GoogleSearchTool(provider="serper"),
        VisitTool(browser),
        PageUpTool(browser),
        PageDownTool(browser),
        FinderTool(browser),
        FindNextTool(browser),
        ArchiveSearchTool(browser),
        TextInspectorTool(model, text_limit),
        visualizer,  # gives the supervisor visual capabilities
    ]

    verification_agent = ToolCallingAgent(
        model=model,
        tools=VERIFICATION_TOOLS,
        max_steps=8,  # low step budget to control cost
        verbosity_level=0,  # silence detailed logs to keep the supervisor's output clean
        name="verification_agent",
        description="A specialized agent for quick fact-checking and information verification.",
        # Token usage could additionally be capped at the model level, e.g.
        # model=LiteLLMModel(model_id=model.model_id, max_tokens=1024)
    )

    # Steer the system prompt toward fast, targeted retrieval and verification.
    # NOTE(review): this is a ToolCallingAgent, yet tools are described with
    # to_code_prompt() (code-signature format); consider to_tool_calling_prompt()
    # after confirming it exists in the installed smolagents version.
    verification_agent.prompt_templates["system_prompt"] += """
    You are a highly efficient fact-checking assistant. Your primary goal is to quickly and accurately verify, correct, or summarize information based on a given task.
    - Be precise and concise.
    - Avoid deep, exploratory searches.
    - Execute the task, find the specific information needed, and use the 'final_answer' tool immediately to return the result.
    - Your response should be direct and factual.
    Here are the tools you can use:
    """ + "\n".join([tool.to_code_prompt() for tool in verification_agent.tools.values()])

    return verification_agent


def create_agent_team(model: Model):
    """Build the standard (unsupervised) agent team.

    A web-browsing ``ToolCallingAgent`` named "search_agent" is managed by a
    ``CodeAgent``; both use the same ``TextInspectorTool`` instance.

    Args:
        model: Model shared by both agents and by the text inspector tool.

    Returns:
        dict with "manager" (top-level CodeAgent), "web_search" (the managed
        search agent) and "text_inspector" (the shared inspector tool).
    """
    inspector = TextInspectorTool(model, 100_000)
    web_browser = SimpleTextBrowser(**BROWSER_CONFIG)

    search_agent = ToolCallingAgent(
        model=model,
        tools=[
            GoogleSearchTool(provider="serper"),
            VisitTool(web_browser),
            PageUpTool(web_browser),
            PageDownTool(web_browser),
            FinderTool(web_browser),
            FindNextTool(web_browser),
            ArchiveSearchTool(web_browser),
            inspector,
        ],
        max_steps=20,
        verbosity_level=2,
        planning_interval=4,
        name="search_agent",
        description="""A team member that will search the internet to answer your question.
    Ask him for all your questions that require browsing the web.
    Provide him as much context as possible, in particular if you need to search on a specific timeframe!
    And don't hesitate to provide him with a complex search task, like finding a difference between two webpages.
    Your request must be a real sentence, not a google search! Like "Find me this information (...)" rather than a few keywords.
    """,
        provide_run_summary=True,
    )

    # Extra guidance appended to the search agent's "managed_agent" task template.
    managed_task_suffix = """You can navigate to .txt online files.
    If a non-html page is in another format, especially .pdf or a Youtube video, use tool 'inspect_file_as_text' to inspect it.
    Additionally, if after some searching you find out that you need more information to answer the question, you can use `final_answer` with your request for clarification as argument to request for more information."""
    managed_templates = search_agent.prompt_templates["managed_agent"]
    managed_templates["task"] += managed_task_suffix

    manager = CodeAgent(
        model=model,
        tools=[visualizer, inspector],
        max_steps=12,
        verbosity_level=2,
        additional_authorized_imports=["*"],
        planning_interval=4,
        managed_agents=[search_agent],
    )

    return {"manager": manager, "web_search": search_agent, "text_inspector": inspector}


def create_agent_team_with_supervisor(model: Model, summary_model: Model, example: dict, kb_manager: SupervisorKBManager):
    """Build the agent team with a SupervisorAgent monitoring both agent levels.

    The supervisor gets its own verification agent. It is wired in through a
    single step callback (from ``create_supervisor_callback``) that is registered
    on both the web search agent and the manager.

    Args:
        model: Model used by the agents, the supervisor and the text inspector tool.
        summary_model: Model passed to the supervisor as ``summary_model``.
        example: Current GAIA row; bound into the supervisor callback.
        kb_manager: Knowledge-base manager handed to the supervisor.

    Returns:
        dict with keys "manager", "supervisor", "verification", "web_search",
        "summary_model" and "text_inspector".
    """
    # === 1. Dedicated, fully equipped verification agent for the supervisor ===
    verification_agent = create_verification_agent(model)

    # === 2. Supervisor with the verification agent injected ===
    supervisor = SupervisorAgent(model=model, summary_model=summary_model, verification_agent=verification_agent, kb_manager=kb_manager)
    # Callback instance bound to this specific example.
    supervisor_callback_instance = create_supervisor_callback(supervisor, example)

    # === 3. Web agent for the main task ===
    text_limit = 100000
    ti_tool = TextInspectorTool(model, text_limit)
    main_browser = SimpleTextBrowser(**BROWSER_CONFIG)
    MAIN_WEB_TOOLS = [
        GoogleSearchTool(provider="serper"),
        VisitTool(main_browser),
        PageUpTool(main_browser),
        PageDownTool(main_browser),
        FinderTool(main_browser),
        FindNextTool(main_browser),
        ArchiveSearchTool(main_browser),
        ti_tool,
    ]
    main_web_agent = ToolCallingAgent(
        model=model,
        tools=MAIN_WEB_TOOLS,
        max_steps=20,
        verbosity_level=2,
        planning_interval=4,
        name="search_agent",
        description="""A team member that will search the internet to answer your question.
    Ask him for all your questions that require browsing the web.
    Provide him as much context as possible, in particular if you need to search on a specific timeframe!
    And don't hesitate to provide him with a complex search task, like finding a difference between two webpages.
    Your request must be a real sentence, not a google search! Like "Find me this information (...)" rather than a few keywords.
    """,
        provide_run_summary=True,
        step_callbacks=[supervisor_callback_instance],  # supervise the web agent's own steps too
    )

    main_web_agent.prompt_templates["managed_agent"]["task"] += """You can navigate to .txt online files.
    If a non-html page is in another format, especially .pdf or a Youtube video, use tool 'inspect_file_as_text' to inspect it.
    Additionally, if after some searching you find out that you need more information to answer the question, you can use `final_answer` with your request for clarification as argument to request for more information."""

    # === 4. Top-level manager agent (keeps its own visual capability) ===
    manager_agent = CodeAgent(
        model=model,
        tools=[visualizer, ti_tool],
        max_steps=12,
        verbosity_level=2,
        additional_authorized_imports=["*"],
        planning_interval=4,
        managed_agents=[main_web_agent],
        step_callbacks=[supervisor_callback_instance],
        name="manager_agent",
    )

    return {
        "manager": manager_agent,
        "supervisor": supervisor,
        "verification": verification_agent,
        "web_search": main_web_agent,
        "summary_model": summary_model,
        "text_inspector": ti_tool
    }


def load_gaia_dataset(use_raw_dataset: bool, set_to_run: str, level: int | None = None) -> datasets.Dataset:
    """Load one GAIA split as a ``datasets.Dataset``, downloading the snapshot first if needed.

    Args:
        use_raw_dataset: Download ``gaia-benchmark/GAIA`` if True, otherwise the
            gated ``smolagents/GAIA-annotated`` repository.
        set_to_run: Split directory under ``data/gaia/2023`` (e.g. "validation").
        level: If given, keep only rows whose "Level" equals this value. The
            comparison is string-based, so int and str storage both match.

    Returns:
        Dataset with "Question" / "Final answer" / "Level" renamed to
        "question" / "true_answer" / "task". Non-empty "file_name" values are
        rewritten to paths under ``data/gaia/2023/<set_to_run>/``.

    Raises:
        FileNotFoundError: If the split's ``metadata.jsonl`` does not exist.
    """
    # 1. Download the dataset on first use.
    if not os.path.exists("data/gaia"):
        # WARNING: smolagents/GAIA-annotated is gated; request access on the Hub first.
        repo_id = "gaia-benchmark/GAIA" if use_raw_dataset else "smolagents/GAIA-annotated"
        snapshot_download(
            repo_id=repo_id,
            repo_type="dataset",
            local_dir="data/gaia",
            ignore_patterns=[".gitattributes", "README.md"],
        )

    # 2. Locate the split's metadata file.
    split_dir = os.path.join("data/gaia", "2023", set_to_run)
    metadata_file = os.path.join(split_dir, "metadata.jsonl")
    if not os.path.exists(metadata_file):
        raise FileNotFoundError(f"Metadata file not found: {metadata_file}")

    # 3. Read the JSONL. An explicit UTF-8 encoding avoids depending on the
    #    platform's locale; blank lines (e.g. a trailing newline) are skipped.
    with open(metadata_file, "r", encoding="utf-8") as f:
        rows = [json.loads(line) for line in f if line.strip()]

    # 4. Optional level filter; compare as strings so it works whether "Level"
    #    is stored as an int or a string.
    if level is not None:
        rows = [row for row in rows if str(row.get("Level")) == str(level)]

    # 5. Resolve attachment paths and rename columns to the names used downstream.
    for row in rows:
        if row.get("file_name"):
            row["file_name"] = os.path.join(split_dir, row["file_name"])
        row["question"] = row.pop("Question", None)
        row["true_answer"] = row.pop("Final answer", None)
        row["task"] = row.pop("Level", None)

    # 6. Convert to a Dataset.
    return datasets.Dataset.from_list(rows)


def append_answer(entry: dict, jsonl_file: str) -> None:
    """Append ``entry`` as a single JSON Lines record to ``jsonl_file``.

    Parent directories are created as needed. The write happens while holding
    ``append_answer_lock``, so concurrent callers don't interleave records.

    Args:
        entry: JSON-serializable record to store.
        jsonl_file: Path of the ``.jsonl`` answers file.
    """
    jsonl_path = Path(jsonl_file)
    jsonl_path.parent.mkdir(parents=True, exist_ok=True)
    # JSON Lines requires exactly one JSON value per line, so no ``indent``:
    # pretty-printed records span many lines and break line-oriented readers
    # such as ``pandas.read_json(..., lines=True)``. Serializing before taking
    # the lock also means a non-serializable entry never touches the file.
    record = json.dumps(entry, ensure_ascii=False)
    with append_answer_lock, jsonl_path.open("a", encoding="utf-8") as fp:
        fp.write(record + "\n")
    print("Answer exported to file:", jsonl_path.resolve())

# Convert the ChatMessage objects in intermediate_steps into plain dicts.
# Added on 0825.
def serialize_intermediate_steps(steps):
    """Return a JSON-friendly list built from ``steps``.

    Items exposing a ``content`` attribute (e.g. ChatMessage) become
    ``{"role": <role or None>, "content": <content>}``; everything else is
    converted with ``str()``.
    """

    def _as_plain(step):
        if not hasattr(step, "content"):
            # Not a message-like object: fall back to its string form.
            return str(step)
        return {"role": getattr(step, "role", None), "content": step.content}

    return [_as_plain(step) for step in steps]

# 优化最终答案提取逻辑
def extract_final_answer(result: any) -> str | None:
    """
    Robustly extracts the final answer from the various possible return types
    of the agent's run method.
    """
    final_answer = None

    # 1. 检查是否为 RunResult 对象 (当 return_full_result=True)
    if isinstance(result, RunResult):
        print("-> Result is a RunResult object. Extracting from .output attribute.")
        final_answer = result.output

    # 2. 检查是否为 FinalAnswerStep 对象 (这是一种不太可能但安全的检查)
    elif isinstance(result, FinalAnswerStep):
        print("-> Result is a FinalAnswerStep object. Extracting from .output attribute.")
        final_answer = result.output
        
    # 3. 检查是否为字符串
    elif isinstance(result, str):
        print("-> Result is a string. Parsing string for answer.")
        # 有可能字符串里包含 "Final answer:", 也可能就是纯净答案
        if "Final answer:" in result:
            final_answer = result.split("")[-1].strip()
        else:
            # 如果没有特定前缀，我们假设整个字符串就是答案
            final_answer = result.strip()
            
    # 如果 final_answer 仍然是 None 或空字符串，返回 None
    if final_answer and str(final_answer).strip():
        return str(final_answer).strip()
    
    return None

def answer_single_question(
    example: dict, model_id: str, model_id_summary, answers_file: str, visual_inspection_tool: TextInspectorTool, supervisor: bool = False, kb_manager: SupervisorKBManager = None
) -> None:
    """Answer one GAIA example with the agent team and append the annotated result.

    Args:
        example: Dataset row. Reads ``"question"``, ``"file_name"``, ``"task"``,
            ``"true_answer"`` and ``"task_id"``.
        model_id: LiteLLM model id of the main model (``"o1"`` gets reasoning settings).
        model_id_summary: LiteLLM model id of the summary model that is handed to
            ``create_agent_team_with_supervisor``.
        answers_file: Output file; one record is appended via ``append_answer``.
        visual_inspection_tool: Tool passed to the attached-file description helpers.
        supervisor: If True, build the team with ``create_agent_team_with_supervisor``;
            otherwise with ``create_agent_team``.
        kb_manager: Supervisor knowledge base forwarded to the supervisor team.

    Errors raised while running the agent or post-processing its answer are caught,
    printed with a traceback and stored in the record's ``agent_error`` field.
    Errors during model/team construction or attached-file description are NOT
    caught and propagate to the caller.
    """
    # Stays a zero usage record unless prepare_response succeeds and replaces it.
    prep_token_usage = TokenUsage()

    model_params: dict[str, Any] = {
        "model_id": model_id,
        "custom_role_conversions": custom_role_conversions,
    }
    model_params_summary: dict[str, Any] = {
        "model_id": model_id_summary,
        "custom_role_conversions": custom_role_conversions,
    }
    if model_id == "o1":
        # o1 is configured with a reasoning effort and max_completion_tokens.
        model_params["reasoning_effort"] = "high"
        model_params["max_completion_tokens"] = 8192
    else:
        # Every other main model gets a plain max_tokens cap (the summary model gets none).
        model_params["max_tokens"] = 4096

    # Wrap both models in MonitoredModel so their token usage can be tracked.
    model_raw = LiteLLMModel(**model_params)
    model = MonitoredModel(model_raw)
    model_raw_summary = LiteLLMModel(**model_params_summary)
    summary_model = MonitoredModel(model_raw_summary)

    document_inspection_tool = TextInspectorTool(model, 100000)

    agent_dict: dict[str, Any]
    if supervisor:
        agent_dict = create_agent_team_with_supervisor(model, summary_model, example, kb_manager=kb_manager)
        agent = agent_dict["manager"]
        print("=== Using Supervisor Agent ===")
    else:
        agent_dict = create_agent_team(model)
        agent = agent_dict["manager"]
        print("=== Using Standard Agent Team ===")

    augmented_question = """You have one question to answer. It is paramount that you provide a correct answer.
Give it all you can: I know for a fact that you have access to all the relevant tools to solve it and find the correct answer (the answer does exist).
Failure or 'I cannot answer' or 'None found' will not be tolerated, success will be rewarded.
Run verification steps if that's needed, you must make sure you find the correct answer! Here is the task:

""" + example["question"]

    # Append a description of the attached file(s), if any, to the prompt.
    if example["file_name"]:
        if ".zip" in example["file_name"]:
            prompt_use_files = "\n\nTo solve the task above, you will have to use these attached files:\n"
            prompt_use_files += get_zip_description(
                example["file_name"], example["question"], visual_inspection_tool, document_inspection_tool
            )
        else:
            prompt_use_files = "\n\nTo solve the task above, you will have to use this attached file:\n"
            prompt_use_files += get_single_file_description(
                example["file_name"], example["question"], visual_inspection_tool, document_inspection_tool
            )
        augmented_question += prompt_use_files

    start_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    agent_error = None
    try:
        # Run agent 🚀 and keep the full RunResult so the raw output can be recorded.
        raw_result = agent.run(augmented_question, return_full_result=True)
        raw_final_result = raw_result.output
        print(f"> Raw final result: {raw_final_result}")
        agent_memory = agent.write_memory_to_messages()

        # The final answer is always reformulated from the whole memory using the raw
        # (unmonitored) model; its token usage is returned separately and passed to
        # calculate_and_log_token_usage below. extract_final_answer(raw_final_result)
        # is an available, currently unused, cheaper alternative.
        output, prep_token_usage, _prep_prompt = prepare_response(
            augmented_question, agent_memory, reformulation_model=model_raw
        )

        # Drop the stored model inputs from every memory step (presumably to release
        # memory held by the agent).
        for memory_step in agent.memory.steps:
            memory_step.model_input_messages = None
        intermediate_steps = agent_memory

        # Flag runs where the LLM failed to follow the required format. Message content
        # may be a non-string (e.g. a list of content parts), so search its string form;
        # `in` on a list would be a membership test and never match.
        parsing_error = any(
            "AgentParsingError" in str(step.content if hasattr(step, "content") else step)
            for step in intermediate_steps
        )

        # Stringify so a non-string answer cannot raise TypeError here and discard an
        # otherwise successful run.
        iteration_limit_exceeded = "Agent stopped due to iteration limit or time limit." in str(output)

    except Exception as e:
        print("Error on ", augmented_question, e)
        print("------------------- TRACEBACK START -------------------")
        traceback.print_exc()
        print("-------------------- TRACEBACK END --------------------")
        output = None
        raw_final_result = "Failure"
        intermediate_steps = []
        parsing_error = False
        iteration_limit_exceeded = False
        agent_error = str(e)

    end_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    total_token_counts = calculate_and_log_token_usage(
        agent_dict=agent_dict,
        monitored_model=model,
        prep_token_usage=prep_token_usage,
        supervisor=supervisor,
    )

    annotated_example = {
        "agent_name": model.model_id,
        "question": example["question"],
        "augmented_question": augmented_question,
        "Level": example["task"],
        "raw_prediction": raw_final_result,
        "prediction": output,
        "true_answer": example["true_answer"],
        "token_counts": total_token_counts,
        "intermediate_steps": serialize_intermediate_steps(intermediate_steps),
        "parsing_error": parsing_error,
        "iteration_limit_exceeded": iteration_limit_exceeded,
        "agent_error": agent_error,
        "task_id": example["task_id"],
        "start_time": start_time,
        "end_time": end_time,
    }
    append_answer(annotated_example, answers_file)


# def get_examples_to_answer(answers_file: str, eval_ds: datasets.Dataset) -> list[dict]:
    # print(f"Loading answers from {answers_file}...")
    # try:
    #     done_questions = pd.read_json(answers_file, lines=True)["question"].tolist()
    #     print(f"Found {len(done_questions)} previous results!")
    # except Exception as e:
    #     print("Error when loading records: ", e)
    #     print("No usable records! ▶️ Starting new.")
    #     done_questions = []
    # return [line for line in eval_ds.to_list() if line["question"] not in done_questions and line["file_name"]]

# def get_examples_to_answer(answers_file: str, eval_ds: datasets.Dataset, task_id: str = None) -> list[dict]:
#     print(f"Loading answers from {answers_file}...")
#     try:
#         done_questions = pd.read_json(answers_file, lines=True)["question"].tolist()
#         print(f"Found {len(done_questions)} previous results!")
#     except Exception as e:
#         print("Error when loading records: ", e)
#         print("No usable records! ▶️ Starting new.")
#         done_questions = []
    
#     # examples = [line for line in eval_ds.to_list() if line["question"] not in done_questions and line["file_name"]]
#     examples = [line for line in eval_ds.to_list() if line["question"] not in done_questions]

#     # 如果指定了 task_id，就只保留这一条
#     if task_id:
#         examples = [ex for ex in examples if str(ex.get("task_id")) == str(task_id)]
#         print(f"Filtered for task_id={task_id}, {len(examples)} tasks found.")
#     return examples

def load_jsonl_multiline(path: str):
    """Read a file of consecutive JSON objects, either one per line or pretty-printed.

    ``append_answer`` writes each record with ``indent=2``, so a record spans several
    lines and plain line-by-line JSONL parsing does not work. This reader decodes
    consecutive JSON values directly from the file text in a single linear pass.

    If a record is truncated or corrupt (e.g. an interrupted write), it is skipped and
    decoding resumes at the next line that begins with ``{`` at column 0. Top-level
    records start there in both the one-line and the ``indent=2`` layouts. Without
    this, every record after the damage would be lost.

    Args:
        path: File to read (UTF-8).

    Returns:
        List of decoded JSON values, in file order.
    """
    with open(path, "r", encoding="utf-8") as f:
        text = f.read()

    decoder = json.JSONDecoder()
    records = []
    pos = 0
    end = len(text)
    while True:
        # raw_decode does not skip leading whitespace, so skip the gap between records.
        while pos < end and text[pos].isspace():
            pos += 1
        if pos >= end:
            break
        try:
            record, pos = decoder.raw_decode(text, pos)
        except json.JSONDecodeError:
            # Resynchronise at the next record start after the broken one.
            next_start = text.find("\n{", pos + 1)
            if next_start == -1:
                break
            pos = next_start + 1
            continue
        records.append(record)
    return records


def get_examples_to_answer(answers_file: str, eval_ds: datasets.Dataset, task_id: str | None = None) -> list[dict]:
    """Return the examples of ``eval_ds`` whose question has no record in ``answers_file``.

    Args:
        answers_file: Output file of previous runs, read with ``load_jsonl_multiline``.
            If it cannot be read (e.g. it does not exist yet), every example is pending.
        eval_ds: Evaluation dataset; each row must have a ``"question"`` key.
        task_id: If given, keep only the pending example(s) whose ``task_id`` equals
            it (compared as strings).

    Returns:
        Pending example dicts, in dataset order.
    """
    print(f"Loading answers from {answers_file}...")
    try:
        records = load_jsonl_multiline(answers_file)
        done_questions = [r["question"] for r in records if "question" in r]
        print(f"Found {len(done_questions)} previous results!")
    except Exception as e:
        # Best effort: an unreadable/missing file just means "start from scratch".
        print("Error when loading records: ", e)
        print("No usable records! ▶️ Starting new.")
        done_questions = []

    # Set lookup keeps filtering linear instead of O(len(eval_ds) * len(done_questions)).
    done_set = set(done_questions)
    examples = [line for line in eval_ds.to_list() if line["question"] not in done_set]

    # If a task_id was given, keep only that task.
    if task_id:
        examples = [ex for ex in examples if str(ex.get("task_id")) == str(task_id)]
        print(f"Filtered for task_id={task_id}, {len(examples)} tasks found.")
    return examples

def main():
    """Run the GAIA evaluation: load the split, skip answered questions, answer the rest.

    Answers are appended to ``output/<set_to_run>/<run_name>.jsonl``. Tasks run
    concurrently on a thread pool of ``--concurrency`` workers. The supervisor
    knowledge-base file defaults to the original hard-coded path and can be
    overridden with the ``SUPERVISOR_KB_PATH`` environment variable.
    """
    args = parse_args()
    print(f"Starting run with arguments: {args}")

    # Agent activity tracing; set up after argument parsing so `--help` and argument
    # errors exit before tracing is configured.
    register()
    SmolagentsInstrumentor().instrument()

    eval_ds = load_gaia_dataset(args.use_raw_dataset, args.set_to_run, args.level)
    print("Loaded evaluation dataset:")
    print(pd.DataFrame(eval_ds)["task"].value_counts())

    # Build the knowledge-base manager once; the same instance is passed to every task.
    print("Initializing Knowledge Base Manager...")
    kb_file_path = os.getenv(
        "SUPERVISOR_KB_PATH",
        "/home/ofo/project_workflow_auto_generation/smolagents/examples/open_deep_research/data/supervisor_database.json",
    )
    kb_manager = SupervisorKBManager(json_file_path=kb_file_path)
    print("Knowledge Base Manager is ready.")

    answers_file = f"output/{args.set_to_run}/{args.run_name}.jsonl"

    if args.task_id:
        # A single task was requested; --num-examples does not apply.
        tasks_to_run = get_examples_to_answer(answers_file, eval_ds, args.task_id)
    else:
        tasks_to_run = get_examples_to_answer(answers_file, eval_ds)
        if args.num_examples:
            tasks_to_run = tasks_to_run[: args.num_examples]

    with ThreadPoolExecutor(max_workers=args.concurrency) as exe:
        futures = [
            exe.submit(answer_single_question, example, args.model_id, args.model_id_summary, answers_file, visualizer, args.supervisor, kb_manager)
            for example in tasks_to_run
        ]
        # result() re-raises any exception that escaped answer_single_question.
        for f in tqdm(as_completed(futures), total=len(tasks_to_run), desc="Processing tasks"):
            f.result()

    print("All tasks processed.")


# Script entry point: run the evaluation only when executed directly, not on import.
if __name__ == "__main__":
    main()
