#!/bin/bash

#SBATCH --gres=gpu:8
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=8
#SBATCH --cpus-per-task=8

#SBATCH --job-name=mf_hpo_n_trainable
#SBATCH --open-mode=append

#SBATCH --error=./slurm_logs/%j_%a_%N_log.err
#SBATCH --output=./slurm_logs/%j_%a_%N_log.out

#SBATCH --partition=<partition>
#SBATCH --time=24:00:00 
#SBATCH --mail-type=FAIL
#SBATCH --exclude=<exclude>

# Usage: sbatch <this script> <group_name> <dataset>
# Launches experiments/resnet/hpo/n_trainable.py as a single srun step with
# one GPU per task, forwarding the two positional arguments to it.

# Validate arguments before touching the environment so a bad submission
# fails fast. Diagnostics go to stderr (i.e. the --error log).
if [[ -z "${1:-}" ]]; then
    echo "Error: group_name argument is required" >&2
    exit 1
fi

if [[ -z "${2:-}" ]]; then
    echo "Error: dataset argument is required" >&2
    exit 1
fi

group_name=$1
dataset=$2

cd <repo path> || exit 1

source .venv/bin/activate || exit 1

# Make the repository root importable by the experiment script.
export PYTHONPATH="$PWD"

# Take task/CPU counts from the allocation so overriding the #SBATCH values at
# submit time (e.g. `sbatch --ntasks-per-node=4 --gres=gpu:4 ...`) keeps the
# srun step consistent with it; fall back to 8, the values requested in the
# header, when these variables are not set.
# --cpus-per-task is still passed explicitly to srun because not every Slurm
# release propagates the batch job's --cpus-per-task to job steps.
ntasks=${SLURM_NTASKS:-8}
cpus_per_task=${SLURM_CPUS_PER_TASK:-8}

# --unbuffered: enables multiple processes to log output instantly but can slow I/O overall
srun --unbuffered \
    --ntasks "$ntasks" \
    --gpus-per-task 1 \
    --cpus-per-task "$cpus_per_task" \
    python experiments/resnet/hpo/n_trainable.py \
        --group_name "$group_name" \
        --num_dataloader_workers 8 \
        --dataset "$dataset"

# end of file
