# Per-node resource request. Each of the nodes gets the same spec.
resources:
  # Target infrastructure: an existing Kubernetes cluster.
  infra: k8s
  # Container image every node runs in.
  image_id: 'docker:verlai/verl:base-verl0.5-cu126-cudnn9.8-torch2.7.0-fa2.7.4'
  # One H100 per node; the run script passes trainer.n_gpus_per_node=1 to match.
  accelerators: 'H100:1'
  # Memory request; in SkyPilot's resource syntax the trailing "+" means
  # "at least this much".
  memory: '128+'
  # Exposed port; the run script starts the Ray dashboard on 8265.
  ports: 8265

# Number of nodes to launch. The run script reads the node count from
# $SKYPILOT_NUM_NODES and waits until that many nodes have joined the Ray
# cluster before starting training.
num_nodes: 2

# Secrets required by the job. The value is deliberately null so that no key
# is committed to this file; provide it when launching the task
# (presumably via SkyPilot's --secret option -- confirm against your CLI version).
secrets:
  WANDB_API_KEY: null

# One-time node preparation: check out and install verl (with the vLLM extra),
# then preprocess the GSM8K dataset into ~/data/gsm8k.
setup: |
  # Start from a fresh checkout so re-running setup is repeatable.
  rm -rf verl
  git clone https://github.com/volcengine/verl.git
  # Stop here if the clone did not produce a verl directory; otherwise the
  # editable install below would run against whatever the current directory is.
  cd verl || exit 1

  # Quote the extras spec so the shell cannot interpret [vllm] as a glob.
  pip3 install -v -e '.[vllm]'
  pip3 install flashinfer-python

  # Preprocess GSM8K with the script shipped in the verl checkout.
  echo "Downloading GSM8K dataset..."
  mkdir -p ~/data/gsm8k
  gsm8k_script="$(pwd)/examples/data_preprocess/gsm8k.py"
  if [ -f "$gsm8k_script" ]; then
    # Only report completion when the preprocessing script actually succeeded.
    if python3 "$gsm8k_script" --local_dir ~/data/gsm8k; then
      echo "GSM8K dataset download completed"
    else
      echo "Warning: gsm8k.py failed, dataset may be missing or incomplete"
    fi
  else
    echo "Warning: gsm8k.py script not found, skipping dataset download"
    # The dataset has to be provided manually under ~/data/gsm8k in this case.
  fi

# Executed on every node. Rank 0 starts the Ray head, waits for the other
# nodes to join, and launches verl PPO training; every other rank starts a
# Ray worker that connects to the head.
run: |
  # The first entry of SKYPILOT_NODE_IPS is used as the Ray head address.
  HEAD_IP=$(echo "$SKYPILOT_NODE_IPS" | head -n1)
  NUM_NODES=$SKYPILOT_NUM_NODES

  # login wandb
  # python3 -c "import wandb; wandb.login(relogin=True, key='$WANDB_API_KEY')"

  if [ "$SKYPILOT_NODE_RANK" == "0" ]; then
    # Head node starts Ray Head (skipped if a Ray process on 6379 already runs).
    echo "Starting Ray head node..."
    ps aux | grep ray | grep 6379 &> /dev/null || ray start --head --disable-usage-stats \
          --port=6379 \
          --dashboard-host=0.0.0.0 \
          --dashboard-port=8265

    # Poll `ray status` until every node has joined, up to
    # max_retries attempts spaced 10 seconds apart.
    echo "Waiting for all nodes to join Ray cluster..."
    retry_count=0
    max_retries=30
    while [ "$retry_count" -lt "$max_retries" ]; do
      # grep -c always prints a count (0 when nothing matches) but exits 1 on
      # zero matches. "|| true" absorbs that exit status without printing a
      # second "0", which would break the integer comparison below.
      connected_nodes=$(ray status 2>/dev/null | grep -c "node_" || true)
      connected_nodes=${connected_nodes:-0}
      echo "Connected nodes: $connected_nodes/$NUM_NODES (attempt $((retry_count+1))/$max_retries)"

      if [ "$connected_nodes" -ge "$NUM_NODES" ]; then
        echo "All nodes connected to Ray cluster"
        break
      fi

      retry_count=$((retry_count+1))
      sleep 10
    done

    # Best effort: on timeout only warn and dump the status, then still try
    # to train with whatever nodes are present.
    if [ "$retry_count" -eq "$max_retries" ]; then
      echo "WARNING: Not all nodes connected to Ray cluster after $max_retries attempts"
      echo "Current Ray status:"
      ray status
    fi

    # trainer.nnodes follows the actual cluster size instead of a hard-coded
    # value, and the logger list is quoted so the shell cannot glob-expand it.
    python3 -m verl.trainer.main_ppo \
     data.train_files="$HOME/data/gsm8k/train.parquet" \
     data.val_files="$HOME/data/gsm8k/test.parquet" \
     data.train_batch_size=256 \
     data.max_prompt_length=512 \
     data.max_response_length=256 \
     actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \
     actor_rollout_ref.actor.optim.lr=1e-6 \
     actor_rollout_ref.actor.ppo_mini_batch_size=64 \
     actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
     actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \
     actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
     actor_rollout_ref.rollout.name=vllm \
     actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
     actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \
     critic.optim.lr=1e-5 \
     critic.model.path=Qwen/Qwen2.5-0.5B-Instruct \
     critic.ppo_micro_batch_size_per_gpu=4 \
     algorithm.kl_ctrl.kl_coef=0.001 \
     trainer.logger='[console,wandb]' \
     trainer.val_before_train=False \
     trainer.default_hdfs_dir=null \
     trainer.n_gpus_per_node=1 \
     trainer.nnodes="$NUM_NODES" \
     trainer.save_freq=20 \
     trainer.test_freq=20 \
     trainer.total_epochs=2 \
     trainer.project_name=verl_examples \
     trainer.experiment_name=experiment_name_gsm8k

  else
    # Wait for Ray Head to start
    sleep 15
    # Worker node starts Ray Worker (skipped if one is already connected to
    # the head address).
    echo "Starting Ray worker node..."
    ps aux | grep ray | grep "$HEAD_IP:6379" &> /dev/null || ray start --address "$HEAD_IP:6379" --disable-usage-stats
    sleep 10
  fi

  echo "Node setup and Ray start script finished for rank $SKYPILOT_NODE_RANK."