#!/bin/bash
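#
# SLURM job: run train2.py on the "chain" environment with wandb logging.
# Hyperparameters can be overridden through environment variables
# (VAR, LR, DROPOUT, N_EMBD, N_LAYER, N_HEAD, N_EPOCHS), which sbatch
# propagates by default, e.g.:
#   LR=3e-4 N_LAYER=6 sbatch <this-script>
# NOTE: ./log must exist before submission; Slurm does not create the
# directories named in the -o/-e paths below.
#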
#SBATCH -N 1
#SBATCH -t 1-00:00
#SBATCH -o ./log/%j.out
#SBATCH -e ./log/%j.err

# Load the project environment from ./env.sh (presumably module/virtualenv
# setup; its contents are not visible here). Abort if the file is missing.
# Otherwise `source` would print an error and training would run in the
# wrong environment while holding the allocation.
if [[ ! -f ./env.sh ]]; then
    echo "error: ./env.sh not found in $(pwd)" >&2
    exit 1
fi
# shellcheck source=/dev/null
source ./env.sh

# Hyperparameter defaults. Each can be overridden by setting the variable
# in the submitting environment. ${NAME:=default} assigns the default only
# when NAME is unset or empty. Because the expansion is quoted, values
# containing spaces or glob characters are handled safely.
: "${VAR:=0.0}"
: "${LR:=1e-4}"
: "${DROPOUT:=0}"
: "${N_EMBD:=32}"
: "${N_LAYER:=4}"
: "${N_HEAD:=4}"
: "${N_EPOCHS:=100}"

# Launch training on the "chain" environment with the hyperparameters set
# above. `--log wandb` selects train2.py's wandb logging backend. Every
# expansion is quoted so each value reaches train2.py as exactly one argv
# entry, with no word-splitting or globbing.
python3 train2.py --env chain \
    --var "$VAR" \
    --n_epochs "$N_EPOCHS" \
    --lr "$LR" \
    --dropout "$DROPOUT" \
    --n_embd "$N_EMBD" \
    --n_layer "$N_LAYER" \
    --n_head "$N_HEAD" \
    --log wandb
