import os
import json
import asyncio
from openai import AsyncOpenAI
from pathlib import Path
import logging
from tqdm.asyncio import tqdm_asyncio
from typing import Dict, Any, List

# --- Configuration ---
# Set up logging: INFO level, each record prefixed with timestamp and level name.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# API configuration (kept consistent with qa_generator.py).
# main() logs an error and returns immediately while API_KEY is empty.
API_KEY = ""
BASE_URL = ""

# File path configuration.
# TARGET_FILE is both the input that is scanned and the file the repaired data is written back to.
PROJECT_ROOT = Path("/root/gMad")
TARGET_FILE = PROJECT_ROOT / "2_qa_generator" / "generated_answers_8.json"

# Concurrency control: size of the asyncio.Semaphore that main() creates.
CONCURRENT_REQUESTS = 10

# Error detection.
# Stored answers are matched against these message *prefixes* rather than full
# messages, because the request ID embedded in each error message varies.
ERROR_MESSAGE_PREFIXES = [
    "Error: Error code: 429",  # too many requests
    "Error: Error code: 401"   # token quota exhausted
]

# --- Core functionality ---

def is_retry_needed_error(answer: str) -> bool:
    """
    Decide whether a stored or freshly returned answer must be regenerated.

    An answer needs a retry when it is:
      * not a string at all (e.g. ``None`` or a number),
      * empty or made up only of whitespace, or
      * after stripping surrounding whitespace, starts with one of the
        prefixes listed in ``ERROR_MESSAGE_PREFIXES``.

    Args:
        answer: The answer value to inspect. Annotated as ``str`` but any
            value is accepted; non-strings are always reported as failed.

    Returns:
        ``True`` if the answer should be regenerated, ``False`` otherwise.
    """
    if not isinstance(answer, str):
        return True

    answer_stripped = answer.strip()
    # A blank answer is as useless as an empty one; checking after the strip
    # catches whitespace-only strings that a plain truthiness test lets through.
    if not answer_stripped:
        return True

    # str.startswith accepts a tuple of candidates, so one call covers every prefix.
    return answer_stripped.startswith(tuple(ERROR_MESSAGE_PREFIXES))

async def get_answer_with_retry(client: AsyncOpenAI, question: str, model: str, semaphore: asyncio.Semaphore) -> str:
    """
    Ask ``model`` to answer ``question`` and keep trying until a usable reply arrives.

    The question is sent as a single user message through
    ``client.chat.completions.create``. A reply is accepted when its content
    is non-empty and ``is_retry_needed_error`` does not flag it. Any other
    reply, or any ``Exception`` raised during the attempt, is logged and
    followed by a 5-second pause before the next attempt. There is no attempt
    limit.

    The semaphore is held for the whole retry loop, pauses included.

    Args:
        client: Async client used to issue the chat completion request.
        question: Text sent as the user message.
        model: Model name passed to the API.
        semaphore: Limits how many of these retry loops run at once.

    Returns:
        The content of the first accepted reply.
    """
    async with semaphore:
        while True:
            try:
                response = await client.chat.completions.create(
                    model=model,
                    messages=[{"role": "user", "content": question}],
                )
                content = response.choices[0].message.content

                # Unusable reply: wait, then go around the loop again.
                if not content or is_retry_needed_error(content):
                    logging.warning(f"模型 '{model}' 返回了无效回答。将在5秒后重试...")
                    await asyncio.sleep(5)
                    continue

                logging.info(f"成功为模型 '{model}' 重新生成了回答: '{question[:30]}...'")
                return content

            except Exception as exc:
                logging.error(f"模型 '{model}' 在重新生成时API调用失败: {exc}。将在5秒后重试...")
                await asyncio.sleep(5)

async def process_single_failed_item(client: AsyncOpenAI, item: Dict, model_to_retry: str, semaphore: asyncio.Semaphore):
    """
    Regenerate one model's answer for ``item`` and store it back into the item.

    The question is read from ``item["question"]``. When it is missing or
    empty nothing is requested and the item is left untouched. Otherwise the
    new answer from ``get_answer_with_retry`` overwrites
    ``item["answers"][model_to_retry]`` in place.

    Args:
        client: Async client forwarded to ``get_answer_with_retry``.
        item: Q&A entry; mutated in place.
        model_to_retry: Key in ``item["answers"]`` whose answer is replaced,
            and the model name used for the request.
        semaphore: Concurrency limiter forwarded to ``get_answer_with_retry``.
    """
    prompt = item.get("question")
    if prompt:
        item["answers"][model_to_retry] = await get_answer_with_retry(
            client, prompt, model_to_retry, semaphore
        )

async def main():
    """
    Load the Q&A file, regenerate every failed answer, and save the result.

    Steps:
      1. Abort with an error log if ``API_KEY`` is empty or ``TARGET_FILE``
         does not exist.
      2. Load ``TARGET_FILE`` as JSON and collect every (item, model) pair
         whose answer is flagged by ``is_retry_needed_error``.
      3. If there is nothing to fix, log that and return without creating a
         client or touching the file.
      4. Regenerate all flagged answers concurrently, bounded by a semaphore
         of size ``CONCURRENT_REQUESTS``.
      5. Write the updated data to a sibling ``<name>.tmp`` file, then move it
         over ``TARGET_FILE``.

    The client is used as an async context manager so its HTTP resources are
    released once the requests are done. The temp-file-then-replace write
    means an interrupted or failed dump cannot truncate the existing data.
    """
    if not API_KEY:
        logging.error("错误: API 密钥未配置。")
        return

    if not TARGET_FILE.exists():
        logging.error(f"错误: 目标文件未找到于 {TARGET_FILE}")
        return

    # Load the data to process.
    with open(TARGET_FILE, 'r', encoding='utf-8') as f:
        data = json.load(f)

    logging.info(f"从 {TARGET_FILE} 加载了 {len(data)} 个问答条目。")

    # Find every (item, model) pair that needs a retry. Collecting plain pairs
    # first means no client and no coroutine objects exist when there is no work.
    failed = []
    for item in data:
        answers = item.get("answers", {})
        for model, answer in answers.items():
            if is_retry_needed_error(answer):
                logging.info(f"发现失败的回答 -> 问题: '{item.get('question', '')[:30]}...', 模型: {model}")
                failed.append((item, model))

    if not failed:
        logging.info("未发现需要重新生成的回答。脚本执行完毕。")
        return

    logging.info(f"共发现 {len(failed)} 个失败的回答，开始重新生成...")

    semaphore = asyncio.Semaphore(CONCURRENT_REQUESTS)

    # Run all retry tasks concurrently; leaving the `async with` closes the client.
    async with AsyncOpenAI(api_key=API_KEY, base_url=BASE_URL) as client:
        await tqdm_asyncio.gather(
            *(
                process_single_failed_item(client, item, model, semaphore)
                for item, model in failed
            ),
            desc="正在重新生成失败的回答",
            unit="个回答",
        )

    # Save the repaired data. Dump to a temporary sibling file and only then
    # replace the target, so the original stays intact if the dump fails.
    tmp_file = TARGET_FILE.with_name(TARGET_FILE.name + ".tmp")
    with open(tmp_file, 'w', encoding='utf-8') as f:
        json.dump(data, f, ensure_ascii=False, indent=2)
    tmp_file.replace(TARGET_FILE)

    logging.info(f"所有失败的回答已成功重新生成并保存回 {TARGET_FILE}")

# Script entry point: run main() on a fresh event loop when executed directly.
if __name__ == "__main__":
    # On Windows (os.name == 'nt') install the selector event loop policy first.
    # NOTE(review): presumably to avoid Proactor event loop issues with the
    # HTTP client used by the API library — confirm.
    if os.name == 'nt':
        asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
    
    asyncio.run(main())
