---
# Name of this run configuration.
name: 'spiral-multi-node-training'

# Container image for the run. The tag name suggests CUDA 12.4, vLLM 0.8.4,
# SGLang 0.4.5, Megatron-core 0.12.0 and TransformerEngine 2.2 (not verified
# here).
image: 'whatcanyousee/verl:ngc-cu124-vllm0.8.4-sglang0.4.5-mcore0.12.0-te2.2'
integrations:
  # Git checkout of the SPIRAL sources (branch `dev`, cloned over SSH).
  # The `command` script expects the checkout at /workspace/spiral.
  - integration_type: git_repo
    git_repo: 'anonymous/spiral'
    git_branch: 'dev'
    ssh_clone: true
    # Pip target relative to the checkout; presumably installs the repo root.
    # The `command` script additionally runs `pip install -e .`.
    pip_install: '.'
command: |
  # Job-wide knobs, exported so later steps and child processes can read them.
  # NUM_NODES * GPUS_PER_NODE must stay equal to `compute.gpus`.
  NUM_NODES=2
  GPUS_PER_NODE=8
  TOTAL_GPUS=$((NUM_NODES * GPUS_PER_NODE))
  # BS feeds the rollout/train batch-size flags; TP feeds --num_gpus_per_actor.
  BS=128
  TP=4
  export NUM_NODES GPUS_PER_NODE TOTAL_GPUS BS TP

  # Everything below runs from the repository checkout.
  cd /workspace/spiral
  
  # System setup: OS packages first, then a few Python helpers.
  # Swap the TUNA apt mirror for the US Ubuntu archive before refreshing.
  sed -i 's|mirrors.tuna.tsinghua.edu.cn|us.archive.ubuntu.com|g' /etc/apt/sources.list
  apt update
  # iproute2 supplies the `ip` tool used for interface detection below;
  # each package is installed in its own apt call, in this order.
  for apt_pkg in iproute2 dnsutils awscli; do
    apt install -y "$apt_pkg"
  done
  # urllib3 is held below 2.x before s3fs/boto3 are installed
  # (presumably for botocore compatibility -- TODO confirm).
  for py_pkg in 'urllib3<2' s3fs boto3; do
    pip install "$py_pkg"
  done

  # Runtime environment for distributed training
  # Gloo is bound to the interface that carries the default route.
  INTERFACE=$(ip route | awk '/default/ {print $5}' | head -1)
  echo "Using interface: $INTERFACE"
  export GLOO_SOCKET_IFNAME="$INTERFACE"

  # Put the Python interpreter's LIBDIR at the front of the loader path.
  PYTHON_LIBDIR=$(python -c "import sysconfig; print(sysconfig.get_config_var('LIBDIR'))")
  export LD_LIBRARY_PATH="$PYTHON_LIBDIR:$LD_LIBRARY_PATH"

  # Simple toggles:
  #   HYDRA_FULL_ERROR=1           full Hydra stack traces
  #   HF_HUB_ENABLE_HF_TRANSFER=1  hf_transfer-based Hub downloads
  #                                NOTE(review): needs the `hf_transfer` package -- confirm the image ships it
  #   NCCL_CUMEM_ENABLE=0          turn off NCCL's cuMem allocator path
  #   LP_DEBUG / LP_LOG_LEVEL      verbose logging (presumably Launchpad -- TODO confirm)
  export HYDRA_FULL_ERROR=1 HF_HUB_ENABLE_HF_TRANSFER=1 NCCL_CUMEM_ENABLE=0
  export LP_DEBUG=1 LP_LOG_LEVEL=DEBUG

  # SPIRAL Python dependencies, per the README.
  # Python 3.10 is required and already present in the base image.
  # oat-llm is only attempted once the pinned vLLM install has succeeded.
  if pip install vllm==0.8.4; then
    pip install oat-llm==0.2.0
  fi
  # Editable install of the repo, which brings in the pyproject.toml deps:
  # textblob, textarena==0.6.4, latex2sympy2, tabulate, human_eval
  pip install -e .

  # Fetch the base model ahead of training by building a CPU pipeline for it.
  python3 -c "import transformers; transformers.pipeline(model='Qwen/Qwen3-14B-Base', device='cpu')"

  # Distributed rendezvous parameters
  # NODE_RANK must come from the launcher (0 = head node, 1+ = workers).
  # Without it the rank flags below would be empty, so stop early.
  if [ -z "${NODE_RANK:-}" ]; then
    echo "ERROR: NODE_RANK is not set; cannot determine this node's role" >&2
    exit 1
  fi

  if [ "$NODE_RANK" = "0" ]; then
    # Head node: advertise its own first IP address as the master address.
    export MASTER_ADDR=$(hostname -I | awk '{print $1}')
    echo "Head node IP (MASTER_ADDR): $MASTER_ADDR"
  else
    # Worker node: prefer HEAD_NODE_IP when it is provided; otherwise keep a
    # MASTER_ADDR already exported into the environment instead of clobbering
    # it with an empty value.
    # NOTE(review): confirm which of these variables the platform really sets.
    export MASTER_ADDR="${HEAD_NODE_IP:-${MASTER_ADDR:-}}"
    echo "Worker node connecting to head node at: $MASTER_ADDR"
  fi

  # An empty address would make the unquoted `--master_addr $MASTER_ADDR`
  # swallow the following flag, so fail here with a clear message instead.
  if [ -z "$MASTER_ADDR" ]; then
    echo "ERROR: MASTER_ADDR is empty on node $NODE_RANK; set HEAD_NODE_IP or MASTER_ADDR" >&2
    exit 1
  fi

  export MASTER_PORT=12345
  # The trainer's "group" terminology maps one-to-one onto nodes.
  export GROUP_RANK="$NODE_RANK"
  export NUM_GROUPS="$NUM_NODES"

  # Every node launches the same trainer; only its rank differs.
  echo "Starting SPIRAL training on node $NODE_RANK of $NUM_GROUPS..."

  # W&B flags are attached on the head node only (avoids conflicting runs);
  # every other node passes no W&B arguments at all.
  WB_ARGS=""
  case "$NODE_RANK" in
    0) WB_ARGS="--use-wb --wb-run-name spiral-qwen3-14b-base-kp-4k-self-play-multi-node --wb_project spiral" ;;
  esac

  # Launch the SPIRAL trainer. Flags are grouped per line by concern, in order:
  #   environment and evaluation setup, GPU/group topology and rendezvous,
  #   sampling and model/engine options, optimizer settings, batch sizes,
  #   length limits, sampling temperatures, eval/save cadence, run limits.
  # The per-device rollout batch and buffer length are both BS / TOTAL_GPUS.
  # $WB_ARGS stays unquoted on purpose: it must split into separate flags on
  # the head node and expand to nothing on the other nodes.
  python train_spiral.py \
    --env_id KuhnPoker-v1 --use_llm_obs_wrapper \
    --eval_env_ids TicTacToe-v0 KuhnPoker-v1 \
    --eval_use_llm_obs_wrappers False True \
    --eval_opponent_names google/gemini-2.0-flash-lite-001 \
    --eval_split all --gamma 1 \
    --gpus $TOTAL_GPUS --num_gpus_per_actor $TP \
    --num_groups $NUM_GROUPS --group_rank $GROUP_RANK \
    --master_addr $MASTER_ADDR --master_port $MASTER_PORT \
    --gradient-checkpointing \
    --num_samples 1 --dump_game_state_every 1 --num_envs 1 \
    --pretrain Qwen/Qwen3-14B-Base \
    --enable_prefix_caching --collocate --vllm_sleep --vllm_gpu_ratio 0.45 \
    --zero_stage 3 --no-use_fused_lm_head \
    --learning_rate 0.000001 --lr_scheduler constant --lr_warmup_ratio 0 \
    --num_ppo_epochs 2 \
    --rollout_batch_size $BS \
    --rollout_batch_size_per_device $((BS / TOTAL_GPUS)) \
    --pi_buffer_maxlen_per_device $((BS / TOTAL_GPUS)) \
    --train_batch_size $BS --train_batch_size_per_device 1 \
    --beta 0 \
    --max_model_len 12800 --generate_max_length 4096 --max_context_length 32768 \
    --temperature 1.0 --top_p 1 \
    --eval_steps 16 --save_steps -1 --eval_games 16 \
    --eval_temperature 0.6 --eval_top_p 0.95 --eval_generate_max_length 4096 \
    --max_train 51200 --max_save_num 30 \
    $WB_ARGS

compute:
  cluster: r8z13p2
  gpu_type: h100_80gb
  # 2 nodes × 8 GPUs each; keep in sync with NUM_NODES and GPUS_PER_NODE
  # in `command`.
  gpus: 16