#!/bin/bash

# Submit one SLURM DR-GRPO training job per (task, model) pair selected in the
# TASKS and MODEL_CONFIGS arrays below (5 reasoning tasks; Qwen2.5 7B/3B/1.5B instruct configs).
# Based on the simple_equation.sh template and the command_llama.txt configurations.

# Counts are not known yet (arrays are defined below); the accurate
# "N tasks with M model configurations" line is printed before submission.
echo "Submitting SLURM jobs for the configured reasoning tasks and model sizes..."

# Function to create and submit a SLURM job
# Parameters: task_name config_name gpus cpus tensor_parallel micro_batch_size mini_batch_size model_path model_size train_batch_size additional_params
#
# Writes a batch script to a staging directory, submits it with sbatch, then
# deletes the staging copy (sbatch has already read it at submission time).
#   task_name          trainer.experiment_name; selects val_<task_name>; names job/output
#   config_name        Hydra --config-name (looked up under ../configs/drgrpo)
#   gpus               --gres=gpu:N and trainer.n_gpus_per_node
#   cpus               --cpus-per-task and ray_init.num_cpus
#   tensor_parallel    rollout tensor-parallel size (passed under both override keys)
#   micro_batch_size   actor.ppo_micro_batch_size_per_gpu
#   mini_batch_size    actor.ppo_mini_batch_size
#   model_path         actor_rollout_ref.model.path
#   model_size         short tag used in job, project and output-file names
#   train_batch_size   data.train_batch_size
#   additional_params  optional extra Hydra overrides, space separated (may be empty)
# Environment:
#   DRSLURM_SCRIPT_DIR staging dir for generated scripts
#                      (default: /nlp/scr/qinanyu/rl-explanations/temp)
# Returns: 0 if sbatch accepted the job, non-zero otherwise.
submit_task() {
    local task_name=$1
    local config_name=$2
    local gpus=$3
    local cpus=$4
    local tensor_parallel=$5
    local micro_batch_size=$6
    local mini_batch_size=$7
    local model_path=$8
    local model_size=$9
    local train_batch_size=${10}
    local additional_params=${11:-}

    local script_dir="${DRSLURM_SCRIPT_DIR:-/nlp/scr/qinanyu/rl-explanations/temp}"
    local script_file="${script_dir}/drslurm_${task_name}_${model_size}.sh"

    # Without the staging dir the heredoc redirect below would fail.
    if ! mkdir -p -- "$script_dir"; then
        echo "Error: cannot create script directory ${script_dir}" >&2
        return 1
    fi

    # Separate extra overrides with a space so they can never fuse onto the
    # data.train_batch_size value (a leading space from callers is harmless).
    local extra_overrides=""
    if [[ -n "$additional_params" ]]; then
        extra_overrides=" ${additional_params}"
    fi

    cat > "$script_file" << EOF
#!/bin/bash
#SBATCH --partition=sphinx --qos=normal
#SBATCH --account=nlp
#SBATCH --cpus-per-task=${cpus}
#SBATCH --exclude=sphinx[1,2,4,5,7]
#SBATCH --gres=gpu:${gpus}
#SBATCH --job-name=d_${model_size}-${task_name}
#SBATCH --mem=560G
#SBATCH --open-mode=append
#SBATCH --output=/nlp/scr/qinanyu/rl-explanations/bash_output/drgrpo-${task_name}-${model_size}.out
#SBATCH --time=14-0

# Unique per job
unset ROCR_VISIBLE_DEVICES 
export HYDRA_FULL_ERROR=1 

export JOB_TAG="\${SLURM_JOB_ID:-\${LSB_JOBID:-jid.\$(id -u)-\$\$-\$(date +%s)}}"

# Per-job dirs
export BASE="/dev/shm/\$USER/vllm.\$JOB_TAG"
export TORCH_EXTENSIONS_DIR="\${BASE}/torch_ext"
export FLASHINFER_JIT_DIR="\${BASE}/flashinfer_jit"
export CUDA_CACHE_PATH="\${BASE}/cuda_cache"
export TMPDIR="\${BASE}/nlp/scr/qinanyu/rl-explanations/temp"

# Create them
mkdir -p "\$TORCH_EXTENSIONS_DIR" "\$FLASHINFER_JIT_DIR" "\$CUDA_CACHE_PATH" "\$TMPDIR"
chmod 700 "\$BASE" "\$TMPDIR"

# Optional
export RAY_TMPDIR="/nlp/scr/qinanyu/ray_st"; mkdir -p "\$RAY_TMPDIR"

# ------- Pick arch for the current node -------
export TORCH_CUDA_ARCH_LIST="8.0;9.0a" 

# If nvcc isn't present, avoid dead CUDA_HOME
command -v nvcc >/dev/null 2>&1 || unset CUDA_HOME

# Clean up at exit
# trap 'rm -rf "\$BASE"' EXIT

# Debug (optional)
echo "BASE=\$BASE"
echo "TMPDIR=\$TMPDIR"
echo "TORCH_EXTENSIONS_DIR=\$TORCH_EXTENSIONS_DIR"
echo "FLASHINFER_JIT_DIR=\$FLASHINFER_JIT_DIR"
echo "CUDA_CACHE_PATH=\$CUDA_CACHE_PATH"
echo "TORCH_CUDA_ARCH_LIST=\$TORCH_CUDA_ARCH_LIST"

echo "starting training"

python -u /nlp/scr/qinanyu/rl-explanations/trainers/verl_train.py \\
--config-name ${config_name} \\
--config-path ../configs/drgrpo \\
actor_rollout_ref.model.path=${model_path} \\
trainer.project_name=drgrpo_${model_size} \\
trainer.experiment_name=${task_name} \\
actor_rollout_ref.rollout.tensor_model_parallel_size=${tensor_parallel} \\
trainer.n_gpus_per_node=${gpus} \\
actor_rollout_ref.rollout.tensor_parallel_size=${tensor_parallel} \\
ray_init.num_cpus=${cpus} \\
reasoning_gym.val_path=trainers/drgrpo/${model_size}-instruct/val_${task_name} \\
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${micro_batch_size} \\
actor_rollout_ref.actor.ppo_mini_batch_size=${mini_batch_size} \\
data.train_batch_size=${train_batch_size}${extra_overrides}
train_status=\$?

# Remove per-job scratch; refuse to run if BASE is empty
rm -rf -- "\${BASE:?}"

# Preserve the training exit status so SLURM reports crashed runs as FAILED
exit \$train_status
EOF

    echo "Submitting job for ${task_name} (${model_size})..."
    local submit_status=0
    sbatch "$script_file" || submit_status=$?
    if (( submit_status != 0 )); then
        echo "Error: sbatch failed for ${task_name} (${model_size}) with status ${submit_status}" >&2
    fi

    # Clean up temporary script (sbatch read it at submission time)
    rm -f -- "$script_file"
    return "$submit_status"
}

# =====================================================================
# CONFIGURATION: Edit these arrays to control which tasks and models to run
# =====================================================================

# Tasks to run (comment out or remove tasks you don't want to run)
TASKS=(
    #"mini_sudoku"
    "spiral_matrix"
    #"family_relationships"
    #"simple_equations"
    #"futoshiki"
)

# Model configurations: each entry contains "model_path|model_size|gpus|cpus|tensor_parallel|micro_batch_size|mini_batch_size|train_batch_size"
MODEL_CONFIGS=(
    # 7B model configuration (currently commented out)
    #"/nlp/scr/qinanyu/models/qwen2.5-7b-instruct|q7b|4|4|2|16|128|128"
    # 3B model configuration
    "/nlp/scr/qinanyu/models/qwen2.5-3b-instruct|q3b|4|4|2|16|128|128"
    # 1.5B model configuration
    #"/nlp/scr/qinanyu/models/qwen2.5-1.5b-instruct|q1.5b|2|2|1|32|128|128"
)

# =====================================================================
# Submit jobs based on configuration
# =====================================================================

echo "Submitting jobs for ${#TASKS[@]} tasks with ${#MODEL_CONFIGS[@]} model configurations..."

submitted_jobs=0
failed_jobs=0
for task in "${TASKS[@]}"; do
    for config in "${MODEL_CONFIGS[@]}"; do
        # Parse the configuration string (8 '|'-separated fields, see MODEL_CONFIGS)
        IFS='|' read -r model_path model_size gpus cpus tensor_parallel micro_batch_size mini_batch_size train_batch_size <<< "$config"

        # A short entry leaves trailing fields empty; don't submit a job with blank overrides.
        if [[ -z "$train_batch_size" ]]; then
            echo "WARNING: malformed model config (expected 8 '|' fields), skipping: $config" >&2
            failed_jobs=$((failed_jobs + 1))
            continue
        fi

        echo "=== Submitting $task ($model_size) ==="
        echo "Output file: /nlp/scr/qinanyu/rl-explanations/bash_output/drgrpo-${task}-${model_size}.out"
        # The task name doubles as the Hydra config name; no extra overrides.
        if submit_task "$task" "$task" "$gpus" "$cpus" "$tensor_parallel" "$micro_batch_size" "$mini_batch_size" "$model_path" "$model_size" "$train_batch_size" ""; then
            submitted_jobs=$((submitted_jobs + 1))
        else
            echo "WARNING: submission failed for $task ($model_size)" >&2
            failed_jobs=$((failed_jobs + 1))
        fi
    done
done

total_jobs=$((${#TASKS[@]} * ${#MODEL_CONFIGS[@]}))
if (( failed_jobs == 0 )); then
    echo "All $total_jobs jobs submitted (${#TASKS[@]} tasks × ${#MODEL_CONFIGS[@]} models)! Check status with: squeue -u \$USER"
else
    echo "$submitted_jobs of $total_jobs jobs submitted ($failed_jobs failed). Check status with: squeue -u \$USER" >&2
fi
echo "Monitor outputs in: /nlp/scr/qinanyu/rl-explanations/bash_output/"

# Non-zero exit lets wrappers notice that some submissions did not go through.
if (( failed_jobs > 0 )); then
    exit 1
fi