#!/bin/bash
#SBATCH --job-name=dpo_qwen25_05B
#SBATCH --account=<your-account-name>
#SBATCH --partition=<your-gpu-partition>
#SBATCH --gres=gpu:a100:1
#SBATCH --time=100:00:00
#SBATCH --mem=160G
#SBATCH --output=dpo_qwen25_05B_%j.out
#SBATCH --error=dpo_qwen25_05B_%j.err
#SBATCH --mail-type=BEGIN,END,FAIL
#SBATCH --mail-user=<your-email@domain.com>

# Job banner: record which SLURM job this is, which node it landed on, and
# when it started, so the .out log can be matched to scheduler records.
printf 'Job ID: %s\nHostname: %s\nStart time: %s\n' \
    "$SLURM_JOB_ID" \
    "$(hostname)" \
    "$(date)"

# --- Conda environment setup -------------------------------------------------
# Edit the two values below for your installation. They are kept as quoted
# strings so that an unedited placeholder is not parsed by bash as a
# redirection (`<file>`), which would be a syntax error.
conda_profile="/path/to/your/miniconda3/etc/profile.d/conda.sh"
conda_env="<your-conda-env>"

# Batch shells are non-interactive, so conda's shell functions are not loaded
# automatically; source them explicitly and stop if that fails.
if ! source "$conda_profile"; then
    echo "ERROR: could not source conda profile script: $conda_profile" >&2
    exit 1
fi

eval "$(conda shell.bash hook)"

# Abort rather than silently falling back to whatever python3 is on PATH:
# continuing without the environment would waste the GPU allocation.
if ! conda activate "$conda_env"; then
    echo "ERROR: could not activate conda environment: $conda_env" >&2
    exit 1
fi

# PyTorch CUDA caching-allocator setting, read by PyTorch at startup.
export PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True"

# --- Training run ------------------------------------------------------------
# Command-line arguments forwarded verbatim to src/dpo_trainer.py. They are
# held in an array so individual entries can be added, removed, or commented
# out without having to maintain trailing line-continuation backslashes.
train_args=(
    --model_name /path/to/models/Qwen2.5-0.5B-Base
    --per_device_train_batch_size 6
    --save_steps 250
    --output_dir Qwen_0.5B_dpo
    --num_train_epochs 1
    --run_name Qwen25_0.5B_DPO_gsm8k_reasoner
    --learning_rate 1e-5
    --weight_decay 0.01
    --warmup_steps 100
    --max_prompt_length 2000
    --max_completion_length 500
    --gradient_accumulation_steps 1
    --training_data_directory data/dpo_data_gsm8k
)

# Relative paths (src/, data/, output dir) resolve against the job's working
# directory.
python3 src/dpo_trainer.py "${train_args[@]}"
train_status=$?

if [[ "$train_status" -ne 0 ]]; then
    echo "ERROR: training exited with status $train_status" >&2
fi

# Print completion time
echo "End time: $(date)"

# Exit with the trainer's status rather than that of the final echo, so the
# scheduler records a failed run as failed (and FAIL mail is triggered).
exit "$train_status"