import os
import json
import argparse
from openai import OpenAI
from tqdm import tqdm

# --- LLM and prompt configuration ---

# Credentials are read from the environment so that no secret is ever stored
# in the source tree. Before running, export one of:
#   export ZJU_API_KEY='your_key_here'     # key for the custom BASE_URL below
#   export OPENAI_API_KEY='your_key_here'
# ZJU_API_KEY takes precedence when both are set.
# NOTE: any key that was ever hard-coded in this file must be treated as
# leaked and revoked with the provider.
API_KEY = os.getenv("ZJU_API_KEY") or os.getenv("OPENAI_API_KEY")
BASE_URL = "https://zjuapi.com/v1"  # adjust to the endpoint you use
# BASE_URL = "https://35.aigcbest.top/v1"  # alternative endpoint
MODEL_NAME = "gpt-5"

# Optimized prompt template.
#
# The text is rendered with str.format() (see distill_experience), and
# `{input_json}` is its only placeholder. Because of that, any literal brace
# added to this text must be doubled (`{{` / `}}`), otherwise formatting
# will raise.
OPTIMIZED_PROMPT_TEMPLATE = """
You are a world-class expert in Multi-Agent Systems (MAS), tasked with mentoring a Supervisor Agent by analyzing past task records.
Your goal is to distill a multi-faceted "Supervisory Insight Memo" from the complete task record provided below. This memo should equip the Supervisor Agent with practical knowledge for future, similar situations.
It must be clear, concise, and effective.

CRITICAL INSTRUCTION: GENERALIZATION
Your generated insights MUST be generalizable. DO NOT mention specific named entities from the input record like "Sloan Research Fellow", "IZA", or "California". 
Instead, you must abstract them into their generic categories, such as "a prestigious fellowship", "a research institute", or "a qualitative biographical fact". 
The goal is to create a timeless, reusable principle, not a specific solution for the given task.

Instead of a single principle, analyze the record from the following angles. If an angle is relevant, provide a concise insight. If not, you may omit it.
1.  Strategic Heuristic: Is there a high-level planning or execution strategy that can be generalized?
2.  Failure Pattern / Diagnostic Signature: Does this case reveal a common way agents get stuck, make errors, or behave inefficiently? Describe the symptoms and the likely diagnosis.
3.  Corrective Action / Intervention Tactic: Based on the agent's actions or tool outputs, what is a concrete, effective intervention the Supervisor could perform?
4.  Verification Checkpoint: What type of information or claim in this task is inherently uncertain and should be a strong trigger for secondary verification?

INPUT TASK RECORD:
{input_json}

OUTPUT FORMAT
Your output MUST be a single JSON object. The keys should be one or more of the following: strategic_heuristic, failure_pattern, corrective_action, verification_checkpoint. 
The values should be the insights you generated.

YOUR OUTPUT:
"""

import concurrent.futures
from tqdm import tqdm
import json
import argparse
from openai import OpenAI

def distill_experience(client: OpenAI, task_record: dict):
    """
    Ask the LLM to distill a Supervisor Experience from a single task record.

    Args:
        client: OpenAI-compatible client used to issue the chat completion.
        task_record: One knowledge-base record; it is embedded in the prompt
            as JSON text.

    Returns:
        A ``(experience, total_tokens)`` tuple. ``experience`` is the value
        parsed from the model's JSON reply, or, when the reply is not valid
        JSON, a string starting with ``"[Error: "`` that quotes the raw reply.
        ``total_tokens`` is the usage reported with the response, or 0 when
        the response carries no usage information.

    Raises:
        Any exception raised by the client call propagates to the caller.
    """
    # Serialize the record as real JSON. Formatting the dict directly would
    # embed its Python repr (single quotes, True/None) instead of the JSON the
    # prompt announces. `default=str` keeps values that are not natively JSON
    # serializable from raising, since the text is only read by the model.
    record_json = json.dumps(task_record, ensure_ascii=False, default=str)
    prompt_content = OPTIMIZED_PROMPT_TEMPLATE.format(input_json=record_json)

    response = client.chat.completions.create(
        model=MODEL_NAME,
        messages=[{"role": "user", "content": prompt_content}],
        response_format={"type": "json_object"},
    )

    # `usage` can be absent *or* present-but-None; hasattr() alone does not
    # protect against the latter.
    usage = getattr(response, "usage", None)
    total_tokens = (getattr(usage, "total_tokens", 0) or 0) if usage is not None else 0

    response_content = response.choices[0].message.content
    try:
        experience_data = json.loads(response_content)
    except json.JSONDecodeError:
        return f"[Error: Failed to decode LLM response as JSON. Response was: {response_content}]", total_tokens
    return experience_data, total_tokens


def process_record(client, record):
    """Distill a single record; kept separate so it can be run in parallel.

    Returns a ``(new_record, token_cost, error)`` triple. On success
    ``new_record`` is a shallow copy of ``record`` without the
    ``agent_experience`` / ``search_agent_experience`` keys and with
    ``supervisor_experience`` added, and ``error`` is ``None``. On any
    exception the triple is ``(None, 0, message)``.
    """
    try:
        experience, tokens_used = distill_experience(client, record)

        enriched = record.copy()
        # Drop the per-agent experiences; only the supervisor one is kept.
        for stale_key in ("agent_experience", "search_agent_experience"):
            enriched.pop(stale_key, None)
        enriched["supervisor_experience"] = experience
    except Exception as exc:
        return None, 0, f"Error processing record with question: '{record.get('question', 'N/A')}', {exc}"
    return enriched, tokens_used, None


def process_kb_file(input_path: str, output_path: str, concurrency: int = 4):
    """Distill every not-yet-processed record of a knowledge base and save the result.

    Records whose ``question`` already appears in ``output_path`` are skipped,
    so an interrupted or quota-limited run can be resumed by running again.

    Args:
        input_path: JSON file containing the list of task records.
        output_path: JSON file the enriched records are written to; if it
            already exists its records are kept and extended.
        concurrency: Number of worker threads issuing LLM requests.

    Raises:
        ValueError: If no API key is configured, or if the existing output
            file does not contain a JSON list.
    """
    if not API_KEY:
        raise ValueError("API key not found. Please set the OPENAI_API_KEY or ZJU_API_KEY environment variable.")
    client = OpenAI(api_key=API_KEY, base_url=BASE_URL)

    print(f"Loading knowledge base from: {input_path}")
    with open(input_path, 'r', encoding='utf-8') as f:
        kb_data = json.load(f)

    # Results of earlier runs are kept and extended.
    new_kb, processed_questions = _load_existing_results(output_path)
    previously_done = len(new_kb)
    total_token_cost = 0
    errors = []

    # Skip the questions that were already processed.
    remaining_records = [rec for rec in kb_data if rec.get("question") not in processed_questions]
    print(f"Loaded {len(kb_data)} records ({len(kb_data) - len(remaining_records)} already processed).")
    print(f"Starting distillation of {len(remaining_records)} remaining records with concurrency={concurrency}...")

    try:
        with concurrent.futures.ThreadPoolExecutor(max_workers=concurrency) as executor:
            futures = [executor.submit(process_record, client, record) for record in remaining_records]
            try:
                for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Distilling Experiences"):
                    new_record, token_cost, err = future.result()
                    if err:
                        # Provider message meaning "this token's quota has been used up".
                        if "该令牌额度已用尽" in str(err):
                            print("Token quota exhausted. Stopping further processing.")
                            break
                        errors.append(err)
                        print(f"Current error number: {len(errors)}")
                        continue
                    if new_record:
                        new_kb.append(new_record)
                        total_token_cost += token_cost
            finally:
                # Leaving the `with` block waits for every task still queued,
                # so a bare `break` would keep issuing requests. Cancelling
                # the not-yet-started futures makes the stop effective; it is
                # a no-op for futures that already finished.
                for pending in futures:
                    pending.cancel()
    except BaseException:
        # Interrupted (e.g. Ctrl-C) or failed unexpectedly: keep what was
        # already distilled so that the next run can resume from it.
        newly_added = len(new_kb) - previously_done
        if newly_added > 0:
            print(f"\nInterrupted. Saving {newly_added} newly distilled records to: {output_path}")
            _write_results(new_kb, output_path)
        raise

    # Summary and final write.
    newly_added = len(new_kb) - previously_done
    print(f"\nDistillation complete. Saving {len(new_kb)} records ({newly_added} new) to: {output_path}")
    print(f"Total token cost for LLM calls: {total_token_cost} tokens")
    if errors:
        print(f"Total errors: {len(errors)}")
        for e in errors[:5]:
            print(e)  # only show the first 5 errors

    _write_results(new_kb, output_path)

    print("Done.")


def _load_existing_results(output_path: str):
    """Return ``(records, questions)`` recovered from a previous run's output file."""
    if not os.path.exists(output_path):
        return [], set()

    with open(output_path, 'r', encoding='utf-8') as f:
        try:
            existing_data = json.load(f)
        except json.JSONDecodeError as exc:
            # The file will be overwritten at the end of the run, so say so
            # instead of discarding it silently.
            print(f"Warning: could not parse existing output file {output_path} ({exc}); starting from an empty result set.")
            return [], set()

    if not isinstance(existing_data, list):
        raise ValueError(f"Existing output file {output_path} does not contain a JSON list; refusing to overwrite it.")

    questions = {
        item["question"]
        for item in existing_data
        if isinstance(item, dict) and "question" in item
    }
    return list(existing_data), questions


def _write_results(records, output_path: str):
    """Write ``records`` to ``output_path`` as indented UTF-8 JSON."""
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(records, f, indent=2, ensure_ascii=False)


if __name__ == "__main__":
    # Command-line entry point: every option has a default, so the script
    # can be launched without arguments.
    cli = argparse.ArgumentParser(
        description="Distill Supervisor Experiences from an Agent Knowledge Base.",
    )
    cli.add_argument(
        "--input_file",
        type=str,
        default='/home/ofo/project_workflow_auto_generation/smolagents/examples/open_deep_research/data/agent_kb_database_refined_v2.json',
    )
    cli.add_argument(
        "--output_file",
        type=str,
        default='/home/ofo/project_workflow_auto_generation/smolagents/examples/open_deep_research/data/supervisor_database.json',
    )
    cli.add_argument(
        "--concurrency",
        type=int,
        default=16,
        help="Number of concurrent requests",
    )
    options = cli.parse_args()

    process_kb_file(
        options.input_file,
        options.output_file,
        concurrency=options.concurrency,
    )