#!/bin/bash
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=8
#SBATCH --gres=gpu:4
#SBATCH --time=48:00:00
#SBATCH --mem=720G
#SBATCH --exclusive
#SBATCH --account=your_slurm_account
#SBATCH --licenses=walrus:1,octopus:1,narwhal:1,cat:1
#
# Train meta-llama/Llama-3.1-8B on GSM8K dataset
# Learning rate: 1e-6
# Completions per prompt: 16
#
# Training setup:
#   - Model: meta-llama/Llama-3.1-8B
#   - Dataset: GSM8K (openai/gsm8k)
#   - Batch size: 128 prompts per step (x16 samples = 2048 rollouts), 2 nodes
#   - Samples per prompt: 16
#   - Learning rate: 1e-6
#   - On-policy training (JustRL-style)
#   - 2 nodes with 4 GPUs each (8 GPUs total, colocate_all=true)
#
# Usage:
#   sbatch sbatch_llama31_8b_gsm8k.sh

# Load modules - must match setup_skyrl_env.sh
module load release/25.06 GCCcore/13.3.0
module load Python/3.12.3
module load CUDA/12.6.0
module load NCCL/2.22.3-CUDA-12.6.0

# WORK must be set to your workspace directory before submitting, e.g.
#   export WORK=/path/to/your/workspace
# Caches, the virtualenv and the HF hub all live under it, so fail fast when it
# is unset instead of silently using paths like "/.cache" or "/hub".
: "${WORK:?WORK must be set to your workspace directory}"

export XDG_CACHE_HOME="$WORK/.cache"

if ! source "$WORK/skyrl-venv/bin/activate"; then
  echo "ERROR: cannot activate virtualenv at $WORK/skyrl-venv" >&2
  exit 1
fi

# set -a exports every variable assigned while sourcing .env to child
# processes (presumably including HF_TOKEN, which is used below).
set -a; source "$HOME/SkyRL/skyrl-train/experiment-slurm-scripts/.env"; set +a

if ! cd "$HOME/SkyRL/skyrl-train"; then
  echo "ERROR: cannot cd to $HOME/SkyRL/skyrl-train" >&2
  exit 1
fi

export HF_HOME="$WORK/hub"
export PYTHONFAULTHANDLER=1
export TORCH_NCCL_ASYNC_ERROR_HANDLING=1
export TRITON_CACHE_DIR="$WORK/triton_cache"
# Enable only when needed:
# export NCCL_DEBUG=INFO
# export CUDA_LAUNCH_BLOCKING=1

# Login to Hugging Face. Skip (with a warning) when no token is configured
# rather than invoking the CLI with an empty --token argument.
if [[ -n "${HF_TOKEN:-}" ]]; then
  huggingface-cli login --token "$HF_TOKEN"
else
  echo "WARNING: HF_TOKEN is not set; skipping Hugging Face login (downloads of gated models may fail)" >&2
fi

export DATA_DIR="$HOME/data/gsm8k"
export NUM_GPUS=4  # GPUs per node; must match "#SBATCH --gres=gpu:4"
export TRAIN_GPUS_PER_NODE=$NUM_GPUS  # Use all GPUs for training (colocated with inference)
# One inference engine per GPU on every node (the training command uses
# tensor_parallel_size=1): 2 nodes x 4 GPUs/node = 8 engines. Derived from the
# node/GPU counts so the two can never drift apart.
export INFERENCE_ENGINES=$((NUM_GPUS * SLURM_NNODES))
export LOGGER="wandb"  # or "console"
export INFERENCE_BACKEND="vllm"

# Model configuration
export MODEL="meta-llama/Llama-3.1-8B"
export MODEL_NAME="Llama-3.1-8B"

# JustRL-style training: on-policy, no clipping, no KL penalty.
# Advantage estimator: log_reinforce (uncomment one alternative below to switch).
export ADVANTAGE_ESTIMATOR="log_reinforce"
#export ADVANTAGE_ESTIMATOR="grpo"
#export ADVANTAGE_ESTIMATOR="a_reinforce"
#export ADVANTAGE_ESTIMATOR="rloo"
# The four knobs below may be overridden from the submitting environment; all of
# them are passed to the trainer as +trainer.algorithm.* overrides regardless of
# which estimator is selected.
export A_REINFORCE_ZERO_RHO_ADV="${A_REINFORCE_ZERO_RHO_ADV:--0.5}"  # Advantage value when rho ≈ 0 (after warmup)
export A_REINFORCE_WARMUP_STEPS="${A_REINFORCE_WARMUP_STEPS:-120}"  # Steps to use 0 before switching to configured value
export LOG_REINFORCE_ZERO_RHO_ADV="${LOG_REINFORCE_ZERO_RHO_ADV:--1.0}"  # Advantage value when rho ≈ 0 (after warmup)
export LOG_REINFORCE_WARMUP_STEPS="${LOG_REINFORCE_WARMUP_STEPS:-120}"  # Steps to use 0 before switching to configured value
# NOTE(review): RUN_NAME_SUFFIX is exported but the trainer.run_name passed to
# the training command does not reference it - confirm whether it should.
export RUN_NAME_SUFFIX="justrl_gsm8k"
export UPDATE_EPOCHS_PER_BATCH=1
export LEARNING_RATE="1e-6"

# Multi-node configuration: the first host of the allocation is the Ray head.
# Assign separately from `export` so an empty result (scontrol failure, missing
# SLURM_JOB_NODELIST) is detected instead of masked.
MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n1)
if [[ -z "$MASTER_ADDR" ]]; then
  echo "ERROR: could not determine head node from SLURM_JOB_NODELIST='${SLURM_JOB_NODELIST:-}'" >&2
  exit 1
fi
export MASTER_ADDR
export MASTER_PORT=29500
export RAY_TMPDIR="${SLURM_TMPDIR:-/tmp}/ray_${SLURM_JOB_ID}"
export RAY_ADDRESS="$MASTER_ADDR:6379"

# Disable Ray dashboard completely to avoid OpenTelemetry compatibility issues
export RAY_DASHBOARD_AGENT_ENABLED=0

# Prepare the dataset (idempotent - only runs if a split the trainer reads is
# missing). Training consumes both train.parquet and validation.parquet.
if [[ -f "$DATA_DIR/train.parquet" && -f "$DATA_DIR/validation.parquet" ]]; then
  echo "Dataset already prepared at $DATA_DIR"
else
  echo "Preparing GSM8K dataset..."
  mkdir -p "$DATA_DIR"
  # Abort instead of launching a multi-node job against a missing dataset.
  if ! python "$HOME/SkyRL/skyrl-train/examples/gsm8k/gsm8k_dataset.py" \
    --output_dir "$DATA_DIR"; then
    echo "ERROR: GSM8K dataset preparation failed" >&2
    exit 1
  fi
fi

export TOTAL_NUM_GPUS=$((NUM_GPUS * SLURM_NNODES))

# Ray head node setup (runs only on the first node, before srun).
# This ensures the head is fully initialized before workers try to connect.
if [[ $(hostname) == "$MASTER_ADDR" ]]; then
  echo "Initializing Ray head node on $MASTER_ADDR"
  ulimit -n 65536 || true
  ray stop --force >/dev/null 2>&1 || true

  # Clean up stale Ray metadata to avoid IP address conflicts.
  # ${RAY_TMPDIR:?} guarantees this can never expand to "rm -rf /*".
  rm -rf "${RAY_TMPDIR:?}"/* 2>/dev/null || true
  mkdir -p "$RAY_TMPDIR"

  if ! ray start --head \
    --port=6379 \
    --node-ip-address="$MASTER_ADDR" \
    --num-gpus="$NUM_GPUS" \
    --include-dashboard=false \
    --disable-usage-stats \
    --temp-dir="$RAY_TMPDIR"; then
    echo "ERROR: failed to start Ray head node on $MASTER_ADDR" >&2
    exit 1
  fi

  # Wait for the Ray head to answer, polling for up to ~60s instead of a fixed
  # sleep (ray status reads the head address from RAY_ADDRESS).
  echo "Testing Ray head node..."
  for _ in {1..30}; do
    ray status >/dev/null 2>&1 && break
    sleep 2
  done
  if ! ray status; then
    echo "ERROR: Ray head on $MASTER_ADDR is not reachable" >&2
    exit 1
  fi
fi

# Wait for head node to be ready
sleep 5

# Launch one task per node. The script body is single-quoted, so every $VAR is
# expanded inside each task from the environment exported via --export=ALL.
srun --ntasks="$SLURM_NNODES" --nodes="$SLURM_NNODES" --ntasks-per-node=1 --gres=gpu:"$NUM_GPUS" --export=ALL bash -c '
  host=$(hostname)
  hostlist=($(scontrol show hostnames "$SLURM_JOB_NODELIST"))
  for i in "${!hostlist[@]}"; do
    if [[ "${hostlist[$i]}" == "$host" ]]; then
      node_rank=$i
      break
    fi
  done

  # Completion marker that releases the worker nodes once training ends.
  # It must live on a filesystem shared by all nodes: RAY_TMPDIR is typically
  # node-local (SLURM_TMPDIR or /tmp), so a marker there is never seen by the
  # workers and they would poll until the job time limit. WORK is shared, since
  # every node runs ray and python from the virtualenv under it. The directory
  # is keyed by SLURM_JOB_ID so a rerun never sees a stale marker.
  done_dir="$WORK/.slurm_job_signals/$SLURM_JOB_ID"
  done_file="$done_dir/training_complete"

  echo "Node $node_rank ($host) starting..."

  # Worker nodes connect to Ray head
  if [[ "$host" != "$MASTER_ADDR" ]]; then
    ulimit -n 65536 || true
    echo "Node $node_rank ($host): Connecting to Ray head at $MASTER_ADDR:6379"
    if ! ray start \
      --address="$MASTER_ADDR:6379" \
      --num-gpus=$NUM_GPUS \
      --node-ip-address="$host" \
      --temp-dir="$RAY_TMPDIR" \
      --disable-usage-stats; then
      # Do not enter the wait loop below: nothing would ever release it.
      echo "Node $node_rank ($host): ERROR: failed to join Ray cluster" >&2
      exit 1
    fi

    echo "Node $node_rank ($host): Connected to Ray cluster, waiting..."
    # Worker nodes wait for training completion signal
    while [ ! -f "$done_file" ]; do sleep 10; done
    ray stop --force
  fi

  # Only head node runs training
  if [[ "$host" == "$MASTER_ADDR" ]]; then
    # Wait for all workers to connect
    sleep 30

    echo "Starting training on head node"
    echo "MASTER_ADDR: $MASTER_ADDR"
    echo "MASTER_PORT: $MASTER_PORT"
    echo "NUM_GPUS: $NUM_GPUS"
    echo "TOTAL_GPUS: $TOTAL_NUM_GPUS"
    nvidia-smi

    # Verify Ray cluster resources before training
    echo "Checking Ray cluster resources..."
    ray status

    # GSM8K has ~7.5k training examples
    # With train_batch_size=128: ~58 batches per epoch
    export TOTAL_EPOCHS=10

    # Run training with JustRL configuration
    # Note: Base models are automatically patched with a plain chat template in main_base.py
    echo "Using advantage estimator: $ADVANTAGE_ESTIMATOR"
    python -m skyrl_train.entrypoints.main_base \
      data.train_data="['"'"'$DATA_DIR/train.parquet'"'"']" \
      data.val_data="['"'"'$DATA_DIR/validation.parquet'"'"']" \
      trainer.algorithm.advantage_estimator="$ADVANTAGE_ESTIMATOR" \
      trainer.policy.model.path="$MODEL" \
      generator.model_dtype=bfloat16 \
      generator.use_conversation_multi_turn=false \
      trainer.placement.colocate_all=true \
      trainer.strategy=fsdp2 \
      trainer.placement.policy_num_nodes=$SLURM_NNODES \
      trainer.placement.policy_num_gpus_per_node=$TRAIN_GPUS_PER_NODE \
      trainer.placement.ref_num_nodes=$SLURM_NNODES \
      trainer.placement.ref_num_gpus_per_node=$TRAIN_GPUS_PER_NODE \
      trainer.placement.critic_num_nodes=$SLURM_NNODES \
      trainer.placement.critic_num_gpus_per_node=$TRAIN_GPUS_PER_NODE \
      generator.num_inference_engines=$INFERENCE_ENGINES \
      generator.inference_engine_tensor_parallel_size=1 \
      generator.inference_engine_pg_strategy=SPREAD \
      trainer.epochs=$TOTAL_EPOCHS \
      trainer.eval_batch_size=128 \
      trainer.eval_before_train=true \
      trainer.eval_interval=5 \
      trainer.update_epochs_per_batch=$UPDATE_EPOCHS_PER_BATCH \
      trainer.train_batch_size=128 \
      trainer.policy_mini_batch_size=32 \
      trainer.micro_forward_batch_size_per_gpu=4 \
      trainer.micro_train_batch_size_per_gpu=4 \
      trainer.ckpt_interval=50 \
      trainer.max_prompt_length=1024 \
      generator.sampling_params.max_generate_length=2048 \
      trainer.policy.optimizer_config.lr=$LEARNING_RATE \
      trainer.algorithm.use_kl_loss=false \
      trainer.algorithm.eps_clip_low=0.0 \
      trainer.algorithm.eps_clip_high=0.0 \
      +trainer.algorithm.a_reinforce_zero_rho_advantage=$A_REINFORCE_ZERO_RHO_ADV \
      +trainer.algorithm.a_reinforce_warmup_steps=$A_REINFORCE_WARMUP_STEPS \
      +trainer.algorithm.log_reinforce_zero_rho_advantage=$LOG_REINFORCE_ZERO_RHO_ADV \
      +trainer.algorithm.log_reinforce_warmup_steps=$LOG_REINFORCE_WARMUP_STEPS \
      generator.backend=$INFERENCE_BACKEND \
      generator.run_engines_locally=true \
      generator.weight_sync_backend=nccl \
      generator.async_engine=true \
      generator.batched=true \
      environment.env_class=gsm8k \
      generator.n_samples_per_prompt=16 \
      generator.gpu_memory_utilization=0.6 \
      trainer.logger="$LOGGER" \
      trainer.project_name="justrl_gsm8k" \
      trainer.run_name="justrl_gsm8k_${MODEL_NAME}_${ADVANTAGE_ESTIMATOR}_${LEARNING_RATE}_n16_ar015" \
      trainer.resume_mode=null \
      trainer.ckpt_path="$WORK/ckpts/justrl_gsm8k_llama31_8b_ckpt"
    train_rc=$?

    # Signal workers to shut down (also after a failed run, so workers never hang)
    echo "Training complete, shutting down cluster..."
    if [[ $train_rc -ne 0 ]]; then
      echo "Training exited with status $train_rc" >&2
    fi
    mkdir -p "$done_dir"
    touch "$done_file"
    ray stop --force
    # Propagate the training status so srun, and therefore the Slurm job,
    # reports failure when training fails.
    exit $train_rc
  fi
'
