#!/bin/bash
#
# Slurm job-array script: every array task runs
# experiments/resnet/hw_metrics/collect_hw_metrics.py for one
# (model, batch size) combination chosen from the array task ID.
#
# The <repo path>, <partition> and <exclude> placeholders below must be
# replaced with site-specific values before the script is submitted.

# Resources per array task: one GPU, one task on one node, 8 CPU cores.
#SBATCH --gres=gpu:1
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=8

#SBATCH --job-name=hw_metrics
#SBATCH --open-mode=append  # important for multiple processes to share a log file

# Log files for stderr / stdout. In Slurm filename patterns %j is the job ID,
# %a the array task index and %N the node name. Make sure the slurm_logs
# directory exists before submitting.
#SBATCH --error=<repo path>/slurm_logs/%j_%a_%N_log.err
#SBATCH --output=<repo path>/slurm_logs/%j_%a_%N_log.out

# Scheduling: target partition, 30-minute time limit per task, e-mail
# notification on failure only, and a list of nodes to avoid.
#SBATCH --partition=<partition>
#SBATCH --time=00:30:00  
#SBATCH --mail-type=FAIL
#SBATCH --exclude=<exclude>

# Job array: one task per (model, batch size) combination.
# The range must be 0 .. (number of models * number of batch sizes - 1).
# The script body defines 5 models and 7 batch sizes, i.e. 35 combinations,
# so the valid range is 0-34. A larger range would start tasks whose model
# index points past the end of the model list.
#SBATCH --array=0-34

# Enter the repository root; all later paths are relative to it.
# The path is quoted so that it always reaches cd as a single argument,
# even if it contains spaces or shell metacharacters.
cd "<repo path>" || {
  echo "error: cannot change into the repository root" >&2
  exit 1
}

# Activate the project's virtual environment (located in the repo root).
source .venv/bin/activate || {
  echo "error: cannot activate .venv/bin/activate" >&2
  exit 1
}

# Put the repository root (the current directory after the cd above) on
# Python's module search path.
export PYTHONPATH="$PWD"

# Define the models and batch sizes
MODELS=(resnet18 resnet34 resnet50 resnet101 resnet152)
BATCH_SIZES=(16 32 64 128 256 512 1024)

#######################################
# Map a flat task index onto one (model, batch size) pair.
# Indices walk the grid model by model: index / #batch_sizes selects the
# model, index % #batch_sizes selects the batch size.
# Globals:
#   MODELS, BATCH_SIZES (read)
#   NUM_BATCH_SIZES, MODEL_IDX, BATCH_IDX, MODEL, BATCH_SIZE (written)
# Arguments:
#   $1 - task index, 0 .. (#MODELS * #BATCH_SIZES - 1)
# Returns:
#   0 on success; 1 (with a message on stderr) if the index is not a
#   non-negative integer or lies outside the grid. Globals are left
#   untouched in that case.
#######################################
select_config() {
  local task_id=$1
  local num_batch_sizes=${#BATCH_SIZES[@]}
  local num_configs=$(( ${#MODELS[@]} * num_batch_sizes ))

  # Reject anything that is not a plain decimal number before doing
  # arithmetic on it; 10# forces base 10 so a leading zero is not octal.
  if [[ ! "$task_id" =~ ^[0-9]+$ ]] || (( 10#$task_id >= num_configs )); then
    printf 'error: array task ID "%s" is outside the valid range 0-%d\n' \
      "$task_id" "$(( num_configs - 1 ))" >&2
    return 1
  fi
  task_id=$(( 10#$task_id ))

  NUM_BATCH_SIZES=$num_batch_sizes
  MODEL_IDX=$(( task_id / NUM_BATCH_SIZES ))
  BATCH_IDX=$(( task_id % NUM_BATCH_SIZES ))
  MODEL=${MODELS[MODEL_IDX]}
  BATCH_SIZE=${BATCH_SIZES[BATCH_IDX]}
}

# Calculate which model and batch size to use based on the array task ID.
# When the script runs outside a job array the variable is unset; fall back
# to index 0, which is what an unset ID evaluated to in the arithmetic before.
TASK_ID=${SLURM_ARRAY_TASK_ID:-0}

# Abort instead of launching the experiment with an empty model/batch size.
select_config "$TASK_ID" || exit 1

echo "Running task $TASK_ID: Model=$MODEL, Batch Size=$BATCH_SIZE"

# Run the experiment with the new script.
# - MODEL and BATCH_SIZE are quoted so an unexpectedly empty value is still
#   passed as its own argument instead of vanishing and shifting the flags.
# - The dataloader worker count follows the CPU allocation of the task
#   (SLURM_CPUS_PER_TASK, set by Slurm from --cpus-per-task) and falls back
#   to 8 when that variable is not available.
python experiments/resnet/hw_metrics/collect_hw_metrics.py \
    --model "$MODEL" \
    --batch_size "$BATCH_SIZE" \
    --num_dataloader_workers "${SLURM_CPUS_PER_TASK:-8}" \
    --device cuda \
    --warmup_passes 100
status=$?

# Propagate a failure to Slurm. Without this the final echo would make the
# batch script exit 0, so the task would be recorded as successful and the
# failure e-mail notification would never be sent.
if (( status != 0 )); then
  echo "Task $SLURM_ARRAY_TASK_ID failed with exit code $status" >&2
  exit "$status"
fi

echo "Task $SLURM_ARRAY_TASK_ID completed"

# end of file
