#!/usr/bin/env bash
# Launch script for OpenRLHF PPO training on a single-node Ray cluster.
# Abort on any failing command (-e), on use of an unset variable (-u) and on
# a failure anywhere in a pipeline (-o pipefail), so a broken `ray start`
# does not fall through to `ray job submit`.
set -euo pipefail
# Trace every command to stderr to make launch problems easy to diagnose.
set -x

# python -m openrlhf.cli.serve_rm \
#     --reward_pretrain OpenRLHF/Llama-3-8b-rm-700k \
#     --port 5000 \
#     --bf16 \
#     --flash_attn \
#     --normalize_reward \
#     --max_len 8192 \
#     --batch_size 16

# # Start the Ray head node
# ray stop
# ray start --head \
#     --node-ip-address=10.1.4.14 \
#     --dashboard-port=8265 \
#     --dashboard-host=0.0.0.0

# # Wait for Ray to finish starting up
# sleep 5

# ray job submit --address="http://10.1.4.14:8265" \
#    --runtime-env-json='{"working_dir": "/cofs04/user/maxiaoya/OpenRLHF"}' \
#    -- python3 -m openrlhf.cli.train_ppo_ray \
#    --ref_num_nodes 1 \
#    --ref_num_gpus_per_node 2 \
#    --critic_num_nodes 1 \
#    --critic_num_gpus_per_node 2 \
#    --actor_num_nodes 1 \
#    --actor_num_gpus_per_node 2 \
#    --vllm_num_engines 1 \
#    --vllm_tensor_parallel_size 8 \
#    --colocate_actor_ref \
#    --pretrain /shared/VauAI/maxiaoya/cache_files/hub/models--meta-llama--Llama-3.1-8B-Instruct \
#    --remote_rm_url http://localhost:5000/get_reward \
#    --save_path /openrlhf/checkpoint/llama3.1-8b-instruct-rlhf \
#    --micro_train_batch_size 8 \
#    --train_batch_size 128 \
#    --micro_rollout_batch_size 16 \
#    --rollout_batch_size 1024 \
#    --max_samples 100000 \
#    --max_epochs 1 \
#    --prompt_max_len 1024 \
#    --generate_max_len 1024 \
#    --zero_stage 3 \
#    --bf16 \
#    --actor_learning_rate 5e-7 \
#    --critic_learning_rate 9e-6 \
#    --init_kl_coef 0.01 \
#    --prompt_data HuggingFaceH4/ultrafeedback_binarized \
#    --input_key prompt \
#    --apply_chat_template \
#    --normalize_reward \
#    --packing_samples \
#    --adam_offload \
#    --flash_attn \
#    --gradient_checkpointing \
#    --use_wandb {wandb_token} \
#    --prompt_split train_prefs \
#    --eval_split test_prefs


# Restart a single-node Ray head and submit an OpenRLHF PPO training job.
#
# Environment overrides (defaults reproduce the previously hard-coded values):
#   PROMPT_DATA  (required) dataset passed to --prompt_data
#   WORKING_DIR  directory given to `ray job submit --working-dir` (default: /path)
#   NUM_GPUS     GPUs advertised by the Ray head node (default: 7)
#   PRETRAIN     model passed to --pretrain
#   SAVE_PATH    directory passed to --save_path
#   CKPT_PATH    directory passed to --ckpt_path
#
# The job points --remote_rm_url at http://localhost:5000/get_reward, so a
# reward-model server must already be listening there (see the commented-out
# `openrlhf.cli.serve_rm` command at the top of this file).

# Validate inputs before touching the running Ray cluster. Previously
# --prompt_data was passed with no value, so the trainer saw
# `--prompt_data --input_key prompt` and the job failed only after Ray had
# already been restarted.
prompt_data="${PROMPT_DATA:?PROMPT_DATA must be set to the dataset for --prompt_data}"
# NOTE(review): '/path' looks like a placeholder; set WORKING_DIR to the real project dir.
working_dir="${WORKING_DIR:-/path}"
num_gpus="${NUM_GPUS:-7}"
pretrain="${PRETRAIN:-/openrlhf/examples/checkpoint/llama3-8b-rlhf}"
save_path="${SAVE_PATH:-/openrlhf/examples/checkpoint/llama3-8b-rlhf}"
ckpt_path="${CKPT_PATH:-/openrlhf/examples/checkpoint/llama3-8b-rlhf}"

if [[ "$save_path" == "$pretrain" ]]; then
  printf 'warning: SAVE_PATH equals PRETRAIN (%s); saving may overwrite the starting checkpoint\n' \
    "$pretrain" >&2
fi

# Restart Ray so the job runs against a freshly started head node.
ray stop
# Presumably tells Ray not to set CUDA_VISIBLE_DEVICES for its workers
# (leaving GPU placement to OpenRLHF) — TODO confirm against Ray/OpenRLHF docs.
export RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES=1
if ! ray start --head --node-ip-address 0.0.0.0 --num-gpus "$num_gpus" --dashboard-host=0.0.0.0; then
  printf 'error: ray start failed; not submitting the training job\n' >&2
  exit 1
fi

# Arguments for openrlhf.cli.train_ppo_ray, kept as an array so each value is
# passed as exactly one word and no line continuations are needed.
train_args=(
  # Resource layout.
  --ref_num_nodes 1
  --ref_num_gpus_per_node 2
  --critic_num_nodes 1
  --critic_num_gpus_per_node 2
  --actor_num_nodes 1
  --actor_num_gpus_per_node 2
  --vllm_num_engines 1
  --vllm_tensor_parallel_size 1
  --colocate_actor_ref
  # Models and outputs.
  --pretrain "$pretrain"
  --remote_rm_url http://localhost:5000/get_reward
  --save_path "$save_path"
  # Batching and schedule.
  --micro_train_batch_size 2
  --train_batch_size 8
  --micro_rollout_batch_size 2
  --rollout_batch_size 1024
  --max_samples 100000
  --num_episodes 8
  --max_epochs 1
  --prompt_max_len 1024
  --generate_max_len 1024
  # Optimization.
  --zero_stage 2
  --bf16
  --actor_learning_rate 5e-7
  --critic_learning_rate 9e-6
  --init_kl_coef 0.01
  # Data.
  --prompt_data "$prompt_data"
  --input_key prompt
  --chosen_label_key chosen
  --reject_label_key reject
  --apply_chat_template
  --normalize_reward
  --packing_samples
  # Memory/performance.
  --adam_offload
  --flash_attn
  --gradient_checkpointing
  # Splits, evaluation and checkpointing.
  --prompt_split train_prefs
  --eval_split test_prefs
  --save_steps 20
  --ckpt_path "$ckpt_path"
  --eval_steps 20
)

# NOTE(review): `setup_commands` is not among the runtime_env fields listed in
# Ray's documentation; if openrlhf[vllm] is not installed on the workers,
# consider '{"pip": ["openrlhf[vllm]"]}' instead — verify.
RAY_ADDRESS='http://127.0.0.1:8265' ray job submit \
  --working-dir "$working_dir" \
  --runtime-env-json='{"setup_commands": ["pip install openrlhf[vllm]"]}' \
  -- python -m openrlhf.cli.train_ppo_ray "${train_args[@]}"