import json
from collections import defaultdict
from itertools import combinations
import argparse
import random
from decimal import Decimal

def convert_decimals_to_float(obj):
    """
    递归地将数据结构中的Decimal对象转换为float。
    """
    if isinstance(obj, list):
        return [convert_decimals_to_float(i) for i in obj]
    if isinstance(obj, dict):
        return {k: convert_decimals_to_float(v) for k, v in obj.items()}
    if isinstance(obj, Decimal):
        return float(obj)
    return obj

def random_sample_reward_pairs(input_file, output_file, k=10, seed=42):
    """
    读取all_merged_results.json，对每个reward模型的分数进行归一化，
    然后计算模型两两配对，并为每个模型对随机选取k个样本，最后保存结果。
    
    Args:
        input_file: 输入的JSON文件路径
        output_file: 输出的JSON文件路径
        k: 每个模型对随机选取的样本数量，默认10
        seed: 随机种子，用于保证结果可复现，默认42
    """
    print("开始处理reward模型分数...")
    
    # 设置随机种子
    random.seed(seed)
    
    try:
        with open(input_file, 'r', encoding='utf-8') as f:
            data = json.load(f)
        
        print(f"成功加载数据，共{data['total_questions']}个问题，{data['total_models']}个模型")
        
        # 第一遍：收集每个模型的所有分数，用于计算最小值和最大值
        model_scores = defaultdict(list)
        
        for question_data in data['data']:
            for answer in question_data['answers']:
                scores = answer.get('scores', {})
                for model_name, score in scores.items():
                    model_scores[model_name].append(score)
        
        # 计算每个模型的最小值和最大值
        model_stats = {}
        for model_name, scores in model_scores.items():
            if scores:
                min_score = min(scores)
                max_score = max(scores)
                model_stats[model_name] = {"min": min_score, "max": max_score}
                print(f"模型 {model_name}: 最小值={min_score:.4f}, 最大值={max_score:.4f}")
        
        # 第二遍：归一化分数并按模型对分类
        classified_data = defaultdict(list)
        
        for question_idx, question_data in enumerate(data['data']):
            question = question_data['question']
            answers = question_data['answers']
            
            # 对于每对答案，计算归一化分数
            for i in range(len(answers)):
                for j in range(i + 1, len(answers)):
                    answer_a = answers[i]
                    answer_b = answers[j]
                    
                    # 获取原始分数
                    scores_a = answer_a.get('scores', {})
                    scores_b = answer_b.get('scores', {})
                    
                    # 归一化分数
                    normalized_scores_a = {}
                    normalized_scores_b = {}
                    
                    for model_name in scores_a.keys():
                        if model_name in scores_b and model_name in model_stats:
                            stats = model_stats[model_name]
                            score_range = stats["max"] - stats["min"]
                            
                            if score_range > 0:
                                # 归一化到[0, 1]范围
                                norm_a = (scores_a[model_name] - stats["min"]) / score_range
                                norm_b = (scores_b[model_name] - stats["min"]) / score_range
                            else:
                                # 如果所有分数相同，设为0
                                norm_a = 0.0
                                norm_b = 0.0
                            
                            normalized_scores_a[model_name] = norm_a
                            normalized_scores_b[model_name] = norm_b
                    
                    # 计算每个模型的偏好分数（答案A - 答案B）
                    preference_scores = {}
                    for model_name in normalized_scores_a.keys():
                        preference_scores[model_name] = normalized_scores_a[model_name] - normalized_scores_b[model_name]
                    
                    # 获取所有模型名称并排序
                    model_names = sorted(list(preference_scores.keys()))
                    
                    # 对每一对模型计算差异值并创建条目
                    for model_x, model_y in combinations(model_names, 2):
                        pair_key = f"{model_x}_vs_{model_y}"
                        
                        pref_x = preference_scores[model_x]
                        pref_y = preference_scores[model_y]
                        
                        # 计算偏好分数差异的绝对值（虽然不用于筛选，但保留用于分析）
                        score_discrepancy = abs(pref_x - pref_y)
                        
                        # 构建数据条目，只保留当前比较的两个模型的归一化分数
                        entry = {
                            "question": question,
                            "response_pair": {
                                "response_a": {
                                    "answer": answer_a['answer'],
                                    "index": answer_a['index'],
                                    "normalized_scores": {
                                        model_x: normalized_scores_a.get(model_x),
                                        model_y: normalized_scores_a.get(model_y)
                                    }
                                },
                                "response_b": {
                                    "answer": answer_b['answer'],
                                    "index": answer_b['index'],
                                    "normalized_scores": {
                                        model_x: normalized_scores_b.get(model_x),
                                        model_y: normalized_scores_b.get(model_y)
                                    }
                                }
                            },
                            "preference_scores": {
                                model_x: pref_x,
                                model_y: pref_y
                            },
                            "preference_score_discrepancy": score_discrepancy,
                            "question_index": question_idx
                        }
                        
                        classified_data[pair_key].append(entry)
            
            if (question_idx + 1) % 100 == 0:
                print(f"已处理 {question_idx + 1}/{len(data['data'])} 个问题")
        
        # 对每个模型对随机选取k个样本
        print(f"\n开始为每个模型对随机选取{k}个样本...")
        random_k_results = {}
        
        for pair_key, entries in classified_data.items():
            # 随机选择k个样本
            if len(entries) <= k:
                # 如果样本数不足k个，取全部
                selected_entries = entries.copy()
            else:
                # 随机选择k个样本
                selected_entries = random.sample(entries, k)
            
            random_k_results[pair_key] = selected_entries
            print(f"  {pair_key}: 从{len(entries)}个样本中随机选取了{len(selected_entries)}个")
        
        # 转换Decimal对象为float
        final_results = convert_decimals_to_float(random_k_results)
        
        # 保存结果
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(final_results, f, indent=2, ensure_ascii=False)
        
        print(f"\n成功处理完成！结果已保存到 '{output_file}'")
        print(f"共生成 {len(final_results)} 个模型对的数据")
        print(f"使用随机种子: {seed}")
        
        # 打印统计信息
        for pair_key, entries in final_results.items():
            print(f"  {pair_key}: {len(entries)} 个样本")
    
    except FileNotFoundError:
        print(f"错误：找不到输入文件 '{input_file}'")
    except json.JSONDecodeError as e:
        print(f"错误：解析JSON文件 '{input_file}' 时出错: {e}")
    except Exception as e:
        print(f"发生意外错误: {e}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='处理reward模型分数，为每个模型对随机选取k个样本')
    parser.add_argument('--input', type=str, 
                        default="/root/gMad/3_reward_score/all_merged_results_5.json",
                        help='输入的JSON文件路径')
    parser.add_argument('--output', type=str, 
                        default="/root/gMad/4_oracle_judge/random/random_sampled_reward_pairs_5.json",
                        help='输出的随机采样结果JSON文件路径')
    parser.add_argument('--k', type=int, default=10,
                        help='每个模型对随机选取的样本数量，默认10')
    parser.add_argument('--seed', type=int, default=42,
                        help='随机种子，用于保证结果可复现，默认42')
    
    args = parser.parse_args()
    
    # 执行处理流程
    random_sample_reward_pairs(args.input, args.output, k=args.k, seed=args.seed)
