#!/bin/bash
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=8
#SBATCH --gres=gpu:4
#SBATCH --time=24:00:00
#SBATCH --mem=720G
#SBATCH --exclusive
#SBATCH --account=your_slurm_account
#SBATCH --licenses=walrus:1,octopus:1,narwhal:1,cat:1
#
# Usage:
#   Default (log_reinforce): sbatch sbatch_capella_competition_math.sh
#   Use GRPO:                sbatch --export=ALL,ADVANTAGE_ESTIMATOR=grpo sbatch_capella_competition_math.sh
#   Use log-REINFORCE:       sbatch --export=ALL,ADVANTAGE_ESTIMATOR=log_reinforce sbatch_capella_competition_math.sh
#   Use rloo:                sbatch --export=ALL,ADVANTAGE_ESTIMATOR=rloo sbatch_capella_competition_math.sh
#   Use cancel-reinforce:    sbatch --export=ALL,ADVANTAGE_ESTIMATOR=cancel_reinforce sbatch_capella_competition_math.sh
#
# rho_hat_logs are written under $HOME/exports/... (the exact subdirectory is not set in this script).

# Load modules - must match setup_skyrl_env.sh
module load release/25.06 GCCcore/13.3.0
module load Python/3.12.3
module load CUDA/12.6.0
module load NCCL/2.22.3-CUDA-12.6.0

# WORK must point at your workspace directory: the venv, caches, HF hub and
# checkpoints all live under it. Set it before submitting if your cluster does
# not, e.g.
#   export WORK=/path/to/your/workspace
# Abort early: with WORK unset, every "$WORK/..." path below would resolve
# under the filesystem root.
: "${WORK:?WORK is not set; export it to your workspace directory before submitting}"

export XDG_CACHE_HOME="$WORK/.cache"

source "$WORK/skyrl-venv/bin/activate" \
  || { echo "error: could not activate $WORK/skyrl-venv" >&2; exit 1; }

# Per-user settings / secrets (presumably HF_TOKEN and the W&B key - confirm).
# `set -a` exports every variable the file defines so they reach the srun step.
# A missing file is not fatal: the variables may come from the submission env.
env_file="$HOME/SkyRL/skyrl-train/experiment-slurm-scripts/.env"
if [[ -f "$env_file" ]]; then
  set -a
  # shellcheck source=/dev/null
  source "$env_file"
  set +a
else
  echo "warning: $env_file not found; relying on the submission environment" >&2
fi

# The training entrypoint is launched from the skyrl-train checkout.
cd "$HOME/SkyRL/skyrl-train" \
  || { echo "error: cannot cd to $HOME/SkyRL/skyrl-train" >&2; exit 1; }

export HF_HOME="$WORK/hub"
export PYTHONFAULTHANDLER=1
export TORCH_NCCL_ASYNC_ERROR_HANDLING=1
export TRITON_CACHE_DIR="$WORK/triton_cache"
# Enable only when needed:
# export NCCL_DEBUG=INFO
# export CUDA_LAUNCH_BLOCKING=1

# Login to Hugging Face (skipped with a warning when no token is available).
# NOTE: --token puts the secret on the command line, where it is visible in
# process listings (ps) on this node.
if [[ -n "${HF_TOKEN:-}" ]]; then
  huggingface-cli login --token "$HF_TOKEN"
else
  echo "warning: HF_TOKEN is not set; skipping Hugging Face login" >&2
fi

# Run configuration, exported so the srun step below sees it.
# DATA_DIR must contain train.parquet and validation.parquet.
export DATA_DIR="$WORK/data/competition_math_hendrycks"
export NUM_GPUS=4
export LOGGER="wandb"  # or "console"
export INFERENCE_BACKEND="vllm"

#export MODEL="Qwen/Qwen2.5-3B-Instruct"
#export MODEL_NAME="qwen2.5-3B-instruct"
#export MODEL="Qwen/Qwen2.5-7B-Instruct"
#export MODEL_NAME="qwen2.5-7B-instruct"
#export MODEL="Qwen/Qwen2.5-1.5B-Instruct"
#export MODEL_NAME="qwen2.5-1.5B-instruct"
#export MODEL="Qwen/Qwen2.5-0.5B-Instruct"
#export MODEL_NAME="qwen2.5-0.5B-instruct"
#export MODEL="google/gemma-2-2b-it"
#export MODEL_NAME="gemma-2-2b-it"
#export MODEL="HuggingFaceTB/SmolLM2-1.7B-Instruct"
#export MODEL_NAME="smollm2-1.7b-instruct"

# MODEL is the HF model id; MODEL_NAME is the short tag used in the run name.
export MODEL="meta-llama/Llama-3.2-3B-Instruct"
export MODEL_NAME="llama-3.2-3b-instruct"

#export VLLM_USE_FLASH_ATTENTION=0
#export VLLM_ALLOW_FLASH_ATTENTION=0

# Advantage estimator configuration. Override at submission time, e.g.
#   sbatch --export=ALL,ADVANTAGE_ESTIMATOR=grpo <this script>
# Options listed in the usage header: "grpo", "log_reinforce" (default),
# "rloo", "cancel_reinforce".
export ADVANTAGE_ESTIMATOR="${ADVANTAGE_ESTIMATOR:-log_reinforce}"
case "$ADVANTAGE_ESTIMATOR" in
  grpo | log_reinforce | rloo | cancel_reinforce) ;;
  *)
    # Only a warning: the trainer may accept estimators beyond this list, but a
    # typo is cheaper to spot here than after Ray and vLLM have started.
    echo "warning: ADVANTAGE_ESTIMATOR=$ADVANTAGE_ESTIMATOR is not one of grpo, log_reinforce, rloo, cancel_reinforce" >&2
    ;;
esac
export RUN_NAME_SUFFIX="${ADVANTAGE_ESTIMATOR}"
export UPDATE_EPOCHS_PER_BATCH=1




# Launch the training job on the allocated node. It starts a local single-node
# Ray head, sanity-checks it, and then runs the SkyRL entrypoint on Competition
# MATH. The single-quoted script runs under srun; the variables exported above
# reach it through --export=ALL. Note: the script must not contain apostrophes.
srun --ntasks=1 --nodes=1 --gres="gpu:${NUM_GPUS}" --export=ALL bash -c '
  host=$(hostname)
  mapfile -t hostlist < <(scontrol show hostnames "$SLURM_JOB_NODELIST")
  for i in "${!hostlist[@]}"; do
    if [[ "${hostlist[$i]}" == "$host" ]]; then
      node_rank=$i
      break
    fi
  done

  # Fail before starting Ray if the prepared dataset is missing.
  for f in "$DATA_DIR/train.parquet" "$DATA_DIR/validation.parquet"; do
    if [[ ! -f "$f" ]]; then
      echo "error: dataset file not found: $f" >&2
      exit 1
    fi
  done

  # --- Start local Ray (single node) ---
  export MASTER_ADDR="${hostlist[0]}"
  export MASTER_PORT=29500
  export RAY_TMPDIR="$WORK/ray"

  # Disable Ray dashboard completely to avoid OpenTelemetry compatibility issues
  export RAY_DASHBOARD_AGENT_ENABLED=0

  ulimit -n 65536 || true
  ray stop --force >/dev/null 2>&1 || true
  if ! ray start --head --port=6379 --node-ip-address="$MASTER_ADDR" \
      --num-gpus="$NUM_GPUS" --include-dashboard=false --disable-usage-stats; then
    echo "error: ray start failed on $host" >&2
    exit 1
  fi
  # Tear the Ray head down on every exit path. A bash EXIT trap does not alter
  # the exit status, so the training exit code still reaches srun.
  trap "ray stop --force >/dev/null 2>&1 || true" EXIT

  # Wait for Ray to fully initialize
  sleep 10

  # Test Ray before proceeding (with proper shutdown); abort if it is unusable.
  python - <<"PY" || { echo "error: Ray sanity check failed on $host" >&2; exit 1; }
import ray
import time
ray.init(address="auto")
print("GPUs:", ray.cluster_resources().get("GPU", 0))
ray.shutdown()
time.sleep(2)
PY

  # Give Ray time to stabilize after test disconnection
  sleep 5

  echo "Starting python on $host with node_rank=$node_rank"
  echo "MASTER_ADDR: $MASTER_ADDR"
  echo "MASTER_PORT: $MASTER_PORT"
  echo "NUM_GPUS: $NUM_GPUS"
  echo "CUDA_VISIBLE_DEVICES: $CUDA_VISIBLE_DEVICES"
  nvidia-smi

  # --- Prepare Competition MATH data once (idempotent) ---
  # Using nlile/hendrycks-MATH-benchmark with explicit train/test splits
  # Uncomment to regenerate (already done - 12000 train, 500 test)
  #mkdir -p "$DATA_DIR"
  #python $HOME/SkyRL/data/prepare_competition_math.py --output_dir "$DATA_DIR" --max_level 5

  # --- Run training on Competition MATH ---
  echo "Using advantage estimator: $ADVANTAGE_ESTIMATOR"
  # NOTE(review): the data-prep note above mentions a 500-example test split;
  # confirm validation.parquet is that split.
  # Checkpoints go to a per-run directory named after run_name. Jobs with
  # different estimators, models or epoch counts therefore no longer overwrite
  # one shared (and misleadingly named) directory.
  run_name="competition_math_${RUN_NAME_SUFFIX}_${MODEL_NAME}_${UPDATE_EPOCHS_PER_BATCH}_all"
  python -m skyrl_train.entrypoints.main_base \
    data.train_data="['\''$DATA_DIR/train.parquet'\'']" \
    data.val_data="['\''$DATA_DIR/validation.parquet'\'']" \
    trainer.algorithm.advantage_estimator="$ADVANTAGE_ESTIMATOR" \
    trainer.policy.model.path="$MODEL" \
    trainer.placement.colocate_all=true \
    trainer.strategy=fsdp2 \
    trainer.placement.policy_num_gpus_per_node="$NUM_GPUS" \
    trainer.placement.critic_num_gpus_per_node="$NUM_GPUS" \
    trainer.placement.ref_num_gpus_per_node="$NUM_GPUS" \
    generator.num_inference_engines="$NUM_GPUS" \
    generator.inference_engine_tensor_parallel_size=1 \
    trainer.epochs=4 \
    trainer.eval_batch_size=128 \
    trainer.eval_before_train=true \
    trainer.eval_interval=4 \
    trainer.update_epochs_per_batch="$UPDATE_EPOCHS_PER_BATCH" \
    trainer.train_batch_size=128 \
    trainer.policy_mini_batch_size=32 \
    trainer.micro_forward_batch_size_per_gpu=8 \
    trainer.micro_train_batch_size_per_gpu=8 \
    trainer.ckpt_interval=10 \
    trainer.max_prompt_length=1024 \
    generator.sampling_params.max_generate_length=3072 \
    trainer.policy.optimizer_config.lr=5.0e-7 \
    trainer.algorithm.use_kl_loss=false \
    generator.backend="$INFERENCE_BACKEND" \
    generator.run_engines_locally=true \
    generator.weight_sync_backend=nccl \
    generator.async_engine=true \
    generator.batched=true \
    environment.env_class=aime \
    generator.n_samples_per_prompt=16 \
    generator.gpu_memory_utilization=0.6 \
    trainer.logger="$LOGGER" \
    trainer.project_name="competition_math" \
    trainer.run_name="$run_name" \
    trainer.resume_mode=null \
    trainer.ckpt_path="$WORK/ckpts/$run_name"
'
