#!/usr/bin/env python3
"""
脚本功能：收集指定模型在多个数据集上的评估结果
使用方法：python collect_results.py <模型名称> [--output 输出CSV路径] [--outputs-dir outputs目录] [--judge-name judge名称]
"""

import argparse
import math
import os
import sys
from pathlib import Path

import pandas as pd


# Dataset configuration. Entries are processed in list order, and that order
# also defines the dataset column order of the output CSV (see save_results).
# Keys of each entry:
#   name          - dataset name; a result file is matched when its name ends
#                   with "_<name><file_suffix>" (see get_actual_file_name).
#   file_suffix   - file-name suffix after the dataset name; the placeholder
#                   "JUDGE" is replaced by the judge name (see
#                   collect_model_results).
#   row_column    - column used to select the row; "" together with an empty
#                   row_value means "no filter, use the first row".
#   row_value     - value to match in row_column; "" matches empty/NaN cells.
#   result_column - column holding the score.
#   category      - grouping label; not read by the functions in this script.
DATASETS_CONFIG = [
    # Math benchmarks
    {
        "name": "MathVista_MINI",
        "file_suffix": "_JUDGE_score.csv",
        "row_column": "Task&Skill",
        "row_value": "Overall",
        "result_column": "acc",
        "category": "Math"
    },
    {
        "name": "MathVision_MINI",
        "file_suffix": "_JUDGE_score.csv",
        "row_column": "Subject",
        "row_value": "Overall",
        "result_column": "acc",
        "category": "Math"
    },
    {
        "name": "MathVerse_MINI_Vision_Only",
        "file_suffix": "_JUDGE_score.csv",
        "row_column": "split",
        "row_value": "Vision Only",
        "result_column": "Overall",
        "category": "Math"
    },
    {
        "name": "MathVerse_MINI_Vision_Intensive",
        "file_suffix": "_JUDGE_score.csv",
        "row_column": "split",
        "row_value": "Vision Intensive",
        "result_column": "Overall",
        "category": "Math"
    },
    # WeMath: the score row is the one whose "Model" cell is empty/NaN; the
    # score may be a percent string such as "21.62%".
    {
        "name": "WeMath",
        "file_suffix": "_JUDGE_score.csv",
        "row_column": "Model",
        "row_value": "",
        "result_column": "Score (Strict)",
        "category": "Math"
    },
    # Spatial-reasoning benchmarks (3DSRBench and SpatialEval are disabled)
    # {
    #     "name": "3DSRBench",
    #     "file_suffix": "_acc.csv",
    #     "row_column": "split",
    #     "row_value": "none",
    #     "result_column": "Overall",
    #     "category": "Spatial Reasoning"
    # },
    {
        "name": "A-OKVQA",
        "file_suffix": "_acc.csv",
        "row_column": "split",
        "row_value": "val",
        "result_column": "Overall",
        "category": "Spatial Reasoning"
    },
    # {
    #     "name": "SpatialEval",
    #     "file_suffix": "_acc.csv",
    #     "row_column": "split",
    #     "row_value": "none",
    #     "result_column": "Overall",
    #     "category": "Spatial Reasoning"
    # },
    # Perception benchmarks
    {
        "name": "RealWorldQA",
        "file_suffix": "_acc.csv",
        "row_column": "split",
        "row_value": "none",
        "result_column": "Overall",
        "category": "Perception"
    },
    {
        "name": "MMStar",
        "file_suffix": "_acc.csv",
        "row_column": "split",
        "row_value": "none",
        "result_column": "Overall",
        "category": "Perception"
    },
    # TextVQA_VAL: no row filter, the score is read from the first row.
    {
        "name": "TextVQA_VAL",
        "file_suffix": "_acc.csv",
        "row_column": "",
        "row_value": "",
        "result_column": "Overall",
        "category": "Perception"
    },
    # Hallucination-oriented benchmarks (also grouped under "Perception")
    {
        "name": "HallusionBench",
        "file_suffix": "_score.csv",
        "row_column": "split",
        "row_value": "Overall",
        "result_column": "fAcc",
        "category": "Perception"
    },
    {
        "name": "POPE",
        "file_suffix": "_score.csv",
        "row_column": "split",
        "row_value": "Overall",
        "result_column": "Overall",
        "category": "Perception"
    },
]


def format_value(value):
    """
    Convert a raw score to a percentage rounded to 2 decimal places.

    Values below 1 are treated as fractions (e.g. 0.8123) and multiplied by
    100; values of 1 or more are assumed to already be percentages.
    NOTE(review): this heuristic is ambiguous for a fraction of exactly 1.0
    (100%) and for genuine percentages below 1%.

    Args:
        value: raw value (number or numeric string).

    Returns:
        The value as a float rounded to 2 decimals, or None when it cannot be
        converted to a finite number (non-numeric text, None, NaN, inf).
    """
    try:
        num_value = float(value)
    except (ValueError, TypeError):
        return None

    # NaN (e.g. an empty CSV cell read by pandas) and inf are not usable
    # scores; returning them would poison any average computed later.
    if not math.isfinite(num_value):
        return None

    if num_value < 1:
        num_value *= 100

    return round(num_value, 2)


def extract_result(csv_path, row_column, row_value, result_column):
    """
    Extract one score from an evaluation-result CSV file.

    Row selection:
      * row_column == "" and row_value == "": the first row of the file.
      * row_value == "": the first row whose row_column cell is empty or NaN.
      * otherwise: the first row whose row_column cell equals row_value.

    Args:
        csv_path: path to the CSV file.
        row_column: column used to locate the row ("" = no row filter).
        row_value: value to look for in row_column.
        result_column: column holding the score.

    Returns:
        The score as a float rounded to 2 decimals, or None if the file is
        missing or unreadable, or the row/column/value is missing or not
        numeric. A warning is printed for every None except a missing file.
        Percent strings such as "21.62%" are taken as percentages as-is;
        other values go through format_value (fractions < 1 are scaled x100).
    """
    try:
        df = pd.read_csv(csv_path)

        if result_column not in df.columns:
            print(f"  警告: 在 {csv_path} 中未找到列 '{result_column}'")
            return None

        if row_column == "" and row_value == "":
            # No row filter configured: the score is in the first row.
            matched_rows = df
        else:
            if row_column not in df.columns:
                print(f"  警告: 在 {csv_path} 中未找到列 '{row_column}'")
                return None
            cells = df[row_column]
            if row_value == "":
                # pandas reads empty cells as NaN, so accept both forms.
                mask = (cells == "") | cells.isna()
            else:
                mask = cells == row_value
            matched_rows = df[mask]

        if matched_rows.empty:
            print(f"  警告: 在 {csv_path} 中未找到 {row_column}='{row_value}' 的行")
            return None

        result = matched_rows.iloc[0][result_column]

        if pd.isna(result):
            print(f"  警告: 在 {csv_path} 中 '{result_column}' 的值为空")
            return None

        if isinstance(result, str) and result.strip().endswith("%"):
            # e.g. WeMath's "21.62%": already a percentage, so bypass the
            # "< 1 means fraction" scaling, which would turn "0.50%" into 50.
            try:
                return round(float(result.strip()[:-1]), 2)
            except ValueError:
                print(f"  警告: 无法格式化值 '{result}'")
                return None

        formatted_result = format_value(result)
        if formatted_result is None:
            print(f"  警告: 无法格式化值 '{result}'")
            return None
        return formatted_result

    except FileNotFoundError:
        return None
    except Exception as e:
        print(f"  错误: 读取 {csv_path} 时出错: {e}")
        return None


def get_actual_file_name(model_name, dataset_name, file_suffix, outputs_dir):
    """
    Locate the result file of one dataset for one model.

    Files are matched by the name suffix "_<dataset_name><file_suffix>", so
    the model prefix inside the file name does not have to equal model_name.
    A non-empty regular file directly inside <outputs_dir>/<model_name> is
    preferred. Otherwise the run sub-directories whose names start with "T"
    (e.g. T20251030_G5103a673) are searched in reverse lexicographic order,
    which for T<YYYYMMDD>_... names means the most recent date first.

    Args:
        model_name: sub-directory name of the model under outputs_dir.
        dataset_name: the dataset name.
        file_suffix: the file-name suffix following the dataset name.
        outputs_dir: the outputs directory.

    Returns:
        (sub_directory, file_name), where sub_directory is None when the file
        is directly in the model directory; (None, None) if nothing matches or
        the model directory does not exist.
    """
    model_dir = os.path.join(outputs_dir, model_name)
    if not os.path.isdir(model_dir):
        return None, None

    target_suffix = f"_{dataset_name}{file_suffix}"

    def _find_in(directory):
        # First matching non-empty regular file. os.listdir order is arbitrary,
        # so sort for a deterministic pick; isfile() rejects directories and
        # broken symlinks.
        for name in sorted(os.listdir(directory)):
            path = os.path.join(directory, name)
            if name.endswith(target_suffix) and os.path.isfile(path) and os.path.getsize(path) > 0:
                return name
        return None

    file_name = _find_in(model_dir)
    if file_name is not None:
        return None, file_name

    run_dirs = sorted(
        (item for item in os.listdir(model_dir)
         if item.startswith("T") and os.path.isdir(os.path.join(model_dir, item))),
        reverse=True,
    )
    for run_dir in run_dirs:
        file_name = _find_in(os.path.join(model_dir, run_dir))
        if file_name is not None:
            return run_dir, file_name

    return None, None

def collect_model_results(model_name, judge_name, outputs_dir="/home/ec2-user/workspace/VLMEvalKit/outputs"):
    """
    Collect the scores of one model on every dataset in DATASETS_CONFIG.

    Args:
        model_name: model name, i.e. the sub-directory name under outputs_dir.
        judge_name: judge model name substituted for the "JUDGE" placeholder
            in each dataset's file_suffix.
        outputs_dir: path of the outputs directory.

    Returns:
        A dict {"Model": model_name, <dataset name>: score or None, ...,
        "AVG": mean of the scores that were found, rounded to 2 decimals, or
        None if none were found}; None if the model directory does not exist.
    """
    model_dir = os.path.join(outputs_dir, model_name)

    if not os.path.isdir(model_dir):
        print(f"错误: 模型目录不存在: {model_dir}")
        return None

    print(f"\n正在收集模型 '{model_name}' 的结果...")
    results = {"Model": model_name}

    for dataset_config in DATASETS_CONFIG:
        dataset_name = dataset_config["name"]
        file_suffix = dataset_config["file_suffix"]

        # "JUDGE" in a suffix is a placeholder for the judge model's name.
        if "JUDGE" in file_suffix:
            file_suffix = file_suffix.replace("JUDGE", judge_name)

        subdir, actual_file_name = get_actual_file_name(model_name, dataset_name, file_suffix, outputs_dir)

        if actual_file_name is None:
            print(f"  处理 {dataset_name}... ✗ (未找到文件)")
            results[dataset_name] = None
            continue

        if subdir:
            csv_path = os.path.join(model_dir, subdir, actual_file_name)
        else:
            csv_path = os.path.join(model_dir, actual_file_name)

        print(f"  处理 {dataset_name}...", end=" ")

        result = extract_result(
            csv_path,
            dataset_config["row_column"],
            dataset_config["row_value"],
            dataset_config["result_column"]
        )

        if result is not None:
            print(f"✓ (值: {result})")
        else:
            print("✗ (缺失)")

        results[dataset_name] = result

    print(f"  处理 {model_name} 的平均结果...")
    # Average only the dataset scores that were found. The "Model" entry is
    # excluded (it used to be counted in the denominator), as are missing
    # datasets. NOTE: models with different coverage therefore average over
    # different subsets; check the per-dataset columns when comparing.
    scores = [value for key, value in results.items() if key != "Model" and value is not None]
    if scores:
        avg = round(sum(scores) / len(scores), 2)
        print(f"  平均结果: {avg}")
    else:
        avg = None
        print("  没有有效的结果")
    results["AVG"] = avg

    return results


def save_results(results, output_csv="model_results.csv"):
    """
    Write one model's results into a CSV file that holds one row per model.

    If output_csv already exists, any existing row for the same model is
    replaced and the new row is appended. Otherwise a new file is created.
    Numeric columns are written with 2 decimal places, missing values are
    written as empty cells, and rows are sorted by model name.

    Args:
        results: dict as returned by collect_model_results.
        output_csv: path of the output CSV file.

    Returns:
        The DataFrame that was written.
    """
    columns = ["Model"] + [config["name"] for config in DATASETS_CONFIG] + ["AVG"]

    # Keys of results that are not in columns are ignored.
    new_row = pd.DataFrame([results], columns=columns)

    if os.path.exists(output_csv):
        print(f"\n输出文件已存在，追加新行到: {output_csv}")
        # Read model names as strings: otherwise numeric-looking names
        # (e.g. "1.5") come back as floats, never match results["Model"],
        # and break sorting against the string names.
        existing_df = pd.read_csv(output_csv, dtype={"Model": str})

        if results["Model"] in existing_df["Model"].values:
            print(f"警告: 模型 '{results['Model']}' 已存在于结果文件中，将更新该行")
            existing_df = existing_df[existing_df["Model"] != results["Model"]]

        combined_df = pd.concat([existing_df, new_row], ignore_index=True)
    else:
        print(f"\n创建新的输出文件: {output_csv}")
        combined_df = new_row

    # Keep the configured column order with AVG last. Columns that exist only
    # in an older file (e.g. a dataset since removed from the config) go
    # right before AVG.
    extra_columns = [col for col in combined_df.columns if col not in columns]
    combined_df = combined_df[columns[:-1] + extra_columns + ["AVG"]]

    # Format every score column with 2 decimals; NaN/unparsable -> "".
    for col in combined_df.columns:
        if col != "Model":
            numeric = pd.to_numeric(combined_df[col], errors='coerce')
            combined_df[col] = numeric.apply(lambda x: f"{x:.2f}" if pd.notna(x) else "")

    combined_df = combined_df.sort_values(by="Model").reset_index(drop=True)

    combined_df.to_csv(output_csv, index=False)
    print(f"✓ 结果已保存到 {output_csv}")

    return combined_df


def main():
    """Command-line entry point: collect one model's results, save them and print a summary."""
    parser = argparse.ArgumentParser(
        description="收集指定模型在多个数据集上的评估结果",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
示例:
  python collect_results.py Qwen2.5-VL-7B-Instruct-lmdeploy
  python collect_results.py Qwen_7b_reasoning_13k --output my_results.csv
        """
    )
    parser.add_argument("model_name", help="模型名称（outputs目录下的子文件夹名称）")
    # %(default)s keeps the help text in sync with the actual default
    # (it previously advertised model_results.csv).
    parser.add_argument("--output", "-o", default="model_results_new.csv",
                        help="输出CSV文件路径（默认: %(default)s）")
    parser.add_argument("--outputs-dir", default="/home/ec2-user/workspace/VLMEvalKit/outputs",
                        help="outputs目录路径（默认: %(default)s）")
    parser.add_argument("--judge-name", default="bedrock-claude-haiku-4.5",
                        help="judge名称，用于替换文件后缀中的 JUDGE 占位符（默认: %(default)s）")
    args = parser.parse_args()

    results = collect_model_results(args.model_name, args.judge_name, args.outputs_dir)

    if results is None:
        # The model directory does not exist; collect_model_results already
        # printed the error.
        sys.exit(1)

    df = save_results(results, args.output)

    print("\n" + "="*60)
    print("结果摘要:")
    print("="*60)
    print(df.to_string(index=False))
    print("="*60)


if __name__ == "__main__":
    # Run the CLI only when the file is executed as a script, not when imported.
    main()

