#!/bin/bash
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=8
#SBATCH --gres=gpu:4
#SBATCH --time=48:00:00
#SBATCH --mem=720G
#SBATCH --exclusive
#SBATCH --account=your_slurm_account
#SBATCH --licenses=walrus:1,octopus:1,narwhal:1,cat:1
#
# JustRL-style training on DAPO-Math-17k-Processed dataset
# Paper: https://arxiv.org/abs/2512.16649
# Dataset: open-r1/DAPO-Math-17k-Processed (deduplicated version)
#
# Training setup:
#   - Dataset: English subset (~13.1k train, 1k test) - deduplicated
#   - Batch size: 256 (2x compared to 1-node)
#   - Samples per prompt: 8
#   - Learning rate: 1e-6
#   - Steps: 750 (~14-15 epochs with 13.1k samples)
#   - On-policy training (no KL penalty, no clipping)
#   - 2 nodes with 4 GPUs each (8 GPUs total, colocate_all=true)
#
# Usage:
#   sbatch sbatch_justrl_dapo_math.sh

# --- Runtime environment ------------------------------------------------------
# WORK must point at your workspace directory before this script runs, e.g.:
#   export WORK=/path/to/your/workspace
# Every cache, venv, dataset and checkpoint path below is derived from it, so
# fail fast instead of silently resolving paths such as "/.cache".
: "${WORK:?WORK must be set to your workspace directory (export WORK=/path/to/your/workspace)}"

# Load modules - must match setup_skyrl_env.sh
module load release/25.06 GCCcore/13.3.0
module load Python/3.12.3
module load CUDA/12.6.0
module load NCCL/2.22.3-CUDA-12.6.0

export XDG_CACHE_HOME="$WORK/.cache"

# Activate the project virtualenv; nothing below can work without it.
# shellcheck disable=SC1091
source "$WORK/skyrl-venv/bin/activate" || {
  echo "ERROR: cannot activate virtualenv at $WORK/skyrl-venv" >&2
  exit 1
}

# "set -a" auto-exports every variable assigned while .env is sourced.
# NOTE(review): presumably this is where HF_TOKEN (used below) comes from, as
# it is not assigned anywhere else in this script - confirm.
# shellcheck disable=SC1091
set -a; source "$HOME/SkyRL/skyrl-train/experiment-slurm-scripts/.env"; set +a

cd "$HOME/SkyRL/skyrl-train" || {
  echo "ERROR: cannot cd to $HOME/SkyRL/skyrl-train" >&2
  exit 1
}

export HF_HOME="$WORK/hub"
export PYTHONFAULTHANDLER=1
export TORCH_NCCL_ASYNC_ERROR_HANDLING=1
export TRITON_CACHE_DIR="$WORK/triton_cache"
# Enable only when needed:
# export NCCL_DEBUG=INFO
# export CUDA_LAUNCH_BLOCKING=1

# Login to Hugging Face.
# The token is quoted so an empty value cannot turn "--token" into a flag
# without an argument; when no token is available the login is skipped.
if [[ -n "${HF_TOKEN:-}" ]]; then
  huggingface-cli login --token "$HF_TOKEN"
else
  echo "WARNING: HF_TOKEN is not set; skipping Hugging Face login" >&2
fi

# --- Run configuration --------------------------------------------------------
export DATA_DIR="$WORK/data/dapo_math"

# GPUs per node; keep in sync with "#SBATCH --gres=gpu:4" in the header.
export NUM_GPUS=4
# Use all GPUs for training (colocated with inference)
export TRAIN_GPUS_PER_NODE="$NUM_GPUS"
# One inference engine per GPU: nodes x GPUs/node (8 for the 2-node default).
# Derived from the allocation so that changing --nodes does not leave a stale
# engine count behind; falls back to 2 nodes outside of a Slurm job.
export INFERENCE_ENGINES=$(( NUM_GPUS * ${SLURM_NNODES:-2} ))
export LOGGER="wandb"  # or "console"
export INFERENCE_BACKEND="vllm"

# Model configuration
export MODEL="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
export MODEL_NAME="DS-R1-Distill-Qwen-1.5b"

#export MODEL="Qwen/Qwen2.5-3B"
#export MODEL_NAME="Qwen2.5-3B"

# JustRL-style training: on-policy, no clipping, no KL penalty
# We use REINFORCE as the advantage estimator for simplicity
export ADVANTAGE_ESTIMATOR="log_reinforce"
#export ADVANTAGE_ESTIMATOR="grpo"
#export ADVANTAGE_ESTIMATOR="a_reinforce"
export RUN_NAME_SUFFIX="justrl_dapo_math"
export UPDATE_EPOCHS_PER_BATCH=1

# Multi-node configuration
# The first hostname of the allocation is used as the Ray head address.
# The assignment is kept separate from "export" and the result is validated:
# an empty MASTER_ADDR would otherwise silently produce RAY_ADDRESS=":6379"
# and no node would ever recognise itself as the head.
MASTER_ADDR=$(scontrol show hostnames "${SLURM_JOB_NODELIST:-}" | head -n1)
if [[ -z "$MASTER_ADDR" ]]; then
  echo "ERROR: could not derive MASTER_ADDR from SLURM_JOB_NODELIST='${SLURM_JOB_NODELIST:-}'" >&2
  exit 1
fi
export MASTER_ADDR
export MASTER_PORT=29500
# Per-job Ray temp directory under SLURM_TMPDIR, or /tmp when that is unset.
export RAY_TMPDIR="${SLURM_TMPDIR:-/tmp}/ray_${SLURM_JOB_ID}"
export RAY_ADDRESS="$MASTER_ADDR:6379"

# Disable Ray dashboard completely to avoid OpenTelemetry compatibility issues
export RAY_DASHBOARD_AGENT_ENABLED=0

# Prepare the dataset (idempotent - only runs if files don't exist)
# Note: Uses open-r1/DAPO-Math-17k-Processed (deduplicated version)
# The preparation script converts to JustRL's "\boxed{}" format for AIME environment compatibility
#
# Training reads both train.parquet and validation.parquet from DATA_DIR, so
# both must be present before preparation is skipped.
# NOTE(review): presumably prepare_dapo_math.py writes both of these files; if
# it names the held-out split differently, preparation will simply re-run on
# every submission - confirm the output file names.
if [[ -f "$DATA_DIR/train.parquet" && -f "$DATA_DIR/validation.parquet" ]]; then
  echo "Dataset already prepared at $DATA_DIR"
else
  echo "Preparing DAPO-Math-17k-Processed dataset (English subset)..."
  mkdir -p "$DATA_DIR"
  # Abort on failure: continuing would only fail later, after the multi-node
  # Ray cluster has already been brought up.
  python "$HOME/SkyRL/data/prepare_dapo_math.py" \
    --output_dir "$DATA_DIR" \
    --test_size 1000 \
    --seed 42 \
    --subset en || {
    echo "ERROR: dataset preparation failed" >&2
    exit 1
  }
fi

export TOTAL_NUM_GPUS=$((NUM_GPUS * SLURM_NNODES))

# Ray head node setup (runs only on first node, before srun)
# This ensures head is fully initialized before workers try to connect
#
# Identify this node by its Slurm node name when available: MASTER_ADDR comes
# from "scontrol show hostnames", and plain `hostname` may return an FQDN that
# never compares equal to it (which would silently skip the head start).
this_node="${SLURMD_NODENAME:-$(hostname)}"
if [[ "$this_node" == "$MASTER_ADDR" ]]; then
  echo "Initializing Ray head node on $MASTER_ADDR"
  ulimit -n 65536 || true
  # Best effort: there may be no previous Ray instance to stop.
  ray stop --force >/dev/null 2>&1 || true

  # Clean up stale Ray metadata to avoid IP address conflicts.
  # Guard first so an empty RAY_TMPDIR can never expand the glob to "/*".
  : "${RAY_TMPDIR:?RAY_TMPDIR must be set}"
  rm -rf -- "$RAY_TMPDIR"/* 2>/dev/null || true
  mkdir -p "$RAY_TMPDIR"

  ray start --head \
    --port=6379 \
    --node-ip-address="$MASTER_ADDR" \
    --num-gpus="$NUM_GPUS" \
    --include-dashboard=false \
    --disable-usage-stats \
    --temp-dir="$RAY_TMPDIR" || {
    echo "ERROR: failed to start Ray head on $MASTER_ADDR" >&2
    exit 1
  }

  # Wait for Ray head to be fully ready
  sleep 10

  # Test Ray connection
  echo "Testing Ray head node..."
  ray status || echo "WARNING: 'ray status' failed on the head node" >&2
else
  echo "WARNING: this node ($this_node) is not MASTER_ADDR ($MASTER_ADDR); Ray head was NOT started here" >&2
fi

# Wait for head node to be ready
sleep 5

# Launch task on all nodes
#
# One task per node: the head node runs training, every other node joins the
# Ray cluster and idles until the head signals completion.
#
# Completion signal: RAY_TMPDIR lives under ${SLURM_TMPDIR:-/tmp}, which is
# typically node-local, so a flag file created there by the head is invisible
# to the workers and they would spin until the job time limit. The flag is
# therefore also written under $WORK.
# NOTE(review): presumably $WORK is a filesystem shared by all nodes (the venv,
# dataset and checkpoints already live there) - confirm, or set
# TRAINING_DONE_FLAG to a path on shared storage before submitting.
export TRAINING_DONE_FLAG="${TRAINING_DONE_FLAG:-${WORK:?WORK must be set}/.ray_sync/training_complete_${SLURM_JOB_ID}_${SLURM_RESTART_COUNT:-0}}"
mkdir -p -- "$(dirname -- "$TRAINING_DONE_FLAG")" || {
  echo "ERROR: cannot create directory for $TRAINING_DONE_FLAG" >&2
  exit 1
}
# Drop any leftover flag from an earlier attempt with the same job id.
rm -f -- "$TRAINING_DONE_FLAG"

srun --ntasks="$SLURM_NNODES" --nodes="$SLURM_NNODES" --ntasks-per-node=1 --gres=gpu:4 --export=ALL bash -c '
  # NOTE: this whole script is a single-quoted string; do not use single
  # quotes in comments or code added here.

  # Prefer the Slurm node name so the comparison with the scontrol host list
  # (and with MASTER_ADDR) is exact even when hostname returns an FQDN.
  host="${SLURMD_NODENAME:-$(hostname)}"
  mapfile -t hostlist < <(scontrol show hostnames "$SLURM_JOB_NODELIST")
  node_rank="unknown"
  for i in "${!hostlist[@]}"; do
    if [[ "${hostlist[$i]}" == "$host" ]]; then
      node_rank=$i
      break
    fi
  done

  echo "Node $node_rank ($host) starting..."

  # Worker nodes connect to Ray head
  if [[ "$host" != "$MASTER_ADDR" ]]; then
    ulimit -n 65536 || true
    echo "Node $node_rank ($host): Connecting to Ray head at $MASTER_ADDR:6379"
    if ! ray start \
      --address="$MASTER_ADDR:6379" \
      --num-gpus="$NUM_GPUS" \
      --node-ip-address="$host" \
      --temp-dir="$RAY_TMPDIR" \
      --disable-usage-stats; then
      echo "Node $node_rank ($host): ERROR: failed to join the Ray cluster" >&2
      exit 1
    fi

    echo "Node $node_rank ($host): Connected to Ray cluster, waiting..."
    # Worker nodes wait for training completion signal. Either the flag on
    # shared storage or the legacy flag inside RAY_TMPDIR ends the wait.
    while [[ ! -f "$TRAINING_DONE_FLAG" && ! -f "$RAY_TMPDIR/training_complete" ]]; do
      sleep 10
    done
    ray stop --force
    exit 0
  fi

  # Only head node runs training (every non-head node has exited above)

  # Wait for all workers to connect
  sleep 30

  echo "Starting training on head node"
  echo "MASTER_ADDR: $MASTER_ADDR"
  echo "MASTER_PORT: $MASTER_PORT"
  echo "NUM_GPUS: $NUM_GPUS"
  echo "TOTAL_GPUS: $TOTAL_NUM_GPUS"
  nvidia-smi

  # Verify Ray cluster resources before training
  echo "Checking Ray cluster resources..."
  ray status

  # Calculate total steps needed
  # ~13.1k train samples (after removing 1k test) / 256 batch size = ~51 batches per epoch
  # 750 steps / 51 = ~14.7 epochs
  # NOTE(review): the file header documents 750 steps (~14.7 epochs), but 9
  # epochs is only about 9 x 51 = ~460 steps - confirm which one is intended.
  export TOTAL_EPOCHS=9

  # Run training with JustRL configuration
  echo "Using advantage estimator: $ADVANTAGE_ESTIMATOR"
  python -m skyrl_train.entrypoints.main_base \
    data.train_data="['"'"'$DATA_DIR/train.parquet'"'"']" \
    data.val_data="['"'"'$DATA_DIR/validation.parquet'"'"']" \
    trainer.algorithm.advantage_estimator="$ADVANTAGE_ESTIMATOR" \
    trainer.policy.model.path="$MODEL" \
    trainer.placement.colocate_all=true \
    trainer.strategy=fsdp2 \
    trainer.placement.policy_num_nodes=$SLURM_NNODES \
    trainer.placement.policy_num_gpus_per_node=$TRAIN_GPUS_PER_NODE \
    trainer.placement.ref_num_nodes=$SLURM_NNODES \
    trainer.placement.ref_num_gpus_per_node=$TRAIN_GPUS_PER_NODE \
    trainer.placement.critic_num_nodes=$SLURM_NNODES \
    trainer.placement.critic_num_gpus_per_node=$TRAIN_GPUS_PER_NODE \
    generator.num_inference_engines=$INFERENCE_ENGINES \
    generator.inference_engine_tensor_parallel_size=1 \
    generator.inference_engine_pg_strategy=SPREAD \
    trainer.epochs=$TOTAL_EPOCHS \
    trainer.eval_batch_size=256 \
    trainer.eval_before_train=true \
    trainer.eval_interval=20 \
    trainer.update_epochs_per_batch=$UPDATE_EPOCHS_PER_BATCH \
    trainer.train_batch_size=256 \
    trainer.policy_mini_batch_size=64 \
    trainer.micro_forward_batch_size_per_gpu=4 \
    trainer.micro_train_batch_size_per_gpu=4 \
    trainer.ckpt_interval=50 \
    trainer.max_prompt_length=1024 \
    generator.sampling_params.max_generate_length=7168 \
    trainer.policy.optimizer_config.lr=1.0e-6 \
    trainer.algorithm.use_kl_loss=false \
    trainer.algorithm.eps_clip_low=0.0 \
    trainer.algorithm.eps_clip_high=0.0 \
    generator.backend=$INFERENCE_BACKEND \
    generator.run_engines_locally=true \
    generator.weight_sync_backend=nccl \
    generator.async_engine=true \
    generator.batched=true \
    environment.env_class=aime \
    generator.n_samples_per_prompt=8 \
    generator.gpu_memory_utilization=0.6 \
    trainer.logger="$LOGGER" \
    trainer.project_name="justrl_dapo_math" \
    trainer.run_name="justrl_dapo_math_${MODEL_NAME}_${ADVANTAGE_ESTIMATOR}_rev" \
    trainer.resume_mode=null \
    trainer.ckpt_path="$WORK/ckpts/justrl_dapo_math_ckpt"
  # Keep the training status so a failed run is not reported as a success.
  train_rc=$?

  # Signal workers to shut down (always, so they are released on failure too)
  if [[ "$train_rc" -eq 0 ]]; then
    echo "Training complete, shutting down cluster..."
  else
    echo "Training FAILED with exit status $train_rc, shutting down cluster..." >&2
  fi
  touch "$TRAINING_DONE_FLAG" "$RAY_TMPDIR/training_complete"
  ray stop --force
  exit "$train_rc"
'
srun_rc=$?

# Remove the shared completion flag now that every task has finished, then
# restore the srun status as the status of the last command without forcing
# the script to terminate here.
rm -f -- "$TRAINING_DONE_FLAG"
( exit "$srun_rc" )
