#!/usr/bin/env bash
#
# Sweep driver: for each configured checkpoint, run
# run_verification_weight_grpo_trap.py once per selected weight mode.
# Pass -h or --help as the first argument to print usage and exit.
set -euo pipefail

case "${1:-}" in
  -h | --help)
    printf '%s\n' \
      "usage: $0 [weight_mode] [extra run_verification_weight_grpo_trap.py args]" \
      "weight_mode: equal|none|optimal (default: all three)" \
      "env: CKPT_DIR=checkpoints N_HUBS=5 M=6 BLOCK_SIZE=128 EVAL_START_MULT=500"
    exit 0
    ;;
esac

# By default every weight mode is swept. A first argument that does not start
# with "--" narrows the sweep to that single mode; anything else (including
# all remaining arguments once the mode is consumed) is passed through.
WEIGHT_MODES=(equal none optimal)
if (( $# > 0 )) && [[ "$1" != --* ]]; then
  case "$1" in
    equal | none | optimal) WEIGHT_MODES=("$1") ;;
    *)
      echo "unknown weight_mode: $1" >&2
      exit 1
      ;;
  esac
  shift
fi

# Leftover arguments are forwarded verbatim to the Python script.
EXTRA_ARGS=("$@")

# Run configuration. Each scalar can be overridden from the environment;
# an unset or empty value is replaced by the default written here.
: "${CKPT_DIR:=checkpoints}"
: "${N_HUBS:=5}"
: "${M:=6}"
: "${BLOCK_SIZE:=128}"
: "${EVAL_START_MULT:=500}"

# Sweep grids: embedding widths, and depths (each depth is used for both the
# layer count and the head count of the checkpoint name).
EMB_LIST=(32)
N_LIST=(4)

# Launch one GRPO run per (checkpoint, weight_mode) combination.
#
# For every embedding width in EMB_LIST and every depth in N_LIST (used as
# both the layer count and the head count), look for
# ${CKPT_DIR}/best_<run_tag>.pt and fine-tune it once per entry in
# WEIGHT_MODES. A missing checkpoint is reported on stderr and skipped.
#
# Possibly-empty arrays are expanded as ${arr[@]+"${arr[@]}"} instead of
# "${arr[@]}": with `set -u`, bash releases before 4.4 (such as the 3.2
# shipped with macOS) abort with "unbound variable" on an empty "${arr[@]}".
# seed_args is empty for every mode except "none", and EXTRA_ARGS is empty
# whenever no pass-through flags are given.
#
# Outputs are written to ${SAVE_DIR}, defaulting to ./checkpoints when
# SAVE_DIR is unset or empty.
for n_emb in "${EMB_LIST[@]}"; do
  for n in "${N_LIST[@]}"; do
    n_layer="$n"
    n_head="$n"
    run_tag="h${N_HUBS}_m${M}_emb${n_emb}_l${n_layer}_head${n_head}_bs${BLOCK_SIZE}"
    ckpt="${CKPT_DIR}/best_${run_tag}.pt"
    if [[ ! -f "$ckpt" ]]; then
      echo "missing ckpt: $ckpt" >&2
      continue
    fi

    for weight_mode in "${WEIGHT_MODES[@]}"; do
      # Only the "none" mode pins the seed; the other modes rely on the
      # Python script's default (NOTE(review): confirm this is intended).
      seed_args=()
      if [[ "$weight_mode" == "none" ]]; then
        seed_args=(--seed 42)
      fi
      # EXTRA_ARGS goes last so user flags follow the defaults below;
      # presumably the Python parser lets a later flag override an earlier
      # one (argparse's default action does) - TODO confirm.
      python run_verification_weight_grpo_trap.py \
        --init_ckpt "$ckpt" \
        --weight_mode "$weight_mode" \
        ${seed_args[@]+"${seed_args[@]}"} \
        --steps 4000 \
        --batch_size 4 \
        --group_size 40 \
        --verify_k 4 \
        --eval_start_mult "$EVAL_START_MULT" \
        --lr 1e-5 \
        --kl_coef 0.02 \
        --temperature 1.0 \
        --top_k 0 \
        --amp \
        --save_best \
        --save_every 500 \
        --save_dir "${SAVE_DIR:-checkpoints}" \
        ${EXTRA_ARGS[@]+"${EXTRA_ARGS[@]}"}
    done
  done
done
