#!/bin/bash
#
# Slurm batch script: launches
# experiments/resnet/rank_correlation/grid_search.py as 8 srun tasks on a
# single node.
#
# Usage:
#   sbatch <this script> <group_name> <n_trainable_layers>
#
# Angle-bracket placeholders (<partition>, <exclude>, <repo path>) are
# template values and must be replaced before the script is submitted.
#
# NOTE: keep all #SBATCH lines above the first executable command; sbatch
# only reads directives that appear before it.

# Resources: one node with 8 GPUs, running 8 tasks with 8 CPUs each.
#SBATCH --gres=gpu:8
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=8
#SBATCH --cpus-per-task=8

# Append to existing log files instead of truncating them.
#SBATCH --job-name=hpo_freezing
#SBATCH --open-mode=append

# Log files, relative to the submission directory. Slurm filename patterns:
# %j = job ID, %a = job array index, %N = short hostname of the node.
# NOTE(review): ./slurm_logs presumably has to exist before submission —
# confirm Slurm does not create it on this cluster.
#SBATCH --error=./slurm_logs/%j_%a_%N_log.err
#SBATCH --output=./slurm_logs/%j_%a_%N_log.out

# Scheduling: target partition, 24 h wall-clock limit, e-mail on failure,
# and a list of nodes to avoid.
#SBATCH --partition=<partition>
#SBATCH --time=24:00:00 
#SBATCH --mail-type=FAIL
#SBATCH --exclude=<exclude>

# --- Arguments ---------------------------------------------------------------
#   $1  group_name          forwarded to grid_search.py as --group_name
#   $2  n_trainable_layers  forwarded to grid_search.py as --n_trainable_layers
#
# Validate before touching the environment so a bad submission fails fast.
# Diagnostics go to stderr rather than stdout.
if [[ -z "$1" ]]; then
    echo "Error: group_name argument is required" >&2
    exit 1
fi

if [[ -z "$2" ]]; then
    echo "Error: n_trainable_layers argument is required" >&2
    exit 1
fi

group_name=$1
n_trainable_layers=$2

# --- Environment -------------------------------------------------------------
# Template placeholder: replace <repo path> with the repository root.
# Quoted so the path may contain spaces and so the unsubstituted template is
# still syntactically valid shell.
cd "<repo path>" || exit 1

# The virtualenv is expected at .venv/ inside the repository root.
source .venv/bin/activate || exit 1

# Put the repository root on Python's import path.
export PYTHONPATH="$PWD"

# --- Launch ------------------------------------------------------------------
# One task per GPU: 8 tasks, each with 1 GPU and 8 CPUs; --dataloader_workers
# is set to the same value as --cpus-per-task.
# --unbuffered: enables multiple processes to log output instantly but can
# slow I/O overall.
# The expansions are quoted so each argument reaches python as exactly one
# word, whatever characters it contains.
# srun is the last command, so its exit status becomes the job's exit status.
srun --unbuffered \
    --ntasks 8 \
    --gpus-per-task 1 \
    --cpus-per-task 8 \
    python experiments/resnet/rank_correlation/grid_search.py \
        --group_name "$group_name" \
        --n_trainable_layers "$n_trainable_layers" \
        --dataloader_workers 8 \
        --dataset c100 \
        --model_name vit_b_16 \
        --batch_size 256

# end of file
