#!/bin/sh

# ---------------------------------------------------------------------------
# Experiment configuration consumed by the launch command further down.
# Replace the "path/to/..." placeholders with real locations before running.
# ---------------------------------------------------------------------------

# Identity of this run: passed as experiment_name / trial_name.
EXP_NAME='questa'
TRIAL_NAME='partial-50'

# Scheduler backend: passed as mode=...
MODE='slurm'

# Model family: used as the `type._class` of the actor, critic and ref models.
MODEL_FAMILY='qwen2'

# Checkpoint used as the path for the actor, critic and ref models.
SFT_MODEL_PATH='path/to/OpenMath-Nemotron-1.5B'

# Training data file: passed as dataset.path.
DATA_PATH='path/to/processed_OpenR1-50-0-4.jsonl'

# ---------------------------------------------------------------------------
# Launch: python3 -m realhf.apps.quickstart async-ppo-math <key=value ...>
#
# Inputs (shell variables defined earlier in this script):
#   MODE, EXP_NAME, TRIAL_NAME, MODEL_FAMILY, SFT_MODEL_PATH, DATA_PATH
#
# Environment:
#   The NAME=value prefixes in front of python3 are placed in the environment
#   of that one process only. The variable name
#   SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK is kept exactly as written
#   ("INBALANCE"); it is a runtime key read by the consumer, do not respell it.
#
# Review notes (values intentionally left unchanged):
#   NOTE(review): cluster.n_nodes=16 while n_nodes=8 -- presumably the cluster
#     size vs. the nodes used by this trial; confirm this is intended.
#   NOTE(review): allocation_mode=sglang.d16m2p1+d4p4m2 presumably maps to
#     16*2*1 + 4*4*2 = 64 GPUs = n_nodes(8) * n_gpus_per_node(8); confirm.
#   NOTE(review): exp_ctrl.ckpt_freq_secs=50 looks very frequent if the unit
#     is seconds; confirm.
#
# Every variable expansion is double-quoted so that a value containing
# whitespace or glob characters stays a single argument. The final argument
# deliberately has no trailing backslash, so lines appended after this command
# are not swallowed into it.
# ---------------------------------------------------------------------------

# Fail early, before submitting anything, if a required setting is empty/unset.
: "${MODE:?MODE must be set}"
: "${EXP_NAME:?EXP_NAME must be set}"
: "${TRIAL_NAME:?TRIAL_NAME must be set}"
: "${MODEL_FAMILY:?MODEL_FAMILY must be set}"
: "${SFT_MODEL_PATH:?SFT_MODEL_PATH must be set}"
: "${DATA_PATH:?DATA_PATH must be set}"

# Drop any CLUSTER_SPEC_PATH inherited by this shell; the launcher receives its
# value from the per-command assignment directly below.
unset CLUSTER_SPEC_PATH
CLUSTER_SPEC_PATH=path/to/sif_cluster \
SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK=True \
REAL_GPU_MEMORY_KILL_THRESHOLD=1 \
REAL_ETCD_ADDR=path/to/etcd \
python3 -m realhf.apps.quickstart async-ppo-math \
    new_tokens_per_chunk=1024 \
    max_head_offpolicyness=16 \
    max_concurrent_rollouts=512 \
    mode="${MODE}" \
    experiment_name="${EXP_NAME}" \
    trial_name="${TRIAL_NAME}" \
    wandb.mode=online \
    exp_ctrl.total_train_epochs=100 \
    exp_ctrl.save_freq_steps=50 \
    exp_ctrl.ckpt_freq_secs=50 \
    actor.type._class="${MODEL_FAMILY}" \
    actor.path="${SFT_MODEL_PATH}" \
    critic.type._class="${MODEL_FAMILY}" \
    critic.type.is_critic=True \
    critic.init_critic_from_actor=True \
    critic.path="${SFT_MODEL_PATH}" \
    ref.type._class="${MODEL_FAMILY}" \
    ref.path="${SFT_MODEL_PATH}" \
    dataset.path="${DATA_PATH}" \
    dataset.max_prompt_len=4096 \
    dataset.train_bs_n_seqs=128 \
    ppo.gen.max_new_tokens=24648 \
    group_size=16 \
    ppo.gen.min_new_tokens=0 \
    ppo.disable_value=True \
    ppo.gen.temperature=1.0 \
    ppo.ppo_n_minibatches=1 \
    ppo.kl_ctl=0.0 \
    ppo.eps_clip=0.2 \
    ppo.value_eps_clip=0.2 \
    ppo.reward_output_scaling=5 \
    ppo.reward_output_bias=0.0 \
    ppo.adv_norm=True \
    ppo.value_norm=True \
    ppo.discount=1.0 \
    ppo.recompute_logprob=True \
    ppo.use_decoupled_loss=True \
    ppo.behav_imp_weight_cap=5 \
    actor.optimizer.lr=2e-5 \
    actor.optimizer.lr_scheduler_type=constant \
    actor.optimizer.warmup_steps_proportion=0.001 \
    actor.sglang.triton_attention_num_kv_splits=16 \
    actor.sglang.mem_fraction_static=0.7 \
    actor.sglang.context_length=30720 \
    ref_inf.mb_spec.max_tokens_per_mb=30720 \
    actor_inf.mb_spec.max_tokens_per_mb=30720 \
    actor_train.mb_spec.max_tokens_per_mb=30720 \
    cluster.n_nodes=16 \
    cache_clear_freq=1 \
    success_rate_ub=0.95 \
    success_rate_lb=0.05 \
    n_nodes=8 \
    allocation_mode=sglang.d16m2p1+d4p4m2 \
    n_gpus_per_node=8 \
    recover_mode=auto \
    recover_retries=10 \
    torch_cache_mysophobia=True
