#!/bin/bash

# Script to submit SLURM SFT jobs for every combination of the tasks listed in
# TASKS and the model configurations listed in MODEL_CONFIGS (both defined below).
# Based on simple_equation.sh template and command_llama.txt configurations

# Startup banner. The tasks and models actually submitted are whatever the
# TASKS and MODEL_CONFIGS arrays contain, so no fixed counts are claimed here.
echo "Submitting SLURM jobs for the configured reasoning tasks and model sizes..."

#######################################
# Generate a single-node SLURM batch script for one SFT run and submit it.
#
# The job script is written to a private temporary file, handed to sbatch,
# and removed again regardless of whether the submission succeeded.
#
# Arguments:
#   $1 - task_name         dataset/experiment name (selects the parquet files)
#   $2 - model_path        value passed as model.partial_pretrain
#   $3 - model_size        short tag used in job, project and log file names
#   $4 - gpus              GPUs requested; also used for --cpus-per-task and
#                          torchrun's --nproc_per_node
#   $5 - micro_batch_size  value for data.micro_batch_size_per_gpu
#   $6 - train_batch_size  value for data.train_batch_size
# Outputs:
#   A progress line on stdout plus whatever sbatch prints.
# Returns:
#   2 if fewer than six arguments are given, 1 if the temporary script cannot
#   be created or written, otherwise the exit status of sbatch.
#
# NOTE(review): --config-path and the data.*_files paths in the job are
# relative, so they presumably resolve against the directory the job starts
# in -- confirm this matches where the script is launched from.
#######################################
submit_task() {
    if (( $# < 6 )); then
        echo "usage: submit_task task_name model_path model_size gpus micro_batch_size train_batch_size" >&2
        return 2
    fi

    local task_name=$1
    local model_path=$2
    local model_size=$3
    local gpus=$4
    local micro_batch_size=$5
    local train_batch_size=$6

    # One CPU is requested per GPU.
    local cpus=$gpus
    # Extra text appended verbatim after the data.train_batch_size override;
    # currently always empty.
    local additional_params=""

    # Use an unpredictable, private file so concurrent runs cannot clobber each
    # other and no pre-existing project temp directory is required.
    local script_file
    script_file=$(mktemp "${TMPDIR:-/tmp}/slurm_${task_name}_${model_size}.XXXXXX") || {
        echo "submit_task: could not create a temporary job script" >&2
        return 1
    }

    # Unquoted heredoc: \${...} expands now, while \\\$-escaped text is left for
    # the job itself to expand at run time.
    if ! cat > "$script_file" << EOF
#!/bin/bash
#SBATCH --partition=sphinx --qos=normal
#SBATCH --account=nlp
#SBATCH --cpus-per-task=${cpus}
#SBATCH --exclude=sphinx[1,2]
#SBATCH --gres=gpu:${gpus}
#SBATCH --job-name=sft_${model_size}-${task_name}
#SBATCH --mem=560G
#SBATCH --open-mode=append
#SBATCH --output=/nlp/scr/qinanyu/rl-explanations/bash_output/sft-${task_name}-${model_size}.out
#SBATCH --time=14-0

# Unique per job
unset ROCR_VISIBLE_DEVICES
export HYDRA_FULL_ERROR=0

export JOB_TAG="\${SLURM_JOB_ID:-\${LSB_JOBID:-jid.\$(id -u)-\$\$-\$(date +%s)}}"


echo "starting training"

torchrun --standalone --nnodes=1 --nproc_per_node=${gpus} /nlp/scr/qinanyu/rl-explanations/trainers/verl_sft_trainer.py \\
--config-name sft_trainer \\
--config-path configs/sft \\
model.partial_pretrain=${model_path} \\
trainer.project_name=sft_${model_size} \\
trainer.experiment_name=${task_name} \\
trainer.n_gpus_per_node=${gpus} \\
data.train_batch_size=${train_batch_size}${additional_params} \\
data.train_files=generate/train_data_generic/o3-mini_gpt-4.1-mini/${task_name}.parquet \\
data.val_files=generate/val_data_generic/o3-mini_gpt-4.1-mini/${task_name}.parquet \\
data.prompt_key=question \\
data.response_key=teacher_thinking_without_answer \\
data.developer_prompt=DeepSeekZero \\
data.micro_batch_size_per_gpu=${micro_batch_size}

# BASE is not set anywhere in this job script, so only clean it up when the
# environment actually provides a non-empty value.
if [ -n "\${BASE:-}" ]; then
    rm -rf -- "\$BASE"
fi
EOF
    then
        echo "submit_task: could not write ${script_file}" >&2
        rm -f -- "$script_file"
        return 1
    fi

    echo "Submitting job for ${task_name} (${model_size})..."
    local status=0
    sbatch "$script_file" || status=$?

    # Clean up the temporary script, then report sbatch's result to the caller.
    rm -f -- "$script_file"
    return "$status"
}

# =====================================================================
# CONFIGURATION: Edit these arrays to control which tasks and models to run
# =====================================================================

# Reasoning tasks to train on. Toggle a task by (un)commenting its line.
declare -a TASKS=(
    #"mini_sudoku"
    "spiral_matrix"
    #"family_relationships"
    #"simple_equations"
    #"futoshiki"
)

# Model configurations, one '|'-separated string per entry:
#   model_path|model_size|gpus|micro_batch_size|train_batch_size
declare -a MODEL_CONFIGS=(
    # 7B model (disabled)
    #"/nlp/scr/qinanyu/models/qwen2.5-7b-instruct|q7b|1|2|64"

    # 1.5B model (disabled)
    #"/nlp/scr/qinanyu/models/qwen2.5-1.5b-instruct|q1.5b|2|32|128"

    # 3B model
    "/nlp/scr/qinanyu/models/qwen2.5-3b-instruct|q3b|1|32|128"
)

# =====================================================================
# Submit jobs based on configuration
# =====================================================================
# Submits one job per (task, model configuration) pair and then prints a
# summary. Successes and failures are counted so the summary only claims that
# every job was submitted when submit_task succeeded for all of them.

echo "Submitting jobs for ${#TASKS[@]} tasks with ${#MODEL_CONFIGS[@]} model configurations..."

submitted=0
failed=0

for task in "${TASKS[@]}"; do
    for config in "${MODEL_CONFIGS[@]}"; do
        # Parse the configuration string (fields are separated by '|')
        IFS='|' read -r model_path model_size gpus micro_batch_size train_batch_size <<< "$config"

        # An entry with a missing field would otherwise yield a job script with
        # empty resource requests, so reject it up front.
        if [[ -z "$model_path" || -z "$model_size" || -z "$gpus" \
              || -z "$micro_batch_size" || -z "$train_batch_size" ]]; then
            echo "Skipping malformed model configuration for $task: '$config' (expected model_path|model_size|gpus|micro_batch_size|train_batch_size)" >&2
            failed=$((failed + 1))
            continue
        fi

        echo "=== Submitting $task ($model_size) ==="
        echo "Output file: /nlp/scr/qinanyu/rl-explanations/bash_output/sft-${task}-${model_size}.out"
        if submit_task "$task" "$model_path" "$model_size" "$gpus" "$micro_batch_size" "$train_batch_size"; then
            submitted=$((submitted + 1))
        else
            echo "Failed to submit $task ($model_size)" >&2
            failed=$((failed + 1))
        fi
    done
done

total_jobs=$((${#TASKS[@]} * ${#MODEL_CONFIGS[@]}))
if (( failed == 0 )); then
    echo "All $total_jobs jobs submitted (${#TASKS[@]} tasks × ${#MODEL_CONFIGS[@]} models)! Check status with: squeue -u \$USER"
else
    echo "Submitted $submitted of $total_jobs jobs; $failed failed. Check status with: squeue -u \$USER" >&2
fi
echo "Monitor outputs in: /nlp/scr/qinanyu/rl-explanations/bash_output/"