#!/usr/bin/env python3
"""
可视化模型评估结果
使用方法：python visualize_results.py [--input results.csv]
"""

import argparse
import pandas as pd
import sys


def print_separator(char="=", length=100):
    """Print a horizontal rule made of ``char`` repeated ``length`` times.

    Args:
        char: The string to repeat (a single character by default).
        length: How many times ``char`` is repeated.
    """
    rule = char * length
    print(rule)


def print_category_header(category_name):
    """Print a category banner: a blank line, then the name framed by two rules.

    Args:
        category_name: Title shown between the two 100-character ``=`` rules,
            indented by two spaces.
    """
    rule = "=" * 100
    # A single print call with a newline separator yields the same three lines
    # (preceded by one blank line) as three consecutive prints would.
    print(f"\n{rule}", f"  {category_name}", rule, sep="\n")


def format_value(value, is_percentage=False):
    """Format a score for display.

    Args:
        value: The score to format. Missing values (as detected by
            ``pd.isna``) are rendered as ``"N/A"``.
        is_percentage: When true, render with two decimals and a trailing
            ``%``; values below 1 are treated as fractions and multiplied by
            100 first, values of 1 or more are assumed to already be
            percentages. When false, render with four decimals.

    Returns:
        The formatted string, or ``str(value)`` when ``value`` cannot be
        converted to ``float``.
    """
    if pd.isna(value):
        return "N/A"

    # Only conversion failures fall back to the raw text; anything else
    # (e.g. KeyboardInterrupt) must propagate instead of being swallowed.
    try:
        number = float(value)
    except (TypeError, ValueError, OverflowError):
        return str(value)

    if not is_percentage:
        return f"{number:.4f}"
    if number < 1:
        # Fraction in [0, 1): scale up to a percentage.
        return f"{number * 100:.2f}%"
    # Already expressed on a percentage scale.
    return f"{number:.2f}%"


def _print_benchmark_section(df, title, datasets, headers, widths, percentage_datasets):
    """Print one benchmark category as a fixed-width table.

    Emits the category banner, a header line, one line per row of ``df`` and
    a final ``Average`` line holding the column-wise mean of each dataset.

    Args:
        df: DataFrame with a ``Model`` column and one column per entry of
            ``datasets``.
        title: Banner text passed to ``print_category_header``.
        datasets: Column names to show, in display order.
        headers: Column captions, parallel to ``datasets``.
        widths: Right-aligned field width of each column, parallel to
            ``datasets``.
        percentage_datasets: Column names whose values are rendered as
            percentages by ``format_value``.
    """
    print_category_header(title)
    header_cells = " ".join(f"{caption:>{width}}" for caption, width in zip(headers, widths))
    print(f"\n{'Model':<40} {header_cells}")
    print("-" * 100)

    for _, row in df.iterrows():
        # str() keeps non-string model names (e.g. NaN or numbers) from
        # crashing the length check below.
        model_name = str(row['Model'])
        if len(model_name) > 38:
            model_name = model_name[:35] + "..."

        cells = " ".join(
            f"{format_value(row[ds], ds in percentage_datasets):>{width}}"
            for ds, width in zip(datasets, widths)
        )
        print(f"{model_name:<40} {cells}")

    # Average row: the label is printed once, followed by one cell per dataset
    # so the values line up with the columns above.
    print("-" * 100)
    averages = " ".join(
        f"{format_value(df[ds].mean(), ds in percentage_datasets):>{width}}"
        for ds, width in zip(datasets, widths)
    )
    print(f"{'Average':<40} {averages}")


def print_comparison_table(df):
    """Print the per-category comparison tables.

    Three tables are printed in order: math, spatial reasoning and
    hallucination benchmarks. Each lists every model in ``df`` and ends with
    an ``Average`` row.

    Args:
        df: DataFrame with a ``Model`` column and one column per benchmark
            dataset named below.

    Raises:
        KeyError: If ``df`` lacks ``Model`` or one of the benchmark columns.
    """
    # Dataset grouping by category.
    math_datasets = ('MathVista_MINI', 'MathVision_MINI', 'MathVerse_MINI_Vision_Only', 'WeMath')
    spatial_datasets = ('3DSRBench', 'A-OKVQA', 'SpatialEval', 'RealWorldQA')
    hallucination_datasets = ('HallusionBench', 'POPE')

    # Columns rendered with a trailing '%'; the remaining columns are shown
    # as plain four-decimal numbers.
    percentage_datasets = frozenset(math_datasets + hallucination_datasets)

    _print_benchmark_section(
        df,
        "📊 Math Benchmarks",
        math_datasets,
        ('MathVista', 'MathVision', 'MathVerse', 'WeMath'),
        (12, 12, 12, 12),
        percentage_datasets,
    )
    _print_benchmark_section(
        df,
        "🌍 Spatial Reasoning Benchmarks",
        spatial_datasets,
        ('3DSRBench', 'A-OKVQA', 'SpatialEval', 'RealWorldQA'),
        (12, 12, 12, 12),
        percentage_datasets,
    )
    _print_benchmark_section(
        df,
        "👁️ Hallucination Benchmarks",
        hallucination_datasets,
        ('HallusionBench', 'POPE'),
        (15, 12),
        percentage_datasets,
    )


def print_ranking(df):
    """Print, for every dataset, the models ordered from best to worst score.

    Rows whose score for the dataset is missing are not printed. The top
    three positions are marked with medal emoji; later positions show their
    numeric rank.

    Args:
        df: DataFrame with a ``Model`` column and one column per dataset
            listed in ``display_names`` below.
    """
    print_category_header("🏆 Model Rankings by Dataset")

    # Column name -> caption, in the order the rankings are printed.
    display_names = {
        'MathVista_MINI': 'MathVista',
        'MathVision_MINI': 'MathVision',
        'MathVerse_MINI_Vision_Only': 'MathVerse',
        'WeMath': 'WeMath',
        '3DSRBench': '3DSRBench',
        'A-OKVQA': 'A-OKVQA',
        'SpatialEval': 'SpatialEval',
        'RealWorldQA': 'RealWorldQA',
        'HallusionBench': 'HallusionBench',
        'POPE': 'POPE',
    }

    # Datasets whose scores are shown as percentages.
    percentage_columns = {
        'MathVista_MINI', 'MathVision_MINI',
        'MathVerse_MINI_Vision_Only', 'WeMath',
        'HallusionBench', 'POPE',
    }

    medals = {1: "🥇", 2: "🥈", 3: "🥉"}

    for column, caption in display_names.items():
        # Highest score first.
        ranked = df.sort_values(column, ascending=False)
        as_percentage = column in percentage_columns

        print(f"\n{caption}:")
        print("-" * 80)

        position = 0
        for _, entry in ranked.iterrows():
            position += 1
            score = entry[column]
            if pd.isna(score):
                # Rows without a score for this dataset are left out.
                continue

            label = entry['Model']
            if len(label) > 50:
                label = label[:47] + "..."

            score_text = format_value(score, as_percentage)
            marker = medals.get(position, f"{position}.")
            print(f"  {marker} {label:<50} {score_text:>15}")


def _best_model(summary, score_column):
    """Return the ``Model`` with the highest ``score_column`` value, or "N/A".

    "N/A" is returned when the column has no non-missing value (including an
    empty frame), where ``idxmax`` would otherwise fail.
    """
    scores = summary[score_column].dropna()
    if scores.empty:
        return "N/A"
    return summary.loc[scores.idxmax(), 'Model']


def print_summary_statistics(df):
    """Print per-model category averages and the best model per category.

    For each model the mean over the math, spatial and hallucination columns
    is printed, plus the mean over all ten datasets. The averages are computed
    on a separate frame, so ``df`` is left unmodified.

    Args:
        df: DataFrame with a ``Model`` column and one column per dataset
            named below.

    Raises:
        KeyError: If ``df`` lacks ``Model`` or one of the dataset columns.
    """
    print_category_header("📈 Summary Statistics")

    math_datasets = ['MathVista_MINI', 'MathVision_MINI', 'MathVerse_MINI_Vision_Only', 'WeMath']
    spatial_datasets = ['3DSRBench', 'A-OKVQA', 'SpatialEval', 'RealWorldQA']
    hallucination_datasets = ['HallusionBench', 'POPE']
    all_datasets = math_datasets + spatial_datasets + hallucination_datasets

    # Per-model averages by category, kept in a separate frame so the
    # caller's DataFrame does not gain extra columns as a side effect.
    summary = df[['Model']].copy()
    summary['Math_Avg'] = df[math_datasets].mean(axis=1)
    summary['Spatial_Avg'] = df[spatial_datasets].mean(axis=1)
    summary['Hallucination_Avg'] = df[hallucination_datasets].mean(axis=1)
    # NOTE(review): Math/Hallucination averages are displayed as percentages
    # below while Spatial is displayed as a plain decimal. If the underlying
    # columns really use different scales, this unweighted mean mixes them;
    # confirm against the data.
    summary['Overall_Avg'] = df[all_datasets].mean(axis=1)

    print(f"\n{'Model':<40} {'Math':>12} {'Spatial':>12} {'Hallucination':>15} {'Overall':>12}")
    print("-" * 100)

    for _, row in summary.iterrows():
        # str() keeps non-string model names from crashing the length check.
        model_name = str(row['Model'])
        if len(model_name) > 38:
            model_name = model_name[:35] + "..."

        print(f"{model_name:<40} "
              f"{format_value(row['Math_Avg'], True):>12} "
              f"{format_value(row['Spatial_Avg'], False):>12} "
              f"{format_value(row['Hallucination_Avg'], True):>15} "
              f"{format_value(row['Overall_Avg'], False):>12}")

    # Best model per category.
    print("\n" + "="*100)
    print("🏆 Best Models by Category:")
    print("-" * 100)

    best_math = _best_model(summary, 'Math_Avg')
    best_spatial = _best_model(summary, 'Spatial_Avg')
    best_hallucination = _best_model(summary, 'Hallucination_Avg')
    best_overall = _best_model(summary, 'Overall_Avg')

    print(f"  Math Benchmarks:        {best_math}")
    print(f"  Spatial Reasoning:      {best_spatial}")
    print(f"  Hallucination:          {best_hallucination}")
    print(f"  Overall Performance:    {best_overall}")


def main():
    """Command-line entry point: load the results CSV and print all reports.

    Reads the file given by ``--input``/``-i`` (default ``model_results.csv``),
    then prints the title banner, the comparison tables, the summary
    statistics and the per-dataset rankings. If the file is missing or cannot
    be read, an error message is written to stderr and the process exits with
    status 1.
    """
    parser = argparse.ArgumentParser(
        description="可视化模型评估结果",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
示例:
  python visualize_results.py
  python visualize_results.py --input my_results.csv
        """
    )
    parser.add_argument("--input", "-i", default="model_results.csv",
                       help="输入CSV文件路径（默认: model_results.csv）")

    args = parser.parse_args()

    # Load the results; diagnostics go to stderr so that redirected or piped
    # report output is not polluted by error text.
    try:
        df = pd.read_csv(args.input)
    except FileNotFoundError:
        print(f"错误: 文件不存在: {args.input}", file=sys.stderr)
        print("\n请先运行 collect_results.py 或 collect_all_results.py 生成结果文件",
              file=sys.stderr)
        sys.exit(1)
    except Exception as e:
        print(f"错误: 无法读取文件: {e}", file=sys.stderr)
        sys.exit(1)

    # Title banner.
    print_separator("=")
    print(f"{'':^100}")
    print(f"{'模型评估结果对比':^94}")
    print(f"{'':^100}")
    print(f"{'共 ' + str(len(df)) + ' 个模型':^94}")
    print(f"{'':^100}")
    print_separator("=")

    # Comparison tables.
    print_comparison_table(df)

    # Summary statistics.
    print_summary_statistics(df)

    # Rankings.
    print_ranking(df)

    print("\n" + "="*100)
    print(f"结果文件: {args.input}")
    print("="*100 + "\n")


# Run the report only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

