#!/usr/bin/env bash
#
# Launch verification_reinforce_grpo.py from an SFT checkpoint with a fixed
# set of default hyperparameters.
#
# Usage:
#   <script> /path/to/sft_checkpoint.pt [extra verification_reinforce_grpo.py args]
#   <script> -h | --help
#
# Environment:
#   EVAL_START_MULT  value passed as --eval_start_mult (default: 500).
#   PYTHON           interpreter used to run the trainer (default: python).
#
# Extra args are appended AFTER the defaults below. The usage text relies on
# later flags overriding earlier ones (e.g. --eval_start_mult); that presumes
# the trainer's parser is "last flag wins" (argparse default) — confirm there.
#
# NOTE: verification_reinforce_grpo.py and the checkpoints/ save dir are both
# resolved relative to the current working directory, not this script's dir.
set -euo pipefail

# Print usage text to stdout; callers redirect to stderr on error paths.
usage() {
  printf 'usage: %s /path/to/sft_checkpoint.pt [extra verification_reinforce_grpo.py args]\n' "$0"
  printf '%s\n' "eval_start_mult: multiplier for eval starts (default: 500, override via EVAL_START_MULT or --eval_start_mult)"
  printf '%s\n' "note: pass --n_hubs/--m if the checkpoint lacks ds metadata"
}

# Print an error message to stderr and exit non-zero.
die() {
  printf 'error: %s\n' "$*" >&2
  exit 1
}

main() {
  if (( $# < 1 )); then
    usage >&2
    exit 1
  fi

  # Explicit help request: print usage to stdout and succeed, instead of
  # forwarding "-h" to the trainer as a checkpoint path.
  case "$1" in
    -h | --help)
      usage
      exit 0
      ;;
  esac

  local ckpt=$1
  shift

  # Fail fast with a clear message rather than after interpreter/model startup.
  [[ -f "$ckpt" ]] || die "checkpoint not found: $ckpt"

  local eval_start_mult=${EVAL_START_MULT:-500}
  local python_bin=${PYTHON:-python}

  local -a default_args=(
    --init_ckpt "$ckpt"
    --steps 1600
    --batch_size 8
    --group_size 20
    --seed 42
    --verify_k 2
    --eval_start_mult "$eval_start_mult"
    --lr 1e-5
    --kl_coef 0.02
    --temperature 1.0
    --top_k 0
    --amp
    --save_best
    --save_every 500
    --save_dir checkpoints
  )

  # exec replaces this shell so signals (Ctrl-C, SIGTERM from a scheduler)
  # reach the trainer directly; its exit status becomes the script's status.
  exec "$python_bin" verification_reinforce_grpo.py "${default_args[@]}" "$@"
}

main "$@"
