# Resources requested for each node of the job.
resources:
  infra: k8s
  # One H100 accelerator per node.
  accelerators: 'H100:1'
  # Trailing "+" asks for this amount or more (see the SkyPilot resource spec
  # for the unit).
  memory: '128+'
  # Docker image used as the runtime environment for setup and run.
  image_id: 'docker:verlai/verl:base-verl0.5-cu126-cudnn9.8-torch2.7.0-fa2.7.4'
  # Same port the run script passes to `ray start --dashboard-port`.
  ports: 8265

# Number of nodes for the job. The run script sizes the Ray cluster and sets
# trainer.nnodes from SKYPILOT_NUM_NODES, which is presumably derived from this
# value (confirm against the SkyPilot docs).
num_nodes: 2

secrets:
  # Deliberately carries no value in this file so the key is never committed;
  # it is expected to be supplied when the task is launched. The run script
  # enables the wandb logger, which needs it.
  WANDB_API_KEY: null

setup: |
  # Installs verl from a fresh clone (with its vllm extra) plus
  # flashinfer-python, then runs verl's math_dataset.py preprocessing script to
  # populate ~/data/math, which the run step reads as train.parquet and
  # test.parquet.

  # Stop at the first failing command. Without this the script always ends on
  # a successful `echo`, so a broken clone or install would still print
  # "completed" and only surface later as a confusing failure in the run step.
  set -e

  # Start from a clean checkout so re-running setup does not fail on an
  # existing directory.
  rm -rf verl
  git clone https://github.com/volcengine/verl.git
  cd verl

  # Quote the extras spec so the shell never treats [vllm] as a glob pattern.
  pip3 install -v -e ".[vllm]"
  pip3 install flashinfer-python

  echo "Downloading Math dataset..."
  mkdir -p ~/data/math
  python3 "$(pwd)/examples/data_preprocess/math_dataset.py" --local_dir ~/data/math
  echo "Math dataset download completed"

run: |
  # Executed on every node. Rank 0 starts the Ray head, waits for the other
  # nodes to join, and launches GRPO training with verl; every other rank
  # joins the cluster as a Ray worker.

  # The first entry of SKYPILOT_NODE_IPS is used as the head node address.
  HEAD_IP=$(echo "$SKYPILOT_NODE_IPS" | head -n1)
  NUM_NODES=$SKYPILOT_NUM_NODES
  NUM_GPUS_PER_NODE=$SKYPILOT_NUM_GPUS_PER_NODE

  if [ "$SKYPILOT_NODE_RANK" = "0" ]; then
    echo "Starting Ray head node..."
    # Skip `ray start` when a Ray process on port 6379 is already running, so
    # re-executing the task on a live cluster does not start a second head.
    ps aux | grep ray | grep 6379 &> /dev/null || ray start --head --disable-usage-stats \
          --port=6379 \
          --dashboard-host=0.0.0.0 \
          --dashboard-port=8265

    # Wait for all worker nodes to join (up to max_retries polls, 10s apart).
    retry_count=0
    max_retries=30
    while [ $retry_count -lt $max_retries ]; do
      # `grep -c` already prints 0 when nothing matches, but also exits 1.
      # Use `|| true` rather than `|| echo "0"`, which would append a second
      # "0" and turn the value into "0\n0", breaking the integer test below.
      connected_nodes=$(ray status 2>/dev/null | grep -c "node_" || true)
      echo "Connected nodes: $connected_nodes/$NUM_NODES (attempt $((retry_count+1))/$max_retries)"

      if [ "$connected_nodes" -ge "$NUM_NODES" ]; then
        echo "All nodes connected to Ray cluster"
        break
      fi

      retry_count=$((retry_count+1))
      sleep 10
    done

    # Make a timeout visible instead of falling through silently. Training is
    # still attempted, as before; any failure is reported via its exit status.
    if [ "$connected_nodes" -lt "$NUM_NODES" ]; then
      echo "WARNING: only $connected_nodes/$NUM_NODES nodes detected after $max_retries attempts; starting training anyway" >&2
    fi

    # trainer.logger is quoted so the shell never treats [...] as a glob.
    python3 -m verl.trainer.main_ppo \
      algorithm.adv_estimator=grpo \
      data.train_files=$HOME/data/math/train.parquet \
      data.val_files=$HOME/data/math/test.parquet \
      data.train_batch_size=32 \
      data.max_prompt_length=256 \
      data.max_response_length=256 \
      data.filter_overlong_prompts=True \
      data.truncation='error' \
      actor_rollout_ref.model.path=Qwen/Qwen2.5-7B-Instruct \
      actor_rollout_ref.actor.optim.lr=1e-6 \
      actor_rollout_ref.model.use_remove_padding=True \
      actor_rollout_ref.actor.ppo_mini_batch_size=16 \
      actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
      actor_rollout_ref.actor.ppo_epochs=1 \
      actor_rollout_ref.actor.use_kl_loss=False \
      actor_rollout_ref.actor.entropy_coeff=0 \
      actor_rollout_ref.model.enable_gradient_checkpointing=True \
      actor_rollout_ref.actor.fsdp_config.param_offload=True \
      actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
      actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \
      actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
      actor_rollout_ref.rollout.name=vllm \
      actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
      actor_rollout_ref.rollout.n=1 \
      actor_rollout_ref.rollout.enable_chunked_prefill=True \
      actor_rollout_ref.rollout.max_num_batched_tokens=2048 \
      actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \
      actor_rollout_ref.ref.fsdp_config.param_offload=True \
      algorithm.use_kl_in_reward=False \
      trainer.critic_warmup=0 \
      trainer.logger='[console,wandb]' \
      trainer.project_name=verl_math_grpo_demo \
      trainer.experiment_name=qwen25_7b_grpo \
      trainer.n_gpus_per_node=$NUM_GPUS_PER_NODE \
      trainer.nnodes=$NUM_NODES \
      trainer.save_freq=-1 \
      trainer.test_freq=-1 \
      trainer.total_epochs=1
    train_status=$?

    # Propagate a training failure; otherwise the final echo below would make
    # the script exit 0 and the run would be reported as successful.
    if [ "$train_status" -ne 0 ]; then
      echo "ERROR: training exited with status $train_status" >&2
      exit "$train_status"
    fi

  else
    # Give the head node a head start before trying to connect to it.
    sleep 15
    echo "Starting Ray worker node..."
    # Skip `ray start` if a Ray process for this head address already exists.
    # -F matches the address literally (dots are not regex wildcards).
    ps aux | grep ray | grep -F "$HEAD_IP:6379" &> /dev/null || ray start --address "$HEAD_IP:6379" --disable-usage-stats
    sleep 10
  fi

  echo "Node setup and Ray start script finished for rank $SKYPILOT_NODE_RANK."