#!/bin/bash
# ==============================================================================
# Evaluation-only script - skips the generation step and directly evaluates
# responses that have already been generated
# ==============================================================================

# Abort the whole run as soon as any command exits with a non-zero status.
set -e

# ==============================================================================
# >>>>>>>>>>>>>>>>>>>>>> Configuration section - edit values here <<<<<<<<<<<<<<<<<<<<<<<<
# ==============================================================================

# Paths of the already-generated response files (parquet format, containing a
# responses column). Entries that are not regular files are reported and skipped.
GENERATION_PATHS=(
    "/path/to/generations1.parquet"   # TODO: change to the path of your generation results
    "/path/to/generations2.parquet"
    # Add more...
)

# Which k values to report metrics for (space-separated; forwarded to --k)
K_VALUES="1 4 8"

# Output directory. A relative path is resolved against the parent of this
# script's directory, because the script changes into that directory before
# creating it.
OUTPUT_DIR="./eval_results"

# ==============================================================================
# Data column name configuration
# ==============================================================================

# Name of the responses column (forwarded to --response_key)
RESPONSE_KEY="responses"

# Name of the data-source column (forwarded to --data_source_key)
DATA_SOURCE_KEY="data_source"

# Name of the reward_model column, which contains the ground_truth
# (forwarded to --reward_model_key)
REWARD_MODEL_KEY="reward_model"

# ==============================================================================
# Advanced configuration
# ==============================================================================

# Custom reward function (optional). It is only forwarded to the evaluator when
# REWARD_FN_PATH is non-empty and points to an existing file.
REWARD_FN_PATH=""
REWARD_FN_NAME="compute_score"

# Whether to show verbose output (adds --verbose when set to true)
VERBOSE=false

# Parallelism mode: none, multiprocessing, ray (forwarded to --parallel)
PARALLEL="multiprocessing"

# ==============================================================================
# >>>>>>>>>>>>>>>>>>>>>> End of configuration section <<<<<<<<<<<<<<<<<<<<<<<<
# ==============================================================================

# Working directory
# SCRIPT_DIR is the directory holding this script (symlinks resolved by
# realpath); WORK_DIR is its parent. The script runs from WORK_DIR so relative
# paths such as OUTPUT_DIR are anchored there.
# Every expansion is quoted so locations containing spaces or glob characters
# are handled correctly.
SCRIPT_DIR=$(dirname "$(realpath "$0")")
WORK_DIR=$(dirname "${SCRIPT_DIR}")
cd "${WORK_DIR}" || exit 1
# Append WORK_DIR to PYTHONPATH; the ":+" form avoids a leading empty entry
# when PYTHONPATH was previously unset or empty.
export PYTHONPATH="${PYTHONPATH:+${PYTHONPATH}:}${WORK_DIR}"

# Create the output sub-directories (metrics first, then summary).
for RESULT_SUBDIR in metrics summary; do
    mkdir -p "${OUTPUT_DIR}/${RESULT_SUBDIR}"
done

# A single timestamp per run; it is embedded in every output file name.
TIMESTAMP=$(date '+%Y%m%d_%H%M%S')

# Start-up banner summarising the run configuration.
printf '%s\n' \
    "==============================================" \
    "Evaluation Only Script" \
    "==============================================" \
    "K values to report: ${K_VALUES}" \
    "Number of files: ${#GENERATION_PATHS[@]}" \
    "Output dir: ${OUTPUT_DIR}" \
    "=============================================="

# Pre-flight check: count the configured inputs that are regular files,
# warning about each one that is missing, and stop if none are usable.
VALID_COUNT=0
for GEN_PATH in "${GENERATION_PATHS[@]}"; do
    if [[ ! -f "$GEN_PATH" ]]; then
        echo "Warning: File not found: $GEN_PATH"
        continue
    fi
    VALID_COUNT=$((VALID_COUNT + 1))
done

if (( VALID_COUNT == 0 )); then
    echo "Error: No valid files found!"
    exit 1
fi

printf '%s\n' "Found ${VALID_COUNT} valid file(s)" ""

# ==============================================================================
# Evaluate each file
# ==============================================================================
#
# For every configured input that exists, run eval_pass_at_k.py (located next to
# this script) and write its metrics to
# ${OUTPUT_DIR}/metrics/<input stem>_<timestamp>.csv.
#
# The command line is assembled as a bash array and executed directly instead of
# being concatenated into a string and passed to `eval`, so paths containing
# spaces, quotes or other shell metacharacters reach the evaluator intact.

# Split the space-separated K_VALUES into one argument per k.
read -r -a K_ARGS <<< "${K_VALUES}"

# Optional custom reward function: only forwarded when the file exists.
REWARD_ARGS=()
if [[ -n "$REWARD_FN_PATH" ]]; then
    if [[ -f "$REWARD_FN_PATH" ]]; then
        REWARD_ARGS=(--reward_fn_path "${REWARD_FN_PATH}" --reward_fn_name "${REWARD_FN_NAME}")
    else
        # Previously this case was ignored silently; make the fallback visible.
        echo "Warning: REWARD_FN_PATH not found, custom reward function ignored: ${REWARD_FN_PATH}" >&2
    fi
fi

for GEN_PATH in "${GENERATION_PATHS[@]}"; do
    if [[ ! -f "$GEN_PATH" ]]; then
        echo "Skipping non-existent file: $GEN_PATH"
        continue
    fi

    # Take the base name first and strip the extension afterwards, so a dot in
    # a directory component cannot truncate the name.
    FILENAME=${GEN_PATH##*/}
    FILENAME=${FILENAME%.*}
    # A dot-file without extension would strip to nothing; keep its full name.
    [[ -n "$FILENAME" ]] || FILENAME=${GEN_PATH##*/}
    METRICS_OUTPUT="${OUTPUT_DIR}/metrics/${FILENAME}_${TIMESTAMP}.csv"

    echo ""
    echo "=============================================="
    echo "Evaluating: ${FILENAME}"
    echo "=============================================="
    echo "  Input: ${GEN_PATH}"
    echo "  Output: ${METRICS_OUTPUT}"
    echo "  K values: ${K_VALUES}"
    echo ""

    # Build the evaluation command, one array element per argument.
    EVAL_CMD=(
        python3 "${SCRIPT_DIR}/eval_pass_at_k.py"
        --data_path "${GEN_PATH}"
        --k "${K_ARGS[@]}"
        --response_key "${RESPONSE_KEY}"
        --data_source_key "${DATA_SOURCE_KEY}"
        --reward_model_key "${REWARD_MODEL_KEY}"
        --output_path "${METRICS_OUTPUT}"
        --parallel "${PARALLEL}"
        "${REWARD_ARGS[@]}"
    )

    if [[ "$VERBOSE" == true ]]; then
        EVAL_CMD+=(--verbose)
    fi

    # A failing evaluation aborts the script because of `set -e`.
    "${EVAL_CMD[@]}"

    echo ""
    echo "File ${FILENAME} completed!"
    echo "  Metrics: ${METRICS_OUTPUT}"
done

# ==============================================================================
# Generate the combined summary JSON file
# ==============================================================================
#
# Reads the per-file "*_summary.json" files found in ${OUTPUT_DIR}/metrics,
# merges them into one JSON document, and prints a summary table.
#
# Settings are handed to Python through environment variables and the heredoc
# delimiter is quoted, so the shell never interpolates values into the Python
# source (quotes or backslashes in a path can no longer break the program).

SUMMARY_OUTPUT="${OUTPUT_DIR}/summary/eval_${TIMESTAMP}_all_summary.json"

echo ""
echo "=============================================="
echo "Generating Combined Summary JSON..."
echo "=============================================="

EVAL_OUTPUT_DIR="${OUTPUT_DIR}" \
EVAL_TIMESTAMP="${TIMESTAMP}" \
EVAL_K_VALUES="${K_VALUES}" \
EVAL_SUMMARY_OUTPUT="${SUMMARY_OUTPUT}" \
python3 - << 'EOF'
import glob
import json
import os
from datetime import datetime

output_dir = os.environ["EVAL_OUTPUT_DIR"]
run_timestamp = os.environ["EVAL_TIMESTAMP"]
output_path = os.environ["EVAL_SUMMARY_OUTPUT"]
k_values = [int(k) for k in os.environ["EVAL_K_VALUES"].split()]

# Find all summary JSON files (escape the directory so glob characters in it
# are matched literally).
summary_files = sorted(
    glob.glob(os.path.join(glob.escape(output_dir), "metrics", "*_summary.json"))
)

# Prefer the summaries written by this run so files left over from earlier runs
# do not get averaged into the overall figures.
# NOTE(review): assumes eval_pass_at_k.py embeds the run timestamp (part of the
# metrics file name) in the summary file name - confirm. When no file name
# contains the timestamp, every summary is used, exactly as before.
current_run_files = [p for p in summary_files if run_timestamp in os.path.basename(p)]
if current_run_files:
    summary_files = current_run_files

# Skeleton of the merged summary.
combined_summary = {
    "timestamp": datetime.now().isoformat(),
    "k_values": k_values,
    "datasets": {},
    "overall": {}
}

# Per-metric value lists used for the overall average. Each data source
# contributes its value once per problem, which weights it by problem count.
metric_kinds = ("avg", "best", "pass")
all_metrics = {f"{kind}@{k}": [] for k in k_values for kind in metric_kinds}

# Read every summary file; a file that cannot be processed is reported and skipped.
for summary_file in summary_files:
    try:
        with open(summary_file, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # Derive the dataset name from the recorded data path; fall back to the
        # summary file itself when the path is missing or null.
        dataset_name = os.path.basename(data.get("data_path") or summary_file)

        # Record the per-dataset information.
        combined_summary["datasets"][dataset_name] = {
            "data_path": data.get("data_path"),
            "total_problems": data.get("total_problems", 0),
            "per_data_source": data.get("per_data_source", {}),
            "overall": data.get("overall", {})
        }

        # Collect the metrics that feed the overall average.
        for source_data in data.get("per_data_source", {}).values():
            n_problems = source_data.get("n_problems", 0)
            for metrics in source_data.get("metrics", {}).values():
                for metric_name, value in metrics.items():
                    if metric_name in all_metrics:
                        all_metrics[metric_name].extend([value] * n_problems)
    except Exception as e:
        print(f"Warning: Failed to read {summary_file}: {e}")

# Compute the overall averages. Each metric is averaged only when it has values,
# so a summary lacking best@k or pass@k can no longer cause a division by zero.
for k in k_values:
    averaged = {}
    for kind in metric_kinds:
        values = all_metrics[f"{kind}@{k}"]
        if values:
            averaged[f"{kind}@{k}"] = sum(values) / len(values)
    if averaged:
        combined_summary["overall"][f"k={k}"] = averaged

# Save the merged summary.
os.makedirs(os.path.dirname(output_path), exist_ok=True)
with open(output_path, 'w', encoding='utf-8') as f:
    json.dump(combined_summary, f, indent=2, ensure_ascii=False)

print(f"Combined summary saved to: {output_path}")

# Print the summary table.
print("\n" + "=" * 70)
print("COMBINED SUMMARY")
print("=" * 70)

for dataset_name, dataset_info in combined_summary["datasets"].items():
    print(f"\n📊 {dataset_name}")
    print("-" * 50)
    for data_source, source_data in dataset_info.get("per_data_source", {}).items():
        n = source_data.get("n_problems", 0)
        print(f"  [{data_source}] ({n} problems)")
        for k_key, metrics in source_data.get("metrics", {}).items():
            k = k_key.replace("k=", "")
            avg_val = metrics.get(f"avg@{k}", 0)
            best_val = metrics.get(f"best@{k}", 0)
            pass_val = metrics.get(f"pass@{k}", 0)
            print(f"    {k_key}: avg={avg_val:.4f}  best={best_val:.4f}  pass={pass_val:.4f}")

print("\n" + "=" * 70)
print("🎯 OVERALL (Weighted Average)")
print("-" * 50)
for k_key, metrics in combined_summary["overall"].items():
    k = k_key.replace("k=", "")
    avg_val = metrics.get(f"avg@{k}", 0)
    best_val = metrics.get(f"best@{k}", 0)
    pass_val = metrics.get(f"pass@{k}", 0)
    print(f"  {k_key}: avg={avg_val:.4f}  best={best_val:.4f}  pass={pass_val:.4f}")
print("=" * 70)
EOF

# Closing banner: tell the user where the results were written.
printf '%s\n' \
    "" \
    "==============================================" \
    "All evaluations completed!" \
    "==============================================" \
    "Results saved to: ${OUTPUT_DIR}" \
    "  - Metrics: ${OUTPUT_DIR}/metrics/" \
    "  - Summary JSON: ${SUMMARY_OUTPUT}" \
    "=============================================="
