#!/bin/bash

#SBATCH -N 1
#SBATCH -c 4
#SBATCH --time=2:00:00
#SBATCH --gres=gpu:1
#SBATCH --mem=16G

set -ux

module load Python/3.9.6-GCCcore-11.2.0
module load CUDA/11.6.0
source env/bin/activate

export WANDB_MODE=offline

python3 train_jax.py --dataset_root /data/cifar $@
