#!/bin/bash
#SBATCH --job-name=sft_tulu_32b_to_7b
#SBATCH --output=/scratch/%u/run_logs/%x_%j.out
#SBATCH --error=/scratch/%u/run_logs/%x_%j.err
#SBATCH --partition=<partition>
#SBATCH --gpus-per-node=1
#SBATCH --cpus-per-task=16
#SBATCH --export=NONE
#SBATCH --account=<account>
#SBATCH --time=1-00:00:00

# Unified experiment launcher for KD/SFT/DPO pipelines (single GPU).
#
# Usage:
#   sbatch run_pipeline.sh [CONFIG_YAML] [EXTRA_ARGS...]
#   bash run_pipeline.sh [CONFIG_YAML] [EXTRA_ARGS...]   # inside an allocation
# CONFIG_YAML defaults to configs/experiments/sft_32b_to_7b.yaml; any further
# arguments are handed to run_experiment.py.
#
# Interactive allocations:
#   1x GPU:   srun -c 16 --gres=gpu:1 --partition=<partition> --mem=120GB --pty --time=3:00:00 --account=<account> bash
#   CPU-only: srun -c 16 --partition=<partition> --mem=120GB --pty --time=3:00:00 --account=<account> bash
# Example override of run_dir:
#   bash run_pipeline.sh configs/experiments/dpo_32b_to_7b.yaml --run-dir /tmp/my_run
#
# Notes:
# - #SBATCH lines are not shell-expanded, so the log paths use Slurm's %u
#   (user name) filename pattern instead of $USER; %x is the job name and
#   %j the job ID.
# - Slurm does not create the log directory: make sure
#   /scratch/<user>/run_logs exists before submitting.

# Abort on the first failing command and trace every command into the log.
set -e
set -x

# Positional arguments: $1 is the experiment config (default below); every
# argument after it is an extra argument for run_experiment.py.
CONFIG_PATH=${1:-"configs/experiments/sft_32b_to_7b.yaml"}
# Space-joined string of the extra arguments. "${*:2}" is the explicit join
# form; assigning "${@:2}" to a scalar is flagged by ShellCheck (SC2124).
EXTRA_ARGS="${*:2}"

# Job banner. SLURM_* values are empty when run outside a Slurm job.
echo "==============================================="
echo "Distillation Energy Benchmark"
echo "==============================================="
echo "Job ${SLURM_JOB_NAME} (${SLURM_JOB_ID}) started at $(date)"
echo "Running on node: $(hostname)"
echo "Job ID: $SLURM_JOB_ID"
echo "Config: $CONFIG_PATH"
echo "GPU resources: ${CUDA_VISIBLE_DEVICES}"
echo "Extra args: $EXTRA_ARGS"
echo "==============================================="

# Export hardware metadata for energy tracking. Re-exporting makes sure the
# values reach the Python child process even if they were set but not
# exported in this shell (they are empty outside a Slurm job).
export SLURM_JOB_ID="$SLURM_JOB_ID"
export SLURM_JOB_NAME="$SLURM_JOB_NAME"
export SLURM_NODELIST="$SLURM_NODELIST"

# Device settings: size every CPU thread pool to the Slurm CPU allocation,
# falling back to 16 when SLURM_CPUS_PER_TASK is unset.
export OMP_NUM_THREADS="${SLURM_CPUS_PER_TASK:-16}"
export MKL_NUM_THREADS="$OMP_NUM_THREADS"
export OPENBLAS_NUM_THREADS="$OMP_NUM_THREADS"
export NUMEXPR_NUM_THREADS="$OMP_NUM_THREADS"

# Hugging Face settings: cache under scratch and force offline mode for
# datasets/hub/transformers (presumably because compute nodes have no
# internet access -- models and data must be pre-downloaded into HF_HOME).
export HF_HOME="${HF_HOME:-/scratch/$USER/hf_cache}"
export HF_DATASETS_OFFLINE=1
export HF_HUB_OFFLINE=1
export TRANSFORMERS_OFFLINE=1

# Memory optimization: enable expandable segments in PyTorch's CUDA caching
# allocator (documented by PyTorch as a way to reduce fragmentation).
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

# Weights & Biases: log offline under scratch. Fall back to /scratch/$USER
# (the same root as HF_HOME's default) when $SCRATCH is unset; otherwise the
# paths would resolve to the filesystem root (e.g. /wandb) and mkdir would fail.
wandb_root="${SCRATCH:-/scratch/$USER}"
export WANDB_MODE=offline
export WANDB_DIR="$wandb_root/wandb"
export WANDB_CACHE_DIR="$wandb_root/.cache/wandb"
export WANDB_CONFIG_DIR="$wandb_root/.config/wandb"
mkdir -p "$WANDB_DIR" "$WANDB_CACHE_DIR" "$WANDB_CONFIG_DIR"
# Default the W&B project to the Slurm job name unless already set.
export WANDB_PROJECT="${WANDB_PROJECT:-$SLURM_JOB_NAME}"

# Load the cluster software stack via environment modules, then activate the
# virtualenv at .venv relative to the current directory (under sbatch this is
# the directory the job was submitted from).
module load StdEnv/2023
module load gcc python/3.11 arrow/21
if [[ ! -f .venv/bin/activate ]]; then
  echo "ERROR: .venv/bin/activate not found in $PWD; create the virtualenv first" >&2
  exit 1
fi
source .venv/bin/activate

# Show which Python we're actually running and whether torch is visible.
echo "Python in batch job:"
command -v python
python -c "import torch; print('torch version in batch job =', torch.__version__)"

# Run the experiment. The exit status is captured explicitly: with `set -e`
# active (enabled at the top of this script), a bare failing command would
# abort right here, skipping the footer so the log never shows the exit code.
# Extra arguments are forwarded straight from the script's own positional
# parameters ("${@:2}") so arguments containing spaces or glob characters
# reach run_experiment.py intact.
EXIT_CODE=0
python run_experiment.py --config "$CONFIG_PATH" "${@:2}" || EXIT_CODE=$?

echo "==============================================="
echo "Job completed at $(date)"
echo "Exit code: $EXIT_CODE"
echo "==============================================="

# Propagate the experiment's status so Slurm records the job as failed/succeeded.
exit "$EXIT_CODE"
