#!/usr/bin/env bash
# Setup section of the PPO training launcher: exports runtime environment
# variables, creates the per-run log directory, and defines the model, data,
# and training parameters consumed by the commands further down.
# Fill in the blank placeholders before running.

# Environment setup
# The API key is exported BEFORE tracing is switched on so that `set -x`
# never echoes the secret. A key already exported by the caller is kept
# rather than being overwritten by the empty placeholder.
export WANDB_API_KEY="${WANDB_API_KEY:-}"
set -x
export VLLM_ATTENTION_BACKEND=XFORMERS
export HYDRA_FULL_ERROR=1

# Log configuration
# TIME (month/day/hour/minute/second) makes each run's log directory and
# experiment name unique.
TIME=$(date +"%m%d%H%M%S")
WANDB_PROJECT=
EXPERIMENT_NAME=
# Abort early if the log directory cannot be created: the training output
# is later written into it.
mkdir -p -- "logs/$EXPERIMENT_NAME/$TIME" || exit 1

# Model configuration
model_path=

# Data configuration
DATA_DIR=data/ppo

# Training parameters
reward_manager='naive'
train_batch_size=256
train_mini_batch_size=64
train_micro_batch_size_per_gpu=2
max_prompt_length=12000
max_response_length=2048
total_epochs=1
lr=1e-6

# ray start --head --dashboard-host='0.0.0.0' --dashboard-port=8265 --ray-debugger-external

# Launch verl PPO training. Every argument is a Hydra-style `key=value`
# override built from the variables defined above; stdout and stderr are
# merged and copied to this run's log file.

# Fail fast on an unfilled placeholder instead of passing an empty model
# path to the trainer.
if [ -z "$model_path" ]; then
    printf 'error: model_path is empty; set it to the model path before training\n' >&2
    exit 1
fi

train_log="logs/$EXPERIMENT_NAME/$TIME/output.log"

# By default a pipeline's exit status is that of its last command (tee),
# which hides a trainer crash. Enable pipefail where the shell supports it;
# the subshell probe keeps shells without the option running as before.
if (set -o pipefail) 2>/dev/null; then
    set -o pipefail
fi

python3 -m verl.trainer.main_ppo \
    "data.train_files=$DATA_DIR/train.parquet" \
    "data.val_files=$DATA_DIR/test.parquet" \
    data.max_train_samples=10000 \
    "data.train_batch_size=$train_batch_size" \
    "data.max_prompt_length=$max_prompt_length" \
    "data.max_response_length=$max_response_length" \
    data.filter_overlong_prompts=True \
    data.truncation=error \
    "actor_rollout_ref.model.path=$model_path" \
    "actor_rollout_ref.actor.optim.lr=$lr" \
    actor_rollout_ref.model.use_remove_padding=True \
    "actor_rollout_ref.actor.ppo_mini_batch_size=$train_mini_batch_size" \
    "actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=$train_micro_batch_size_per_gpu" \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    "actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=$train_micro_batch_size_per_gpu" \
    actor_rollout_ref.rollout.tensor_model_parallel_size=4 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
    "actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=$train_micro_batch_size_per_gpu" \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    "reward_model.reward_manager=$reward_manager" \
    critic.optim.lr=1e-5 \
    critic.model.use_remove_padding=True \
    "critic.model.path=$model_path" \
    critic.model.enable_gradient_checkpointing=True \
    critic.ppo_micro_batch_size_per_gpu=4 \
    critic.model.fsdp_config.param_offload=False \
    critic.model.fsdp_config.optimizer_offload=False \
    algorithm.kl_ctrl.kl_coef=0.001 \
    trainer.critic_warmup=0 \
    'trainer.logger=[console,wandb]' \
    "trainer.project_name=$WANDB_PROJECT" \
    "trainer.experiment_name=${EXPERIMENT_NAME}_${TIME}" \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=1 \
    trainer.save_freq=30 \
    trainer.test_freq=1 \
    "trainer.total_epochs=$total_epochs" 2>&1 | tee "$train_log"
# NOTE: the logger override is single-quoted so its brackets are never
# treated as a glob pattern; the string the trainer receives is unchanged.

# Stop here when training failed so later steps do not run on a broken run.
train_status=$?
if [ "$train_status" -ne 0 ]; then
    printf 'error: training pipeline exited with status %s; see %s\n' \
        "$train_status" "$train_log" >&2
    exit "$train_status"
fi



# Post-training step: run verl's model_merger script on a checkpoint
# directory. Set MODEL_CKPT_DIR here, or export it before invoking the script.
MODEL_CKPT_DIR="${MODEL_CKPT_DIR:-}"
# An empty, unquoted value would expand to nothing and leave --local_dir
# without its argument, so reject it with a clear message instead.
if [ -z "$MODEL_CKPT_DIR" ]; then
    printf 'error: MODEL_CKPT_DIR is empty; set it to the checkpoint directory to merge\n' >&2
    exit 1
fi
# python3 matches the interpreter used for the training command.
python3 verl/model_merger.py --local_dir "$MODEL_CKPT_DIR"