#!/bin/bash
# Prologue: load the Weights & Biases API key and set process-wide environment
# variables before the rest of the script runs.
#
# The key is read BEFORE `set -x` is enabled. With tracing on, bash prints the
# fully expanded `WANDB_API_KEY=<secret>` assignment to stderr, which would put
# the secret into the console and any captured log file.
WANDB_API_KEY_FILE="<path_to_your_wandb_api_key>"
if ! WANDB_API_KEY=$(cat -- "${WANDB_API_KEY_FILE}"); then
    # Not fatal: the variable is still exported (empty) and the run continues.
    printf 'warning: could not read W&B API key from %s\n' "${WANDB_API_KEY_FILE}" >&2
fi
export WANDB_API_KEY

# Make Python flush stdout/stderr immediately so log lines appear in real time.
export PYTHONUNBUFFERED=1

# Trace every following command (the secret has already been loaded above).
set -x

# ---------------------------------------------------------------------------
# Run configuration. Values in <angle brackets> are placeholders to fill in.
# NOTE: the names CONGI_FILE and TOTAL_EPOCHES are misspelled ("CONFIG",
# "EPOCHS") but the launch command references them under exactly these names,
# so they are kept as-is here.
# ---------------------------------------------------------------------------

# Comma-separated CUDA device indices exposed to the trainer.
CUDA_IDS=0,1,2,3,4,5,6,7
# CUDA_IDS=0,1,2,3

# Number of GPUs, derived from CUDA_IDS so the two values cannot drift apart
# when switching between device lists (e.g. the 4-GPU alternative above).
IFS=',' read -r -a cuda_id_list <<< "${CUDA_IDS}"
N_GPU=${#cuda_id_list[@]}

# Initial policy weights: a local checkpoint directory or a model id.
MODEL_PATH="<path_to_your_sft_model_checkpoint_or_huggingface_model_id>"

# Trailing "alt:" numbers are the values noted next to each setting in the
# original script — presumably alternative/previous settings; TODO confirm.
TOTAL_EPOCHES=2            # original note: "2 * 20k" (presumably 2 passes over ~20k samples)
GLOBAL_BATCH_SIZE=128      # alt: 128
ROLLOUT_BATCH_SIZE=384     # alt: 512
VAL_BATCH_SIZE=384         # alt: 1024 (NOTE(review): not referenced by the launch command)
MAX_PROMPT_LENGTH=10240
MAX_RESPONSE_LENGTH=2048   # alt: 2048
# Equals MAX_PROMPT_LENGTH + MAX_RESPONSE_LENGTH (10240 + 2048); keep it in
# step if either length changes — assumed to be intentional, TODO confirm.
MAX_NUM_BATCHED_TOKENS=12288

# Experiment name passed to the trainer.
EXP_NAME="<your_output_name>"

# Base trainer configuration file.
# NOTE(review): this path is relative to "examples/", while FORMAT_PROMPT and
# REWARD_FUNCTION below start with "EasyR1/examples/" — confirm which working
# directory the script is meant to be launched from.
CONGI_FILE="examples/config_custom_bf16.yaml"

# Training and validation data.
TRAIN_FILE="<path_to_your_train_file>"
TEST_FILE="<path_to_your_test_file>"

# Prompt template and reward function ("<file>:<function name>").
FORMAT_PROMPT="EasyR1/examples/format_prompt/plain.jinja"
REWARD_FUNCTION="EasyR1/examples/reward_function/math.py:compute_score"

# Value forwarded as data.add_rollout_suffix.
ADD_ROLLOUT_SUFFIX="_w_policy"

# Ray setting exported for the trainer process; per Ray's memory-monitor docs
# this is the node memory usage fraction above which Ray starts killing
# workers (0.98 here) — TODO confirm against the Ray version in use.
export RAY_memory_usage_threshold=0.98

# key=value overrides passed to verl.trainer.main. They are collected in an
# array and expanded quoted, so a value containing spaces or glob characters
# (e.g. a data path) reaches the trainer as exactly one argument instead of
# being word-split or pathname-expanded by the shell.
trainer_overrides=(
    # Base config, data sources and prompt formatting.
    "config=${CONGI_FILE}"
    "data.train_files=${TRAIN_FILE}"
    "data.val_files=${TEST_FILE}"
    "data.rollout_batch_size=${ROLLOUT_BATCH_SIZE}"
    "data.format_prompt=${FORMAT_PROMPT}"
    # Model, batch size and trainer bookkeeping.
    "worker.actor.model.model_path=${MODEL_PATH}"
    "worker.actor.global_batch_size=${GLOBAL_BATCH_SIZE}"
    "trainer.experiment_name=${EXP_NAME}"
    "trainer.n_gpus_per_node=${N_GPU}"
    "trainer.total_epochs=${TOTAL_EPOCHES}"
    "worker.reward.reward_function=${REWARD_FUNCTION}"
    # Fixed algorithm settings: asymmetric clip ratios, KL disabled, and
    # online filtering on the "accuracy" key within (0.01, 0.99).
    "worker.actor.clip_ratio_low=0.2"
    "worker.actor.clip_ratio_high=0.28"
    "algorithm.disable_kl=true"
    "algorithm.online_filtering=true"
    "algorithm.filter_key=accuracy"
    "algorithm.filter_low=0.01"
    "algorithm.filter_high=0.99"
    # Sequence-length limits and rollout suffix.
    "data.max_prompt_length=${MAX_PROMPT_LENGTH}"
    "data.max_response_length=${MAX_RESPONSE_LENGTH}"
    "data.max_num_batched_tokens=${MAX_NUM_BATCHED_TOKENS}"
    "data.add_rollout_suffix=${ADD_ROLLOUT_SUFFIX}"
)

# CUDA_VISIBLE_DEVICES is set for this command only. The launch stays the last
# command so the script's exit status is the trainer's exit status.
CUDA_VISIBLE_DEVICES="${CUDA_IDS}" python3 -m verl.trainer.main "${trainer_overrides[@]}"