#!/usr/bin/env bash
#
# Launch GRPO verification-weight training from an SFT checkpoint by calling
# run_verification_weight_grpo.py with a fixed set of hyperparameters.
#
# Arguments:
#   $1     path to the SFT checkpoint (required)
#   $2     optional weight mode; consumed only when it does not start with
#          "--". When omitted, the script sweeps "none" and then "optimal".
#   rest   forwarded verbatim to run_verification_weight_grpo.py, after the
#          built-in flags.
#
# Environment:
#   EVAL_START_MULT  value passed as --eval_start_mult (default: 500).
set -euo pipefail

# Print the usage text to stderr.
usage() {
  cat >&2 <<EOF
usage: $0 /path/to/sft_checkpoint.pt [weight_mode] [extra run_verification_weight_grpo.py args]
weight_mode: equal|none|optimal (default: run none, then optimal)
eval_start_mult: multiplier for eval starts (default: 500, override via EVAL_START_MULT or --eval_start_mult)
note: pass --n_hubs/--m if the checkpoint lacks ds metadata
EOF
}

if [[ $# -lt 1 ]]; then
  usage
  exit 1
fi

CKPT="$1"
shift

# Without an explicit mode, keep the historical sweep: "none" then "optimal".
# An explicit positional mode replaces the sweep with that single run.
WEIGHT_MODES=(none optimal)
if [[ $# -gt 0 && "$1" != --* ]]; then
  WEIGHT_MODES=("$1")
  shift
fi

# Honour an EVAL_START_MULT set in the environment; otherwise fall back to 500.
EVAL_START_MULT="${EVAL_START_MULT:-500}"

# NOTE: an earlier, commented-out configuration in this file used
# --steps 1600, --group_size 20 and --verify_k 2; every other flag matched
# the values below.

#######################################
# Run one training job.
# Globals:   CKPT (read), EVAL_START_MULT (read)
# Arguments: $1 - weight mode; remaining args are appended to the command
# Returns:   exit status of the python process
#######################################
run_grpo() {
  local weight_mode="$1"
  shift
  # Caller-supplied flags go last. NOTE(review): this presumably lets them
  # override the defaults above (last occurrence wins) - confirm against the
  # Python script's argument parser.
  python run_verification_weight_grpo.py \
    --init_ckpt "$CKPT" \
    --weight_mode "$weight_mode" \
    --steps 4000 \
    --batch_size 8 \
    --group_size 40 \
    --seed 42 \
    --verify_k 4 \
    --eval_start_mult "$EVAL_START_MULT" \
    --lr 1e-5 \
    --kl_coef 0.02 \
    --temperature 1.0 \
    --top_k 0 \
    --amp \
    --save_best \
    --save_every 500 \
    --save_dir checkpoints \
    "$@"
}

# Under `set -e` a failing run aborts the sweep before later modes start.
for WEIGHT_MODE in "${WEIGHT_MODES[@]}"; do
  run_grpo "$WEIGHT_MODE" "$@"
done
