#!/bin/sh
# Absolute directory containing this script, so main.py is found regardless of
# the caller's working directory. CDPATH is cleared for this one cd: when it is
# set in the environment, cd prints the resolved directory to stdout, which
# would be captured here alongside pwd's output and corrupt the path. `--`
# guards against a $0 that starts with '-'.
SCRIPT_DIR="$(CDPATH='' cd -- "$(dirname -- "$0")" && pwd)"
PYTHON_FILE="${SCRIPT_DIR}/main.py"

# Run configuration. MODEL_TYPE, NUM_EPOCH, TRAIN_BATCH_SIZE, SEED and LR are
# iterated word by word by the sweep loop, so each may hold a space-separated
# list of values (e.g. SEED="1 2 3") to launch one run per value.
DEVICE_IDX="0"            # used as CUDA_VISIBLE_DEVICES and --device_idx
DATASET="color_rot_mnist"
MODEL_TYPE="cgeconv"
ENTITY="{YOUR WANDB ID}"  # placeholder: replace with your wandb entity
TASK="bias"
LATENT_DIM="256"
TRAIN_BATCH_SIZE="64"
TEST_BATCH_SIZE="256"
NUM_EPOCH="1"
MAX_STEPS="0"
SAVE_STEPS="100000000"
SEED="1"
LR="1e-4"

# On Ctrl-C, end the whole sweep rather than letting the loop continue to the
# next combination. 130 (128 + SIGINT) is the conventional exit status for an
# interrupted script.
trap 'exit 130' INT
#            torchrun --standalone --nproc_per_node=8 $PYTHON_FILE \
#            CUDA_VISIBLE_DEVICES=$DEVICE_IDX python $PYTHON_FILE \

# Hyper-parameter sweep: launch one run of main.py for every combination of
# the values in MODEL_TYPE x NUM_EPOCH x TRAIN_BATCH_SIZE x SEED x LR.
#
# The list variables in the `for` headers are expanded UNQUOTED on purpose so
# that space-separated values become separate iterations. Every other
# expansion is quoted, so a script path containing spaces, or an ENTITY value
# containing spaces, reaches python as a single argument.

# Refuse to start while the wandb entity is still the shipped placeholder.
if [ "$ENTITY" = "{YOUR WANDB ID}" ]; then
  printf '%s: set ENTITY to your wandb entity before running\n' "$0" >&2
  exit 1
fi

for MODEL in $MODEL_TYPE; do
  for EPOCH in $NUM_EPOCH; do
    for BATCH in $TRAIN_BATCH_SIZE; do
      for seed in $SEED; do
        for lr in $LR; do
          # NOTE(review): with CUDA_VISIBLE_DEVICES=$DEVICE_IDX the selected
          # GPU is renumbered to index 0 inside the process. If main.py uses
          # --device_idx as a CUDA ordinal, any DEVICE_IDX other than 0 may
          # point at a non-visible device — confirm against main.py.
          CUDA_VISIBLE_DEVICES="$DEVICE_IDX" python "$PYTHON_FILE" \
            --device_idx "$DEVICE_IDX" \
            --dataset "$DATASET" \
            --project_name "${TASK}_${DATASET}_SUPPLE_TEST" \
            --entity "$ENTITY" \
            --model_type "$MODEL" \
            --task "$TASK" \
            --latent_dim "$LATENT_DIM" \
            --per_gpu_train_batch_size "$BATCH" \
            --test_batch_size "$TEST_BATCH_SIZE" \
            --num_epoch "$EPOCH" \
            --max_steps "$MAX_STEPS" \
            --save_steps "$SAVE_STEPS" \
            --seed "$seed" --lr_rate "$lr" --weight_decay 0.0 \
            --do_train --do_eval
        done
      done
    done
  done
done