#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
将Reward模型的两两比较结果转换为全局排名
基于classified_reward_scores_by_pair.json文件，使用Bradley-Terry模型计算全局排名
"""

import argparse
import json
import numpy as np
from pathlib import Path
from typing import Any, Dict, List, Tuple
from collections import defaultdict
from scipy.optimize import minimize


class RewardPairwiseToGlobalRanking:
    """Turn pairwise reward-model comparisons into one global ranking.

    Every sample compares two reward models against an oracle judgment. The
    per-pair outcomes are accumulated into win/total matrices, a
    Bradley-Terry model is fitted to those matrices, and the fitted
    coefficients are converted into ranks.

    Attributes:
        regularization: L2 penalty weight applied to the free Bradley-Terry
            coefficients (all but the first, which is pinned to 0).
        reward_model_names: Sorted model names discovered by
            ``extract_pairwise_comparisons``; row/column ``k`` of the
            matrices refers to ``reward_model_names[k]``.
    """

    def __init__(self, regularization: float = 1e-6):
        self.regularization = regularization
        self.reward_model_names: List[str] = []

    def load_classified_data(self, file_path: Path) -> Dict[str, Any]:
        """Load the classified pairwise data from a JSON file and print counts.

        Args:
            file_path: Path of the JSON file to read (UTF-8).

        Returns:
            The parsed JSON object. The rest of this class treats it as a
            mapping from ``"<model1>_vs_<model2>"`` to a list of sample dicts.
        """
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        print(f"加载了 {len(data)} 个模型对的分类数据")

        # Summary statistics only; nothing is filtered here.
        total_samples = sum(len(samples) for samples in data.values())
        oracle_judged_samples = sum(
            1
            for samples in data.values()
            for sample in samples
            if sample.get('oracle_preference') is not None
        )
        print(f"总样本数: {total_samples}")
        print(f"已进行Oracle判断的样本数: {oracle_judged_samples}")

        return data

    def extract_pairwise_comparisons(self, classified_data: Dict[str, Any]) -> Dict[str, Any]:
        """Build the Bradley-Terry win/total matrices from classified samples.

        A sample is used only when its ``oracle_preference`` is 1 or 2 and
        ``preference_scores`` holds a value for both models of the pair.
        A model "wins" a comparison when its prediction matches the oracle
        and the other model's does not; when both match, each gets 0.5;
        when neither matches, nobody scores (the comparison still counts in
        ``total``).

        Args:
            classified_data: Mapping from ``"<model1>_vs_<model2>"`` keys to
                lists of sample dicts. Keys that do not split into exactly
                two names on ``"_vs_"`` are reported and ignored.

        Returns:
            A dict with:
                ``reward_model_names``: sorted list of model names,
                ``wins``: float matrix, ``wins[i][j]`` = score of i against j,
                ``total``: float matrix of comparison counts (symmetric),
                ``oracle_accuracy``: per-model fraction of predictions that
                    matched the oracle (0.0 when the model had no usable
                    sample),
                ``total_samples``: number of samples actually compared.

        Side effects:
            Sets ``self.reward_model_names`` and prints progress messages.
        """
        print("从分类数据中提取两两比较数据...")

        # Parse every pair key once; remember the valid ones for the main loop.
        parsed_pairs = {}
        for model_pair_key in classified_data:
            parts = model_pair_key.split('_vs_')
            if len(parts) != 2:
                print(f"警告: 无法解析模型对密钥 '{model_pair_key}'")
                continue
            parsed_pairs[model_pair_key] = (parts[0], parts[1])

        self.reward_model_names = sorted(
            {name for pair in parsed_pairs.values() for name in pair}
        )
        print(f"涉及的Reward模型 ({len(self.reward_model_names)}个): {self.reward_model_names}")

        # Name -> matrix index, built once instead of list.index() per sample.
        index_of = {name: idx for idx, name in enumerate(self.reward_model_names)}

        n_models = len(self.reward_model_names)
        wins = np.zeros((n_models, n_models))
        total = np.zeros((n_models, n_models))

        # Per-model agreement with the oracle.
        oracle_accuracy = {name: {'correct': 0, 'total': 0} for name in self.reward_model_names}
        processed_samples = 0   # samples carrying a usable oracle judgment
        total_comparisons = 0   # samples that actually produced a comparison

        for model_pair_key, (model1, model2) in parsed_pairs.items():
            i = index_of[model1]
            j = index_of[model2]

            for sample in classified_data[model_pair_key]:
                oracle_pref = sample.get('oracle_preference')
                if oracle_pref not in (1, 2):
                    continue
                processed_samples += 1

                # `or {}` also covers an explicit JSON null for this field.
                preference_scores = sample.get('preference_scores') or {}
                pref1 = preference_scores.get(model1)
                pref2 = preference_scores.get(model2)
                if pref1 is None or pref2 is None:
                    continue
                total_comparisons += 1

                # Per the original author's note, a preference score > 0
                # means the model prefers response_a (prediction 1) and < 0
                # means it prefers response_b (prediction 2). Each model is
                # therefore judged on the sign of its OWN score; comparing the
                # two scores with each other would force the models to
                # disagree. A score of exactly 0 falls into prediction 2.
                model1_prediction = 1 if pref1 > 0 else 2
                model2_prediction = 1 if pref2 > 0 else 2
                model1_correct = model1_prediction == oracle_pref
                model2_correct = model2_prediction == oracle_pref

                oracle_accuracy[model1]['total'] += 1
                oracle_accuracy[model2]['total'] += 1
                if model1_correct:
                    oracle_accuracy[model1]['correct'] += 1
                if model2_correct:
                    oracle_accuracy[model2]['correct'] += 1

                total[i][j] += 1
                total[j][i] += 1

                if model1_correct and model2_correct:
                    # Both right: split the point.
                    wins[i][j] += 0.5
                    wins[j][i] += 0.5
                elif model1_correct:
                    wins[i][j] += 1
                elif model2_correct:
                    wins[j][i] += 1
                # Both wrong: no points for either side.

        print(f"处理了 {processed_samples} 个有Oracle判断的样本")

        accuracy_scores = {
            model_name: (stats['correct'] / stats['total'] if stats['total'] > 0 else 0.0)
            for model_name, stats in oracle_accuracy.items()
        }

        return {
            'reward_model_names': self.reward_model_names,
            'wins': wins,
            'total': total,
            'oracle_accuracy': accuracy_scores,
            'total_samples': total_comparisons
        }

    def bradley_terry_log_likelihood(self, xi: np.ndarray, wins: np.ndarray, total: np.ndarray) -> float:
        """Return the regularized NEGATIVE log-likelihood of a Bradley-Terry fit.

        For every ordered pair ``(i, j)`` with ``i != j`` and
        ``total[i][j] > 0`` the term
        ``wins[i][j] * log(p_ij) + (total[i][j] - wins[i][j]) * log(1 - p_ij)``
        is accumulated, where ``p_ij = 1 / (1 + exp(xi[j] - xi[i]))`` is
        clipped to ``[1e-15, 1 - 1e-15]``. Because both ``(i, j)`` and
        ``(j, i)`` are visited, each comparison is counted from both sides.

        Args:
            xi: Coefficient vector, one entry per model.
            wins: Square matrix of win scores.
            total: Square matrix of comparison counts.

        Returns:
            ``-(log_likelihood) + regularization * sum(xi[1:] ** 2)``; the
            first coefficient is the reference and is not penalized.
        """
        xi = np.asarray(xi, dtype=float)
        wins = np.asarray(wins, dtype=float)
        total = np.asarray(total, dtype=float)

        # diff[i, j] = xi[i] - xi[j]  ->  p[i, j] = P(i beats j).
        diff = xi[:, None] - xi[None, :]
        with np.errstate(over='ignore'):
            # exp may overflow to inf for very large gaps; the resulting
            # probability of 0 is handled by the clip below.
            prob_i_beats_j = 1.0 / (1.0 + np.exp(-diff))
        prob_i_beats_j = np.clip(prob_i_beats_j, 1e-15, 1 - 1e-15)
        prob_j_beats_i = 1.0 - prob_i_beats_j

        losses = total - wins
        active = total > 0
        np.fill_diagonal(active, False)  # a model is never compared to itself

        win_terms = np.where(active & (wins > 0), wins * np.log(prob_i_beats_j), 0.0)
        loss_terms = np.where(active & (losses > 0), losses * np.log(prob_j_beats_i), 0.0)
        log_likelihood = win_terms.sum() + loss_terms.sum()

        regularization_term = self.regularization * np.sum(xi[1:] ** 2)
        return float(-(log_likelihood - regularization_term))

    def fit_bradley_terry_model(self, comparison_data: Dict[str, Any]) -> Dict[str, Any]:
        """Fit the Bradley-Terry coefficients and derive the global ranking.

        The first model's coefficient is fixed at 0 and the remaining ones
        are found by minimizing ``bradley_terry_log_likelihood`` with BFGS.
        If the optimizer raises, the oracle accuracies (shifted so the first
        model is 0) are used as coefficients instead.

        Args:
            comparison_data: Output of ``extract_pairwise_comparisons``.

        Returns:
            A JSON-serializable dict with ``reward_model_names``,
            ``bt_coefficients`` (list), ``rankings`` (name -> rank),
            ``oracle_accuracy``, ``win_matrix`` (``wins / total``, 0 where
            there were no comparisons) and ``total_comparisons``.

        Raises:
            ValueError: If fewer than two models are present.
        """
        print("拟合Bradley-Terry模型...")

        reward_model_names = comparison_data['reward_model_names']
        # Coerce to float arrays so list / integer inputs also work with the
        # in-place division below.
        wins = np.asarray(comparison_data['wins'], dtype=float)
        total = np.asarray(comparison_data['total'], dtype=float)
        oracle_accuracy = comparison_data['oracle_accuracy']

        n_models = len(reward_model_names)
        if n_models < 2:
            raise ValueError("需要至少2个reward模型来计算排名")

        def objective(xi_free):
            # Re-insert the pinned reference coefficient before evaluating.
            xi = np.zeros(n_models)
            xi[1:] = xi_free
            return self.bradley_terry_log_likelihood(xi, wins, total)

        try:
            result = minimize(
                objective, np.zeros(n_models - 1), method='BFGS', options={'maxiter': 1000}
            )
            if not result.success:
                print(f"警告: Bradley-Terry优化未收敛: {result.message}")

            bt_coefficients = np.zeros(n_models)
            bt_coefficients[1:] = result.x

        except Exception as e:
            # Deliberate best-effort: any optimizer failure degrades to an
            # accuracy-based ranking instead of aborting the run.
            print(f"Bradley-Terry优化失败: {e}")
            print("使用Oracle准确率作为替代排名依据")
            bt_coefficients = np.array(
                [oracle_accuracy[name] for name in reward_model_names], dtype=float
            )
            bt_coefficients = bt_coefficients - bt_coefficients[0]

        rankings = self.compute_rankings(bt_coefficients, reward_model_names)
        win_matrix = np.divide(wins, total, out=np.zeros_like(wins), where=total != 0)

        return {
            'reward_model_names': reward_model_names,
            'bt_coefficients': bt_coefficients.tolist(),
            'rankings': rankings,
            'oracle_accuracy': oracle_accuracy,
            'win_matrix': win_matrix.tolist(),
            'total_comparisons': comparison_data['total_samples']
        }

    def compute_rankings(self, bt_coefficients: np.ndarray, reward_model_names: List[str]) -> Dict[str, int]:
        """Convert coefficients into competition-style ranks (1 = best).

        Models are ordered by descending coefficient. A model whose
        coefficient is within 1e-12 of the one just before it in that order
        shares that model's rank; the next distinct model gets its 1-based
        position, so ranks look like 1, 1, 3.

        Args:
            bt_coefficients: One coefficient per model.
            reward_model_names: Model names, aligned with ``bt_coefficients``.

        Returns:
            Mapping from model name to integer rank.
        """
        coefficients = np.asarray(bt_coefficients, dtype=float)
        # Stable sort keeps the input order among equal coefficients, which
        # makes the insertion order of the result deterministic.
        order = np.argsort(-coefficients, kind='stable')

        rankings: Dict[str, int] = {}
        previous_idx = None
        for position, idx in enumerate(order):
            name = reward_model_names[idx]
            tied_with_previous = (
                previous_idx is not None
                and np.abs(coefficients[idx] - coefficients[previous_idx]) < 1e-12
            )
            if tied_with_previous:
                rankings[name] = rankings[reward_model_names[previous_idx]]
            else:
                rankings[name] = position + 1
            previous_idx = idx
        return rankings

    def process_classified_data(self, classified_data: Dict[str, Any]) -> Dict[str, Any]:
        """Run extraction and model fitting in sequence.

        Args:
            classified_data: Mapping accepted by
                ``extract_pairwise_comparisons``.

        Returns:
            The result dict produced by ``fit_bradley_terry_model``.
        """
        comparison_data = self.extract_pairwise_comparisons(classified_data)
        return self.fit_bradley_terry_model(comparison_data)


def main() -> int:
    """Command-line entry point: load the data, fit the ranking, save and print it.

    Reads ``--input``, runs ``RewardPairwiseToGlobalRanking`` on it, writes
    the result as JSON to ``--output`` (creating the output directory when
    needed) and prints a ranking table.

    Returns:
        0 on success, 1 when the input file does not exist.
    """
    import sys  # local import: only used for the stderr error path

    parser = argparse.ArgumentParser(description='将Reward模型两两排名转换为全局排名')
    parser.add_argument('--input', type=str,
                        default='/root/gMad/4_oracle_judge/random_result/classified_reward_scores_by_pair_5.json',
                        help='输入文件路径 (classified_reward_scores_by_pair.json)')
    # The help text reports the real default instead of claiming the output
    # lands next to the input file, which the code never did.
    parser.add_argument('--output', type=str,
                        default="/root/gMad/5_result_analyze/random/reward_global_ranking_result_5.json",
                        help='输出文件路径 (默认: %(default)s)')
    parser.add_argument('--regularization', type=float, default=1e-6,
                        help='Bradley-Terry模型的正则化参数，默认1e-6')
    args = parser.parse_args()

    input_path = Path(args.input)
    if not input_path.exists():
        print(f"错误: 输入文件不存在: {input_path}", file=sys.stderr)
        return 1

    processor = RewardPairwiseToGlobalRanking(regularization=args.regularization)

    print("开始处理Reward模型排名...")
    print("=" * 60)

    classified_data = processor.load_classified_data(input_path)
    ranking_result = processor.process_classified_data(classified_data)

    output_path = Path(args.output)
    # Make sure the target directory exists before opening the file.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(ranking_result, f, ensure_ascii=False, indent=2)

    print("\n" + "="*60)
    print("Reward模型全局排名结果 (Bradley-Terry模型):")
    print("="*60)

    rankings = ranking_result['rankings']
    oracle_accuracy = ranking_result['oracle_accuracy']
    # Name -> coefficient, so the display loop needs no list.index() lookups.
    bt_by_name = dict(zip(ranking_result['reward_model_names'],
                          ranking_result['bt_coefficients']))

    for model_name, rank in sorted(rankings.items(), key=lambda item: item[1]):
        bt_score = bt_by_name[model_name]
        oracle_acc = oracle_accuracy[model_name]
        print(f"{rank:2d}. {model_name}")
        print(f"    BT系数: {bt_score:6.3f} | Oracle一致性: {oracle_acc:.3f}")

    print(f"\n总比较样本数: {ranking_result['total_comparisons']}")
    print(f"结果已保存到: {output_path}")
    return 0


if __name__ == '__main__':
    # Pass main()'s return value to the interpreter as the process exit
    # status; a return value of None is treated as success (status 0).
    raise SystemExit(main())
