# EXAMPLE COMMAND: from folder examples/open_deep_research, run: python run_gaia.py --concurrency 32 --run-name generate-traces-03-apr-noplanning --model-id gpt-4o
import argparse
import json
import os
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
from pathlib import Path
from typing import Any

import datasets
import pandas as pd
from dotenv import load_dotenv
from huggingface_hub import login, snapshot_download
from scripts.reformulator import prepare_response
from scripts.run_agents import (
    get_single_file_description,
    get_zip_description,
)
from scripts.text_inspector_tool import TextInspectorTool
from scripts.text_web_browser import (
    ArchiveSearchTool,
    FinderTool,
    FindNextTool,
    PageDownTool,
    PageUpTool,
    SimpleTextBrowser,
    VisitTool,
)
from scripts.visual_qa import visualizer
from tqdm import tqdm

from smolagents import (
    CodeAgent,
    GoogleSearchTool,
    LiteLLMModel,
    Model,
    ToolCallingAgent,
    SupervisorAgent,
    SupervisorKBManager,
)

from phoenix.otel import register
from openinference.instrumentation.smolagents import SmolagentsInstrumentor

from smolagents.agents import RunResult
from smolagents.memory import FinalAnswerStep, ActionStep, TaskStep

from smolagents.models import MonitoredModel    # 用于全局监督 token cost
import re
import traceback


# Load variables from .env before the os.getenv() calls below; override=True lets .env values
# take precedence over variables already present in the process environment.
load_dotenv(override=True)
# login(os.getenv("HF_TOKEN"))
# Serializes writes to the shared answers file across worker threads (used by append_answer).
append_answer_lock = threading.Lock()
### IMPORTANT: EVALUATION SWITCHES
print("Make sure you deactivated any VPN like Tailscale, else some URLs will be blocked!")
# Role remapping handed to the summary LiteLLMModel: "tool-call" messages are sent with the
# "assistant" role and "tool-response" messages with the "user" role.
custom_role_conversions = {"tool-call": "assistant", "tool-response": "user"}
# Browser-like User-Agent, placed in BROWSER_CONFIG's request headers below.
user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36 Edg/119.0.0.0"
# Shared settings for every SimpleTextBrowser created in this script.
BROWSER_CONFIG = {
    "viewport_size": 1024 * 5,
    "downloads_folder": "downloads_folder",
    "request_kwargs": {
        "headers": {"User-Agent": user_agent},
        "timeout": 300,  # presumably seconds (requests-style timeout) — TODO confirm
    },
    "serpapi_key": os.getenv("SERPAPI_API_KEY"),
}
# Make sure the browser's download directory exists before any tool writes into it.
os.makedirs(f"./{BROWSER_CONFIG['downloads_folder']}", exist_ok=True)


def parse_args(argv=None):
    """Parse the command-line options of the evaluation run.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses ``sys.argv[1:]``,
            exactly as a plain ``parse_args()`` call does.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--concurrency", type=int, default=1,
        help="Thread-pool size: number of questions answered in parallel.",
    )
    parser.add_argument("--num-examples", type=int, default=None)
    parser.add_argument("--model-id", type=str, default="gpt-4.1")
    parser.add_argument(
        "--model-id-summary", type=str, default="gpt-4.1",
        help="LiteLLM model id of the supervisor's summary model.",
    )
    parser.add_argument(
        "--run-name", type=str, required=True,
        help="Results are appended to output/<set-to-run>/<run-name>.jsonl.",
    )
    # Default to "test": the HumanEval-style data is loaded from ``*test.jsonl`` files, so
    # "test" is the split that exists. A "validation" default made the default run fail
    # with an unknown-split error.
    parser.add_argument("--set-to-run", type=str, default="test")
    # parser.add_argument("--use-open-models", type=bool, default=False)
    # parser.add_argument("--use-raw-dataset", action="store_true")
    parser.add_argument("--dataset-path", type=str, default="./data/aime")
    parser.add_argument(
        "--supervisor", action="store_true",
        help="Attach a SupervisorAgent to the manager agent.",
    )
    # parser.add_argument("--task-id", type=str, default=None, help="Only run the task with this specific ID")
    # parser.add_argument("--level", type=int, default=None, help="Only run tasks with this specific level")
    return parser.parse_args(argv)


class TokenUsage:
    """Mutable pair of token counters: ``input_tokens`` (prompt) and ``output_tokens`` (completion).

    Exposes the same two attributes that ``calculate_and_log_token_usage`` reads, so
    ``TokenUsage()`` can stand in for "no usage recorded".
    """

    def __init__(self, input_tokens=0, output_tokens=0):
        # Both counters default to zero; assigned in (input, output) order.
        self.input_tokens, self.output_tokens = input_tokens, output_tokens


def _log_token_usage(label: str, input_tokens: int, output_tokens: int) -> None:
    """Print one source's token counts as "<label> token usage - Input: ..., Output: ..."."""
    print(f"{label} token usage - Input: {input_tokens}, Output: {output_tokens}")


def calculate_and_log_token_usage(
    agent_dict: dict[str, Any],
    monitored_model: "MonitoredModel",
    prep_token_usage: "TokenUsage",
    supervisor: bool,
) -> dict:
    """Sum token usage over every tracked source, print a breakdown, and return the totals.

    Sources, in print order:
      1. the manager agent: ``agent_dict["manager"].monitor.get_total_token_counts()``;
      2. the supervisor: ``agent_dict["supervisor"].monitor`` (only when *supervisor* is True;
         the verification agent's own monitor is not included);
      3. the ``prepare_response`` step: *prep_token_usage*;
      4. the summary model: ``agent_dict["summary_model"].total_token_usage``, if present;
      5. the text inspector: ``agent_dict["text_inspector"].get_total_token_counts()``, if present;
      6. the wrapped model's own counter: ``monitored_model.total_token_usage``.

    NOTE(review): sources 1-2 are agents built on the same MonitoredModel counted in source 6.
    If MonitoredModel counts every call routed through it, the grand total may count those
    calls twice — confirm against MonitoredModel's implementation.

    Args:
        agent_dict: Team dict built by ``create_agent_team*``.
        monitored_model: The MonitoredModel wrapping the main model.
        prep_token_usage: Token usage reported by ``prepare_response``.
        supervisor: Whether the run used supervisor mode.

    Returns:
        ``{"total_token_count": <sum of all sources>,
           "supervisor_token": <supervisor tokens + summary-model tokens>}``.
    """
    agent_total_input = 0
    agent_total_output = 0

    # 1. Manager agent.
    manager_usage = agent_dict["manager"].monitor.get_total_token_counts()
    agent_total_input += manager_usage.input_tokens
    agent_total_output += manager_usage.output_tokens
    _log_token_usage("Manager agent", manager_usage.input_tokens, manager_usage.output_tokens)

    # 2. Supervisor (its tokens also count toward the "supervisor_token" share).
    supervisor_total_tokens = 0
    if supervisor:
        supervisor_usage = agent_dict["supervisor"].monitor.get_total_token_counts()
        agent_total_input += supervisor_usage.input_tokens
        agent_total_output += supervisor_usage.output_tokens
        supervisor_total_tokens = supervisor_usage.input_tokens + supervisor_usage.output_tokens
        _log_token_usage("Supervisor/Verification", supervisor_usage.input_tokens, supervisor_usage.output_tokens)

    # 3. prepare_response reformulation step.
    agent_total_input += prep_token_usage.input_tokens
    agent_total_output += prep_token_usage.output_tokens
    _log_token_usage("Preparation step", prep_token_usage.input_tokens, prep_token_usage.output_tokens)

    # 4. Summary model (also part of the "supervisor_token" share).
    summary_total_tokens = 0
    if "summary_model" in agent_dict:
        summary_usage = agent_dict["summary_model"].total_token_usage
        agent_total_input += summary_usage.input_tokens
        agent_total_output += summary_usage.output_tokens
        summary_total_tokens = summary_usage.input_tokens + summary_usage.output_tokens
        _log_token_usage("Summary model", summary_usage.input_tokens, summary_usage.output_tokens)

    # 5. Text inspector tool, when the team exposes one.
    if "text_inspector" in agent_dict:
        ti_usage = agent_dict["text_inspector"].get_total_token_counts()
        agent_total_input += ti_usage.input_tokens
        agent_total_output += ti_usage.output_tokens
        _log_token_usage("Text Inspector Tool", ti_usage.input_tokens, ti_usage.output_tokens)

    agent_total_tokens = agent_total_input + agent_total_output

    # 6. Usage recorded by the MonitoredModel wrapper itself.
    internal_usage = monitored_model.total_token_usage
    internal_total_tokens = internal_usage.input_tokens + internal_usage.output_tokens
    print(f"Model-level (internal) total token usage - Input: {internal_usage.input_tokens}, Output: {internal_usage.output_tokens}\n")

    grand_total_tokens = agent_total_tokens + internal_total_tokens
    print(f"Grand total token usage (all sources) = {grand_total_tokens}\n")

    return {
        "total_token_count": grand_total_tokens,
        "supervisor_token": supervisor_total_tokens + summary_total_tokens,
    }


def create_supervisor_callback(supervisor: SupervisorAgent, example: dict):
    """
    Build a step callback that records every agent ActionStep with *supervisor* and, when a
    step meets one of the trigger conditions below, asks the supervisor to review and correct it.

    Both the local task (the calling agent's own sub-task) and the global task (the original
    dataset question) are passed to the supervisor.

    Review is triggered when:
      * the step has an error, unless the step's first tool call is 'find_archived_url';
      * otherwise, when the observations are longer than 3000 characters;
      * otherwise, when ``supervisor._check_for_inefficiency(agent.name)[0]`` is truthy;
      * otherwise, when the observations contain "### 1. Task outcome (short version):".

    Args:
        supervisor (SupervisorAgent): The supervising agent instance.
        example (dict): The dataset row currently being run; its "question" is the global task.

    Returns:
        ``supervisor_callback(memory_step, agent)``, a callable used as a step callback.
    """
    # Caller/callee relationships are not tracked here; each agent's local task is derived
    # from its own ``agent.task``. A more robust call stack may be needed for nested agents.

    def supervisor_callback(memory_step: ActionStep, agent: CodeAgent):
        # Steps that are not ActionSteps are ignored entirely.
        if isinstance(memory_step, ActionStep):
            global_task = example["question"]   # The global task is always the original question
            if agent.name == "manager_agent":
                clean_local_task = ""   # Avoid duplicating the global task as the local task
            else:
                clean_local_task = extract_clean_task(agent.task)   # Extract the bare sub-task from the agent's prompt
            # Every ActionStep is recorded, whether or not it is reviewed below.
            supervisor.record_step(
                step=memory_step,
                agent_name=agent.name,
                # memory=agent.memory,
                local_task=clean_local_task
            )

            should_supervise = False
            
            # Priority 1: step errors.
            ### TODO: errors may need to be identified manually
            if memory_step.error is not None:
                # By default, any error triggers supervision.
                should_supervise = True
                
                # **Only exception**: errors raised by the 'find_archived_url' tool are ignored.
                # Defensively check that tool_calls exists and is non-empty before indexing.
                # NOTE(review): only the first tool call of the step is inspected.
                if memory_step.tool_calls and memory_step.tool_calls[0].name == "find_archived_url":
                    should_supervise = False # Override: do not supervise
                    print(f"--- [Supervisor] Ignoring a known/safe error from 'find_archived_url'. ---")
            
            # Check the remaining triggers when no error forced supervision (this also runs
            # when a 'find_archived_url' error was ignored above).
            if not should_supervise:
                ### TODO: pick a suitable content length
                ### find a suitable threshold during daytime testing
                # Long observations (> 3000 characters) trigger a review.
                if (memory_step.observations and len(memory_step.observations) > 3000):
                    should_supervise = True
                ### TODO: adapt to the dataset
                # Only evaluated when the length check did not fire. Element [0] is read as the
                # "inefficient" flag — presumably a (flag, details) tuple; TODO confirm.
                elif supervisor._check_for_inefficiency(agent.name)[0]:
                    should_supervise = True
                
                ### Probably no longer needed
                # Presumably the header of a managed agent's report; TODO confirm.
                elif (memory_step.observations and "### 1. Task outcome (short version):" in memory_step.observations):
                    should_supervise = True

            # Final-answer checking is currently disabled: these stay fixed at False / None.
            is_final_answer_attempt = False
            proposed_answer = None

            if should_supervise:    # or is_final_answer_attempt:
                supervisor.supervise_and_correct(
                    step=memory_step, 
                    agent_name=agent.name, 
                    local_task=clean_local_task,
                    global_task=global_task,
                    is_final_check=is_final_answer_attempt,
                    proposed_answer = proposed_answer
                )
            else:
                print(f"--- [Supervisor] Skipped review for step {memory_step.step_number} of '{agent.name}'. ---")
            
    return supervisor_callback


def extract_clean_task(full_task_prompt: str) -> str:
    """Return the core task text embedded in an augmented prompt.

    The core task is assumed to sit between the first and second '---' separators,
    optionally introduced by a "Task:" label (removed, along with surrounding whitespace).

    If the prompt has fewer than two '---' separators, or anything goes wrong while
    parsing (e.g. the input is not a string), the input is returned unchanged.
    """
    try:
        pieces = full_task_prompt.split('---')
        if len(pieces) < 3:
            # Not in the expected format: keep the original task for robustness.
            return full_task_prompt
        middle = pieces[1]
        _, marker, after_marker = middle.partition("Task:")
        return (after_marker if marker else middle).strip()
    except Exception:
        # Best effort: never let task extraction break the supervisor callback.
        return full_task_prompt


def create_verification_agent(model: Model) -> ToolCallingAgent:
    """
    Create the lightweight fact-checking agent that the Supervisor delegates verification to.

    It is a ToolCallingAgent (it used to be a CodeAgent). It carries web search, text-browser
    navigation, archive lookup, a text inspector and the visualizer, so the supervisor has the
    same kinds of tools as the agents it supervises. Step count, verbosity and the text
    inspector's text limit are kept low to control cost.

    Args:
        model: Model used both by the agent and by its TextInspectorTool.

    Returns:
        The configured ToolCallingAgent, named "verification_agent".
    """
    browser = SimpleTextBrowser(**BROWSER_CONFIG)
    # Lower text limit for the verification agent's text inspector, to control token usage.
    text_limit = 10000

    # Full tool set, including the visualizer, for capability parity with supervised agents.
    verification_tools = [
        GoogleSearchTool(provider="serper"),
        VisitTool(browser),
        PageUpTool(browser),
        PageDownTool(browser),
        FinderTool(browser),
        FindNextTool(browser),
        ArchiveSearchTool(browser),
        TextInspectorTool(model, text_limit),
        visualizer,  # gives the supervisor visual capability
    ]

    verification_agent = ToolCallingAgent(
        model=model,
        tools=verification_tools,
        max_steps=8,  # kept low to control cost
        verbosity_level=0,  # quiet, so the supervisor's output stays clean
        name="verification_agent",
        description="A specialized agent for quick fact-checking and information verification.",
        # A token cap could also be set at the model level, e.g.
        # model=LiteLLMModel(model_id=model.model_id, max_tokens=1024)
    )

    # Steer the system prompt toward fast, targeted retrieval and verification.
    # NOTE(review): this is a ToolCallingAgent, yet the appended tool list uses the code-style
    # `to_code_prompt()` descriptions — confirm that is intended.
    verification_agent.prompt_templates["system_prompt"] += """
    You are a highly efficient fact-checking assistant. Your primary goal is to quickly and accurately verify, correct, or summarize information based on a given task.
    - Be precise and concise.
    - Avoid deep, exploratory searches.
    - Execute the task, find the specific information needed, and use the 'final_answer' tool immediately to return the result.
    - Your response should be direct and factual.
    Here are the tools you can use:
    """ + "\n".join([tool.to_code_prompt() for tool in verification_agent.tools.values()])

    return verification_agent


def create_agent_team(model: Model):
    """Build the baseline, unsupervised team.

    Returns:
        dict with a single "manager" entry: a CodeAgent with the visualizer tool,
        unrestricted imports, at most 12 steps and ``planning_interval=4``.
    """
    # The web-browsing sub-agent and the text inspector tool are currently disabled.
    team = {}
    team["manager"] = CodeAgent(
        model=model,
        tools=[visualizer],
        additional_authorized_imports=["*"],
        max_steps=12,
        planning_interval=4,
        verbosity_level=2,
    )
    return team


def create_agent_team_with_supervisor(model: Model, summary_model: Model,  example: dict, kb_manager: SupervisorKBManager):
    """Build a manager CodeAgent whose steps are routed through a SupervisorAgent.

    Args:
        model: Model shared by the manager, the supervisor and the verification agent.
        summary_model: Passed to SupervisorAgent as ``summary_model``.
        example: Dataset row being solved; bound into the step callback as the global task.
        kb_manager: Passed to SupervisorAgent as ``kb_manager``.

    Returns:
        dict with keys "manager", "supervisor", "verification" and "summary_model".
    """
    # The supervisor delegates fact-checking to its own fully-equipped verification agent.
    verifier = create_verification_agent(model)
    supervisor = SupervisorAgent(
        model=model,
        summary_model=summary_model,
        verification_agent=verifier,
        kb_manager=kb_manager,
    )
    # Bind a callback to this example so every manager step is reported to the supervisor.
    step_callback = create_supervisor_callback(supervisor, example)

    manager = CodeAgent(
        model=model,
        tools=[visualizer],  # the manager keeps its own visual capability
        name="manager_agent",  # the supervisor callback special-cases this name
        step_callbacks=[step_callback],
        additional_authorized_imports=["*"],
        max_steps=12,
        planning_interval=4,
        verbosity_level=2,
    )

    return {
        "manager": manager,
        "supervisor": supervisor,
        "verification": verifier,
        "summary_model": summary_model,
    }


def load_humaneval_dataset(run_set, dataset_path="./dataset/humaneval"):
    """Load one split of a HumanEval-style JSONL dataset and rename its columns.

    Files matching ``{dataset_path}/*{run_set}.jsonl`` are loaded as split *run_set*. For the
    usual ``run_set="test"`` this reads ``*test.jsonl``. Other split names now load their
    own files instead of failing with an unknown-split error.

    Args:
        run_set: Split name, also used as the file-name suffix (e.g. "test").
        dataset_path: Directory containing the JSONL files.

    Returns:
        A single ``datasets.Dataset`` (not a DatasetDict, because ``split=`` is given). Its
        "prompt" column is renamed to "question" and "canonical_solution" to "true_answer".
    """
    eval_ds = datasets.load_dataset(
        "json",
        data_files={run_set: f"{dataset_path}/*{run_set}.jsonl"},
        split=run_set,
    )
    # Downstream code reads example["question"] and example["true_answer"].
    eval_ds = eval_ds.rename_column("prompt", "question")
    eval_ds = eval_ds.rename_column("canonical_solution", "true_answer")
    return eval_ds


def append_answer(entry: dict, jsonl_file: str) -> None:
    """Append *entry* to *jsonl_file* as one pretty-printed (``indent=2``) JSON object.

    Parent directories are created as needed. The write happens while holding
    ``append_answer_lock``, so records from concurrent worker threads do not interleave.
    Because records span several lines, read the file back with ``load_jsonl_multiline``.
    """
    target = Path(jsonl_file)
    target.parent.mkdir(parents=True, exist_ok=True)
    with append_answer_lock:
        with open(jsonl_file, "a", encoding="utf-8") as handle:
            handle.write(json.dumps(entry, ensure_ascii=False, indent=2) + "\n")
    assert target.exists(), "File not found!"
    print("Answer exported to file:", target.resolve())

def serialize_intermediate_steps(steps):
    """Convert agent memory messages into JSON-friendly values.

    Items exposing a ``content`` attribute (e.g. ChatMessage objects) become
    ``{"role": <role or None>, "content": <content>}``. Any other item is converted
    with ``str()``.
    """
    def _as_plain(item):
        if not hasattr(item, "content"):
            return str(item)
        return {"role": getattr(item, "role", None), "content": item.content}

    return [_as_plain(item) for item in steps]


# 优化最终答案提取逻辑
def extract_final_answer(result: any) -> str | None:
    """
    Robustly extracts the final answer from the various possible return types
    of the agent's run method.
    """
    final_answer = None

    # 1. 检查是否为 RunResult 对象 (当 return_full_result=True)
    if isinstance(result, RunResult):
        print("-> Result is a RunResult object. Extracting from .output attribute.")
        final_answer = result.output

    # 2. 检查是否为 FinalAnswerStep 对象 (这是一种不太可能但安全的检查)
    elif isinstance(result, FinalAnswerStep):
        print("-> Result is a FinalAnswerStep object. Extracting from .output attribute.")
        final_answer = result.output
        
    # 3. 检查是否为字符串
    elif isinstance(result, str):
        print("-> Result is a string. Parsing string for answer.")
        # 有可能字符串里包含 "Final answer:", 也可能就是纯净答案
        if "Final answer:" in result:
            final_answer = result.split("")[-1].strip()
        else:
            # 如果没有特定前缀，我们假设整个字符串就是答案
            final_answer = result.strip()
            
    # 如果 final_answer 仍然是 None 或空字符串，返回 None
    if final_answer and str(final_answer).strip():
        return str(final_answer).strip()
    
    return None


def answer_single_question(
    example: dict, model_id: str, model_id_summary, answers_file: str, visual_inspection_tool: TextInspectorTool, supervisor: bool = False, kb_manager: SupervisorKBManager = None
) -> None:
    """Run the agent team on one dataset example and append an annotated record to *answers_file*.

    Args:
        example: Dataset row; reads "question", "true_answer" and "task_id".
        model_id: NOTE(review): not read here — the main model is configured from the
            QWEN_LLM_* environment variables instead. Confirm this is intended.
        model_id_summary: LiteLLM model id of the summary model (supervisor mode only).
        answers_file: Output path; the record is written with ``append_answer``.
        visual_inspection_tool: Not read by this function.
        supervisor: When True, build the team with a SupervisorAgent on the manager.
        kb_manager: Forwarded to the SupervisorAgent.

    Exceptions raised while running the agent or reformulating its answer are caught,
    printed with a traceback, and recorded in the record's "agent_error" field.
    """
    # Set to a local checkpoint path to use a TransformersModel instead of the API model.
    local_model_path = None
    prep_token_usage = TokenUsage()
    if local_model_path:
        from smolagents import TransformersModel

        model_raw = TransformersModel(
            model_id=local_model_path,
            device_map="mps",
            torch_dtype="float16",
        )
    else:
        load_dotenv()  # Load environment variables from .env file
        from smolagents import OpenAIServerModel

        model_raw = OpenAIServerModel(
            model_id=os.environ.get("QWEN_LLM_MODEL_NAME", "qwen3-32b"),
            api_base=os.environ.get("QWEN_LLM_BASE_URL", ""),
            api_key=os.environ.get("QWEN_LLM_API_KEY", ""),
        )
    # Wrapped for global token-cost monitoring. model_raw stays unwrapped for
    # prepare_response, whose usage is reported separately as prep_token_usage.
    model = MonitoredModel(model_raw)

    if supervisor:
        model_params_summary: dict[str, Any] = {
            "model_id": model_id_summary,
            "custom_role_conversions": custom_role_conversions,
        }
        summary_model = MonitoredModel(LiteLLMModel(**model_params_summary))
        agent_dict = create_agent_team_with_supervisor(model, summary_model, example, kb_manager=kb_manager)
        print("=== Using Supervisor Agent ===")
    else:
        agent_dict = create_agent_team(model)
        print("=== Using Standard Agent Team ===")
    agent = agent_dict["manager"]

    # Augment the raw question with coding instructions.
    augmented_question = f"""You are a Python coding assistant.
Write a correct and efficient solution to the following task.

Task:
{example["question"]}
"""

    start_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    try:
        # 1. Run the agent and keep its raw output.
        raw_result = agent.run(augmented_question, return_full_result=True)
        raw_final_result = raw_result.output
        print(f"> Raw final result: {raw_final_result}")
        agent_memory = agent.write_memory_to_messages()

        # 2. Reformulate the final answer from the agent memory.
        output, prep_token_usage, prep_prompt = prepare_response(augmented_question, agent_memory, reformulation_model=model_raw)

        # Clear the stored model inputs on every step after the memory has been captured above.
        for memory_step in agent.memory.steps:
            memory_step.model_input_messages = None
        intermediate_steps = agent_memory

        # Parsing errors indicate the LLM failed to follow the required format. Content is
        # passed through str() so that non-string content (e.g. a list of content parts or
        # None) is searched as text instead of by list membership, or raising TypeError.
        parsing_error = any(
            "AgentParsingError" in str(step.content if hasattr(step, "content") else step)
            for step in intermediate_steps
        )

        # Only a string output can carry the iteration-limit marker.
        iteration_limit_exceeded = isinstance(output, str) and "Agent stopped due to iteration limit or time limit." in output
        agent_error = None

    except Exception as e:
        print("Error on ", augmented_question, e)
        print("------------------- TRACEBACK START -------------------")
        traceback.print_exc()
        print("-------------------- TRACEBACK END --------------------")
        output = None
        raw_final_result = "Failure"
        intermediate_steps = []
        parsing_error = False
        iteration_limit_exceeded = False
        agent_error = str(e)

    end_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    total_token_counts = calculate_and_log_token_usage(
        agent_dict=agent_dict,
        monitored_model=model,
        prep_token_usage=prep_token_usage,  # always bound: initialized above, replaced on success
        supervisor=supervisor,
    )

    annotated_example = {
        "agent_name": model.model_id,
        "question": example["question"],
        "augmented_question": augmented_question,
        # "Level": example["level"],
        "raw_prediction": str(raw_final_result) if callable(raw_final_result) else raw_final_result,
        "prediction": output,
        "true_answer": example["true_answer"],
        "token_counts": total_token_counts,
        # "intermediate_steps": intermediate_steps,
        "intermediate_steps": serialize_intermediate_steps(intermediate_steps),
        "parsing_error": parsing_error,
        "iteration_limit_exceeded": iteration_limit_exceeded,
        "agent_error": agent_error,
        # "prep_prompt": serialize_intermediate_steps(prep_prompt),
        "task_id": example["task_id"],
        "start_time": start_time,
        "end_time": end_time,
    }

    bad = find_nonserializable(annotated_example)
    if bad is not None:
        print("Bad field:", bad)
        # Stringify values json cannot encode so that append_answer does not raise
        # TypeError and abort the run over one unusual agent output.
        annotated_example = json.loads(json.dumps(annotated_example, default=str))

    append_answer(annotated_example, answers_file)


def find_nonserializable(obj, path="root", _ancestors=None):
    """Locate the first part of *obj* that ``json.dumps`` cannot encode.

    Dicts are searched by value (paths like ``root.key``); lists and tuples by index
    (``root[0]``). When every child encodes but the container still fails, the
    container itself is reported. This covers a non-string dict key and a
    circular reference.

    Args:
        obj: Value to check.
        path: Dotted/indexed path used in the report; "root" at the top level.
        _ancestors: Internal — ids of containers on the current path, for cycle detection.

    Returns:
        ``None`` if *obj* encodes, else ``(path, type_name, offending_value)``.
    """
    try:
        json.dumps(obj)
        return None
    except (TypeError, ValueError):  # ValueError: circular reference
        pass

    ancestors = set() if _ancestors is None else _ancestors
    if id(obj) in ancestors:
        # Reached a container already on the current path: a cycle.
        return (path, type(obj).__name__, obj)

    if isinstance(obj, dict):
        children = ((f"{path}.{k}", v) for k, v in obj.items())
    elif isinstance(obj, (list, tuple)):
        children = ((f"{path}[{i}]", v) for i, v in enumerate(obj))
    else:
        return (path, type(obj).__name__, obj)

    ancestors.add(id(obj))
    try:
        for child_path, child in children:
            bad = find_nonserializable(child, child_path, ancestors)
            if bad is not None:
                return bad
    finally:
        ancestors.discard(id(obj))

    # All children encode on their own, so the container itself is at fault
    # (e.g. a dict with a tuple key).
    return (path, type(obj).__name__, obj)

def load_jsonl_multiline(path: str):
    """Read every JSON value from a file of concatenated, possibly pretty-printed JSON objects.

    ``append_answer`` writes records with ``indent=2``, so one record spans many lines and
    the file is not line-delimited JSONL. Values are decoded back-to-back with
    ``json.JSONDecoder.raw_decode``, so each record is parsed once. A growing buffer is
    not re-parsed after every line, which would be quadratic in record size.

    Reading stops at the first truncated or malformed value; everything decoded before it
    is returned.

    Raises:
        OSError (e.g. FileNotFoundError) if *path* cannot be opened.
    """
    with open(path, "r", encoding="utf-8") as f:
        text = f.read()

    decoder = json.JSONDecoder()
    records = []
    pos = 0
    end = len(text)
    while True:
        # raw_decode does not skip leading whitespace, so skip the separators between records.
        while pos < end and text[pos].isspace():
            pos += 1
        if pos >= end:
            break
        try:
            record, pos = decoder.raw_decode(text, pos)
        except json.JSONDecodeError:
            # Truncated tail (e.g. an interrupted write) or corrupt data: stop here.
            break
        records.append(record)
    return records

def get_examples_to_answer(answers_file: str, eval_ds: datasets.Dataset, task_id: str | None = None) -> list[dict]:
    """Return the rows of *eval_ds* whose question is not yet recorded in *answers_file*.

    This lets an interrupted run resume. Records already written to the answers file are
    read back and their questions are skipped. If the file is missing or unreadable, every
    row is returned.

    Args:
        answers_file: Path of the run's output file.
        eval_ds: Dataset whose rows carry a "question" field.
        task_id: Accepted for interface compatibility; currently not used.

    Returns:
        Row dicts still to be answered, in dataset order.
    """
    # return [line for line in eval_ds.to_list()]
    print(f"Loading answers from {answers_file}...")
    try:
        records = load_jsonl_multiline(answers_file)
        answered = [r["question"] for r in records if "question" in r]
        print(f"Found {len(answered)} previous results!")
    except Exception as e:
        print("Error when loading records: ", e)
        print("No usable records! ▶️ Starting new.")
        answered = []

    # A set gives O(1) membership checks instead of scanning the list for every row.
    done_questions = set(answered)
    return [row for row in eval_ds.to_list() if row["question"] not in done_questions]


def main():
    """Entry point: load the dataset, skip already-answered questions, answer the rest in parallel.

    Any exception from a worker is re-raised here via ``Future.result()``.
    """
    args = parse_args()
    eval_ds = load_humaneval_dataset(run_set=args.set_to_run, dataset_path=args.dataset_path)

    answers_file = f"output/{args.set_to_run}/{args.run_name}.jsonl"
    pending_examples = get_examples_to_answer(answers_file, eval_ds)
    print(answers_file)

    with ThreadPoolExecutor(max_workers=args.concurrency) as pool:
        jobs = [
            pool.submit(
                answer_single_question,
                example,
                args.model_id,
                args.model_id_summary,
                answers_file,
                visualizer,
                args.supervisor,
            )
            for example in pending_examples
        ]
        progress = tqdm(as_completed(jobs), total=len(pending_examples), desc="Processing tasks")
        for finished in progress:
            finished.result()

    print("All tasks processed.")


if __name__ == "__main__":
    main()