#!/bin/bash
# ==============================================================================
# 端到端 Pass@K 评测脚本
# ==============================================================================
# 功能：
# 1. 使用模型为数据集生成 responses（支持多卡加速）
# 2. 评测生成结果，计算 avg@k, best@k, pass@k
#
# 只需提供原始测试数据集，脚本会自动完成生成和评测
# ==============================================================================

# Abort the script as soon as any unguarded command fails.
set -e

# ==============================================================================
# >>>>>>>>>>>>>>>>>>>>>> CONFIGURATION - EDIT BELOW <<<<<<<<<<<<<<<<<<<<<<<<
# ==============================================================================

# Path of the model to evaluate.
MODEL_PATH='/path/to/your/model'  # TODO: change to your model path

# Raw test datasets (parquet format, containing a prompt column).
DATA_PATHS=(
    '/path/to/test_dataset1.parquet'   # TODO: change to your dataset path
    '/path/to/test_dataset2.parquet'
    # add more datasets here...
)

# Number of samples generated per prompt (this bounds the largest usable k).
N_SAMPLES=8

# Which k values to report (space-separated).
# Once N_SAMPLES samples exist, metrics for every k in 1..N_SAMPLES could be
# computed; K_VALUES merely selects a few key points to keep the output short.
# Example: with N_SAMPLES=8, K_VALUES="1 4 8" reports only those three.
K_VALUES='1 4 8'

# Root directory for all outputs.
OUTPUT_DIR='./eval_results'

# ==============================================================================
# GPU and generation settings
# ==============================================================================

# Number of GPUs.
NGPUS=8

# Tensor-parallel size (for large models).
TP_SIZE=1

# Sampling parameters.
TEMPERATURE=1.0
TOP_P=1.0
TOP_K=-1
MAX_PROMPT_LENGTH=1024
MAX_RESPONSE_LENGTH=4096

# Batch size.
BATCH_SIZE=64

# GPU memory utilization.
GPU_MEMORY_UTILIZATION=0.85

# ==============================================================================
# Dataset column names
# ==============================================================================

# Prompt column (the question column of the raw dataset).
PROMPT_KEY='prompt'

# Data-source column.
DATA_SOURCE_KEY='data_source'

# reward_model column (contains ground_truth).
REWARD_MODEL_KEY='reward_model'

# ==============================================================================
# Advanced settings (usually left unchanged)
# ==============================================================================

# Skip the generation step when a generation result already exists.
SKIP_GENERATION=false

# Custom reward function (optional).
REWARD_FN_PATH=''
REWARD_FN_NAME='compute_score'

# Print verbose output.
VERBOSE=false

# ==============================================================================
# >>>>>>>>>>>>>>>>>>>>>> END OF CONFIGURATION <<<<<<<<<<<<<<<<<<<<<<<<
# ==============================================================================

# Working directory: resolve the directory holding this script, then switch to
# its parent and expose that parent on PYTHONPATH. Every expansion is quoted so
# that paths containing spaces or glob characters are handled correctly.
SCRIPT_DIR=$(dirname "$(realpath "$0")")
WORK_DIR=$(dirname "${SCRIPT_DIR}")
cd "${WORK_DIR}"
# Only emit the ':' separator when PYTHONPATH already has content, so an unset
# or empty PYTHONPATH does not end up with an empty leading entry.
export PYTHONPATH="${PYTHONPATH:+${PYTHONPATH}:}${WORK_DIR}"

# Create the output directories (a relative OUTPUT_DIR is resolved against
# WORK_DIR because of the cd above).
mkdir -p "${OUTPUT_DIR}/generations" "${OUTPUT_DIR}/metrics" "${OUTPUT_DIR}/summary"

# Model name (last path component of MODEL_PATH) and a per-run timestamp, both
# used to build output file names.
MODEL_NAME=$(basename "${MODEL_PATH}")
TIMESTAMP=$(date +%Y%m%d_%H%M%S)

# Bookkeeping arrays filled in while datasets are processed; initialised
# explicitly so they are empty arrays rather than merely declared names.
declare -a COMPLETED_DATASETS=()
declare -a METRICS_FILES=()

# Print the run configuration banner in a single heredoc.
cat << BANNER
==============================================
End-to-End Pass@K Evaluation Script
==============================================
Model: ${MODEL_PATH}
Model Name: ${MODEL_NAME}
N samples per prompt: ${N_SAMPLES}
K values to report: ${K_VALUES}
GPUs: ${NGPUS}
Number of datasets: ${#DATA_PATHS[@]}
Output dir: ${OUTPUT_DIR}
Skip generation: ${SKIP_GENERATION}
==============================================
BANNER

# Validate that no requested k exceeds N_SAMPLES.
# Both settings are checked for being plain decimal integers first: an empty
# or malformed value would otherwise make the numeric comparison itself fail,
# and because that comparison sits inside an `if`, the script would carry on
# with an invalid configuration.
if ! [[ "${N_SAMPLES}" =~ ^[0-9]+$ ]]; then
    echo "Error: N_SAMPLES must be a non-negative integer, got '${N_SAMPLES}'" >&2
    exit 1
fi

# Split the space-separated K_VALUES string into an array.
read -r -a K_LIST <<< "${K_VALUES}"
if (( ${#K_LIST[@]} == 0 )); then
    echo "Error: K_VALUES is empty; specify at least one k (e.g. K_VALUES=\"1\")" >&2
    exit 1
fi

MAX_K=0
for K_ITEM in "${K_LIST[@]}"; do
    if ! [[ "${K_ITEM}" =~ ^[0-9]+$ ]]; then
        echo "Error: K_VALUES must contain only non-negative integers, got '${K_ITEM}'" >&2
        exit 1
    fi
    # 10# forces base 10 so a value such as "08" is not parsed as octal.
    if (( 10#${K_ITEM} > MAX_K )); then
        MAX_K=$(( 10#${K_ITEM} ))
    fi
done

if (( MAX_K > 10#${N_SAMPLES} )); then
    echo "Error: max(K_VALUES)=${MAX_K} cannot be larger than N_SAMPLES=${N_SAMPLES}" >&2
    echo "Tip: With N_SAMPLES=${N_SAMPLES}, you can compute metrics for k=1 to ${N_SAMPLES}" >&2
    exit 1
fi

echo "Note: Generating ${N_SAMPLES} samples allows computing metrics for k=1 to ${N_SAMPLES}"
echo "      Currently reporting k=${K_VALUES}"
echo ""

# Count the configured datasets that exist as regular files, warning about
# each missing one; abort when none are usable.
VALID_COUNT=0
for DATA_PATH in "${DATA_PATHS[@]}"; do
    if [[ ! -f "${DATA_PATH}" ]]; then
        echo "Warning: Dataset not found: $DATA_PATH"
        continue
    fi
    # The result of += is always >= 1 here, so the arithmetic command returns
    # success and does not trip `set -e`.
    (( VALID_COUNT += 1 ))
done

if (( VALID_COUNT == 0 )); then
    echo "Error: No valid datasets found!"
    exit 1
fi

printf '%s\n' "Found ${VALID_COUNT} valid dataset(s)" ""

# ==============================================================================
# Process each dataset
# ==============================================================================

# K_VALUES is a space-separated string; split it once into an array so that
# each k reaches the evaluator as its own argument.
read -r -a EVAL_K_ARGS <<< "${K_VALUES}"

for DATA_PATH in "${DATA_PATHS[@]}"; do
    if [[ ! -f "${DATA_PATH}" ]]; then
        echo "Skipping non-existent file: $DATA_PATH"
        continue
    fi

    # Dataset name = file name without its last extension. The basename is
    # taken first so a dot in a parent directory name is never mistaken for
    # the extension separator.
    DATASET_NAME=$(basename "${DATA_PATH}")
    DATASET_NAME="${DATASET_NAME%.*}"
    GENERATION_OUTPUT="${OUTPUT_DIR}/generations/${MODEL_NAME}_${DATASET_NAME}_n${N_SAMPLES}.parquet"
    METRICS_OUTPUT="${OUTPUT_DIR}/metrics/${MODEL_NAME}_${DATASET_NAME}_${TIMESTAMP}.csv"

    echo ""
    echo "=============================================="
    echo "Processing: ${DATASET_NAME}"
    echo "=============================================="

    # -------------------------------------------------------------------------
    # Step 1: generate responses
    # -------------------------------------------------------------------------
    # Generation is skipped only when SKIP_GENERATION is true AND the output
    # file already exists; otherwise it is (re)generated.
    if [[ "${SKIP_GENERATION}" == "true" && -f "${GENERATION_OUTPUT}" ]]; then
        echo "Skipping generation (file exists): ${GENERATION_OUTPUT}"
    else
        echo ""
        echo "[Step 1/2] Generating responses..."
        echo "  Input: ${DATA_PATH}"
        echo "  Output: ${GENERATION_OUTPUT}"
        echo "  N samples: ${N_SAMPLES}"
        echo ""

        python3 -m verl.trainer.main_generation \
            model.path="${MODEL_PATH}" \
            data.path="${DATA_PATH}" \
            data.prompt_key="${PROMPT_KEY}" \
            data.n_samples="${N_SAMPLES}" \
            data.batch_size="${BATCH_SIZE}" \
            data.output_path="${GENERATION_OUTPUT}" \
            rollout.name=vllm \
            rollout.temperature="${TEMPERATURE}" \
            rollout.top_p="${TOP_P}" \
            rollout.top_k="${TOP_K}" \
            rollout.prompt_length="${MAX_PROMPT_LENGTH}" \
            rollout.response_length="${MAX_RESPONSE_LENGTH}" \
            rollout.gpu_memory_utilization="${GPU_MEMORY_UTILIZATION}" \
            rollout.tensor_model_parallel_size="${TP_SIZE}" \
            rollout.enable_chunked_prefill=True \
            trainer.n_gpus_per_node="${NGPUS}" \
            trainer.nnodes=1

        echo "Generation completed: ${GENERATION_OUTPUT}"
    fi

    # -------------------------------------------------------------------------
    # Step 2: evaluate
    # -------------------------------------------------------------------------
    echo ""
    echo "[Step 2/2] Evaluating responses..."
    echo "  Input: ${GENERATION_OUTPUT}"
    echo "  Output: ${METRICS_OUTPUT}"
    echo "  K values: ${K_VALUES}"
    echo ""

    # Build the evaluator command as an array (one element per argument) and
    # run it directly. Unlike a string passed to `eval`, this keeps paths with
    # spaces or quotes intact and never re-parses values as shell code.
    EVAL_CMD=(
        python3 "${SCRIPT_DIR}/eval_pass_at_k.py"
        --data_path "${GENERATION_OUTPUT}"
        --k "${EVAL_K_ARGS[@]}"
        --response_key responses
        --data_source_key "${DATA_SOURCE_KEY}"
        --reward_model_key "${REWARD_MODEL_KEY}"
        --output_path "${METRICS_OUTPUT}"
    )

    # The custom reward function is forwarded only when the file exists; a
    # configured but missing path is reported instead of silently ignored.
    if [[ -n "${REWARD_FN_PATH}" ]]; then
        if [[ -f "${REWARD_FN_PATH}" ]]; then
            EVAL_CMD+=(--reward_fn_path "${REWARD_FN_PATH}")
            EVAL_CMD+=(--reward_fn_name "${REWARD_FN_NAME}")
        else
            echo "Warning: REWARD_FN_PATH not found, custom reward function is not passed to the evaluator: ${REWARD_FN_PATH}" >&2
        fi
    fi

    if [[ "${VERBOSE}" == "true" ]]; then
        EVAL_CMD+=(--verbose)
    fi

    "${EVAL_CMD[@]}"

    echo ""
    echo "Dataset ${DATASET_NAME} completed!"
    echo "  Generations: ${GENERATION_OUTPUT}"
    echo "  Metrics: ${METRICS_OUTPUT}"

    # Record the completed dataset and its metrics file.
    COMPLETED_DATASETS+=("${DATASET_NAME}")
    METRICS_FILES+=("${METRICS_OUTPUT}")
done

# ==============================================================================
# Build the combined summary JSON file
# ==============================================================================

SUMMARY_OUTPUT="${OUTPUT_DIR}/summary/${MODEL_NAME}_${TIMESTAMP}_all_summary.json"

echo ""
echo "=============================================="
echo "Generating Combined Summary JSON..."
echo "=============================================="

# Merge every *_summary.json found under ${OUTPUT_DIR}/metrics with an inline
# Python script. Settings are handed over through environment variables and the
# heredoc delimiter is quoted, so the shell never rewrites the Python source
# and quotes/backslashes in paths cannot break or inject into it.
# NOTE(review): the glob matches all summary files in the metrics directory,
# not only those of this run; files left by earlier runs would be merged too -
# confirm against the naming used by eval_pass_at_k.py before narrowing it.
SUMMARY_OUTPUT_DIR="${OUTPUT_DIR}" \
SUMMARY_MODEL_NAME="${MODEL_NAME}" \
SUMMARY_K_VALUES="${K_VALUES}" \
SUMMARY_OUTPUT_PATH="${SUMMARY_OUTPUT}" \
python3 << 'EOF'
import glob
import json
import os
from datetime import datetime

output_dir = os.environ["SUMMARY_OUTPUT_DIR"]
model_name = os.environ["SUMMARY_MODEL_NAME"]
output_path = os.environ["SUMMARY_OUTPUT_PATH"]
k_values = [int(k) for k in os.environ["SUMMARY_K_VALUES"].split()]

# Find all summary JSON files (sorted so the merge order is deterministic).
summary_files = sorted(glob.glob(os.path.join(output_dir, "metrics", "*_summary.json")))

# Skeleton of the combined summary.
combined_summary = {
    "model": model_name,
    "timestamp": datetime.now().isoformat(),
    "k_values": k_values,
    "datasets": {},
    "overall": {}
}

# Running weighted sums / weights per reported metric; each data source is
# weighted by its number of problems.
METRIC_KINDS = ("avg", "best", "pass")
weighted_sum = {}
weight_total = {}
for k in k_values:
    for kind in METRIC_KINDS:
        weighted_sum[f"{kind}@{k}"] = 0.0
        weight_total[f"{kind}@{k}"] = 0

# Read each summary file; a file that cannot be processed is reported and skipped.
for summary_file in summary_files:
    try:
        with open(summary_file, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # Dataset name comes from the recorded data_path, falling back to the
        # summary file name when that key is absent.
        dataset_name = os.path.basename(data.get("data_path", summary_file))

        # Keep the per-dataset information as-is.
        combined_summary["datasets"][dataset_name] = {
            "data_path": data.get("data_path"),
            "total_problems": data.get("total_problems", 0),
            "per_data_source": data.get("per_data_source", {}),
            "overall": data.get("overall", {})
        }

        # Accumulate the metrics used for the overall weighted average.
        for data_source, source_data in data.get("per_data_source", {}).items():
            n_problems = source_data.get("n_problems", 0)
            for k_key, metrics in source_data.get("metrics", {}).items():
                for metric_name, value in metrics.items():
                    if metric_name in weighted_sum:
                        weighted_sum[metric_name] += value * n_problems
                        weight_total[metric_name] += n_problems
    except Exception as e:
        print(f"Warning: Failed to read {summary_file}: {e}")

# Overall weighted averages. Each metric is averaged independently, so a
# metric that never appeared (zero total weight) is left out instead of
# causing a division by zero.
for k in k_values:
    overall_k = {}
    for kind in METRIC_KINDS:
        name = f"{kind}@{k}"
        if weight_total[name] > 0:
            overall_k[name] = weighted_sum[name] / weight_total[name]
    if overall_k:
        combined_summary["overall"][f"k={k}"] = overall_k

# Save the combined summary.
os.makedirs(os.path.dirname(output_path), exist_ok=True)
with open(output_path, 'w', encoding='utf-8') as f:
    json.dump(combined_summary, f, indent=2, ensure_ascii=False)

print(f"Combined summary saved to: {output_path}")

# Print the summary tables.
print("\n" + "=" * 70)
print("COMBINED SUMMARY")
print("=" * 70)

for dataset_name, dataset_info in combined_summary["datasets"].items():
    print(f"\n📊 {dataset_name}")
    print("-" * 50)
    for data_source, source_data in dataset_info.get("per_data_source", {}).items():
        n = source_data.get("n_problems", 0)
        print(f"  [{data_source}] ({n} problems)")
        for k_key, metrics in source_data.get("metrics", {}).items():
            k = k_key.replace("k=", "")
            avg_val = metrics.get(f"avg@{k}", 0)
            best_val = metrics.get(f"best@{k}", 0)
            pass_val = metrics.get(f"pass@{k}", 0)
            print(f"    {k_key}: avg={avg_val:.4f}  best={best_val:.4f}  pass={pass_val:.4f}")

print("\n" + "=" * 70)
print("🎯 OVERALL (Weighted Average)")
print("-" * 50)
for k_key, metrics in combined_summary["overall"].items():
    k = k_key.replace("k=", "")
    avg_val = metrics.get(f"avg@{k}", 0)
    best_val = metrics.get(f"best@{k}", 0)
    pass_val = metrics.get(f"pass@{k}", 0)
    print(f"  {k_key}: avg={avg_val:.4f}  best={best_val:.4f}  pass={pass_val:.4f}")
print("=" * 70)
EOF

# Final report: where each kind of result was written.
cat << DONE

==============================================
All evaluations completed!
==============================================
Results saved to: ${OUTPUT_DIR}
  - Generations: ${OUTPUT_DIR}/generations/
  - Metrics: ${OUTPUT_DIR}/metrics/
  - Summary JSON: ${SUMMARY_OUTPUT}
==============================================
DONE
