#!/usr/bin/env bash
# Submits scripts/train_elliptical.slurm once per seed (see the SEED loop at
# the end of this file), passing the settings below as positional arguments.
#
# Requires bash (the conditionals below use bash-only test syntax).
#
# NOTE: some values set here are derived or overridden further down:
#   * LOSS_AGG_MODE, KL_LOSS_COEF, NORM_ADV_BY_STD_IN_GRPO depend on ALGORITHM;
#   * PASS_AT_K_FREQ depends on TURN_OFF_AT_HIGHEST_PASS_AT_K;
#   * TEST_FREQ, SAVE_FREQ, TRAIN_BATCH_SIZE, PPO_MINI_BATCH_SIZE are
#     overridden when TASK=dapo-with-aime2425.

# Abort on any unset variable: a misspelled name would otherwise expand to an
# empty string and end up in the positional sbatch argument list.
set -u

TASK=dapo-with-aime2425 # options: math, gsm8k, countdown-4, dapo-with-aime2425
ALGORITHM=grpo # "dr_grpo" selects the alternative loss settings below
MODEL_PATH=Qwen/Qwen2.5-7B-Instruct
SPARSE_DIM=128
BETA=0.01
ROLLOUTS=8
REWARD_TYPE=leverage
RANDOMIZE_SPARSE_MATRIX=True
TURN_OFF_ELLIPTICAL_IF_NONE_CORRECT=True
TURN_OFF_ELLIPTICAL_IF_SOME_CORRECT=False
TURN_OFF_ELLIPTICAL_IF_ALL_CORRECT=False
TURN_OFF_ELLIPTICAL_IF_ROLLOUT_INCORRECT=False
TURN_OFF_AT_HIGHEST_PASS_AT_K=False
TRAIN_RANDOM_SUBSET_SIZE=512
ELLIPTICAL_NORMALIZATION=none
PERSIST_COVARIANCE=False
TRAIN_VAL_N=$((2 * ROLLOUTS)) # always double the rollout size since we're estimating pass@k where k is the rollout size
ALPHA=1.0
TEST_FREQ=20
SAVE_FREQ=20
RESUME_MODE=disable
RESUME_FROM_PATH='' # may be empty; it must still occupy its positional slot
USE_KL_LOSS=True
TURN_OFF_AT_GLOBAL_STEPS=-1
PPO_EPOCHS=1
SAVE_BEST_PASS_AT_1=False
SAVE_BEST_HARD_PASS_AT_1=False
SAVE_BEST_PASS_AT_64=False
SAVE_BEST_HARD_PASS_AT_64=False
# Literal JSON-style list; contains glob characters, so always expand it quoted.
CHECKPOINT_SAVE_CONTENTS='["model","optimizer","extra"]'
MAX_ACTOR_CKPT_TO_KEEP=null
REWARD_MODEL_ENABLE=True
ELLIPTICAL_ENABLE=True
REWARD_MANAGER=elliptical
TRAIN_BATCH_SIZE=1024 # default: 1024
PPO_MINI_BATCH_SIZE=256 # default: 256
REMOVE_PREVIOUS_OPTIM_AND_EXTRA=True

# Loss settings selected by ALGORITHM:
#   dr_grpo         -> LOSS_AGG_MODE=seq-mean-token-sum-norm,
#                      NORM_ADV_BY_STD_IN_GRPO=False
#   anything else   -> LOSS_AGG_MODE=token-mean, NORM_ADV_BY_STD_IN_GRPO=True
# Both branches currently set KL_LOSS_COEF=0.0 even when USE_KL_LOSS=True.
# NOTE(review): presumably that makes the KL term a no-op; confirm intended.
#
# The quoted [[ ]] test stays well-formed when ALGORITHM is empty or contains
# spaces. The unquoted `[` form errored out and silently took the else branch.
if [[ "${ALGORITHM}" == "dr_grpo" ]]; then
    LOSS_AGG_MODE="seq-mean-token-sum-norm"
    KL_LOSS_COEF=0.0
    NORM_ADV_BY_STD_IN_GRPO=False
else
    LOSS_AGG_MODE="token-mean"
    KL_LOSS_COEF=0.0 # default: 0.001
    NORM_ADV_BY_STD_IN_GRPO=True
fi

# PASS_AT_K_FREQ is 5 when TURN_OFF_AT_HIGHEST_PASS_AT_K is exactly "True",
# otherwise -1. NOTE(review): -1 presumably means "disabled"; confirm in the
# trainer.
# The quoted [[ ]] test stays well-formed if the flag is ever empty.
if [[ "${TURN_OFF_AT_HIGHEST_PASS_AT_K}" == True ]]; then
    PASS_AT_K_FREQ=5
else
    PASS_AT_K_FREQ=-1
fi

# Task-specific overrides for dapo-with-aime2425: test/save every 10 instead
# of 20, and use batch sizes 512 / 128 instead of the 1024 / 256 set above.
# The quoted [[ ]] test stays well-formed if TASK is empty or contains spaces.
if [[ "${TASK}" == "dapo-with-aime2425" ]]; then
    TEST_FREQ=10
    SAVE_FREQ=10
    TRAIN_BATCH_SIZE=512
    PPO_MINI_BATCH_SIZE=128
fi

# Submit one SLURM job per seed. Before each submission the configuration is
# printed as "NAME: value" lines so the launch log records what was submitted.
#
# The arguments after the script path are consumed positionally by
# scripts/train_elliptical.slurm. The order of slurm_args is therefore the
# interface with that script: do not reorder, insert, or drop entries without
# updating it too. Every value is quoted so that:
#   * an empty value (e.g. RESUME_FROM_PATH='') still occupies its slot
#     instead of silently shifting every later argument; and
#   * CHECKPOINT_SAVE_CONTENTS='["model","optimizer","extra"]' is passed
#     literally. Unquoted, it is a glob bracket expression and would be
#     replaced by any matching one-character file name in the current
#     directory.

# Variables echoed before each submission, in log order.
log_vars=(
    ALGORITHM MODEL_PATH REWARD_MODEL_ENABLE ELLIPTICAL_ENABLE
    SPARSE_DIM REWARD_MANAGER SEED BETA ROLLOUTS REWARD_TYPE
    RANDOMIZE_SPARSE_MATRIX
    TURN_OFF_ELLIPTICAL_IF_NONE_CORRECT
    TURN_OFF_ELLIPTICAL_IF_SOME_CORRECT
    TURN_OFF_ELLIPTICAL_IF_ALL_CORRECT
    LOSS_AGG_MODE USE_KL_LOSS NORM_ADV_BY_STD_IN_GRPO
    TURN_OFF_AT_HIGHEST_PASS_AT_K PASS_AT_K_FREQ
    TRAIN_RANDOM_SUBSET_SIZE TRAIN_VAL_N ALPHA TEST_FREQ SAVE_FREQ
    ELLIPTICAL_NORMALIZATION RESUME_MODE RESUME_FROM_PATH
    PERSIST_COVARIANCE KL_LOSS_COEF TURN_OFF_AT_GLOBAL_STEPS PPO_EPOCHS
    TURN_OFF_ELLIPTICAL_IF_ROLLOUT_INCORRECT
    SAVE_BEST_PASS_AT_1 SAVE_BEST_PASS_AT_64 CHECKPOINT_SAVE_CONTENTS
    SAVE_BEST_HARD_PASS_AT_1 SAVE_BEST_HARD_PASS_AT_64
    MAX_ACTOR_CKPT_TO_KEEP TRAIN_BATCH_SIZE PPO_MINI_BATCH_SIZE
    REMOVE_PREVIOUS_OPTIM_AND_EXTRA
)

for SEED in 44 45; do
    echo "Running job on ${TASK} with the following parameters:"
    for var_name in "${log_vars[@]}"; do
        # ${!var_name} is bash indirect expansion: the value of the variable
        # whose name is stored in var_name.
        printf '%s: %s\n' "${var_name}" "${!var_name}"
    done

    # Positional arguments for scripts/train_elliptical.slurm (order matters).
    slurm_args=(
        "${MODEL_PATH}"
        "${REWARD_MODEL_ENABLE}"
        "${ELLIPTICAL_ENABLE}"
        "${SPARSE_DIM}"
        "${REWARD_MANAGER}"
        "${SEED}"
        "${BETA}"
        "${ROLLOUTS}"
        "${REWARD_TYPE}"
        "${RANDOMIZE_SPARSE_MATRIX}"
        "${TURN_OFF_ELLIPTICAL_IF_NONE_CORRECT}"
        "${TURN_OFF_ELLIPTICAL_IF_SOME_CORRECT}"
        "${TURN_OFF_ELLIPTICAL_IF_ALL_CORRECT}"
        "${LOSS_AGG_MODE}"
        "${USE_KL_LOSS}"
        "${NORM_ADV_BY_STD_IN_GRPO}"
        "${ALGORITHM}"
        "${TURN_OFF_AT_HIGHEST_PASS_AT_K}"
        "${PASS_AT_K_FREQ}"
        "${TRAIN_RANDOM_SUBSET_SIZE}"
        "${TRAIN_VAL_N}"
        "${ALPHA}"
        "${TEST_FREQ}"
        "${SAVE_FREQ}"
        "${ELLIPTICAL_NORMALIZATION}"
        "${RESUME_MODE}"
        "${RESUME_FROM_PATH}"
        "${PERSIST_COVARIANCE}"
        "${KL_LOSS_COEF}"
        "${TASK}"
        "${TURN_OFF_AT_GLOBAL_STEPS}"
        "${PPO_EPOCHS}"
        "${TURN_OFF_ELLIPTICAL_IF_ROLLOUT_INCORRECT}"
        "${SAVE_BEST_PASS_AT_1}"
        "${SAVE_BEST_PASS_AT_64}"
        "${CHECKPOINT_SAVE_CONTENTS}"
        "${SAVE_BEST_HARD_PASS_AT_1}"
        "${SAVE_BEST_HARD_PASS_AT_64}"
        "${MAX_ACTOR_CKPT_TO_KEEP}"
        "${TRAIN_BATCH_SIZE}"
        "${PPO_MINI_BATCH_SIZE}"
        "${REMOVE_PREVIOUS_OPTIM_AND_EXTRA}"
    )

    # A failed submission is reported with its seed, and the loop moves on to
    # the next seed (same continue-on-failure behaviour as before).
    if ! sbatch \
        --job-name="${TASK}_elliptical_seed_${SEED}_kl_${KL_LOSS_COEF}_ppo_epochs_${PPO_EPOCHS}_beta_${BETA}" \
        scripts/train_elliptical.slurm "${slurm_args[@]}"; then
        printf 'sbatch submission failed for seed %s\n' "${SEED}" >&2
    fi
    echo "--------------------------------"
done