#!/bin/bash
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=8
#SBATCH --gres=gpu:4
#SBATCH --time=42:00:00
#SBATCH --mem=720G
#SBATCH --exclusive
#SBATCH --account=your_slurm_account
#SBATCH --licenses=walrus:1,octopus:1,narwhal:1,cat:1
#
# TinyZero Countdown Task - Reproducing TinyZero results with SkyRL
# Based on: https://github.com/Jiayi-Pan/TinyZero
#
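# Usage (WORK must be set to your workspace directory before submitting):
#   Default:      sbatch sbatch_tinyzero_countdown.sh
#   Custom model: sbatch --export=ALL,MODEL=Qwen/Qwen2.5-1.5B-Instruct sbatch_tinyzero_countdown.sh
#   The a_reinforce knobs can be overridden the same way, e.g.
#                 sbatch --export=ALL,A_REINFORCE_ZERO_RHO_ADV=-0.1,A_REINFORCE_WARMUP_STEPS=50 sbatch_tinyzero_countdown.sh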

# Load modules - must match setup_skyrl_env.sh
module load release/25.06 GCCcore/13.3.0
module load Python/3.12.3
module load CUDA/12.6.0
module load NCCL/2.22.3-CUDA-12.6.0

# WORK must point at your workspace directory (venv, caches, data and
# checkpoints all live under it). Set it before submitting, e.g.
#   export WORK=/path/to/your/workspace
# Fail fast rather than silently using paths such as "/.cache".
: "${WORK:?WORK must be set to your workspace directory}"

export XDG_CACHE_HOME="$WORK/.cache"

if ! source "$WORK/skyrl-venv/bin/activate"; then
  echo "ERROR: could not activate virtualenv at $WORK/skyrl-venv" >&2
  exit 1
fi

# Export every variable defined in .env to child processes (set -a). The
# HF_TOKEN used below is presumably supplied here -- confirm against your .env.
env_file="$HOME/SkyRL/skyrl-train/experiment-slurm-scripts/.env"
if [[ -f "$env_file" ]]; then
  set -a
  source "$env_file"
  set +a
else
  echo "WARNING: $env_file not found; continuing without it" >&2
fi

# Training is launched from the skyrl-train checkout; never run from elsewhere.
cd "$HOME/SkyRL/skyrl-train" || {
  echo "ERROR: cannot cd to $HOME/SkyRL/skyrl-train" >&2
  exit 1
}

export HF_HOME="$WORK/hub"
export PYTHONFAULTHANDLER=1
export TORCH_NCCL_ASYNC_ERROR_HANDLING=1
export TRITON_CACHE_DIR="$WORK/triton_cache"
# Enable only when needed:
# export NCCL_DEBUG=INFO
# export CUDA_LAUNCH_BLOCKING=1

# Login to Hugging Face. Skipped (with a warning) when HF_TOKEN is empty,
# instead of invoking login with a dangling --token flag. The token is visible
# in the process list while login runs; huggingface_hub also honors HF_TOKEN
# from the environment on its own.
if [[ -n "${HF_TOKEN:-}" ]]; then
  huggingface-cli login --token "$HF_TOKEN"
else
  echo "WARNING: HF_TOKEN is not set; skipping Hugging Face login" >&2
fi

export DATA_DIR="$WORK/data/countdown"
export NUM_GPUS=4
export LOGGER="wandb"  # or "console"
export INFERENCE_BACKEND="vllm"

# Model selection. Defaults to Qwen/Qwen2.5-3B; a MODEL already present in the
# environment wins, so the documented override works:
#   sbatch --export=ALL,MODEL=<hf-repo-id> sbatch_tinyzero_countdown.sh
# MODEL_NAME (used in run names and checkpoint paths) is derived from the last
# path component of MODEL so different models never share a checkpoint dir.
export MODEL="${MODEL:-Qwen/Qwen2.5-3B}"
MODEL_NAME="${MODEL%/}"              # tolerate a trailing slash on local paths
export MODEL_NAME="${MODEL_NAME##*/}"
export LEARNING_RATE=1e-6

# TinyZero-inspired hyperparameters, adapted from TinyZero's PPO configuration.
# TinyZero itself uses PPO; SkyRL advantage estimators available here:
#   log_reinforce, grpo, rloo, cancel_reinforce,
#   rejection_sampling  - only reinforces positive outcomes (A^- = 0)
#   a_reinforce         - adaptive advantage based on a rho_hat threshold
export ADVANTAGE_ESTIMATOR="a_reinforce"
# a_reinforce only: advantage value used when rho is approximately 0.
# Default is -0.15 (":-" default operator followed by a negative number).
export A_REINFORCE_ZERO_RHO_ADV="${A_REINFORCE_ZERO_RHO_ADV:--0.15}"
# a_reinforce only: number of initial steps that use 0 as the zero-rho
# advantage before switching to A_REINFORCE_ZERO_RHO_ADV.
export A_REINFORCE_WARMUP_STEPS="${A_REINFORCE_WARMUP_STEPS:-0}"

export RUN_NAME_SUFFIX="tinyzero_countdown"
export UPDATE_EPOCHS_PER_BATCH=1
export FORMAT_SCORE=0.0
export CLIP_GRAD_NORM=0.5

# Batch sizes follow TinyZero defaults, mapped onto SkyRL options.
export BATCH_SIZE=256                   # TinyZero default
export POLICY_MINI_BATCH_SIZE=64        # TinyZero default
export MICRO_TRAIN_BATCH_SIZE_PER_GPU=8 # TinyZero default
# Launch everything as one Slurm task on the allocated node. The payload is a
# single-quoted string: variables inside it are expanded by the inner bash,
# which inherits this script's exported environment through --export=ALL.
# Inside the payload a literal single quote must be written as '\'' , so the
# payload comments deliberately contain no apostrophes.
srun --ntasks=1 --nodes=1 --gres=gpu:4 --export=ALL bash -c '
  host=$(hostname)
  mapfile -t hostlist < <(scontrol show hostnames "$SLURM_JOB_NODELIST")
  for i in "${!hostlist[@]}"; do
    if [[ "${hostlist[$i]}" == "$host" ]]; then
      node_rank=$i
      break
    fi
  done

  # --- Check inputs before spending time on Ray start-up ---
  # The trainer below reads exactly these two files. To (re)generate them:
  #   mkdir -p "$DATA_DIR"
  #   python $HOME/SkyRL/data/prepare_countdown.py \
  #     --output_dir "$DATA_DIR" \
  #     --train_size 327680 \
  #     --test_size 1024 \
  #     --template_type qwen-instruct
  # (downloads Jiayi-Pan/Countdown-Tasks-3to4 and processes it for SkyRL)
  if [[ ! -f "$DATA_DIR/train.parquet" || ! -f "$DATA_DIR/validation.parquet" ]]; then
    echo "ERROR: Countdown data not found in $DATA_DIR (see prepare_countdown.py above)" >&2
    exit 1
  fi

  # --- Start local Ray (single node) ---
  export MASTER_ADDR="${hostlist[0]}"
  export MASTER_PORT=29500
  export RAY_TMPDIR="$WORK/ray"

  # Disable Ray dashboard completely to avoid OpenTelemetry compatibility issues
  export RAY_DASHBOARD_AGENT_ENABLED=0

  # Raise the open-file limit if permitted; failure here is not fatal.
  ulimit -n 65536 || true
  # Clear any Ray instance left over on this node, then make sure the head
  # started below is stopped however this task exits.
  ray stop --force >/dev/null 2>&1 || true
  trap "ray stop --force >/dev/null 2>&1 || true" EXIT
  if ! ray start --head --port=6379 --node-ip-address="$MASTER_ADDR" \
      --num-gpus="$NUM_GPUS" --include-dashboard=false --disable-usage-stats; then
    echo "ERROR: ray start failed on $host" >&2
    exit 1
  fi

  # Wait for Ray to fully initialize
  sleep 10

  # Smoke-test the cluster: connect, check the GPU count, then disconnect.
  # Abort before training if Ray is unreachable or reports too few GPUs.
  python - <<'\''PY'\'' || { echo "ERROR: Ray smoke test failed" >&2; exit 1; }
import os
import time
import ray
ray.init(address="auto")
gpus = ray.cluster_resources().get("GPU", 0)
print("GPUs:", gpus)
ray.shutdown()
time.sleep(2)
expected = int(os.environ["NUM_GPUS"])
if gpus < expected:
    raise SystemExit("Ray sees %s GPUs, expected %d" % (gpus, expected))
PY

  # Give Ray time to stabilize after test disconnection
  sleep 5

  echo "Starting TinyZero Countdown training on $host with node_rank=$node_rank"
  echo "MASTER_ADDR: $MASTER_ADDR"
  echo "MASTER_PORT: $MASTER_PORT"
  echo "NUM_GPUS: $NUM_GPUS"
  echo "MODEL: $MODEL"
  echo "CUDA_VISIBLE_DEVICES: $CUDA_VISIBLE_DEVICES"
  nvidia-smi

  # --- Run TinyZero Countdown Training ---
  # NOTE(review): runs differing only in A_REINFORCE_ZERO_RHO_ADV or
  # A_REINFORCE_WARMUP_STEPS get the same run_name AND ckpt_path; run_name also
  # omits LEARNING_RATE and BATCH_SIZE, and ckpt_path omits CLIP_GRAD_NORM.
  # Add the relevant knobs to both names before sweeping them.
  echo "Using advantage estimator: $ADVANTAGE_ESTIMATOR"
  echo "Training on Countdown task with model: $MODEL"

  # Last command: its exit status becomes the exit status of this srun step.
  python -m skyrl_train.entrypoints.main_base \
    data.train_data="['\''$DATA_DIR/train.parquet'\'']" \
    data.val_data="['\''$DATA_DIR/validation.parquet'\'']" \
    trainer.algorithm.advantage_estimator="$ADVANTAGE_ESTIMATOR" \
    trainer.policy.model.path="$MODEL" \
    trainer.placement.colocate_all=true \
    trainer.strategy=fsdp2 \
    trainer.placement.policy_num_gpus_per_node=$NUM_GPUS \
    trainer.placement.critic_num_gpus_per_node=$NUM_GPUS \
    trainer.placement.ref_num_gpus_per_node=$NUM_GPUS \
    generator.num_inference_engines=$NUM_GPUS \
    generator.inference_engine_tensor_parallel_size=1 \
    trainer.epochs=15 \
    trainer.eval_batch_size=$BATCH_SIZE \
    trainer.eval_before_train=true \
    trainer.eval_interval=4 \
    trainer.update_epochs_per_batch=$UPDATE_EPOCHS_PER_BATCH \
    trainer.train_batch_size=$BATCH_SIZE \
    trainer.policy_mini_batch_size=$POLICY_MINI_BATCH_SIZE \
    trainer.micro_forward_batch_size_per_gpu=$MICRO_TRAIN_BATCH_SIZE_PER_GPU \
    trainer.micro_train_batch_size_per_gpu=$MICRO_TRAIN_BATCH_SIZE_PER_GPU \
    trainer.ckpt_interval=10 \
    trainer.max_prompt_length=512 \
    trainer.policy.optimizer_config.max_grad_norm=$CLIP_GRAD_NORM \
    generator.sampling_params.max_generate_length=1024 \
    trainer.policy.optimizer_config.lr=$LEARNING_RATE \
    trainer.algorithm.use_kl_loss=false \
    +trainer.algorithm.a_reinforce_zero_rho_advantage=$A_REINFORCE_ZERO_RHO_ADV \
    +trainer.algorithm.a_reinforce_warmup_steps=$A_REINFORCE_WARMUP_STEPS \
    generator.backend=$INFERENCE_BACKEND \
    generator.run_engines_locally=true \
    generator.weight_sync_backend=nccl \
    generator.async_engine=true \
    generator.batched=true \
    environment.env_class=countdown \
    +environment.skyrl_gym.countdown.format_score=$FORMAT_SCORE \
    generator.n_samples_per_prompt=16 \
    generator.gpu_memory_utilization=0.4 \
    trainer.logger="$LOGGER" \
    trainer.project_name="tinyzero_countdown" \
    trainer.run_name="countdown_${RUN_NAME_SUFFIX}_${MODEL_NAME}_${ADVANTAGE_ESTIMATOR}_${FORMAT_SCORE}_${CLIP_GRAD_NORM}" \
    trainer.resume_mode=null \
    trainer.ckpt_path="$WORK/ckpts/tinyzero_countdown_${MODEL_NAME}_${ADVANTAGE_ESTIMATOR}_${FORMAT_SCORE}_${BATCH_SIZE}_lr${LEARNING_RATE}/latest.ckpt" \
'
