# EXAMPLE COMMAND: from folder examples/open_deep_research, run: python run_gaia.py --concurrency 32 --run-name generate-traces-03-apr-noplanning --model-id gpt-4o
import argparse
import json
import os
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
from pathlib import Path
from typing import Any

import datasets
import pandas as pd
from dotenv import load_dotenv
from huggingface_hub import login, snapshot_download
from scripts.reformulator import prepare_response
from scripts.run_agents import (
    get_single_file_description,
    get_zip_description,
)
from scripts.text_inspector_tool import TextInspectorTool
from scripts.text_web_browser import (
    ArchiveSearchTool,
    FinderTool,
    FindNextTool,
    PageDownTool,
    PageUpTool,
    SimpleTextBrowser,
    VisitTool,
)
from scripts.visual_qa import visualizer
from tqdm import tqdm

from smolagents import (
    CodeAgent,
    GoogleSearchTool,
    LiteLLMModel,
    Model,
    ToolCallingAgent,
    SupervisorAgent,
    SupervisorKBManager,
)

from phoenix.otel import register
from openinference.instrumentation.smolagents import SmolagentsInstrumentor

from smolagents.agents import RunResult
from smolagents.memory import FinalAnswerStep, ActionStep, TaskStep

from smolagents.models import MonitoredModel    # 用于全局监督 token cost
import re
import traceback
import random


load_dotenv(override=True)
# Log in to the Hugging Face Hub only when a token is configured: `login()` without a
# token falls back to an interactive prompt, which would stall this unattended script.
_hf_token = os.getenv("HF_TOKEN")
if _hf_token:
    login(_hf_token)
else:
    print("HF_TOKEN is not set; skipping Hugging Face Hub login.")
# Serialises appends to the shared answers file across worker threads.
append_answer_lock = threading.Lock()
### IMPORTANT: EVALUATION SWITCHES
print("Make sure you deactivated any VPN like Tailscale, else some URLs will be blocked!")
# Role remapping passed to LiteLLMModel: tool-call -> assistant, tool-response -> user.
custom_role_conversions = {"tool-call": "assistant", "tool-response": "user"}
user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36 Edg/119.0.0.0"
# Configuration for the SimpleTextBrowser instances created in this script.
BROWSER_CONFIG = {
    "viewport_size": 1024 * 5,
    "downloads_folder": "downloads_folder",
    "request_kwargs": {
        "headers": {"User-Agent": user_agent},
        "timeout": 300,
    },
    "serpapi_key": os.getenv("SERPAPI_API_KEY"),
}
# Make sure the browser's download folder exists up front.
os.makedirs(f"./{BROWSER_CONFIG['downloads_folder']}", exist_ok=True)


def compare_numbers_simple(input1, input2, timeout: float = 60.0):
    """Ask an OpenAI-compatible chat endpoint whether two numeric expressions are equal.

    The endpoint is configured through environment variables (``.env`` is loaded on
    each call): ``OPENAI_LLM_MODEL_NAME`` (default ``"gpt-4.1"``),
    ``OPENAI_LLM_API_KEY`` and ``OPENAI_LLM_BASE_URL``.

    Args:
        input1: First expression; inserted verbatim into the prompt.
        input2: Second expression; inserted verbatim into the prompt.
        timeout: Seconds to wait for the HTTP response (the request previously had none).

    Returns:
        True if the model answers "True", False for any other answer, or None when the
        request fails, the server returns a non-200 status, or the response body does
        not have the expected chat-completion shape.
    """
    import requests  # only this helper needs it

    prompt = f"""# Numerical Comparison Task

Determine if two numerical expressions are mathematically equal. Consider different formats like fractions, decimals, LaTeX, percentages, and units.

**Rules:**
- Same numerical value = True
- Different units = False (e.g., "3 m" ≠ "3 cm")
- Ignore formatting differences

**Examples:**
- `1/2` vs `0.5` → True
- `\\frac{{3}}{{4}}` vs `75%` → True  
- `\\sqrt{{16}}` vs `4` → True
- `3.5 m` vs `3.5 cm` → False
- `2^3` vs `8` → True
- `π` vs `3.14159` → True

**Output format:** Just return `True` or `False`

**Compare:**
Input1: `{input1}`
Input2: `{input2}`

**Answer:**"""
    # Endpoint configuration from the environment.
    load_dotenv()  # Load environment variables from .env file
    LLM_MODEL_NAME = os.environ.get("OPENAI_LLM_MODEL_NAME", "gpt-4.1")
    LLM_API_KEY = os.environ.get("OPENAI_LLM_API_KEY", "")
    LLM_BASE_URL = os.environ.get("OPENAI_LLM_BASE_URL", "")
    headers = {
        "Authorization": f"Bearer {LLM_API_KEY}",
        "Content-Type": "application/json"
    }
    data = {
        "model": LLM_MODEL_NAME,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": 50,
        "temperature": 0.1
    }

    try:
        # rstrip('/') avoids a double slash when the base URL ends with '/'.
        response = requests.post(
            f"{LLM_BASE_URL.rstrip('/')}/chat/completions",
            json=data,
            headers=headers,
            timeout=timeout,
        )
    except requests.RequestException as err:
        print(f"compare_numbers_simple: request failed: {err}")
        return None

    if response.status_code != 200:
        return None

    try:
        content = response.json()["choices"][0]["message"]["content"]
    except (ValueError, KeyError, IndexError, TypeError):
        # Body is not JSON or lacks the chat-completion structure.
        return None

    # Models often decorate the verdict, e.g. "`True`", "**True**" or "True.".
    verdict = (content or "").strip().strip("`*'\". ").lower()
    return verdict == "true"


def parse_args():
    """Parse the command line for an evaluation run.

    Options (dashes become underscores on the returned namespace):
    ``--concurrency`` (int, 1), ``--num-examples`` (int, None), ``--model-id``
    (str, "gpt-4.1"), ``--local-model-path`` (str, None), ``--model-id-summary``
    (str, "gpt-4.1"), ``--run-name`` (str, required), ``--set-to-run``
    (str, "validation"), ``--dataset-path`` (str, "./data/aime"), ``--supervisor``
    (flag) and ``--task-id`` (str, None).

    Returns:
        argparse.Namespace: the parsed options.
    """
    # (flag, add_argument keyword arguments), in the order they appear in --help.
    option_specs = [
        ("--concurrency", {"type": int, "default": 1}),
        ("--num-examples", {"type": int, "default": None}),
        ("--model-id", {"type": str, "default": "gpt-4.1"}),
        ("--local-model-path", {"type": str, "default": None}),
        ("--model-id-summary", {"type": str, "default": "gpt-4.1"}),
        ("--run-name", {"type": str, "required": True}),
        ("--set-to-run", {"type": str, "default": "validation"}),
        ("--dataset-path", {"type": str, "default": "./data/aime"}),
        ("--supervisor", {"action": "store_true"}),
        ("--task-id", {"type": str, "default": None, "help": "Only run the task with this specific ID"}),
    ]
    cli = argparse.ArgumentParser()
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()


class TokenUsage:
    """Simple holder for an (input_tokens, output_tokens) pair; both default to 0."""

    def __init__(self, input_tokens=0, output_tokens=0):
        # Stored as-is; no validation or conversion is performed.
        self.input_tokens, self.output_tokens = input_tokens, output_tokens


def calculate_and_log_token_usage(
    agent_dict: dict[str, Any],
    monitored_model: MonitoredModel,
    prep_token_usage: TokenUsage,
    supervisor: bool
) -> dict:
    """
    Compute and print token usage from every source, and return a summary dict for storage.

    Agent-level total = manager agent monitor
                      + supervisor agent monitor (only when ``supervisor`` is True)
                      + ``prep_token_usage``
                      + summary model (if ``"summary_model"`` is in ``agent_dict``)
                      + text inspector tool (if ``"text_inspector"`` is in ``agent_dict``).
    Grand total = agent-level total + ``monitored_model.total_token_usage``.

    Args:
        agent_dict (dict): Agents/models of this run; must contain ``"manager"``, and
            ``"supervisor"`` when ``supervisor`` is True.
        monitored_model (MonitoredModel): The wrapped main model; its ``total_token_usage`` is read.
        prep_token_usage (TokenUsage): Token usage returned by ``prepare_response``.
        supervisor (bool): Whether supervisor mode was used.

    Returns:
        dict: ``{"total_token_count": <grand total>,
                 "supervisor_token": <supervisor tokens + summary-model tokens>}``.

    NOTE(review): in this script the manager (and supervisor) agents are built on the same
    MonitoredModel that is passed here as ``monitored_model``; if MonitoredModel records every
    call routed through it, those tokens are counted twice in ``total_token_count``. Confirm
    MonitoredModel's accounting before relying on the grand total.
    """
    # --- Agent-level ("external") usage ---
    agent_total_input = 0
    agent_total_output = 0
    supervisor_total_tokens = 0
    summary_total_usage = 0

    # 1. Manager Agent
    manager_usage = agent_dict["manager"].monitor.get_total_token_counts()
    agent_total_input += manager_usage.input_tokens
    agent_total_output += manager_usage.output_tokens
    print(f"Manager agent token usage - Input: {manager_usage.input_tokens}, Output: {manager_usage.output_tokens}")

    # 2. Web Search Agent (disabled: the agent teams built in this script have no web-search agent)
    # web_search_usage = agent_dict["web_search"].monitor.get_total_token_counts()
    # agent_total_input += web_search_usage.input_tokens
    # agent_total_output += web_search_usage.output_tokens
    # print(f"Web Search agent token usage - Input: {web_search_usage.input_tokens}, Output: {web_search_usage.output_tokens}")

    # 3. Supervisor & Verification Agents
    # NOTE: only the supervisor's own monitor is read; verification-agent usage is not
    # included (its lines are commented out) even though the log label mentions it.
    supervisor_input = 0
    supervisor_output = 0
    if supervisor:
        supervisor_usage = agent_dict["supervisor"].monitor.get_total_token_counts()
        # verification_usage = agent_dict["verification"].monitor.get_total_token_counts()
        supervisor_input = supervisor_usage.input_tokens # + verification_usage.input_tokens
        supervisor_output = supervisor_usage.output_tokens # + verification_usage.output_tokens
        # agent_total_input += supervisor_input
        # agent_total_output += supervisor_output
        supervisor_total_tokens = supervisor_input + supervisor_output
        print(f"Supervisor/Verification token usage - Input: {supervisor_input}, Output: {supervisor_output}")
    
    # 4. prepare_response usage; the supervisor usage (0 when supervisor=False) is folded
    #    into the agent total here as well.
    agent_total_input += prep_token_usage.input_tokens + supervisor_input
    agent_total_output += prep_token_usage.output_tokens + supervisor_output
    print(f"Preparation step token usage - Input: {prep_token_usage.input_tokens}, Output: {prep_token_usage.output_tokens}")

    # 5. Summary Model (present only in the supervisor team)
    if "summary_model" in agent_dict:
        summary_model = agent_dict["summary_model"]
        summary_usage = summary_model.total_token_usage
        agent_total_input += summary_usage.input_tokens
        agent_total_output += summary_usage.output_tokens
        summary_total_usage = summary_usage.input_tokens + summary_usage.output_tokens
        print(f"Summary model token usage - Input: {summary_usage.input_tokens}, Output: {summary_usage.output_tokens}")
    
    # 6. Text inspector tool (only if it was registered in agent_dict)
    if "text_inspector" in agent_dict:
        ti_tool = agent_dict["text_inspector"]
        ti_usage = ti_tool.get_total_token_counts()
        agent_total_input += ti_usage.input_tokens
        agent_total_output += ti_usage.output_tokens
        print(f"Text Inspector Tool token usage - Input: {ti_usage.input_tokens}, Output: {ti_usage.output_tokens}")

    agent_total_tokens = agent_total_input + agent_total_output

    # --- Model-level ("internal") usage recorded by the MonitoredModel wrapper ---
    internal_usage = monitored_model.total_token_usage
    internal_total_tokens = internal_usage.input_tokens + internal_usage.output_tokens
    print(f"Model-level (internal) total token usage - Input: {internal_usage.input_tokens}, Output: {internal_usage.output_tokens}\n")

    # --- Grand total across all sources ---
    grand_total_tokens = agent_total_tokens + internal_total_tokens
    print(f"Grand total token usage (all sources) = {grand_total_tokens}\n")
    
    # Dict stored with the answer record ("supervisor_token" also includes the summary model)
    return {
        "total_token_count": grand_total_tokens,
        "supervisor_token": supervisor_total_tokens + summary_total_usage,
    }


def create_supervisor_callback(supervisor: SupervisorAgent, example: dict):
    """
    Create a step callback that records every ActionStep with the supervisor and, when
    certain conditions hold, asks the supervisor to review/correct that step. Both the
    local (agent-specific) task and the global task (the original question) are passed.

    Supervision triggers, in priority order:
      1. the step has an error, unless its first tool call is ``find_archived_url``;
      2. otherwise, the step's observations are longer than 3000 characters;
      3. otherwise, ``supervisor._check_for_inefficiency(agent.name)[0]`` is truthy;
      4. otherwise, the observations contain "### 1. Task outcome (short version):".

    Args:
        supervisor (SupervisorAgent): Supervisor agent instance.
        example (dict): The dataset entry currently being run; its "question" is used
            as the global task.

    Returns:
        The callback ``supervisor_callback(memory_step, agent)``; steps that are not
        ActionStep instances are ignored.
    """
    # NOTE(review): no caller/parent tracking is done here; the local task is derived from
    # ``agent.task`` of whichever agent invoked the callback. A more robust call stack
    # may be needed if nested managed agents are added.

    def supervisor_callback(memory_step: ActionStep, agent: CodeAgent):
        if isinstance(memory_step, ActionStep):
            global_task = example["question"]   # The global task is always the original question
            if agent.name == "manager_agent":
                clean_local_task = ""   # Manager's task is the global task; avoid duplicating it as local
            else:
                clean_local_task = extract_clean_task(agent.task)   # Extract the core local task from the wrapped prompt
            # Every ActionStep is recorded, whether or not it is supervised below.
            supervisor.record_step(
                step=memory_step,
                agent_name=agent.name,
                # memory=agent.memory,
                local_task=clean_local_task
            )

            should_supervise = False
            
            # Priority 1: errors
            ### TODO: errors may need to be identified manually
            if memory_step.error is not None:
                # By default, any error triggers supervision
                should_supervise = True
                
                # **Only exception**: errors from the 'find_archived_url' tool are ignored.
                # Defensively check that tool_calls exists and is non-empty (only the first call is inspected).
                if memory_step.tool_calls and memory_step.tool_calls[0].name == "find_archived_url":
                    should_supervise = False # override: do not supervise
                    print(f"--- [Supervisor] Ignoring a known/safe error from 'find_archived_url'. ---")
            
            # Check the remaining triggers only if the error path did not already decide to
            # supervise (this includes the ignored 'find_archived_url' error case).
            if not should_supervise:
                ### TODO: choose a suitable content-length threshold
                ### (tune it empirically)
                if (memory_step.observations and len(memory_step.observations) > 3000):
                    should_supervise = True
                ### TODO: adapt to the dataset
                elif supervisor._check_for_inefficiency(agent.name)[0]:
                    should_supervise = True
                
                ### Probably no longer needed
                elif (memory_step.observations and "### 1. Task outcome (short version):" in memory_step.observations):
                    should_supervise = True

            # Final-answer checks are currently disabled; these stay at their defaults.
            is_final_answer_attempt = False
            proposed_answer = None

            if should_supervise:    # or is_final_answer_attempt:
                supervisor.supervise_and_correct(
                    step=memory_step, 
                    agent_name=agent.name, 
                    local_task=clean_local_task,
                    global_task=global_task,
                    is_final_check=is_final_answer_attempt,
                    proposed_answer = proposed_answer
                )
            else:
                print(f"--- [Supervisor] Skipped review for step {memory_step.step_number} of '{agent.name}'. ---")
            
    return supervisor_callback


def extract_clean_task(full_task_prompt: str) -> str:
    """
    Return the core task description embedded in an augmented instruction prompt.

    The core task is expected between the first and second '---' separators. If that
    segment contains "Task:", everything up to and including its first occurrence is
    dropped. The result is whitespace-stripped.

    Prompts with fewer than two '---' separators, and inputs that cannot be handled as a
    string (e.g. None), are returned unchanged.
    """
    try:
        pieces = full_task_prompt.split('---')
        if len(pieces) < 3:
            # Not in the expected format: hand back the original for robustness.
            return full_task_prompt
        before_label, label, after_label = pieces[1].partition("Task:")
        core = after_label if label else before_label
        return core.strip()
    except Exception:
        # Anything unexpected (e.g. a non-string input): return the input untouched.
        return full_task_prompt


def create_verification_agent(model: Model) -> ToolCallingAgent:
    """
    Create a lightweight verification agent for the Supervisor.

    The agent is a ToolCallingAgent (max 8 steps, verbosity 0). Its tools are:
      - Google search (serper provider);
      - the text-browser tools: visit, page up/down, find, find next, archive search;
      - a TextInspectorTool with a 10000-character text limit;
      - the visualizer.
    The full tool set is intended to give it capability parity with the supervised agent.
    Its system prompt is extended with fact-checking instructions and a listing of its tools.

    Args:
        model (Model): Model used by the agent and by its TextInspectorTool.

    Returns:
        ToolCallingAgent: the configured verification agent.
    """
    browser = SimpleTextBrowser(**BROWSER_CONFIG)
    # Lower text limit for the verification agent to keep token usage down
    text_limit = 10000
    
    # Full tool set (including the visualizer) for capability parity with the supervised agent
    VERIFICATION_TOOLS = [
        GoogleSearchTool(provider="serper"),
        VisitTool(browser),
        PageUpTool(browser),
        PageDownTool(browser),
        FinderTool(browser),
        FindNextTool(browser),
        ArchiveSearchTool(browser),
        TextInspectorTool(model, text_limit),
        visualizer, # gives the Supervisor vision capability
    ]

    # Previously a CodeAgent; now a ToolCallingAgent
    verification_agent = ToolCallingAgent(
        model=model,
        tools=VERIFICATION_TOOLS,
        max_steps=8,  # low max_steps to bound cost
        verbosity_level=0,  # silence detailed logs to keep the Supervisor's output clean
        name="verification_agent",
        description="A specialized agent for quick fact-checking and information verification.",
        # Allow all imports (a CodeAgent option; left disabled here)
        # additional_authorized_imports=["*"], 
        # A token cap could also be set at the model level, e.g.:
        # model=LiteLLMModel(model_id=model.model_id, max_tokens=1024)
    )

    # Extend the system prompt so the agent focuses on fast, targeted lookup and verification.
    # NOTE(review): tools are rendered with `to_code_prompt()` (code-agent style) although this
    # is a ToolCallingAgent; confirm whether a tool-calling rendering was intended. The listing
    # iterates `verification_agent.tools`, which may also contain tools the agent adds itself
    # (e.g. final_answer) — verify.
    verification_agent.prompt_templates["system_prompt"] += """
    You are a highly efficient fact-checking assistant. Your primary goal is to quickly and accurately verify, correct, or summarize information based on a given task.
    - Be precise and concise.
    - Avoid deep, exploratory searches.
    - Execute the task, find the specific information needed, and use the 'final_answer' tool immediately to return the result.
    - Your response should be direct and factual.
    Here are the tools you can use:
    """ + "\n".join([tool.to_code_prompt() for tool in verification_agent.tools.values()])
    
    return verification_agent


def create_agent_team(model: Model):
    """Build the standard (unsupervised) agent team.

    Returns:
        dict with a single key, ``"manager"``: a CodeAgent with the visualizer tool,
        at most 12 steps, verbosity 2, planning_interval 4 and all imports authorised.
    """
    manager_settings = dict(
        model=model,
        tools=[visualizer],
        max_steps=12,
        verbosity_level=2,
        additional_authorized_imports=["*"],
        planning_interval=4,
    )
    return {"manager": CodeAgent(**manager_settings)}


def create_agent_team_with_supervisor(model: Model, summary_model: Model,  example: dict, kb_manager: SupervisorKBManager):
    """Build the supervised agent team for one dataset example.

    Construction order: verification agent -> SupervisorAgent (wired to the verification
    agent, the summary model and ``kb_manager``) -> step callback bound to ``example`` ->
    manager CodeAgent that has that callback in its ``step_callbacks``.

    Returns:
        dict with keys "manager", "supervisor", "verification" and "summary_model".
    """
    verifier = create_verification_agent(model)
    overseer = SupervisorAgent(
        model=model,
        summary_model=summary_model,
        verification_agent=verifier,
        kb_manager=kb_manager,
    )
    # The callback closes over this example, so its global task is this example's question.
    step_hook = create_supervisor_callback(overseer, example)

    manager = CodeAgent(
        model=model,
        tools=[visualizer],  # the manager keeps vision capability as well
        max_steps=12,
        verbosity_level=2,
        additional_authorized_imports=["*"],
        planning_interval=4,
        step_callbacks=[step_hook],
        # The supervisor callback special-cases this exact name.
        name="manager_agent",
    )

    return {
        "manager": manager,
        "supervisor": overseer,
        "verification": verifier,
        "summary_model": summary_model,
    }


def to_builtin_type(obj):
    """Recursively convert container contents into JSON-serialisable builtin types.

    - dict values, list items and tuple items are converted recursively; the container
      kind is preserved (namedtuples are rebuilt with their own type);
    - objects whose type is named "Integer" are converted with ``int()``;
    - when numpy is importable: ``ndarray`` -> nested list, numpy integers -> ``int``,
      numpy floating -> ``float``, ``numpy.bool_`` -> ``bool``;
    - anything else is returned unchanged.
    """
    if isinstance(obj, dict):
        return {k: to_builtin_type(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [to_builtin_type(i) for i in obj]
    if isinstance(obj, tuple):
        items = [to_builtin_type(i) for i in obj]
        # namedtuples take positional fields; plain tuples take an iterable.
        return type(obj)(*items) if hasattr(obj, "_fields") else tuple(items)
    # Scalar types named "Integer" (e.g. from pyarrow/pandas/sympy-style libraries) -> int
    if type(obj).__name__ == "Integer":
        return int(obj)
    try:
        import numpy as np
    except ImportError:
        return obj
    if isinstance(obj, np.ndarray):
        # tolist() yields builtin scalars; recurse for object-dtype arrays.
        return to_builtin_type(obj.tolist())
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.bool_):
        return bool(obj)
    return obj


def append_answer(entry: dict, jsonl_file: str) -> None:
    """Append one answer record to the answers file and print the running accuracy.

    The record is passed through ``to_builtin_type`` and written as pretty-printed
    (multi-line, ``indent=2``) JSON followed by a newline. Writing and re-reading the
    file both happen while holding ``append_answer_lock``. This keeps other threads of
    this process from observing a half-written record.

    Args:
        entry: Answer record to persist.
        jsonl_file: Path of the answers file; parent directories are created if needed.
    """
    jsonl_path = Path(jsonl_file)
    jsonl_path.parent.mkdir(parents=True, exist_ok=True)
    with append_answer_lock:
        with open(jsonl_path, "a", encoding="utf-8") as fp:
            fp.write(json.dumps(to_builtin_type(entry), ensure_ascii=False, indent=2) + "\n")
        assert jsonl_path.exists(), "File not found!"
        # Re-read under the lock: outside it, another worker's in-progress write could be
        # seen as a truncated record.
        records = load_jsonl_multiline(jsonl_file)
    print("Answer exported to file:", jsonl_path.resolve())
    # `is_correct` may be absent or None (e.g. when correctness is not computed); count
    # those as incorrect instead of letting sum() fail on None.
    n_correct = sum(bool(r.get("is_correct")) for r in records if isinstance(r, dict))
    print("#" * 100)
    print(n_correct / len(records) if records else 0)
    print("#" * 100)

def serialize_intermediate_steps(steps):
    """Convert agent memory messages into JSON-friendly values.

    Items exposing a ``content`` attribute (e.g. ChatMessage objects) become
    ``{"role": <role or None>, "content": <content>}``; any other item becomes ``str(item)``.
    """
    def _as_jsonable(item):
        if not hasattr(item, "content"):
            # Not message-like: fall back to its string form.
            return str(item)
        return {
            "role": getattr(item, "role", None),
            "content": item.content,
        }

    return [_as_jsonable(item) for item in steps]


# 优化最终答案提取逻辑
def extract_final_answer(result: any) -> str | None:
    """
    Robustly extracts the final answer from the various possible return types
    of the agent's run method.
    """
    final_answer = None

    # 1. 检查是否为 RunResult 对象 (当 return_full_result=True)
    if isinstance(result, RunResult):
        print("-> Result is a RunResult object. Extracting from .output attribute.")
        final_answer = result.output

    # 2. 检查是否为 FinalAnswerStep 对象 (这是一种不太可能但安全的检查)
    elif isinstance(result, FinalAnswerStep):
        print("-> Result is a FinalAnswerStep object. Extracting from .output attribute.")
        final_answer = result.output
        
    # 3. 检查是否为字符串
    elif isinstance(result, str):
        print("-> Result is a string. Parsing string for answer.")
        # 有可能字符串里包含 "Final answer:", 也可能就是纯净答案
        if "Final answer:" in result:
            final_answer = result.split("")[-1].strip()
        else:
            # 如果没有特定前缀，我们假设整个字符串就是答案
            final_answer = result.strip()
            
    # 如果 final_answer 仍然是 None 或空字符串，返回 None
    if final_answer and str(final_answer).strip():
        return str(final_answer).strip()
    
    return None


def _build_main_model(local_model_path: str | None):
    """Create the main agent model and return ``(raw_model, monitored_model)``.

    If ``local_model_path`` is set, the model is loaded locally with TransformersModel
    (``device_map="mps"``, float16). Otherwise an OpenAIServerModel is configured from
    QWEN_LLM_MODEL_NAME (default "qwen3-32b"), QWEN_LLM_BASE_URL and QWEN_LLM_API_KEY,
    after loading ``.env``.
    """
    if local_model_path:
        from smolagents import TransformersModel

        model_raw = TransformersModel(
            model_id=local_model_path,
            device_map="mps",
            torch_dtype="float16",
        )
    else:
        from smolagents import OpenAIServerModel

        load_dotenv()  # Load environment variables from .env file
        model_raw = OpenAIServerModel(
            model_id=os.environ.get("QWEN_LLM_MODEL_NAME", "qwen3-32b"),
            api_base=os.environ.get("QWEN_LLM_BASE_URL", ""),
            api_key=os.environ.get("QWEN_LLM_API_KEY", ""),
        )
    # Wrap the raw model so its token usage can be read after the run.
    return model_raw, MonitoredModel(model_raw)


def answer_single_question(
    example: dict,  model_id: str, model_id_summary, answers_file: str, visual_inspection_tool: TextInspectorTool, local_model_path: str | None = None, supervisor: bool = False, kb_manager: SupervisorKBManager = None
) -> None:
    """Run the agent team on one dataset example and append the annotated result to ``answers_file``.

    Args:
        example: Dataset row with at least "question", "true_answer" and "task_id".
        model_id: Not used. The main model comes from ``local_model_path`` or the QWEN_LLM_*
            environment variables. NOTE(review): confirm that this is intended.
        model_id_summary: LiteLLM model id for the summary model (used by the supervisor).
        answers_file: File the result record is appended to (via ``append_answer``).
        visual_inspection_tool: Not used.
        local_model_path: Load the main model locally from this path when given; otherwise
            an OpenAI-compatible server configured by QWEN_LLM_* is used.
        supervisor: Use the supervised agent team instead of the standard one.
        kb_manager: Knowledge-base manager handed to the supervisor.
    """
    prep_token_usage = TokenUsage()
    model_params_summary: dict[str, Any] = {
        "model_id": model_id_summary,
        "custom_role_conversions": custom_role_conversions,
    }
    summary_model = MonitoredModel(LiteLLMModel(**model_params_summary))

    # Main model: local Transformers model when a path is given, else the QWEN server.
    model_raw, model = _build_main_model(local_model_path)

    if supervisor:
        agent_dict = create_agent_team_with_supervisor(model, summary_model, example, kb_manager=kb_manager)
        print("=== Using Supervisor Agent ===")
    else:
        agent_dict = create_agent_team(model)
        print("=== Using Standard Agent Team ===")
    agent = agent_dict["manager"]

    # Wrap the original question in the augmentation prompt
    augmented_question = """You have one question to answer. It is paramount that you provide a correct answer.
Give it all you can: I know for a fact that you have access to all the relevant tools to solve it and find the correct answer (the answer does exist).
Failure or 'I cannot answer' or 'None found' will not be tolerated, success will be rewarded.
Run verification steps if that's needed, you must make sure you find the correct answer! Here is the task:

""" + example["question"]

    start_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    exception = None
    try:
        # 1. Run the agent and keep its raw output
        raw_result = agent.run(augmented_question, return_full_result=True)
        raw_final_result = raw_result.output
        print(f"> Raw final result: {raw_final_result}")
        agent_memory = agent.write_memory_to_messages()

        # 2. Reformulate the final answer from the agent's memory (raw, unmonitored model)
        output, prep_token_usage, _ = prepare_response(augmented_question, agent_memory, reformulation_model=model_raw)

        # Clear per-step model inputs once the messages have been extracted
        # (presumably to release memory).
        for memory_step in agent.memory.steps:
            memory_step.model_input_messages = None
        intermediate_steps = agent_memory

        # Parsing errors indicate the LLM failed to follow the required format
        parsing_error = any(
            "AgentParsingError" in (step.content if hasattr(step, "content") else str(step))
            for step in intermediate_steps
        )

        # Only a string output can carry the iteration-limit marker; a None/non-str output
        # must not abort an otherwise finished run.
        iteration_limit_exceeded = isinstance(output, str) and "Agent stopped due to iteration limit or time limit." in output
        raised_exception = False

    except Exception as e:
        print("Error on ", augmented_question, e)
        print("------------------- TRACEBACK START -------------------")
        traceback.print_exc()
        print("-------------------- TRACEBACK END --------------------")
        output = None
        raw_final_result = "Failure"
        intermediate_steps = []
        parsing_error = False
        iteration_limit_exceeded = False
        exception = e
        raised_exception = True

    end_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    total_token_counts = calculate_and_log_token_usage(
        agent_dict=agent_dict,
        monitored_model=model,
        prep_token_usage=prep_token_usage,
        supervisor=supervisor
    )

    annotated_example = {
        "agent_name": model.model_id,
        "question": example["question"],
        "augmented_question": augmented_question,
        "raw_prediction": raw_final_result,
        "prediction": output,
        "true_answer": example["true_answer"],
        "token_counts": total_token_counts,
        "intermediate_steps": serialize_intermediate_steps(intermediate_steps),
        "parsing_error": parsing_error,
        "iteration_limit_exceeded": iteration_limit_exceeded,
        "agent_error": str(exception) if raised_exception else None,
        "task_id": example["task_id"],
        "start_time": start_time,
        "end_time": end_time,
    }
    append_answer(annotated_example, answers_file)


def load_jsonl_multiline(path: str):
    """Read a file of concatenated JSON values, one per line or pretty-printed across lines.

    Values are decoded back to back with ``json.JSONDecoder.raw_decode`` in a single pass
    over the file. Multi-line records (as written by ``append_answer`` with ``indent=2``)
    therefore cost linear rather than quadratic time.

    If a value cannot be decoded (e.g. a truncated or corrupt record), parsing resumes at
    the next line that starts with "{" in column 0. One bad record therefore no longer hides
    every record after it. A trailing undecodable fragment is ignored.

    Args:
        path: File to read (UTF-8).

    Returns:
        list: The decoded values, in file order.

    Raises:
        OSError: If the file cannot be opened (e.g. FileNotFoundError).
    """
    with open(path, "r", encoding="utf-8") as f:
        text = f.read()

    decoder = json.JSONDecoder()
    skip_ws = re.compile(r"\s*")
    records = []
    pos = skip_ws.match(text).end()
    while pos < len(text):
        try:
            record, end = decoder.raw_decode(text, pos)
        except json.JSONDecodeError:
            # Resynchronise on the next top-level record start, or stop if there is none
            # (e.g. a record still being written at the end of the file).
            resume = text.find("\n{", pos)
            if resume == -1:
                break
            pos = resume + 1
            continue
        records.append(record)
        pos = skip_ws.match(text, end).end()
    return records


def load_aime_dataset(run_set, level=None, dataset_path="./data/aime"):
    """Load the local AIME parquet files as a ``datasets.Dataset`` with harness column names.

    All files matching ``{dataset_path}/aime*.parquet`` are registered under the split
    name(s) referenced by ``run_set``, so any split label resolves to these files. This
    includes the CLI default "validation", and slice/combination expressions such as
    "test[:10]" or "test+validation". Only "test" is registered when ``run_set`` is None.

    Columns are renamed Problem -> question, Answer -> true_answer, ID -> task_id.

    Args:
        run_set: Split name (or split expression) to load.
        level: Unused; kept for signature compatibility (level filtering is disabled).
        dataset_path: Directory containing the parquet files.
    """
    pattern = f"{dataset_path}/aime*.parquet"
    if run_set is None:
        split_names = {"test"}
    else:
        # Base split names referenced by the expression: "test[:10]" -> "test", "a+b" -> "a", "b".
        split_names = {part.split("[", 1)[0].strip() for part in run_set.split("+")}
    eval_ds = datasets.load_dataset(
        "parquet",  # explicit file format
        data_files={name: pattern for name in split_names},
        split=run_set,
    )
    for old_name, new_name in (("Problem", "question"), ("Answer", "true_answer"), ("ID", "task_id")):
        eval_ds = eval_ds.rename_column(old_name, new_name)
    return eval_ds


def get_examples_to_answer(answers_file: str, eval_ds: datasets.Dataset, task_id: str | None = None) -> list[dict]:
    """Return the dataset rows that still need to be answered.

    Rows whose "question" already appears in a record of ``answers_file`` are skipped, so an
    interrupted run can be resumed. If the file cannot be read, every row is returned. When
    ``task_id`` is given, only rows whose ``task_id`` equals it (compared as strings) are kept.

    Args:
        answers_file: Answers file written by previous runs.
        eval_ds: Dataset whose ``to_list()`` rows contain "question" and "task_id".
        task_id: Optional single task to run.

    Returns:
        list[dict]: Pending rows, in dataset order.
    """
    print(f"Loading answers from {answers_file}...")
    try:
        records = load_jsonl_multiline(answers_file)
        done_list = [r["question"] for r in records if "question" in r]
        print(f"Found {len(done_list)} previous results!")
        # A set gives O(1) membership checks for the filter below.
        done_questions = set(done_list)
    except Exception as e:
        # Typically a missing file on the first run.
        print("Error when loading records: ", e)
        print("No usable records! ▶️ Starting new.")
        done_questions = set()

    examples = [row for row in eval_ds.to_list() if row["question"] not in done_questions]

    # Keep only the requested task when task_id is given.
    if task_id:
        examples = [ex for ex in examples if str(ex.get("task_id")) == str(task_id)]
        print(f"Filtered for task_id={task_id}, {len(examples)} tasks found.")
    return examples


def main():
    """CLI entry point: answer every pending dataset question concurrently.

    Loads the dataset and skips questions already present in ``output/<run-name>.jsonl``.
    If ``--num-examples`` is given, only that many pending questions are run. The rest are
    dispatched to a thread pool of ``--concurrency`` workers.
    """
    args = parse_args()
    eval_ds = load_aime_dataset(run_set=args.set_to_run, dataset_path=args.dataset_path)

    answers_file = f"output/{args.run_name}.jsonl"
    tasks_to_run = get_examples_to_answer(answers_file, eval_ds, args.task_id)
    # Limit the run to the first N pending examples when --num-examples is given.
    if args.num_examples is not None:
        tasks_to_run = tasks_to_run[: max(args.num_examples, 0)]

    # NOTE(review): kb_manager is not passed, so supervised runs get kb_manager=None — confirm.
    with ThreadPoolExecutor(max_workers=args.concurrency) as exe:
        futures = [
            exe.submit(
                answer_single_question,
                example,
                args.model_id,
                args.model_id_summary,
                answers_file,
                visualizer,
                args.local_model_path,
                args.supervisor,
            )
            for example in tasks_to_run
        ]
        for f in tqdm(as_completed(futures), total=len(futures), desc="Processing tasks"):
            # Re-raise any worker exception.
            f.result()


# Run the evaluation only when executed as a script (not when imported).
if __name__ == "__main__":
    main()