#!/usr/bin/env python3
"""
合并所有 reward 模型的评分结果到一个文件中。
将所有问题和答案组织在一起，每个回答包含所有模型的分数。
"""

import json
import os
import glob
from collections import defaultdict
from typing import Dict, List, Any

def load_json_file(filepath: str) -> Dict[str, Any]:
    """Load a JSON file whose top-level value must be an object.

    Failures are reported with ``print`` instead of being raised, so a caller
    can skip a bad file and carry on with the next one.

    Args:
        filepath: Path of the JSON file; it is read as UTF-8.

    Returns:
        The parsed top-level object as a dict.  An empty dict is returned when
        the file cannot be opened or read, is not valid UTF-8 encoded JSON, or
        parses to something other than an object (list, string, number, ...).
    """
    try:
        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
        print(f"错误: 无法加载文件 {filepath}: {e}")
        return {}

    # Callers use dict methods on the result, so reject any other JSON value
    # here rather than letting them fail later with an AttributeError.
    if not isinstance(data, dict):
        print(f"错误: 文件 {filepath} 的顶层不是 JSON 对象")
        return {}

    return data

def merge_reward_results(results_dir: str, output_file: str):
    """Merge the per-model score files of a directory into one JSON file.

    Every ``*.json`` file directly inside ``results_dir`` is loaded with
    ``load_json_file``.  Each file contributes one model, named by its
    ``"model_name"`` field (falling back to the file name without its
    extension), and a ``"scores"`` list whose entries are read for the keys
    ``"question"``, ``"answer"``, ``"score"`` and ``"index"``.  Entries are
    grouped by stripped question text and then by stripped answer text.

    The file written to ``output_file`` has this layout::

        {
          "total_questions": <number of distinct questions>,
          "total_models": <number of distinct models>,
          "models": [<model names, sorted>],
          "data": [
            {"question": <text>,
             "answers": [
               {"answer": <text>, "index": <index>,
                "scores": {<model name>: <score or null>, ...}}
             ]}
          ]
        }

    Answers are ordered by index and questions by the index of their first
    answer.  A model that did not score an answer gets ``null`` for it.  When
    several files carry the same question/answer pair, the index stored is
    the one from the last file processed (files are processed in sorted path
    order).

    Progress and errors are reported with ``print``.  Nothing is returned,
    and neither a directory without JSON files nor a failed write raises.

    Args:
        results_dir: Directory containing the per-model JSON files.
        output_file: Path of the merged JSON file to write.
    """
    # glob returns paths in arbitrary order; sorting makes the merge
    # repeatable, since the last file processed decides an answer's index.
    json_files = sorted(glob.glob(os.path.join(results_dir, "*.json")))

    if not json_files:
        print(f"错误: 在目录 {results_dir} 中没有找到 JSON 文件")
        return

    print(f"找到 {len(json_files)} 个 JSON 文件")

    # question -> answer -> {"models": {model_name: score}, "index": index}
    merged_data: Dict[str, Dict[str, Dict[str, Any]]] = {}
    all_models = set()

    for json_file in json_files:
        file_name = os.path.basename(json_file)
        print(f"正在处理: {file_name}")

        data = load_json_file(json_file)
        if not data:
            continue

        # splitext drops only the trailing extension; a plain
        # replace(".json", "") would also eat ".json" inside the name.
        model_name = data.get("model_name", os.path.splitext(file_name)[0])
        all_models.add(model_name)

        # "or" fallbacks keep an explicit null from aborting the whole merge.
        for score_entry in data.get("scores") or []:
            question_key = (score_entry.get("question") or "").strip()
            answer_key = (score_entry.get("answer") or "").strip()

            record = merged_data.setdefault(question_key, {}).setdefault(
                answer_key, {"models": {}, "index": 0}
            )
            record["models"][model_name] = score_entry.get("score", 0)
            record["index"] = score_entry.get("index", 0)

    # One sorted list drives both the "models" field and the key order of
    # every per-answer "scores" dict, so the output is identical across runs
    # (iterating the set directly would follow its arbitrary order).
    model_names = sorted(all_models)

    final_results = {
        "total_questions": len(merged_data),
        "total_models": len(model_names),
        "models": model_names,
        "data": [],
    }

    for question, answers in merged_data.items():
        ordered_answers = sorted(
            answers.items(), key=lambda item: item[1]["index"]
        )
        final_results["data"].append({
            "question": question,
            "answers": [
                {
                    "answer": answer,
                    "index": record["index"],
                    # Models that never scored this answer map to None.
                    "scores": {
                        name: record["models"].get(name)
                        for name in model_names
                    },
                }
                for answer, record in ordered_answers
            ],
        })

    # Order questions by the index of their first (lowest-index) answer.
    final_results["data"].sort(
        key=lambda entry: entry["answers"][0]["index"] if entry["answers"] else 0
    )

    # Only the write itself is guarded, so a failure while printing the
    # summary is not misreported as a save failure.
    try:
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(final_results, f, ensure_ascii=False, indent=2)
    except (OSError, TypeError, ValueError) as e:
        print(f"错误: 无法保存文件 {output_file}: {e}")
        return

    print(f"\n合并完成!")
    print(f"- 总问题数: {final_results['total_questions']}")
    print(f"- 总模型数: {final_results['total_models']}")
    print(f"- 输出文件: {output_file}")
    print(f"- 模型列表: {', '.join(map(str, final_results['models']))}")

def main():
    """Print a short banner, then merge the results of the fixed input directory."""
    source_dir, merged_path = (
        "/root/gMad/3_reward_score/results_10",
        "/root/gMad/3_reward_score/all_merged_results_10.json",
    )

    # Banner lines, printed one per line before the merge starts.
    banner = (
        "开始合并 reward 模型评分结果...",
        f"输入目录: {source_dir}",
        f"输出文件: {merged_path}",
        "-" * 50,
    )
    for line in banner:
        print(line)

    merge_reward_results(source_dir, merged_path)

# Run the merge only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
