#!/bin/bash

# Experiment configuration consumed by the launch loop below.
# Under a SLURM job array, SLURM_ARRAY_TASK_ID selects a single (model, seed)
# pair for TwoMoons: the tens digit picks the model variant and the ones digit
# picks the seed, so the valid IDs are 0-4, 10-14 and 20-24. Any other ID (or
# no ID at all) keeps every seed but selects no model.
dataset='TwoMoons'
parametrizations='Standard'
num_workers=0
pin_memory=True

# Indexed by (task ID % 10) and (task ID / 10) respectively.
twomoons_seeds=(321587 26981 156685 296144 572116)
twomoons_models=(MLPIVI MLPIVITemperatureScaling MLPIVITheoreticalScaling)

if [[ $dataset == TwoMoons ]]; then
    lr=0.1
    momentum=0.0
    batch_size=64
    batch_size_test=2096
    epochs=10000

    # Matches exactly "0".."4", "10".."14" and "20".."24" (no leading zeros).
    if [[ $SLURM_ARRAY_TASK_ID =~ ^[12]?[0-4]$ ]]; then
        task_id=$SLURM_ARRAY_TASK_ID
        models=${twomoons_models[task_id / 10]}
        seeds=${twomoons_seeds[task_id % 10]}
    else
        seeds="${twomoons_seeds[*]}"
        models=''
    fi
else
    pretrained=False
    freeze_pretrained_weights=False
fi

# Launch one train.py run per (model, parametrization, seed) combination.
# `models`, `parametrizations` and `seeds` are space-separated lists; they are
# deliberately left unquoted in the `for` headers so that they word-split.
if [[ -z ${models-} ]]; then
    # An unmatched dataset or SLURM_ARRAY_TASK_ID leaves `models` empty or
    # unset; report it instead of finishing silently with zero runs.
    printf 'warning: no models selected (dataset=%s, SLURM_ARRAY_TASK_ID=%s); nothing to run\n' \
        "${dataset-}" "${SLURM_ARRAY_TASK_ID-unset}" >&2
fi

# shellcheck disable=SC2086  # intentional word-splitting of list variables
for model in $models
do
    for parametrization in $parametrizations
    do
        for seed in $seeds
        do
            # Build the command line as an array so every value is passed as
            # exactly one argument, even when empty or containing spaces.
            train_args=(
                --seed "$seed"
                --dataset "$dataset"
            )
            # ood_dataset is never assigned in this script. Expanding it
            # unquoted and empty left a bare `--ood_dataset`, so the next flag
            # was taken as its missing value (presumably rejected by train.py's
            # parser). Forward it only when set, e.g. via the environment.
            if [[ -n ${ood_dataset-} ]]; then
                train_args+=(--ood_dataset "$ood_dataset")
            fi
            train_args+=(
                --model "$model"
                --parametrization "$parametrization"
                --max_epochs "$epochs"
                --lr "$lr"
                --momentum "$momentum"
                --batch_size "$batch_size"
                --batch_size_test "$batch_size_test"
                --num_workers "$num_workers"
                --pin_memory "$pin_memory"
            )
            python train.py "${train_args[@]}"
        done
    done
done