#!/usr/bin/env python3
# -*- coding: utf-8 -*-

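"""
select_best_model_v3.py

Route each question to the single best-performing model.

Steps:
1. Read the ground_truth file.
2. Each model has its own base_path, default time and default tokens.
3. Build each model's JSON file path from the keyword (<base_path>/<KEYWORD>.json).
4. Compare every question across models and select the best one:
   - first keep the models whose judge_result == label;
   - if there are none, count how often each judge_result value occurs and
     keep the most frequent group;
   - then narrow down by lowest time, breaking ties by fewest tokens.
5. Print debug information for the first question.
6. Write the routing results to OUTPUT_PATH as JSON.
"""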

import json
import sys
from collections import Counter
from pathlib import Path

# ========== Configuration ==========
GROUND_TRUTH_PATH = Path()   # TODO: set to the ground-truth JSON file
KEYWORD = "writing"   # editable: each model reads <base_path>/<KEYWORD>.json
OUTPUT_PATH = Path()         # TODO: set to the output JSON file

# Per-model settings (values kept as provided).
#   base_path      - directory containing this model's <KEYWORD>.json
#   default_time   - used when a record has no "time" value
#   default_tokens - used when a record has no cost_tokens.total_tokens value
# Large defaults (e.g. 999999) push a model to the back of the
# time -> tokens ranking whenever its measurement is missing.
MODELS_CONFIG = {
    "claude-3.5-sonnet-20241022": {
        "base_path": Path(),
        "default_time": 999999,
        "default_tokens": 999999,
    },
    "gemma-3-4b-it": {
        "base_path": Path(),
        "default_time": 888888,
        "default_tokens": 888888,
    },
    "gemma-3-27b-it": {
        "base_path": Path(),
        "default_time": 777777,
        "default_tokens": 777777,
    },
    "gpt-3.5-turbo-1106": {
        "base_path": Path(),
        "default_time": 999999,
        "default_tokens": 999999,
    },
    "gpt4o": {
        "base_path": Path(),
        "default_time": 999999,
        "default_tokens": 999999,
    },
    "Llama-3.1-8B-Instruct": {
        "base_path": Path(),
        "default_time": 999999,
        "default_tokens": 999999,
    },
    "llama3.3-70B-instruct": {
        "base_path": Path(),
        "default_time": 999999,
        "default_tokens": 999999,
    },
    "Qwen2.5-7B-Instruct": {
        "base_path": Path(),
        "default_time": 999999,
        "default_tokens": 999999,
    },
    "qwen2.5-72b-instruct": {
        "base_path": Path(),
        "default_time": 999999,
        "default_tokens": 999999,
    },
    "Qwen3-32B": {
        "base_path": Path(),
        "default_time": 999999,
        "default_tokens": 999999,
    },
    "claude-3.7-sonnet-thinking": {
        "base_path": Path(),
        "default_time": 22.74,         # TODO: remember to update this value
        "default_tokens": 999999,
    },
    "deepseek-r1": {
        "base_path": Path(),
        "default_time": 10.5,          # TODO: remember to update this value
        "default_tokens": 1786.8,      # TODO: remember to update this value
    },
    "doubao-1-5-thinking-pro-250415": {
        "base_path": Path(),
        "default_time": 27.02,         # TODO: remember to update this value
        "default_tokens": 2334,        # TODO: remember to update this value
    },
    "gemini-2.5-flash": {
        "base_path": Path(),
        "default_time": 9.97,          # TODO: remember to update this value
        "default_tokens": 3105.6,      # TODO: remember to update this value
    },
    "gemini-2.5-pro": {
        "base_path": Path(),
        "default_time": 999999,
        "default_tokens": 999999,
    },
    "GPT5": {
        "base_path": Path(),
        "default_time": 999999,
        "default_tokens": 999999,
    },
    "Kimi-K2-Instruct": {
        "base_path": Path(),
        "default_time": 999999,
        "default_tokens": 999999,
    },
    "o3mini": {
        "base_path": Path(),
        "default_time": 999999,
        "default_tokens": 999999,
    },
    "QwQ-32B": {
        "base_path": Path(),
        "default_time": 27.21,         # TODO: remember to update this value
        "default_tokens": 2490.8,      # TODO: remember to update this value
    },
}

# ==========================================

def load_json(path: Path):
    """Load and return the parsed JSON content of *path*.

    A missing input is tolerated, because a model may simply have no file for
    the current KEYWORD. A warning is printed to stderr instead of silently
    returning an empty list, so a mis-configured path is noticed.

    Args:
        path: JSON file to read, decoded as UTF-8. A plain ``str`` is accepted too.

    Returns:
        The parsed JSON value, or ``[]`` when *path* is not an existing file.

    Raises:
        json.JSONDecodeError: if the file exists but is not valid JSON.
    """
    path = Path(path)
    # is_file() rather than exists(): a directory (e.g. an unset Path()) would
    # otherwise pass the check and crash inside open().
    if not path.is_file():
        print(f"[WARN] JSON file not found, using empty list: {path}", file=sys.stderr)
        return []
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)

def _index_by_question(records):
    """Map each record's "question" to the first record carrying it.

    Keeps the first-match-wins behaviour of a linear search, but makes each
    per-question lookup O(1) instead of rescanning the whole list.
    """
    index = {}
    for record in records:
        index.setdefault(record.get("question"), record)
    return index


def _extract_info(model, record, cfg):
    """Build the per-model summary used to rank one question's answers.

    Missing or null "time" / "cost_tokens.total_tokens" values fall back to
    the model's configured defaults, so later min() comparisons never see None.
    """
    time_val = record.get("time")
    if time_val is None:
        time_val = cfg["default_time"]
    # "cost_tokens" may be absent or explicitly null in the JSON.
    total_tokens = (record.get("cost_tokens") or {}).get("total_tokens")
    if total_tokens is None:
        total_tokens = cfg["default_tokens"]
    return {
        "model": model,
        "judge_model": record.get("judge model"),
        "judge_result": record.get("judge result", record.get("judge_result")),
        "time": time_val,
        "tokens": total_tokens,
    }


def _majority_candidates(infos):
    """Keep only the models that share the most frequent judge_result.

    Ties go to the result seen first, which follows MODELS_CONFIG order.
    """
    counts = Counter(str(m["judge_result"]).strip() for m in infos)
    majority_result, _ = counts.most_common(1)[0]
    return [m for m in infos if str(m["judge_result"]).strip() == majority_result]


def _pick_best(candidates):
    """Return the candidate with the lowest time, ties broken by fewest tokens.

    min() returns the first of several equal keys, so any remaining tie keeps
    MODELS_CONFIG order. Returns None for an empty candidate list.
    """
    if not candidates:
        return None
    return min(candidates, key=lambda c: (c["time"], c["tokens"]))


def main():
    """Route every ground-truth question to its best model and export the result.

    For each question:
    1. Models whose judge_result equals the label are the candidates.
    2. If there are none, fall back to the majority judge_result.
    3. Pick the fastest candidate, using fewest tokens as the tie-break.

    The results are written to OUTPUT_PATH as a JSON list of
    {"instruction", "input", "output"} records.
    """
    # Read ground truth
    ground_truth = load_json(GROUND_TRUTH_PATH)

    # Read each model's file once and index it by question for O(1) lookups.
    model_index = {
        model: _index_by_question(load_json(cfg["base_path"] / f"{KEYWORD}.json"))
        for model, cfg in MODELS_CONFIG.items()
    }

    results = []

    for idx, gt_item in enumerate(ground_truth):
        question = gt_item.get("question")
        label = gt_item.get("label")
        label_key = str(label).strip()

        all_models_info = []  # every model that answered this question
        candidates = []       # first-round survivors (judge_result == label)

        for model, cfg in MODELS_CONFIG.items():
            matched = model_index[model].get(question)
            if not matched:
                continue
            info = _extract_info(model, matched, cfg)
            all_models_info.append(info)
            if str(info["judge_result"]).strip() == label_key:
                candidates.append(info)

        # No model matched the label: keep the most common judge_result instead.
        if not candidates and all_models_info:
            candidates = _majority_candidates(all_models_info)

        best_model = _pick_best(candidates)

        # Debug output for the first question only
        if idx == 0:
            print("=" * 60)
            print(f"Question: {question}")
            print(f"Label: {label}\n")
            print("所有模型信息：")
            for m in all_models_info:
                print(m)
            print("\n最终选中的模型：")
            print(best_model if best_model else "None")
            print("=" * 60)

        if best_model:
            # NOTE(review): the route target is the record's "judge model" field;
            # confirm it holds the answering model rather than the grader. Fall
            # back to the MODELS_CONFIG key so the output is never null.
            output = best_model["judge_model"] or best_model["model"]
        else:
            output = "None"

        results.append({
            "instruction": "xxxxx",
            "input": question,
            "output": output,
        })

    # Export JSON; create the output directory if needed.
    OUTPUT_PATH.parent.mkdir(parents=True, exist_ok=True)
    with open(OUTPUT_PATH, "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)

    print(f"结果已保存到: {OUTPUT_PATH}")


if __name__ == "__main__":
    main()
