#!/usr/bin/env python3
"""
分析每张图片的具体错误情况
"""

import json
import numpy as np

def _safe_div(numerator, denominator):
    """Return ``numerator / denominator``, or ``0.0`` when the denominator is zero."""
    return numerator / denominator if denominator else 0.0


def _sum_per_class(per_class, field):
    """Sum the counter ``field`` (e.g. ``'tp'``) over every entry of ``per_class``."""
    return sum(cls_data[field] for cls_data in per_class.values())


def analyze_per_image_errors(results_path='./evaluation_results/ade20k_evaluation_val.json'):
    """Print the average number of misses (FN) and false detections (FP) per image.

    Reads the evaluation JSON at ``results_path`` and derives per-image
    averages from its ``metadata`` counts, its ``overall`` micro metrics and
    the per-class confusion counters stored under ``per_class``.

    Args:
        results_path: Path of the evaluation results JSON. Defaults to the
            ADE20K validation results file this script was written for.

    Returns:
        dict: The per-image statistics that were printed (keys:
        ``errors_per_image``, ``avg_tp_per_image``, ``avg_fp_per_image``,
        ``avg_fn_per_image``, ``avg_tn_per_image``, ``avg_positive_per_image``,
        ``miss_rate``, ``fp_per_positive``). Ratios whose denominator is zero
        (no images, or no positive samples) are reported as ``0.0`` instead of
        raising ``ZeroDivisionError``.

    Raises:
        OSError: If ``results_path`` cannot be opened.
        KeyError: If an expected key is missing from the JSON.
    """
    # JSON is UTF-8 by specification; don't depend on the locale's default encoding.
    with open(results_path, 'r', encoding='utf-8') as f:
        results = json.load(f)

    overall = results['overall']
    metadata = results['metadata']

    num_images = metadata['num_images']
    num_classes = metadata['num_classes']

    precision_micro = overall['precision_micro']
    recall_micro = overall['recall_micro']
    hamming_loss = overall['hamming_loss']

    print("🔍 每张图片错误分析:")
    print("=" * 50)

    # Assumes hamming_loss is averaged over every (image, class) label slot,
    # so scaling by the slot count recovers the total number of wrong labels.
    total_labels = num_images * num_classes
    total_errors = hamming_loss * total_labels
    errors_per_image = _safe_div(total_errors, num_images)

    print(f"平均每张图片错误标签数: {errors_per_image:.2f} 个")

    # The per-class confusion counters give an exact breakdown of those errors.
    per_class = results['per_class']

    avg_tp_per_image = _safe_div(_sum_per_class(per_class, 'tp'), num_images)
    avg_fp_per_image = _safe_div(_sum_per_class(per_class, 'fp'), num_images)
    avg_fn_per_image = _safe_div(_sum_per_class(per_class, 'fn'), num_images)
    avg_tn_per_image = _safe_div(_sum_per_class(per_class, 'tn'), num_images)

    print(f"\n📊 平均每张图片:")
    print(f"  正确预测 (TP): {avg_tp_per_image:.2f} 个类别")
    print(f"  误检 (FP):     {avg_fp_per_image:.2f} 个类别  ← 错误预测为存在")
    print(f"  漏检 (FN):     {avg_fn_per_image:.2f} 个类别  ← 遗漏的真实类别")
    print(f"  正确拒绝 (TN): {avg_tn_per_image:.2f} 个类别")

    print(f"\n🎯 关键错误指标:")
    print(f"  平均每张图片漏掉: {avg_fn_per_image:.2f} 个真实存在的类别")
    print(f"  平均每张图片误判: {avg_fp_per_image:.2f} 个不存在的类别")
    print(f"  总错误数 = 漏检 + 误检 = {avg_fn_per_image + avg_fp_per_image:.2f} 个")

    # Average number of ground-truth classes present in one image.
    avg_positive_per_image = _safe_div(_sum_per_class(per_class, 'positive_samples'), num_images)

    print(f"\n📈 数据统计:")
    print(f"  平均每张图片有: {avg_positive_per_image:.2f} 个真实类别")
    print(f"  召回率: {recall_micro:.1%} (找到了 {recall_micro*100:.1f}% 的真实类别)")
    print(f"  精确率: {precision_micro:.1%} (预测为正的有 {precision_micro*100:.1f}% 是对的)")

    # Both ratios share the same denominator; guard it consistently so a
    # dataset without positives doesn't crash mid-report.
    miss_rate = _safe_div(avg_fn_per_image, avg_positive_per_image)
    fp_per_positive = _safe_div(avg_fp_per_image, avg_positive_per_image)
    print(f"\n❌ 错误率分析:")
    print(f"  漏检率: {miss_rate:.1%} (平均漏掉 {miss_rate*100:.1f}% 的真实类别)")
    print(f"  每个真实类别的平均误检数: {fp_per_positive:.2f}")

    return {
        'errors_per_image': errors_per_image,
        'avg_tp_per_image': avg_tp_per_image,
        'avg_fp_per_image': avg_fp_per_image,
        'avg_fn_per_image': avg_fn_per_image,
        'avg_tn_per_image': avg_tn_per_image,
        'avg_positive_per_image': avg_positive_per_image,
        'miss_rate': miss_rate,
        'fp_per_positive': fp_per_positive,
    }

if __name__ == '__main__':
    # Script entry point: analyze the default evaluation results file.
    analyze_per_image_errors()