import asyncio
import json
import logging
import os
from pathlib import Path
from typing import Any, Dict, Mapping, Optional, Sequence, Union

from openai import AsyncOpenAI
from tqdm.asyncio import tqdm_asyncio

# --- Configuration ---
# Logging setup.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# API configuration.
# The API key is read from the OPENAI_API_KEY environment variable, e.g.:
#     export OPENAI_API_KEY='your_api_key'
# An unset variable yields "", which main() treats as "not configured".
API_KEY = os.environ.get("OPENAI_API_KEY", "")
# Optional OpenAI-compatible endpoint (proxy / gateway), read from OPENAI_BASE_URL.
# An unset or empty value becomes None so AsyncOpenAI uses its default endpoint
# instead of being handed an empty, unusable URL.
BASE_URL = os.environ.get("OPENAI_BASE_URL") or None

# Models that every question is sent to.
MODELS_TO_QUERY = [
    "o3-mini",
    "gpt-4.1-mini",
    "deepseek-v3",
    "qwen3-235b-a22b-instruct-2507",
    "gemini-2.5-pro"
]

# File paths.
# The project root is assumed to be /root/gMad.
PROJECT_ROOT = Path("/root/gMad")
INPUT_FILE = PROJECT_ROOT / "1_data_loader" / "aggregated_1000_questions_10.json"
OUTPUT_FILE = PROJECT_ROOT / "2_qa_generator" / "generated_answers_10.json"

# Concurrency budget, expressed per model; main() derives its semaphore limit(s) from it.
CONCURRENT_REQUESTS_PER_MODEL = 5

# --- 核心功能 ---

async def get_answer(client: AsyncOpenAI, question: str, model: str, semaphore: asyncio.Semaphore) -> Dict[str, Any]:
    """
    Send one question to one model and return the model's reply.

    The request is issued while holding ``semaphore``, so the caller controls
    how many requests are in flight. Any exception raised while calling the API
    or reading the reply is logged and folded into the result instead of being
    re-raised. One failing request therefore does not abort a batch gathered
    with ``asyncio.gather``.

    Args:
        client: AsyncOpenAI client used to send the request.
        question: Question text, sent as a single user message.
        model: Name of the model to query.
        semaphore: Semaphore held for the duration of the request.

    Returns:
        ``{"model": model, "answer": <reply content>}`` on success, or
        ``{"model": model, "answer": "Error: <exception>"}`` on failure.
    """
    user_message = {"role": "user", "content": question}

    async with semaphore:
        try:
            completion = await client.chat.completions.create(
                messages=[user_message],
                model=model,
            )
            reply = completion.choices[0].message.content
            logging.info(f"成功获取模型 '{model}' 对问题 '{question[:30]}...' 的回答.")
        except Exception as exc:
            # Best-effort: record the failure as the answer rather than raising.
            logging.error(f"模型 '{model}' 在回答问题 '{question[:30]}...' 时出错: {exc}")
            reply = f"Error: {exc}"

    return {"model": model, "answer": reply}

async def process_question(
    client: AsyncOpenAI,
    question_item: Dict,
    semaphore: Union[asyncio.Semaphore, Mapping[str, asyncio.Semaphore]],
    models: Optional[Sequence[str]] = None,
) -> Dict:
    """
    Ask every model the same question concurrently and collect the answers.

    Args:
        client: AsyncOpenAI client instance.
        question_item: Dict holding the question text under the 'question'
            key. All of its keys are copied into the result.
        semaphore: Either a single semaphore shared by every request, or a
            mapping from model name to a semaphore dedicated to that model.
            A mapping must contain an entry for every model in ``models``.
        models: Model names to query. Defaults to MODELS_TO_QUERY.

    Returns:
        A new dict with the contents of ``question_item`` plus an 'answers'
        dict mapping model name -> the answer returned by get_answer. Items
        whose question is missing or empty get an empty 'answers' dict, and no
        request is sent for them.
    """
    if models is None:
        models = MODELS_TO_QUERY

    question_text = question_item.get("question")
    if not question_text:
        # Keep the item in the output, but make the skip visible instead of silent.
        logging.warning(f"跳过缺少 'question' 字段的条目 (keys: {list(question_item)})")
        return {**question_item, "answers": {}}

    def semaphore_for(model: str) -> asyncio.Semaphore:
        # A mapping gives each model its own limit; anything else is used as
        # one shared limiter, exactly as before.
        if isinstance(semaphore, Mapping):
            return semaphore[model]
        return semaphore

    tasks = [get_answer(client, question_text, model, semaphore_for(model)) for model in models]
    results = await asyncio.gather(*tasks)

    # Keyed by model name; a model listed twice keeps the last answer.
    answers = {res["model"]: res["answer"] for res in results}

    return {
        **question_item,
        "answers": answers
    }

async def main():
    """
    Entry point: load the questions, query every model for each one, and save the results.

    Logs an error and returns early if API_KEY is empty or INPUT_FILE does not
    exist. Otherwise it reads INPUT_FILE and processes every element with
    process_question. Elements are presumably dicts with a 'question' key;
    confirm against the data loader's output. The list of results is written
    to OUTPUT_FILE as UTF-8 JSON.
    """
    if not API_KEY:
        logging.error("错误: 环境变量 OPENAI_API_KEY 未设置。")
        return

    # Check that the input file exists.
    if not INPUT_FILE.exists():
        logging.error(f"错误: 输入文件未找到于 {INPUT_FILE}")
        return

    # parents=True so a relocated output path with missing ancestors still works.
    OUTPUT_FILE.parent.mkdir(parents=True, exist_ok=True)

    # Load the question data.
    with open(INPUT_FILE, 'r', encoding='utf-8') as f:
        questions_data = json.load(f)

    logging.info(f"从 {INPUT_FILE} 加载了 {len(questions_data)} 个问题。")

    # One semaphore per model. A single shared semaphore of
    # CONCURRENT_REQUESTS_PER_MODEL * len(models) lets a slow model accumulate
    # far more than its share of in-flight requests. Dedicated semaphores
    # enforce the per-model limit the constant's name promises.
    semaphores = {
        model: asyncio.Semaphore(CONCURRENT_REQUESTS_PER_MODEL)
        for model in MODELS_TO_QUERY
    }

    # `async with` closes the client's HTTP connections once all requests finish.
    # `or None` maps an empty BASE_URL to the client's default endpoint.
    async with AsyncOpenAI(api_key=API_KEY, base_url=BASE_URL or None) as client:
        tasks = [process_question(client, item, semaphores) for item in questions_data]
        all_results = await tqdm_asyncio.gather(
            *tasks, desc="正在为问题生成答案", unit="个问题"
        )

    # Save the final results.
    with open(OUTPUT_FILE, 'w', encoding='utf-8') as f:
        json.dump(all_results, f, ensure_ascii=False, indent=2)

    logging.info(f"所有答案已成功生成并保存到 {OUTPUT_FILE}")


if __name__ == "__main__":
    # Windows uses the Proactor event loop by default; this script switches to
    # the selector-based policy there. Other platforms keep their default loop.
    # NOTE(review): set_event_loop_policy is deprecated in newer Python
    # versions — revisit if targeting 3.14+.
    running_on_windows = os.name == 'nt'
    if running_on_windows:
        asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())

    asyncio.run(main())
