HOME="./"
export PYTHONPATH=$PYTHONPATH:on-policy-cold-start


MODEL="Llama-3.2-3B-Instruct"
RUN_NAME="Llama-3.2-3B-Instruct_ppo_sft_sciworld"
PROMPT_DATA="data/sciworld_sft.json"


deepspeed --include localhost:0,1,2,3 --module openrlhf.cli.train_sft_ppo \
    --pretrain $MODEL \
    --save_path $HOME/checkpoints/$RUN_NAME \
    --micro_train_batch_size 8 \
    --train_batch_size 32 \
    --micro_rollout_batch_size 8 \
    --rollout_batch_size 256 \
    --eps_clip 0.5 \
    --temperature 0.6 \
    --buffer_limit 0 \
    --n_samples_per_prompt 1 \
    --max_samples 200000 \
    --max_epochs 3 \
    --num_episodes 3 \
    --prompt_max_len 4000 \
    --generate_max_len 4000 \
    --zero_stage 3 \
    --bf16 \
    --actor_learning_rate 2e-5 \
    --critic_learning_rate 9e-6 \
    --init_kl_coef 0.00 \
    --prompt_data  $PROMPT_DATA \
    --input_key conversations \
    --normalize_reward \
    --flash_attn \
    --adam_offload \
    --gradient_checkpointing \
    --save_steps 200 \
    --wandb_run_name $RUN_NAME \
    --ckpt_path $HOME/checkpoints/$RUN_NAME  \
    --max_ckpt_num 20000 \
    --use_multi_turn