#!/bin/bash

#SBATCH --gres=gpu:8
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=8
#SBATCH --cpus-per-task=8

#SBATCH --job-name=hpo_freezing
#SBATCH --open-mode=append

#SBATCH --error=./slurm_logs/%A_%a_%N_log.err
#SBATCH --output=./slurm_logs/%A_%a_%N_log.out

#SBATCH --partition=<partition>
#SBATCH --time=16:00:00 
#SBATCH --mail-type=FAIL
#SBATCH --exclude=<exclude>

# Usage: sbatch <this script> GROUP_NAME DATASET [N_TRAINABLE_LAYERS] [EPOCHS]
#
# Runs experiments/resnet/rank_correlation/grid_search.py as one srun task
# per GPU on a single node.
#
# Prerequisites:
#   - Replace the <repo path> placeholder below with the repository root.
#   - ./slurm_logs must exist before submission: Slurm does not create the
#     directory named in --output/--error, and the job fails without it.

# Print an error to stderr and abort the job.
die() { printf 'Error: %s\n' "$*" >&2; exit 1; }

# Validate required arguments before touching the environment (fail fast).
[[ -n "${1:-}" ]] || die "group_name argument is required"
[[ -n "${2:-}" ]] || die "dataset argument is required"

group_name=$1
dataset=$2
# Optional: when omitted, the corresponding flag is not passed at all
# (passing an empty string would break integer parsing on the Python side).
# NOTE(review): assumes grid_search.py provides defaults for these - confirm.
n_trainable_layers=${3:-}
epochs=${4:-}

cd <repo path> || exit 1

source .venv/bin/activate || exit 1

export PYTHONPATH="$PWD"

# Derive the step geometry from the allocation rather than duplicating the
# #SBATCH header. Fall back to the header's values (8) if Slurm does not
# export them. --cpus-per-task is passed to srun explicitly because newer
# Slurm releases no longer let srun inherit it from the batch allocation.
n_tasks=${SLURM_NTASKS:-8}
cpus_per_task=${SLURM_CPUS_PER_TASK:-8}

# Build the Python arguments as an array so each value stays a single word.
py_args=(--group_name "$group_name")
if [[ -n "$n_trainable_layers" ]]; then
    py_args+=(--n_trainable_layers "$n_trainable_layers")
fi
if [[ -n "$epochs" ]]; then
    py_args+=(--epochs "$epochs")
fi
# One dataloader worker per CPU assigned to the task.
py_args+=(--dataloader_workers "$cpus_per_task" --dataset "$dataset")

srun --unbuffered \
    --ntasks "$n_tasks" \
    --gpus-per-task 1 \
    --cpus-per-task "$cpus_per_task" \
    python experiments/resnet/rank_correlation/grid_search.py "${py_args[@]}"

# end of file
