#!/bin/bash

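# Script to submit SLURM GRPO jobs that start from SFT checkpoints (Qwen 3B / 7B, see MODEL_CONFIGS).
# One job is submitted for every (trained-on task × target task × model config) combination
# selected in the CONFIGURATION section below; up to 5 reasoning tasks are available.
# Based on simple_equation.sh template and command_llama.txt configurations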

# The exact task/model counts depend on the CONFIGURATION arrays and are printed before submission.
echo "Submitting SLURM jobs for the configured reasoning tasks and model sizes..."

# Function to create and submit a SLURM job
#
# Writes a one-off sbatch script for a single run (checkpoint trained on
# task_trained_on, run on task_name, at model_size), submits it with sbatch,
# then deletes the local copy of the script.
#
# Parameters (positional):
#   $1  task_name          target task; used in the log/job name and val_path
#   $2  config_name        Hydra --config-name (looked up under configs/grpo)
#   $3  gpus               --gres=gpu:N and trainer.n_gpus_per_node
#   $4  cpus               --cpus-per-task and ray_init.num_cpus
#   $5  tensor_parallel    actor_rollout_ref.rollout.tensor_model_parallel_size
#   $6  micro_batch_size   actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu
#   $7  mini_batch_size    actor_rollout_ref.actor.ppo_mini_batch_size
#   $8  model_path         actor_rollout_ref.model.path
#   $9  model_size         short model tag (e.g. q3b); used in names and paths
#   $10 train_batch_size   data.train_batch_size
#   $11 additional_params  extra overrides appended verbatim right after the
#                          train_batch_size value; must begin with a space
#                          when non-empty
#   $12 task_trained_on    task the checkpoint was trained on
#
# Environment:
#   SLURM_SCRIPT_TMPDIR    where the temporary sbatch script is written
#                          (default: /nlp/scr/qinanyu/rl-explanations/temp)
#
# Returns 0 when sbatch accepts the job; non-zero if the script cannot be
# written or sbatch fails (errors are reported on stderr).
submit_task() {
    local task_name=$1
    local config_name=$2
    local gpus=$3
    local cpus=$4
    local tensor_parallel=$5
    local micro_batch_size=$6
    local mini_batch_size=$7
    local model_path=$8
    local model_size=$9
    local train_batch_size=${10}
    local additional_params=${11}
    local task_trained_on=${12}

    local tmp_dir=${SLURM_SCRIPT_TMPDIR:-/nlp/scr/qinanyu/rl-explanations/temp}
    # The source task is part of the log name: with --open-mode=append, runs from
    # different checkpoints on the same target task would otherwise share one log.
    local log_file="/nlp/scr/qinanyu/rl-explanations/bash_output/grpo_post_sft-${task_name}-from_${task_trained_on}-${model_size}.out"
    local script_file

    if ! mkdir -p -- "$tmp_dir"; then
        echo "ERROR: cannot create script directory ${tmp_dir}" >&2
        return 1
    fi

    # mktemp guarantees a unique, freshly created file so no two submissions
    # can overwrite each other's script.
    if ! script_file=$(mktemp "${tmp_dir}/slurm_${task_trained_on}_to_${task_name}_${model_size}_XXXXXX"); then
        echo "ERROR: cannot create temporary sbatch script in ${tmp_dir}" >&2
        return 1
    fi

    cat > "$script_file" << EOF
#!/bin/bash
#SBATCH --partition=sphinx --qos=normal
#SBATCH --account=nlp
#SBATCH --cpus-per-task=${cpus}
#SBATCH --exclude=sphinx[1,2,4,5,6]
#SBATCH --gres=gpu:${gpus}
#SBATCH --job-name=${model_size}-${task_name}
#SBATCH --mem=140G
#SBATCH --open-mode=append
#SBATCH --output=${log_file}
#SBATCH --time=14-0

# Unique per job
unset ROCR_VISIBLE_DEVICES
export HYDRA_FULL_ERROR=1

#export JOB_TAG="\${SLURM_JOB_ID:-\${LSB_JOBID:-jid.\$(id -u)-\$\$-\$(date +%s)}}"
#
## Per-job dirs
#export BASE="/dev/shm/\$USER/vllm.\$JOB_TAG"
#export TORCH_EXTENSIONS_DIR="\${BASE}/torch_ext"
#export FLASHINFER_JIT_DIR="\${BASE}/flashinfer_jit"
#export CUDA_CACHE_PATH="\${BASE}/cuda_cache"
#export TMPDIR="\${BASE}/nlp/scr/qinanyu/rl-explanations/temp"
#
## Create them
#mkdir -p "\$TORCH_EXTENSIONS_DIR" "\$FLASHINFER_JIT_DIR" "\$CUDA_CACHE_PATH" "\$TMPDIR"
#chmod 700 "\$BASE" "\$TMPDIR"
#
## Optional
#export RAY_TMPDIR="/nlp/scr/qinanyu/ray_st"; mkdir -p "\$RAY_TMPDIR"
#
## ------- Pick arch for the current node -------
#export TORCH_CUDA_ARCH_LIST="8.0;9.0a" 
#
## If nvcc isn't present, avoid dead CUDA_HOME
#command -v nvcc >/dev/null 2>&1 || unset CUDA_HOME
#
## Clean up at exit
## trap 'rm -rf "\$BASE"' EXIT
#
## Debug (optional)
#echo "BASE=\$BASE"
#echo "TMPDIR=\$TMPDIR"
#echo "TORCH_EXTENSIONS_DIR=\$TORCH_EXTENSIONS_DIR"
#echo "FLASHINFER_JIT_DIR=\$FLASHINFER_JIT_DIR"
#echo "CUDA_CACHE_PATH=\$CUDA_CACHE_PATH"
#echo "TORCH_CUDA_ARCH_LIST=\$TORCH_CUDA_ARCH_LIST"

echo "starting training"

python -u /nlp/scr/qinanyu/rl-explanations/trainers/verl_train.py \\
--config-name ${config_name} \\
--config-path configs/grpo \\
actor_rollout_ref.model.path=${model_path} \\
trainer.project_name=grpo_post_sft_${model_size} \\
trainer.experiment_name=${task_trained_on}_to_${task_name} \\
actor_rollout_ref.rollout.tensor_model_parallel_size=${tensor_parallel} \\
trainer.n_gpus_per_node=${gpus} \\
ray_init.num_cpus=${cpus} \\
reasoning_gym.val_path=trainers/grpo_post_sft/${model_size}-instruct/val_${task_name} \\
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${micro_batch_size} \\
actor_rollout_ref.actor.ppo_mini_batch_size=${mini_batch_size} \\
data.train_batch_size=${train_batch_size}${additional_params}
train_status=\$?

# BASE is only set by the (currently disabled) per-job scratch setup above;
# never hand an empty/unset value to rm.
if [[ -n "\${BASE:-}" ]]; then
    rm -rf -- "\$BASE"
fi

# Exit with the trainer's status so SLURM marks crashed runs as FAILED.
exit "\$train_status"
EOF
    local write_status=$?
    if (( write_status != 0 )); then
        echo "ERROR: failed to write sbatch script ${script_file}" >&2
        rm -f -- "$script_file"
        return 1
    fi

    echo "Submitting job for ${task_name} (${model_size})..."
    local submit_status=0
    sbatch "$script_file" || submit_status=$?

    # sbatch reads the script at submission time, so the local copy can go.
    rm -f -- "$script_file"

    if (( submit_status != 0 )); then
        echo "ERROR: sbatch failed (exit ${submit_status}) for ${task_trained_on} -> ${task_name} (${model_size})" >&2
        return "$submit_status"
    fi
    return 0
}

# =====================================================================
# CONFIGURATION: Edit these arrays to control which tasks and models to run
# =====================================================================

# Tasks that models were trained on (for model_path)
# Each name is substituted into MODEL_CONFIGS' {TASK_TRAINED_ON} placeholder,
# i.e. it selects which SFT checkpoint to start from.
TASKS_TRAINED_ON=(
    #"mini_sudoku"
    "spiral_matrix"
    #"family_relationships" 
    #"simple_equations"
    #"futoshiki"
)

# Target tasks to run (for yaml config)
# Each name is used as the Hydra config name (under configs/grpo) and picks
# the matching validation set (val_<task>).
TASKS_TO_RUN=(
    "mini_sudoku"
    #"spiral_matrix"
    #"family_relationships"
    #"simple_equations" 
    #"futoshiki"
)

# This will create all combinations of:
# - Models trained on TASKS_TRAINED_ON 
# - Run on TASKS_TO_RUN
# For example: model trained on "spiral_matrix" run on "mini_sudoku"

# Model configurations: each entry contains "model_path|model_size|gpus|cpus|tensor_parallel|micro_batch_size|mini_batch_size|train_batch_size"
# Use {TASK_TRAINED_ON} as placeholder that will be replaced with the trained task name for model_path
MODEL_CONFIGS=(
    # 3B model configuration (active)
    "/nlp/scr/qinanyu/rl-explanations/checkpoints/sft_q3b/{TASK_TRAINED_ON}|q3b|2|2|1|32|128|128"
    # 7B model configuration (currently commented out)
    #"/nlp/scr/qinanyu/rl-explanations/checkpoints/sft_q7b/{TASK_TRAINED_ON}|q7b|4|4|2|16|128|128"
)

# =====================================================================
# Submit jobs based on configuration
# =====================================================================

# Literal placeholder used in MODEL_CONFIGS model paths.
TASK_PLACEHOLDER='{TASK_TRAINED_ON}'
submitted_jobs=0
failed_jobs=0

echo "Submitting jobs for ${#TASKS_TRAINED_ON[@]} trained tasks × ${#TASKS_TO_RUN[@]} run tasks × ${#MODEL_CONFIGS[@]} model configurations..."

for task_trained_on in "${TASKS_TRAINED_ON[@]}"; do
    for task_to_run in "${TASKS_TO_RUN[@]}"; do
        for config in "${MODEL_CONFIGS[@]}"; do
            # Replace the placeholder via parameter expansion. Quoting pattern and
            # replacement keeps both literal, unlike sed, where '/', '&' or regex
            # characters in a task name would be interpreted.
            config_with_task=${config//"$TASK_PLACEHOLDER"/"$task_trained_on"}

            # Parse the '|'-separated configuration string
            IFS='|' read -r model_path model_size gpus cpus tensor_parallel micro_batch_size mini_batch_size train_batch_size <<< "$config_with_task"

            echo "=== Submitting $task_to_run using model trained on $task_trained_on ($model_size) ==="
            echo "Model path: $model_path"
            echo "Output file: /nlp/scr/qinanyu/rl-explanations/bash_output/grpo_post_sft-${task_to_run}-from_${task_trained_on}-${model_size}.out"
            if submit_task "$task_to_run" "$task_to_run" "$gpus" "$cpus" "$tensor_parallel" "$micro_batch_size" "$mini_batch_size" "$model_path" "$model_size" "$train_batch_size" "" "$task_trained_on"; then
                submitted_jobs=$((submitted_jobs + 1))
            else
                failed_jobs=$((failed_jobs + 1))
            fi
        done
    done
done

total_jobs=$((${#TASKS_TRAINED_ON[@]} * ${#TASKS_TO_RUN[@]} * ${#MODEL_CONFIGS[@]}))
if (( failed_jobs == 0 )); then
    echo "All $total_jobs jobs submitted (${#TASKS_TRAINED_ON[@]} trained tasks × ${#TASKS_TO_RUN[@]} run tasks × ${#MODEL_CONFIGS[@]} models)! Check status with: squeue -u \$USER"
else
    # Don't claim success when some submissions failed.
    echo "WARNING: only $submitted_jobs of $total_jobs jobs submitted; $failed_jobs failed (see errors above). Check status with: squeue -u \$USER" >&2
fi
echo "Monitor outputs in: /nlp/scr/qinanyu/rl-explanations/bash_output/"