#!/bin/bash
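# OMEGA pass@32 sweep over the transformative, compositional and explorative tasks.
# Usage: bash scripts/evals/eval_omega_pass32.sh [which] [task_suffix]
#   which: base | sft_v2_ep5 | gspo_v2_s20 | novelty_s15 | both | all   (default: both)
#     both = base + novelty_s15; all = every model listed above
#   task_suffix: transformative | compositional | explorative | all   (default: transformative)
#     all = run transformative, compositional and explorative in turn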

# Exit as soon as a command fails (errexit).
set -o errexit

# Work from the repository root: keep a caller-provided PROJECT_DIR, otherwise
# resolve the directory two levels above this script.
if [ -z "${PROJECT_DIR:-}" ]; then
    PROJECT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)"
fi
cd "$PROJECT_DIR"

# Source the environment activation script; VLLM_VENV_PATH overrides the
# default location under $HOME.
. "${VLLM_VENV_PATH:-$HOME/verl-vllm012}/bin/activate"

# Environment exported to every process started below.
export WANDB_CONSOLE=off \
       PYTHONUNBUFFERED=1 \
       VLLM_USE_TRTLLM_ATTENTION=0 \
       VLLM_ATTENTION_BACKEND=FLASH_ATTN \
       VLLM_USE_V1=1

# Positional arguments: model selection first, task suffix second.
WHICH="${1:-both}"
TASK_SUFFIX="${2:-transformative}"

# Directory handed to the evaluator through --include_path.
CUSTOM_TASKS_PATH="evaluate/custom_tasks"

#######################################
# Run one OMEGA pass@32 evaluation for a single model and score its samples.
# Globals:
#   PROJECT_DIR        (read) log files go to ${PROJECT_DIR}/logs
#   CUSTOM_TASKS_PATH  (read) passed to the evaluator as --include_path
# Arguments:
#   $1 - label used in the output directory and log file names
#   $2 - model path or hub id, passed to the evaluator as pretrained=...
#   $3 - task suffix; the task that runs is omega_<suffix>_pass32
# Outputs:
#   Banner and [DONE]/[FAIL]/[WARN] lines on stdout; the evaluator's own
#   stdout and stderr are redirected to the log file.
# Returns:
#   0 when the evaluator succeeds (a scoring failure only prints [WARN]),
#   otherwise the evaluator's exit status.
#######################################
run_eval () {
    local label="$1"
    local model_path="$2"
    local task_suffix="$3"
    local task_name="omega_${task_suffix}_pass32"
    local out_subdir="results/omega_${task_suffix}_pass32/${label}"
    local log_file="${PROJECT_DIR}/logs/omega_${task_suffix}_pass32_${label}.log"

    # Create the log directory as well as the output directory; otherwise the
    # redirection below fails before the evaluator even starts.
    mkdir -p "${out_subdir}" "$(dirname "${log_file}")"

    local model_args="pretrained=${model_path}"
    model_args="${model_args},tensor_parallel_size=1"
    model_args="${model_args},data_parallel_size=4"
    model_args="${model_args},gpu_memory_utilization=0.85"
    model_args="${model_args},max_model_len=26000"

    echo "=========================================="
    echo "OMEGA ${task_suffix} pass@32 — ${label}"
    echo "  model:  ${model_path}"
    echo "  task:   ${task_name}"
    echo "  output: ${out_subdir}"
    echo "  log:    ${log_file}"
    echo "=========================================="

    # Capture the status with `||` rather than reading $? afterwards: the
    # script runs under `set -e`, which would otherwise terminate the shell on
    # a failing evaluator before the failure could be reported.
    local rc=0
    python scripts/evals/lm_eval_dp_diverse.py \
        --model vllm \
        --model_args "${model_args}" \
        --include_path "${CUSTOM_TASKS_PATH}" \
        --tasks "${task_name}" \
        --batch_size auto \
        --apply_chat_template \
        --seed 42 \
        --output_path "${out_subdir}" \
        --log_samples \
        > "${log_file}" 2>&1 || rc=$?

    if [ "${rc}" -ne 0 ]; then
        echo "[FAIL] ${label}/${task_suffix} exit ${rc} — see ${log_file}"
        return "${rc}"
    fi

    echo "[DONE] ${label}/${task_suffix}"
    # Scoring is best-effort: a failure is reported but does not fail the run.
    python scripts/evals/compute_pass_at_k.py "${out_subdir}" \
        --k_values 1,2,4,8,16,32 \
        --cons_k 8,32 \
        --json_output "results/omega_${task_suffix}_pass32/${label}_pass_at_k.json" \
        --workers 8 || echo "[WARN] scoring failed for ${label}/${task_suffix}"
}

#######################################
# Evaluate one model on the task suffix(es) selected by TASK_SUFFIX.
# Globals:
#   TASK_SUFFIX (read) "all" runs transformative, compositional and
#               explorative in that order; any other value is passed to
#               run_eval unchanged.
# Arguments:
#   $1 - label forwarded to run_eval
#   $2 - model path forwarded to run_eval
# Returns:
#   Status of the last run_eval call.
#######################################
run_for_label () {
    local label="$1"
    local model_path="$2"
    # Keep the loop variable local so it cannot overwrite a caller's variable.
    local suffix

    case "${TASK_SUFFIX}" in
        all)
            for suffix in transformative compositional explorative; do
                run_eval "${label}" "${model_path}" "${suffix}"
            done
            ;;
        *)
            run_eval "${label}" "${model_path}" "${TASK_SUFFIX}"
            ;;
    esac
}

# Reject an unrecognised model selection up front. Without this check a typo
# matches none of the dispatch cases below, nothing is evaluated, and the
# script still reports completion with exit status 0.
case "${WHICH}" in
    base|sft_v2_ep5|gspo_v2_s20|novelty_s15|both|all)
        ;;
    *)
        echo "[FAIL] unknown model selection '${WHICH}'; expected one of: base | sft_v2_ep5 | gspo_v2_s20 | novelty_s15 | both | all" >&2
        exit 2
        ;;
esac

# Each model has its own case statement so that "both" (base + novelty_s15)
# and "all" (every model) can select several of them; models run in the order
# listed here.
case "${WHICH}" in
    base|both|all)
        run_for_label "base" "allenai/OLMo-3-7B-Instruct-SFT"
        ;;
esac
case "${WHICH}" in
    novelty_s15|both|all)
        run_for_label "novelty_s15" "checkpoints/olmo3-puzzle-grpo/novelty_production_gspo_topk100_a01/merged_step_15"
        ;;
esac
case "${WHICH}" in
    sft_v2_ep5|all)
        run_for_label "sft_v2_ep5" "checkpoints/olmo3_7b_multi_puzzle_dsr_v2/merged_ep5_fp32"
        ;;
esac
case "${WHICH}" in
    gspo_v2_s20|all)
        run_for_label "gspo_v2_s20" "checkpoints/olmo3-puzzle-grpo/multi_puzzle_gspo_olmo3_v2_sft_v2/merged_step_20"
        ;;
esac

echo "=========================================="
echo "OMEGA eval complete. JSONLs at results/omega_*_pass32/"
echo "=========================================="
