### Required Parameters ###
# Every value below may be overridden by exporting a variable of the same name
# before launching this script; otherwise the default shown here is used.
actor_pretrain_path=${actor_pretrain_path:-/path/to/actor_model}    # Actor pre-trained weights (also used to initialize the critic)
reward_pretrain_path=${reward_pretrain_path:-/path/to/reward_model} # Reward model pre-trained weights
reward_remote_url=${reward_remote_url:-x.x.x.x:30000}               # host:port of the remote reward-model server
prompt_data_path=${prompt_data_path:-/path/to/prompt_data}          # Prompt (training) data, passed to the trainer as json@<path>
total_sample_num=${total_sample_num:-1252429}                       # Number of samples in the training data; used to derive the checkpoint interval

### Hyper Parameters ###
train_batch_size=${train_batch_size:-1024}
rollout_batch_size=${rollout_batch_size:-1024}
lr=${lr:-1e-6}                                                # Actor learning rate (also passed as the minimum actor LR)
epoch=${epoch:-1}                                             # Passed to the trainer as --num_episodes
data_type=${data_type:-Mixed-1252K}                           # Label used only in the run name
policy_model_name=${policy_model_name:-InternLM3-8B-Instruct} # Label used only in the run name
rm_name=${rm_name:-RM_POEM}                                   # Label used only in the run name

### Start Running ###
# Run identifier; reused in the checkpoint, TensorBoard and head-address paths.
name="final-ppo-ray-policy_${policy_model_name}-${rm_name}_data_${data_type}_bsz_${train_batch_size}_lr_${lr}_epoch_${epoch}"

# Save interval: the number of rollout batches in the data divided by 15,
# i.e. roughly 15 checkpoints per episode. Integer division yields 0 when the
# dataset holds fewer than 15 rollout batches; an interval of 0 is meaningless
# (presumably a modulo-by-zero in the trainer), so clamp it to at least 1.
save_steps=$(( (total_sample_num / rollout_batch_size) / 15 ))
if [ "$save_steps" -lt 1 ]; then
    save_steps=1
fi

cd /path/to/OpenRLHF || { echo "Cannot cd to /path/to/OpenRLHF" >&2; exit 1; }

# Rendezvous file: the head node (RANK 0) writes its address here and the
# worker nodes read it back to find the head.
TARGET_FILE="/path/to/OpenRLHF/addr/addr_${name}.txt"
RANK=${RANK:-0}
MASTER_PORT=6379
echo "MASTER_ADDR: $MASTER_ADDR"



echo "Rank $RANK is running on $MASTER_ADDR"
if [ "$RANK" -eq 0 ]; then
    # ---- Head node: publish address, start Ray head, submit the PPO job ----
    echo "Starting head node (RANK=${RANK}) on port $MASTER_PORT..."

    # Workers connect to whatever address is written here, so an empty value
    # leaves them with nothing to reach; make that visible.
    if [ -z "$MASTER_ADDR" ]; then
        echo "WARNING: MASTER_ADDR is empty; worker nodes will not be able to locate the head." >&2
    fi
    # The addr/ directory may not exist on a fresh checkout.
    mkdir -p "$(dirname "$TARGET_FILE")"
    echo "$MASTER_ADDR" > "$TARGET_FILE" || {
        echo "Failed to write head address to $TARGET_FILE" >&2
        exit 1
    }

    # Start the Ray head in the background and give the cluster time to come
    # up before submitting to the job endpoint on port 8265.
    ray start --head --num-gpus 8 --block &
    sleep 60

    echo "Executing main program on head node..."
    # TODO: consider adding --colocate_critic_reward to the arguments below.
    # NOTE(review): --max_ckpt_num reuses save_steps, which is a step interval
    # rather than a checkpoint count; with ~15 saves per episode this keeps
    # every checkpoint for small epoch counts — confirm this is intended.
    # All expansions are quoted so paths/URLs are passed as single arguments.
    # This is the last command in the branch, so its exit status becomes the
    # script's exit status.
    ray job submit --address="http://127.0.0.1:8265" \
        --runtime-env-json='{"working_dir": "/path/to/OpenRLHF"}' \
        -- python -m openrlhf.cli.train_ppo_ray \
        --ref_num_nodes 0 \
        --ref_num_gpus_per_node 8 \
        --critic_num_nodes 1 \
        --critic_num_gpus_per_node 8 \
        --actor_num_nodes 1 \
        --actor_num_gpus_per_node 8 \
        --vllm_num_engines 16 \
        --vllm_tensor_parallel_size 1 \
        --vllm_sync_backend nccl \
        --colocate_actor_ref \
        --pretrain "$actor_pretrain_path" \
        --remote_rm_url "$reward_remote_url" \
        --reward_pretrain "$reward_pretrain_path" \
        --save_path "/path/to/OpenRLHF/ckpts/${name}" \
        --ckpt_path "/path/to/OpenRLHF/ckpts/${name}" \
        --micro_train_batch_size 2 \
        --train_batch_size "$train_batch_size" \
        --micro_rollout_batch_size 32 \
        --rollout_batch_size "$rollout_batch_size" \
        --num_episodes "$epoch" \
        --prompt_max_len 4096 \
        --generate_max_len 4096 \
        --save_steps "$save_steps" \
        --max_ckpt_num "$save_steps" \
        --save_hf_ckpt \
        --zero_stage 1 \
        --bf16 \
        --lambd 1 \
        --actor_learning_rate "$lr" \
        --critic_learning_rate 1e-5 \
        --actor_min_learning_rate "$lr" \
        --critic_min_learning_rate 1e-6 \
        --lr_warmup_ratio 0.03 \
        --critic_pretrain "$actor_pretrain_path" \
        --init_kl_coef 0 \
        --prompt_data "json@${prompt_data_path}" \
        --input_key message_data \
        --label_key ref_message_data \
        --ref_mode \
        --reward_mean -10.0 \
        --reward_std 10.0 \
        --normalize_reward \
        --packing_samples \
        --overlap_comm \
        --flash_attn \
        --gradient_checkpointing \
        --apply_chat_template \
        --load_checkpoint \
        --use_tensorboard "/path/to/OpenRLHF/logs/${name}"

else
    # ---- Worker node: find the head, join the cluster, stay alive ----
    # Give the head a head start, then poll for its address file instead of
    # assuming it exists after a fixed delay (give up after 10 minutes).
    sleep 30
    waited=0
    while [ ! -s "$TARGET_FILE" ]; do
        if [ "$waited" -ge 600 ]; then
            echo "Timed out waiting for head address in $TARGET_FILE" >&2
            exit 1
        fi
        sleep 10
        waited=$((waited + 10))
    done
    MASTER_ADDR=$(cat "$TARGET_FILE")

    echo "Starting worker node (RANK=${RANK}), connecting to ${MASTER_ADDR}:${MASTER_PORT}..."
    ray start --address "${MASTER_ADDR}:${MASTER_PORT}" --num-gpus 8 --block &

    # Keep this process alive while `ray status` reports an "Active:" section;
    # once it stops doing so (presumably the head is gone), exit.
    sleep 120
    while true; do
        status=$(ray status 2>&1)

        if echo "$status" | grep -q "Active:"; then
            echo "Active nodes found. Sleeping for 10 min..."
            sleep 600
        else
            echo "No active nodes found. Exiting..."
            exit 0
        fi
    done

fi
