import concurrent.futures
import functools
import json
import logging
import os
import re
import time

from dotenv import load_dotenv
from openai import OpenAI
from tqdm import tqdm

# --- 1. Configuration ---
# Load environment variables from the .env file
load_dotenv()

# API configuration
# The key is read from the YUNWU_API_KEY environment variable (None when unset).
LLM_API_KEY = os.environ.get("YUNWU_API_KEY")
LLM_API_BASE = "https://yunwu.ai/v1"
LLM_JUDGE_MODEL = "gemini-2.5-pro-thinking-128"

# File path configuration
INPUT_FILE = "mmau_results_base.jsonl"
OUTPUT_FILE = "mmau_results_base_llm.jsonl"

# Upper bound on the worker threads used for concurrent judge requests
MAX_WORKERS = 128


def get_processed_ids(output_file_path: str) -> set:
    """Collect the ``scene_id`` of every record already in the output file.

    The returned set lets a re-run skip records that were judged in a
    previous run (resume support).

    Args:
        output_file_path: Path to the JSONL output file of a previous run.

    Returns:
        The set of ``scene_id`` values found in the file; an empty set when
        the file does not exist.
    """
    processed_ids = set()
    if not os.path.exists(output_file_path):
        return processed_ids

    with open(output_file_path, "r", encoding="utf-8") as f:
        for line in f:
            stripped = line.strip()
            if not stripped:
                # Blank lines are not records; skip them without a warning.
                continue
            try:
                data = json.loads(stripped)
            except json.JSONDecodeError:
                logging.warning(f"输出文件中发现无效的JSON行: {stripped}")
                continue
            # Valid JSON that is not an object (e.g. a bare number) carries no
            # scene_id; a membership test on it could raise TypeError.
            if isinstance(data, dict) and "scene_id" in data:
                processed_ids.add(data["scene_id"])
    return processed_ids


def _call_judge_llm(
    client: OpenAI,
    prompt: str,
    scene_id: str,
    *,
    max_retries: int = 3,
    base_delay: float = 5,
) -> str:
    """Send ``prompt`` to the judge model, retrying with exponential backoff.

    Args:
        client: OpenAI-compatible client used to issue the request.
        prompt: Full user message sent to the judge model.
        scene_id: Identifier of the record being judged; used only in logs.
        max_retries: Total number of attempts before giving up.
        base_delay: Seconds to wait before the first retry; the wait doubles
            after every further failure.

    Returns:
        The text of the first choice, or ``""`` when every attempt failed or
        the response carried no text.
    """
    for attempt in range(max_retries):
        try:
            response = client.chat.completions.create(
                model=LLM_JUDGE_MODEL,
                messages=[{"role": "user", "content": prompt}],
                temperature=0.0,
            )
            # The response schema allows ``content`` to be None; normalise it
            # so the declared ``str`` return type always holds.
            return response.choices[0].message.content or ""
        except Exception as e:
            # Broad on purpose: any failure of this call (network, API error,
            # malformed response) is treated as retryable.
            logging.error(
                f"场景ID {scene_id} API调用失败 (尝试 {attempt + 1}/{max_retries}): {e}"
            )
            if attempt < max_retries - 1:
                delay = base_delay * (2**attempt)
                logging.info(f"将在 {delay} 秒后重试...")
                time.sleep(delay)

    # Reached only when no attempt returned; the empty string signals failure.
    logging.error(f"场景ID {scene_id} 在所有重试后API调用均失败。")
    return ""


# Matches "true"/"false" only as standalone ASCII words in lower-cased text, so
# "untrue" is not read as "true" while a verdict glued to CJK text still matches.
_VERDICT_TOKEN_RE = re.compile(r"(?<![a-z])(true|false)(?![a-z])")


def get_llm_judge_verdict(client: OpenAI, data_to_judge: dict) -> dict:
    """Ask the LLM judge whether a model answer is correct and record the verdict.

    Builds the grading prompt from the record, calls the judge model and
    stores the parsed verdict under ``"llm_judge_is_correct"``.

    Args:
        client: OpenAI-compatible client passed through to the API call.
        data_to_judge: Record with the keys ``"prompt"``, ``"ground_truth"``
            and ``"raw_output"``; ``"scene_id"`` is optional and used for logs.

    Returns:
        The same ``data_to_judge`` dict, mutated in place with the boolean
        ``"llm_judge_is_correct"`` added.

    Raises:
        KeyError: If one of the required keys is missing from the record.

    NOTE(review): a failed API call (empty reply) is recorded as ``False`` and
    is therefore indistinguishable from a genuinely wrong answer in the output.
    """
    prompt_template = """你是一个负责批改单选题的AI助手。你的任务是判断一个模型的答案是否正确。你必须且只能返回 "true" 或 "false"。

# 问题、选项和正确答案
- **问题**: {task}
- **正确选项**: {ground_truth_option}

# 模型的原始输出 (需要你来判断)
{raw_output}


# 你的任务
请仔细阅读模型的原始输出，判断它最终选择的答案是否是正确选项 '{ground_truth_option}'？
你的回答只能是单个单词 "true" 或 "false"，不要包含任何其他字符。
"""
    prompt = prompt_template.format(
        task=data_to_judge["prompt"],
        ground_truth_option=data_to_judge["ground_truth"],
        raw_output=data_to_judge["raw_output"],
    )

    scene_id = data_to_judge.get("scene_id", "N/A")
    raw_output = _call_judge_llm(client, prompt, scene_id)

    is_correct = False  # conservative default for failures and unclear replies
    if raw_output:
        verdicts = set(_VERDICT_TOKEN_RE.findall(raw_output.strip().lower()))
        if verdicts == {"true"}:
            is_correct = True
        elif verdicts != {"false"}:
            # Either no verdict word at all or both words present: the reply
            # is ambiguous, so keep the default instead of guessing.
            logging.warning(
                f"场景ID {scene_id} 的模型返回内容无法明确解析: '{raw_output}'. 判定为 False。"
            )
    else:
        logging.error(f"场景ID {scene_id} 未能从LLM获取任何有效返回。")

    data_to_judge["llm_judge_is_correct"] = is_correct
    return data_to_judge


# --- 3. Main logic ---


def main():
    """Run the multi-threaded LLM-judge evaluation with resume support.

    Reads records from ``INPUT_FILE``, skips those whose ``scene_id`` is
    already present in ``OUTPUT_FILE``, judges the rest concurrently and
    appends each result to ``OUTPUT_FILE`` as one JSON line. Records whose
    evaluation raises are logged and not written, so a later run retries them.

    Raises:
        ValueError: If the API key is not configured.
    """
    # Without configuration the root logger stays at WARNING and silently
    # drops the progress messages below. basicConfig does nothing when the
    # root logger already has handlers, so an embedding application wins.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
    )

    if not LLM_API_KEY:
        raise ValueError("API Key 'YUNWU_API_KEY' 未在 .env 文件中设置。")

    client = OpenAI(api_key=LLM_API_KEY, base_url=LLM_API_BASE)
    processed_ids = get_processed_ids(OUTPUT_FILE)

    if processed_ids:
        logging.info(f"检测到 {len(processed_ids)} 条已评测记录，将从断点处继续...")

    tasks_to_run = []
    try:
        with open(INPUT_FILE, "r", encoding="utf-8") as infile:
            for line in infile:
                stripped = line.strip()
                if not stripped:
                    # Blank lines are not records.
                    continue
                try:
                    data = json.loads(stripped)
                except json.JSONDecodeError:
                    logging.warning(
                        f"输入文件中发现无效的JSON行，已跳过: {stripped}"
                    )
                    continue
                if not isinstance(data, dict):
                    # Valid JSON but not an object: it cannot carry the
                    # fields the judge needs.
                    logging.warning(
                        f"输入文件中发现非对象的JSON行，已跳过: {stripped}"
                    )
                    continue
                if data.get("scene_id") not in processed_ids:
                    tasks_to_run.append(data)
    except FileNotFoundError:
        logging.error(f"错误: 输入文件 '{INPUT_FILE}' 未找到。")
        return

    if not tasks_to_run:
        logging.info("没有需要评测的新任务。程序退出。")
        return

    logging.info(f"总计 {len(tasks_to_run)} 条新任务待处理。开始多线程评测...")

    failed_count = 0
    with concurrent.futures.ThreadPoolExecutor(
        max_workers=MAX_WORKERS
    ) as executor, open(OUTPUT_FILE, "a", encoding="utf-8") as outfile:

        futures = [
            executor.submit(get_llm_judge_verdict, client, task)
            for task in tasks_to_run
        ]

        # Consume futures in submission order so the output keeps the input
        # order, while handling each task's failure on its own.
        for task, future in tqdm(
            zip(tasks_to_run, futures), total=len(tasks_to_run), desc="评测进度"
        ):
            try:
                result = future.result()
            except Exception:
                # One bad record must not abort the run and discard the
                # results of every other task.
                failed_count += 1
                logging.exception(
                    f"场景ID {task.get('scene_id', 'N/A')} 评测时发生异常，已跳过。"
                )
                continue
            outfile.write(json.dumps(result, ensure_ascii=False) + "\n")
            # Flush per record so finished work survives an interruption.
            outfile.flush()

    if failed_count:
        logging.warning(
            f"有 {failed_count} 条任务评测失败，未写入结果；重新运行可重试。"
        )
    logging.info(f"\n评测完成！结果已保存至 '{OUTPUT_FILE}'。")

# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()
