#!/bin/bash

# Random seeds; every hyper-parameter combination is repeated once per seed.
seeds='321587 26981 156685 296144 572116'
# dataset='MNIST' # ='MNIST CIFAR10 CIFAR100 TinyImageNet'

# Map the SLURM array task ID onto (num_samples_train, momentum):
#   IDs  0-7  -> num_samples_train = 1, 2, 4, ..., 128 with momentum 0.0
#   IDs 10-17 -> num_samples_train = 1, 2, 4, ..., 128 with momentum 0.9
# IDs 8 and 9 are not part of the table and take the fallback branch, as does
# an unset SLURM_ARRAY_TASK_ID (e.g. when the script is run outside an array).
task_id=${SLURM_ARRAY_TASK_ID:-}
case "$task_id" in
    [0-7])
        num_samples_train=$((1 << task_id))
        momentum=0.0
        ;;
    1[0-7])
        num_samples_train=$((1 << (task_id - 10)))
        momentum=0.9
        ;;
    *)
        num_samples_train=1
        # Mirror task 0 so that `momentum` is never left undefined; without
        # this the `--momentum` flag would be passed without a value.
        momentum=0.0
        ;;
esac
# Sweep axes that do not depend on the dataset (space-separated lists).
models='MLPIVI'
parametrizations='Standard MaximalUpdate'
num_samples_test=256 #TODO: Should this be much lower?
# momentum_list='0.0 0.9'

# Data-loading and tuning flags, forwarded to train.py unchanged.
num_workers=0
pin_memory=True
tune_batch_size=False
tune_lr=False

# Per-dataset settings. `dataset` is not assigned in this script (its
# assignment near the top is commented out), so it must be provided through
# the environment; an unset value takes the fallback branch.
case "${dataset:-}" in
    MNIST)
        ood_dataset='FashionMNIST'
        learning_rates='0.003 0.01 0.03 0.1 0.3' #'0.003 0.01 0.03 0.1 0.3'
        batch_sizes='64'
        batch_size_test=1024
        max_epochs=10
        ;;
    CIFAR10)
        ood_dataset='CIFAR10C'
        lr=0.1
        # The sweep iterates over `learning_rates`, not `lr`; derive the list
        # from the scalar unless the environment already supplies one.
        learning_rates=${learning_rates:-$lr}
        # TODO: no batch_sizes value is defined for this dataset.
        batch_size_test=1024
        max_epochs=20
        ;;
    CIFAR100)
        ood_dataset='CIFAR100C'
        lr=0.1
        learning_rates=${learning_rates:-$lr}
        # TODO: no batch_sizes value is defined for this dataset.
        batch_size_test=1024
        max_epochs=50
        ;;
    TinyImageNet)
        ood_dataset='TinyImageNetC'
        lr=0.1
        learning_rates=${learning_rates:-$lr}
        # TODO: no batch_sizes value is defined for this dataset.
        batch_size_test=1024
        max_epochs=50
        ;;
    *)
        # Unknown or unset dataset: only these scalars are defined here, the
        # remaining settings are reported by the check below.
        ood_dataset=None
        lr=0.1
        batch_size=1024
        ;;
esac

# Report every setting that is still empty instead of failing silently:
# an empty list variable makes a `for` loop over it run zero iterations, and
# an empty scalar leaves a train.py flag without its value.
missing_settings=''
for setting in dataset ood_dataset learning_rates batch_sizes batch_size_test max_epochs; do
    if [[ -z "${!setting:-}" ]]; then
        missing_settings+="${missing_settings:+ }${setting}"
    fi
done
if [[ -n "$missing_settings" ]]; then
    printf 'warning: incomplete configuration for dataset "%s"; unset: %s\n' \
        "${dataset:-}" "$missing_settings" >&2
fi

# Launch train.py once for every combination of
# model x parametrization x batch size x learning rate x momentum x seed.
#
# Momentum values: use `momentum_list` when the caller provides one, otherwise
# fall back to the single per-task `momentum`. The in-file assignment of
# `momentum_list` is commented out, so iterating over it directly makes the
# loop body run zero times and train.py is never started.
momentum_values=${momentum_list:-${momentum:-}}

# The sweep variables are space-separated lists, so word splitting of the
# unquoted loop operands is intentional.
# shellcheck disable=SC2086
for model in $models; do
    for parametrization in $parametrizations; do
        for batch_size in $batch_sizes; do
            for lr in $learning_rates; do
                for momentum in $momentum_values; do
                    for seed in $seeds; do
                        python train.py \
                            --seed "$seed" \
                            --dataset "$dataset" \
                            --ood_dataset "$ood_dataset" \
                            --model "$model" \
                            --parametrization "$parametrization" \
                            --num_samples_train "$num_samples_train" \
                            --num_samples_test "$num_samples_test" \
                            --max_epochs "$max_epochs" \
                            --lr "$lr" \
                            --momentum "$momentum" \
                            --nesterov False \
                            --batch_size "$batch_size" \
                            --batch_size_test "$batch_size_test" \
                            --num_workers "$num_workers" \
                            --pin_memory "$pin_memory" \
                            --tune_lr "$tune_lr" \
                            --tune_batch_size "$tune_batch_size"
                    done
                done
            done
        done
    done
done