#!/bin/bash

# Submit one SLURM GRPO training job per (task, model) pair listed in the
# TASKS and MODEL_CONFIGS arrays defined further down in this script.
# Based on simple_equation.sh template and command_llama.txt configurations

echo "Submitting SLURM jobs for the configured reasoning tasks and model sizes..."
#######################################
# Generate a SLURM batch script for one (task, model) GRPO run and submit it.
# Globals:
#   RL_EXPLANATIONS_ROOT (read, optional) - project root; defaults to
#     /nlp/scr/qinanyu/rl-explanations. The batch script is staged in
#     <root>/temp, the job log goes to <root>/bash_output, and the trainer
#     is <root>/trainers/verl_train.py.
# Arguments:
#   $1  task_name          - task name; used in job/log names and val path
#   $2  config_name        - Hydra --config-name (under configs/grpo)
#   $3  gpus               - --gres=gpu:N and trainer.n_gpus_per_node
#   $4  cpus               - --cpus-per-task and ray_init.num_cpus
#   $5  tensor_parallel    - rollout.tensor_model_parallel_size
#   $6  micro_batch_size   - actor.ppo_micro_batch_size_per_gpu
#   $7  mini_batch_size    - actor.ppo_mini_batch_size
#   $8  model_path         - actor_rollout_ref.model.path
#   $9  model_size         - short tag used in job/project/log names
#   $10 train_batch_size   - data.train_batch_size
#   $11 additional_params  - optional extra Hydra overrides; emitted on their
#                            own argument line (may be empty/omitted)
# Outputs:
#   Progress messages and sbatch's output on stdout; errors on stderr.
# Returns:
#   0 if sbatch accepted the job; sbatch's non-zero status if it did not;
#   1 if the staging directories or batch script could not be created.
#######################################
submit_task() {
    local task_name=$1
    local config_name=$2
    local gpus=$3
    local cpus=$4
    local tensor_parallel=$5
    local micro_batch_size=$6
    local mini_batch_size=$7
    local model_path=$8
    local model_size=$9
    local train_batch_size=${10}
    local additional_params=${11:-}

    local root=${RL_EXPLANATIONS_ROOT:-/nlp/scr/qinanyu/rl-explanations}
    local tmp_dir="${root}/temp"
    local log_dir="${root}/bash_output"

    # sbatch does not create missing directories for --output, so make sure
    # the log directory (and the staging directory) exist before submitting.
    if ! mkdir -p -- "$tmp_dir" "$log_dir"; then
        echo "Error: cannot create ${tmp_dir} or ${log_dir}" >&2
        return 1
    fi

    # Unique name so concurrent runs of this script cannot overwrite each
    # other's batch script before sbatch has read it.
    local script_file
    if ! script_file=$(mktemp "${tmp_dir}/slurm_${task_name}_${model_size}.XXXXXX"); then
        echo "Error: cannot create batch script in ${tmp_dir}" >&2
        return 1
    fi

    # Extra overrides get their own continuation line; appending them directly
    # to data.train_batch_size would fuse the two into one malformed override.
    local extra_line=""
    if [[ -n "$additional_params" ]]; then
        extra_line="${additional_params} \\"$'\n'
    fi

    cat > "$script_file" << EOF
#!/bin/bash
#SBATCH --partition=sphinx --qos=normal
#SBATCH --account=nlp
#SBATCH --cpus-per-task=${cpus}
#SBATCH --exclude=sphinx[1,2,4,6,5,7]
#SBATCH --gres=gpu:${gpus}
#SBATCH --job-name=${model_size}-${task_name}
#SBATCH --mem=560G
#SBATCH --open-mode=append
#SBATCH --output=${log_dir}/grpo-${task_name}-${model_size}.out
#SBATCH --time=14-0

# Unique per job
unset ROCR_VISIBLE_DEVICES 
export HYDRA_FULL_ERROR=0 

#export JOB_TAG="\${SLURM_JOB_ID:-\${LSB_JOBID:-jid.\$(id -u)-\$\$-\$(date +%s)}}"
#
## Per-job dirs
#export BASE="/dev/shm/\$USER/vllm.\$JOB_TAG"
#export TORCH_EXTENSIONS_DIR="\${BASE}/torch_ext"
#export FLASHINFER_JIT_DIR="\${BASE}/flashinfer_jit"
#export CUDA_CACHE_PATH="\${BASE}/cuda_cache"
#export TMPDIR="\${BASE}/nlp/scr/qinanyu/rl-explanations/temp"
#
## Create them
#mkdir -p "\$TORCH_EXTENSIONS_DIR" "\$FLASHINFER_JIT_DIR" "\$CUDA_CACHE_PATH" "\$TMPDIR"
#chmod 700 "\$BASE" "\$TMPDIR"
#
## Optional
#export RAY_TMPDIR="/nlp/scr/qinanyu/ray_st"; mkdir -p "\$RAY_TMPDIR"
#
## ------- Pick arch for the current node -------
#export TORCH_CUDA_ARCH_LIST="8.0;9.0a" 
#
## If nvcc isn't present, avoid dead CUDA_HOME
#command -v nvcc >/dev/null 2>&1 || unset CUDA_HOME
#
## Clean up at exit
## trap 'rm -rf "\$BASE"' EXIT
#
## Debug (optional)
#echo "BASE=\$BASE"
#echo "TMPDIR=\$TMPDIR"
#echo "TORCH_EXTENSIONS_DIR=\$TORCH_EXTENSIONS_DIR"
#echo "FLASHINFER_JIT_DIR=\$FLASHINFER_JIT_DIR"
#echo "CUDA_CACHE_PATH=\$CUDA_CACHE_PATH"
#echo "TORCH_CUDA_ARCH_LIST=\$TORCH_CUDA_ARCH_LIST"

echo "starting training"

python -u ${root}/trainers/verl_train.py \\
--config-name ${config_name} \\
--config-path configs/grpo \\
actor_rollout_ref.model.path=${model_path} \\
trainer.project_name=grpo_${model_size} \\
trainer.experiment_name=${task_name}_informativeness_ \\
actor_rollout_ref.rollout.tensor_model_parallel_size=${tensor_parallel} \\
trainer.n_gpus_per_node=${gpus} \\
ray_init.num_cpus=${cpus} \\
reasoning_gym.val_path=trainers/grpo/${model_size}-instruct/val_${task_name} \\
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${micro_batch_size} \\
actor_rollout_ref.actor.ppo_mini_batch_size=${mini_batch_size} \\
data.train_batch_size=${train_batch_size} \\
${extra_line}trainer.val_before_train=False \\
trainer.save_freq=-1
# Keep the trainer's exit code so a failed run is not reported as successful.
status=\$?

# BASE is only set by the commented-out per-job setup above; skip if unset.
if [ -n "\${BASE:-}" ]; then
    rm -rf -- "\$BASE"
fi

exit "\$status"
EOF

    echo "Submitting job for ${task_name} (${model_size})..."
    local status=0
    sbatch "$script_file" || status=$?
    if (( status != 0 )); then
        echo "Error: sbatch failed for ${task_name} (${model_size}) with status ${status}" >&2
    fi

    # sbatch has read the script by the time it returns, so the staged copy
    # is removed whether or not submission succeeded.
    rm -f -- "$script_file"
    return "$status"
}

# =====================================================================
# CONFIGURATION: Edit these arrays to control which tasks and models to run
# =====================================================================

# Tasks to run (comment out or remove tasks you don't want to run).
# Each name is passed to submit_task as both the task name and the Hydra
# config name.
TASKS=(
    "mini_sudoku"
    #"spiral_matrix"
    #"family_relationships"
    #"simple_equations"
    #"futoshiki"
    #"knight_swap"
    #"rush_hour"
    #"basic_arithmetic"
    #"puzzle24"
    #"simple_geometry"
    #"arc_1d"
    #"rotate_matrix"
)

# Model configurations: each entry is
#   "model_path|model_size|gpus|cpus|tensor_parallel|micro_batch_size|mini_batch_size|train_batch_size"
# i.e. exactly 8 '|'-separated fields, in this order.
MODEL_CONFIGS=(
    # 7B model configuration (currently commented out)
    #"/nlp/scr/qinanyu/models/qwen2.5-7b-instruct|q7b|4|4|2|32|128|128"

    # 1.5B model configuration (currently commented out)
    #"/nlp/scr/qinanyu/models/qwen2.5-1.5b-instruct|q1.5b|2|2|1|32|128|128"
    # 3B model configuration (active)
    "/nlp/scr/qinanyu/models/qwen2.5-3b-instruct|q3b|2|2|1|32|128|128"
)

# =====================================================================
# Submit jobs based on configuration
# =====================================================================

echo "Submitting jobs for ${#TASKS[@]} tasks with ${#MODEL_CONFIGS[@]} model configurations..."

# Tally submissions so the final summary reflects what actually happened
# rather than assuming every sbatch call succeeded.
submitted=0
failed=0

for task in "${TASKS[@]}"; do
    for config in "${MODEL_CONFIGS[@]}"; do
        # Each config needs exactly 8 non-empty '|'-separated fields; a short
        # or empty-field entry would otherwise yield a broken batch script.
        # Wrapping in '|' makes a leading/trailing empty field show up as '||'.
        separators=${config//[^|]/}
        if (( ${#separators} != 7 )) || [[ "|${config}|" == *"||"* ]]; then
            echo "Skipping malformed model config for ${task} (expected 8 non-empty '|'-separated fields): ${config}" >&2
            failed=$((failed + 1))
            continue
        fi

        # Parse the configuration string
        IFS='|' read -r model_path model_size gpus cpus tensor_parallel micro_batch_size mini_batch_size train_batch_size <<< "$config"

        echo "=== Submitting $task ($model_size) ==="
        echo "Output file: /nlp/scr/qinanyu/rl-explanations/bash_output/grpo-${task}-${model_size}.out"
        if submit_task "$task" "$task" "$gpus" "$cpus" "$tensor_parallel" "$micro_batch_size" "$mini_batch_size" "$model_path" "$model_size" "$train_batch_size" ""; then
            submitted=$((submitted + 1))
        else
            failed=$((failed + 1))
        fi
    done
done

total_jobs=$((${#TASKS[@]} * ${#MODEL_CONFIGS[@]}))
if (( failed == 0 )); then
    echo "All $total_jobs jobs submitted (${#TASKS[@]} tasks × ${#MODEL_CONFIGS[@]} models)! Check status with: squeue -u \$USER"
else
    echo "${submitted} of ${total_jobs} jobs submitted; ${failed} failed or skipped (see errors above). Check status with: squeue -u \$USER" >&2
fi
echo "Monitor outputs in: /nlp/scr/qinanyu/rl-explanations/bash_output/"

# Leave a non-zero status for callers when any submission failed.
(( failed == 0 ))