#!/bin/bash
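# HMMT pass@32 eval — runs a single HMMT competition, or (task_suffix=all) the
# Feb 2025, Feb 2026 and Nov 2025 competitions back to back.
# NOTE(review): this header used to advertise "all 4 HMMT competitions
# (Feb 2024/2025 + Nov 2025 + Feb 2026) combined for ~123 problems", but `all`
# only iterates feb2025/feb2026/nov2025 — confirm whether Feb 2024 belongs in it.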
#
# Usage:
#   bash scripts/evals/eval_hmmt_pass32.sh <which_ckpt> [task_suffix]
#     which_ckpt: base | sft_v2_ep5 | gspo_v2_s20 | novelty_s15 | both | all   (default: both)
#       both = base + novelty_s15;  all = all four checkpoints
#     task_suffix: feb2025 | feb2026 | nov2025 | all   (default: feb2025)
#
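# JSONLs land at results/hmmt_<task_suffix>_pass32/<ckpt_label>/, pass@k summaries at
# results/hmmt_<task_suffix>_pass32/<ckpt_label>_pass_at_k.json, and per-run logs in logs/.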

set -e

# Repo root: honour a caller-supplied PROJECT_DIR; otherwise resolve it as two
# levels above this script (scripts/evals/ -> repo root).
if [ -z "${PROJECT_DIR:-}" ]; then
    PROJECT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)"
fi
cd "$PROJECT_DIR"

# vLLM virtualenv; relocate it with VLLM_VENV_PATH.
source "${VLLM_VENV_PATH:-$HOME/verl-vllm012}/bin/activate"

# Runtime knobs for wandb, Python output buffering and the vLLM engine.
export \
    WANDB_CONSOLE=off \
    PYTHONUNBUFFERED=1 \
    VLLM_USE_TRTLLM_ATTENTION=0 \
    VLLM_ATTENTION_BACKEND=FLASH_ATTN \
    VLLM_USE_V1=1

WHICH="${1:-both}"
TASK_SUFFIX="${2:-feb2025}"

# Reject unknown checkpoint selectors up front: otherwise every dispatch `case`
# further down silently matches nothing and the script reports "complete"
# without having run a single eval.
case "${WHICH}" in
    base|sft_v2_ep5|gspo_v2_s20|novelty_s15|both|all) ;;
    *)
        echo "[ERR] unknown which_ckpt '${WHICH}'" \
            "(expected: base | sft_v2_ep5 | gspo_v2_s20 | novelty_s15 | both | all)" >&2
        exit 2
        ;;
esac

# Passed to lm-eval as --include_path (where the custom hmmt_*_pass32 tasks
# are expected to be found).
CUSTOM_TASKS_PATH="evaluate/custom_tasks"

#######################################
# Run one HMMT pass@32 eval for one checkpoint, then score it.
# Globals:   PROJECT_DIR, CUSTOM_TASKS_PATH (read)
# Arguments: $1 - checkpoint label (names the output dir and the log file)
#            $2 - model path or HF hub id, passed to vLLM as `pretrained=`
#            $3 - task suffix (e.g. feb2025); the task is hmmt_<suffix>_pass32
# Outputs:   banner + [DONE] on stdout; [FAIL]/[WARN] on stderr; all eval
#            output goes to ${PROJECT_DIR}/logs/hmmt_<suffix>_pass32_<label>.log
# Returns:   0 when the eval succeeds (a scoring failure only warns),
#            otherwise the eval's exit status.
#######################################
run_eval () {
    local label="$1"
    local model_path="$2"
    local task_suffix="$3"
    local task_name="hmmt_${task_suffix}_pass32"
    local results_root="results/${task_name}"
    local out_subdir="${results_root}/${label}"
    local log_dir="${PROJECT_DIR}/logs"
    local log_file="${log_dir}/${task_name}_${label}.log"

    # logs/ is not guaranteed to exist (e.g. fresh checkout); without it the
    # redirect below fails before python even starts.
    mkdir -p "${out_subdir}" "${log_dir}"

    local model_args="pretrained=${model_path}"
    model_args="${model_args},tensor_parallel_size=1"
    model_args="${model_args},data_parallel_size=4"
    model_args="${model_args},gpu_memory_utilization=0.85"
    model_args="${model_args},max_model_len=26000"

    echo "=========================================="
    echo "HMMT ${task_suffix} pass@32 — ${label}"
    echo "  model:  ${model_path}"
    echo "  task:   ${task_name}"
    echo "  output: ${out_subdir}"
    echo "  log:    ${log_file}"
    echo "=========================================="

    # Capture the status explicitly: under `set -e` a bare failing command
    # aborts the script on this line, so the [FAIL] message pointing at the
    # log (where all eval output is redirected) would never be printed.
    local rc=0
    python scripts/evals/lm_eval_dp_diverse.py \
        --model vllm \
        --model_args "${model_args}" \
        --include_path "${CUSTOM_TASKS_PATH}" \
        --tasks "${task_name}" \
        --batch_size auto \
        --apply_chat_template \
        --seed 42 \
        --output_path "${out_subdir}" \
        --log_samples \
        > "${log_file}" 2>&1 || rc=$?

    if [ "${rc}" -ne 0 ]; then
        echo "[FAIL] ${label}/${task_suffix} exit ${rc} — see ${log_file}" >&2
        return "${rc}"
    fi

    echo "[DONE] ${label}/${task_suffix}"
    # Scoring is best-effort: a failure is reported but does not fail the eval.
    python scripts/evals/compute_pass_at_k.py "${out_subdir}" \
        --k_values 1,2,4,8,16,32 \
        --cons_k 8,32 \
        --json_output "${results_root}/${label}_pass_at_k.json" \
        --workers 8 || echo "[WARN] scoring step failed for ${label}/${task_suffix}" >&2
}

#######################################
# Evaluate one checkpoint on the competition(s) selected by TASK_SUFFIX.
# Globals:   TASK_SUFFIX (read); "all" expands to feb2025, feb2026, nov2025.
# Arguments: $1 - checkpoint label
#            $2 - model path or HF hub id
# Returns:   status of the last run_eval; under the script's `set -e` a
#            failing run_eval aborts before the remaining suffixes run.
#######################################
run_for_label () {
    local label="$1"
    local model_path="$2"
    local suffix   # keep the loop variable out of global scope
    if [ "${TASK_SUFFIX}" = "all" ]; then
        for suffix in feb2025 feb2026 nov2025; do
            run_eval "${label}" "${model_path}" "${suffix}"
        done
    else
        run_eval "${label}" "${model_path}" "${TASK_SUFFIX}"
    fi
}

# Checkpoints. "base" is given by HF hub id; the others are merged checkpoint
# directories relative to PROJECT_DIR.
BASE_MODEL="allenai/OLMo-3-7B-Instruct-SFT"
NOVELTY_PATH="checkpoints/olmo3-puzzle-grpo/novelty_production_gspo_topk100_a01/merged_step_15"
SFT_V2_EP5_PATH="checkpoints/olmo3_7b_multi_puzzle_dsr_v2/merged_ep5_fp32"
GSPO_V2_S20_PATH="checkpoints/olmo3-puzzle-grpo/multi_puzzle_gspo_olmo3_v2_sft_v2/merged_step_20"

# Exit 1 with an error on stderr if merged checkpoint directory $1 is missing.
require_ckpt () {
    if [ ! -d "$1" ]; then
        echo "[ERR] missing merged checkpoint: $1" >&2
        exit 1
    fi
}

#######################################
# Run every checkpoint selected by WHICH, in the order base, novelty_s15,
# sft_v2_ep5, gspo_v2_s20.
# Globals: WHICH (read); checkpoint paths above (read)
# All local checkpoint directories the selection needs are verified BEFORE any
# eval starts, so a missing merge fails fast instead of after the base evals.
#######################################
run_selected () {
    case "${WHICH}" in novelty_s15|both|all) require_ckpt "${NOVELTY_PATH}" ;; esac
    case "${WHICH}" in sft_v2_ep5|all)       require_ckpt "${SFT_V2_EP5_PATH}" ;; esac
    case "${WHICH}" in gspo_v2_s20|all)      require_ckpt "${GSPO_V2_S20_PATH}" ;; esac

    case "${WHICH}" in base|both|all)        run_for_label "base" "${BASE_MODEL}" ;; esac
    case "${WHICH}" in novelty_s15|both|all) run_for_label "novelty_s15" "${NOVELTY_PATH}" ;; esac
    case "${WHICH}" in sft_v2_ep5|all)       run_for_label "sft_v2_ep5" "${SFT_V2_EP5_PATH}" ;; esac
    case "${WHICH}" in gspo_v2_s20|all)      run_for_label "gspo_v2_s20" "${GSPO_V2_S20_PATH}" ;; esac
}

run_selected

# Closing banner; only reached if no selected eval aborted the script (set -e).
printf '%s\n' \
    "==========================================" \
    "HMMT eval complete. JSONLs at results/hmmt_*_pass32/" \
    "=========================================="
