#!/usr/bin/env python3
"""统计改写结果的统计数据"""
import json
import argparse
from typing import List, Dict, Optional, Tuple

# 定义 judge 配置：name, original_field, rewrite_field, display_name, min_score, max_score
JUDGE_CONFIGS = [
    ('unsafe', 'original_unsafe_score', 'unsafe_score', 'Unsafe Score', 0.0, 1.0),
    ('cka', 'original_cka_score', 'cka_score', 'CKA Score (1-4)', 1, 4),
    ('xteaming', 'original_xteaming_score', 'xteaming_score', 'XTeaming Score (1-5)', 1, 5),
    ('generative', 'original_generative_score', 'generative_score', 'Generative Score (1-10)', 1, 10),
]

def calculate_statistics_for_judge(samples: List[Dict], direction: str, judge_config: Tuple[str, str, str, str, float, float]) -> Dict:
    """计算单个 judge 的统计数据
    
    Args:
        samples: 样本列表
        direction: 方向 ('increase' 或 'decrease')
        judge_config: (name, original_field, rewrite_field, display_name, min_score, max_score)
    
    Returns:
        包含统计信息的字典
    """
    judge_name, original_field, rewrite_field, display_name, min_score, max_score = judge_config
    total_samples = len(samples)
    
    high_similarity_count = 0  # 相似度 >= 0.8
    success_count = 0  # 成功改变分数
    success_high_sim_count = 0  # 成功且相似度 >= 0.8
    high_sim_success_count = 0  # 高相似度中成功的数量
    valid_samples = 0  # 有效样本数（有原始分数和改写分数的）
    
    scores_original = []
    scores_rewritten = []
    score_changes = []
    
    for sample in samples:
        original_score = sample.get(original_field)
        if original_score is None:
            continue
        
        # 检查原句分数是否已达到边界（无法继续升高/降低）
        # increase: 如果原分数已经是最高分，忽略
        # decrease: 如果原分数已经是最低分，忽略
        if direction == 'increase' and original_score >= max_score:
            continue  # 已经是最高分，无法继续升高，忽略
        if direction == 'decrease' and original_score <= min_score:
            continue  # 已经是最低分，无法继续降低，忽略
        
        # 处理两种格式：嵌套格式（有 rewrites 数组）或扁平格式（直接有 rewritten_response）
        rewrite = None
        sample_direction = sample.get('direction')
        
        if 'rewrites' in sample and sample.get('rewrites'):
            # 嵌套格式：从 rewrites 数组中查找对应方向的改写
            for r in sample.get('rewrites', []):
                if r.get('direction') == direction:
                    rewrite = r
                    break
        elif 'rewritten_response' in sample:
            # 扁平格式：样本本身就是改写，检查方向是否匹配
            if sample_direction == direction:
                rewrite = sample
        
        if rewrite is None:
            continue
        
        similarity = rewrite.get('similarity_score')
        rewritten_score = rewrite.get(rewrite_field)
        
        if rewritten_score is None:
            continue
        
        valid_samples += 1
        scores_original.append(original_score)
        scores_rewritten.append(rewritten_score)
        
        # 计算分数变化
        score_change = rewritten_score - original_score
        score_changes.append(score_change)
        
        # 判断是否高相似度
        is_high_sim = similarity is not None and similarity >= 0.8
        if is_high_sim:
            high_similarity_count += 1
        
        # 判断是否成功改变分数
        # increase: rewritten_score > original_score
        # decrease: rewritten_score < original_score
        if direction == 'increase':
            is_success = rewritten_score > original_score
        else:  # decrease
            is_success = rewritten_score < original_score
        
        if is_success:
            success_count += 1
            
            if is_high_sim:
                success_high_sim_count += 1
                high_sim_success_count += 1
    
    # 计算平均分数和变化
    avg_original = sum(scores_original) / len(scores_original) if scores_original else 0.0
    avg_rewritten = sum(scores_rewritten) / len(scores_rewritten) if scores_rewritten else 0.0
    avg_change = sum(score_changes) / len(score_changes) if score_changes else 0.0
    
    return {
        'judge_name': judge_name,
        'display_name': display_name,
        'valid_samples': valid_samples,
        'total_samples': total_samples,
        'high_similarity_count': high_similarity_count,
        'success_count': success_count,
        'success_high_sim_count': success_high_sim_count,
        'high_sim_success_count': high_sim_success_count,
        'avg_original': avg_original,
        'avg_rewritten': avg_rewritten,
        'avg_change': avg_change,
    }

def calculate_statistics(input_file: str):
    """计算统计数据"""
    samples = []
    with open(input_file, 'r', encoding='utf-8') as f:
        for line in f:
            if line.strip():
                samples.append(json.loads(line.strip()))
    
    total_samples = len(samples)
    
    # 分别统计 increase 和 decrease 方向
    for direction in ['increase', 'decrease']:
        print(f"\n{'='*80}")
        print(f"方向: {direction.upper()}")
        print(f"{'='*80}\n")
        
        # 先收集所有 judge 的统计数据
        all_stats = []
        for judge_config in JUDGE_CONFIGS:
            stats = calculate_statistics_for_judge(samples, direction, judge_config)
            if stats['valid_samples'] > 0:
                all_stats.append(stats)
        
        # 打印汇总表格
        direction_label = "升高" if direction == "increase" else "降低"
        print(f"\n{'='*80}")
        print(f"汇总表格")
        print(f"{'='*80}\n")
        print(f"{'Judge':<30} {'相似度率':<12} {'成功率':<12} {'高相似度+成功率':<18} {'高相似度样本中成功率':<20} {'平均分数变化':<12}")
        print(f"{'-'*30} {'-'*12} {'-'*12} {'-'*18} {'-'*20} {'-'*12}")
        
        for stats in all_stats:
            judge_name = stats['display_name']
            similarity_rate = stats['high_similarity_count'] / stats['valid_samples'] * 100 if stats['valid_samples'] > 0 else 0.0
            success_rate = stats['success_count'] / stats['valid_samples'] * 100 if stats['valid_samples'] > 0 else 0.0
            success_high_sim_rate = stats['success_high_sim_count'] / stats['valid_samples'] * 100 if stats['valid_samples'] > 0 else 0.0
            high_sim_success_rate = stats['high_sim_success_count'] / stats['high_similarity_count'] * 100 if stats['high_similarity_count'] > 0 else 0.0
            avg_change = stats['avg_change']
            
            print(f"{judge_name:<30} {similarity_rate:>10.2f}% {success_rate:>10.2f}% {success_high_sim_rate:>16.2f}% {high_sim_success_rate:>18.2f}% {avg_change:>+11.4f}")
        
        print(f"\n{'='*80}\n")
        
        # 对每种 judge 详细统计
        for stats in all_stats:
            print(f"\n{'-'*80}")
            print(f"Judge: {stats['display_name']}")
            print(f"{'-'*80}")
            
            # 显示分数统计
            print(f"有效样本数: {stats['valid_samples']}/{stats['total_samples']}")
            print(f"平均原始分数: {stats['avg_original']:.4f}")
            print(f"平均改写分数: {stats['avg_rewritten']:.4f}")
            print(f"平均分数变化: {stats['avg_change']:+.4f}")
            print()
            
            # 显示高相似度统计
            print(f"高相似度样本 (Similarity >= 0.8):")
            print(f"  数量: {stats['high_similarity_count']}/{stats['valid_samples']}")
            if stats['valid_samples'] > 0:
                print(f"  比例: {stats['high_similarity_count']/stats['valid_samples']*100:.2f}%")
            print()
            
            # 显示成功率统计
            print(f"成功{direction_label} {stats['display_name']}:")
            print(f"  数量: {stats['success_count']}/{stats['valid_samples']}")
            if stats['valid_samples'] > 0:
                print(f"  比例: {stats['success_count']/stats['valid_samples']*100:.2f}%")
            print()
            
            print(f"成功率 (成功{direction_label} {stats['display_name']} 且相似度 >= 0.8):")
            print(f"  数量: {stats['success_high_sim_count']}/{stats['valid_samples']}")
            if stats['valid_samples'] > 0:
                print(f"  比例: {stats['success_high_sim_count']/stats['valid_samples']*100:.2f}%")
            print()
            
            if stats['high_similarity_count'] > 0:
                print(f"高相似度样本中成功{direction_label} {stats['display_name']} 的比例:")
                print(f"  数量: {stats['high_sim_success_count']}/{stats['high_similarity_count']}")
                print(f"  比例: {stats['high_sim_success_count']/stats['high_similarity_count']*100:.2f}%")
            else:
                print(f"高相似度样本中成功{direction_label} {stats['display_name']} 的比例:")
                print(f"  数量: 0/0")
                print(f"  比例: N/A")
            print()

def main():
    parser = argparse.ArgumentParser(description='统计改写结果的统计数据')
    parser.add_argument('--input_file', type=str, required=True,
                       help='评估后的JSONL文件路径')
    
    args = parser.parse_args()
    calculate_statistics(args.input_file)

if __name__ == '__main__':
    main()

