#!/bin/bash
#SBATCH --job-name=benchmark
# sbatch does not perform shell expansion on #SBATCH lines, so "$USER" would be
# used literally; %u is Slurm's filename pattern for the submitting user name.
# Slurm does not create the log directory - it must exist before submission.
#SBATCH --output=/scratch/%u/run_logs/%x_%j.out
#SBATCH --error=/scratch/%u/run_logs/%x_%j.err
#SBATCH --partition=<partition>
#SBATCH --gres=gpu:1
#SBATCH --cpus-per-task=16
#SBATCH --mem=120GB
#SBATCH --export=NONE
#SBATCH --account=<account>
#SBATCH --time=12:00:00

# Unified experiment launcher for KD/SFT/DPO pipelines (single-GPU).
#
# Usage:
#   sbatch run_pipeline.sh [CONFIG_PATH] [EXTRA_ARGS...]
#   bash   run_pipeline.sh [CONFIG_PATH] [EXTRA_ARGS...]
# CONFIG_PATH defaults to configs/experiments/kd_32b_to_1b.yaml. Every further
# argument is forwarded verbatim (as separate words) to run_experiment.py.
# The script exits with run_experiment.py's exit status.
#
# Interactive allocations:
# 1x GPU:       srun -c 16 --gres=gpu:1 --partition=<partition> --mem=120GB --pty --time=3:00:00 --account=<account> bash
# CPU-only:     srun -c 16 --partition=<partition> --mem=120GB --pty --time=3:00:00 --account=<account> bash
# Example override of run_dir:
#   bash run_pipeline.sh configs/experiments/dpo_32b_to_1b.yaml --run-dir /tmp/my_run

# -e: abort if any setup step fails; pipefail: a pipeline fails if any stage
# fails; -x: trace every command into the job log.
set -eo pipefail
set -x

# Get config path and extra args. Extra args are kept in an array so that
# arguments containing spaces or glob characters reach Python unchanged.
CONFIG_PATH=${1:-"configs/experiments/kd_32b_to_1b.yaml"}
EXTRA_ARGS=("${@:2}")

echo "==============================================="
echo "Distillation Energy Benchmark"
echo "==============================================="
echo "Job ${SLURM_JOB_NAME} (${SLURM_JOB_ID}) started at $(date)"
echo "Running on node: $(hostname)"
echo "Job ID: $SLURM_JOB_ID"
echo "Config: $CONFIG_PATH"
echo "GPU resources: ${CUDA_VISIBLE_DEVICES}"
echo "Extra args: ${EXTRA_ARGS[*]}"
echo "==============================================="

# Export hardware metadata for energy tracking (makes the Slurm job metadata
# part of the environment inherited by the Python process).
export SLURM_JOB_ID=$SLURM_JOB_ID
export SLURM_JOB_NAME=$SLURM_JOB_NAME
export SLURM_NODELIST=$SLURM_NODELIST

# Device Settings: size CPU thread pools to the CPUs Slurm allocated
# (falls back to 16, matching --cpus-per-task above, outside a Slurm job).
export OMP_NUM_THREADS="${SLURM_CPUS_PER_TASK:-16}"
export MKL_NUM_THREADS="$OMP_NUM_THREADS"
export OPENBLAS_NUM_THREADS="$OMP_NUM_THREADS"
export NUMEXPR_NUM_THREADS="$OMP_NUM_THREADS"

# Memory optimization
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

# Load modules
# module load gcc arrow/18.1.0

# The virtualenv is resolved relative to the current working directory
# (the submission directory under sbatch); fail with a clear message if absent.
if [[ ! -f .venv/bin/activate ]]; then
  echo "ERROR: .venv/bin/activate not found in $(pwd)" >&2
  exit 1
fi
source .venv/bin/activate

# Show which Python we're actually running and whether torch is visible
echo "Python in batch job:"
command -v python
python -c "import torch; print('torch version in batch job =', torch.__version__)"

# Run experiment or data script. The status is captured explicitly: under
# `set -e` a failing command would otherwise terminate the script before the
# summary below is printed, and `$?` read afterwards would always be 0.
EXIT_CODE=0
python run_experiment.py --config "$CONFIG_PATH" "${EXTRA_ARGS[@]}" || EXIT_CODE=$?

echo "==============================================="
echo "Job completed at $(date)"
echo "Exit code: $EXIT_CODE"
echo "==============================================="

exit "$EXIT_CODE"
