#!/bin/bash
#
# Sweep driver: runs train.py for each (model, parametrization, seed) picked
# from $dataset and $SLURM_ARRAY_TASK_ID.
#
# `dataset` is never assigned in this file, so presumably the caller provides
# it through the environment; accepted values are MNIST, CIFAR10, CIFAR100 and
# TinyImageNet, e.g.
#   dataset='CIFAR10'

# Settings shared by every run.
num_workers=4
pin_memory=True
parametrizations="Standard MaximalUpdate"

#######################################
# Choose hyperparameters, seeds and models for this job.
#
# Globals (read):
#   dataset              - MNIST, CIFAR10, CIFAR100 or TinyImageNet.
#   SLURM_ARRAY_TASK_ID  - array index selecting seed and model group.
# Globals (written):
#   ood_dataset pretrained freeze_pretrained_weights lr momentum batch_size
#   batch_size_test max_epochs seeds models
#
# Task-ID layout:
#   <g><k>, k in 0-4   a single seed (seed_list[k]) with model group g:
#                        g=0 (one-digit IDs)  base + post-hoc models
#                        g=1  WeightSpaceVIDiagonal
#                        g=2  IVILowRank
#                        g=3  IVIKronecker             (CIFAR10 only)
#                        g=4  SWAG
#                        g=10 LaplaceLastLayerMargLik  (TinyImageNet only)
#   99                 all seeds, Ensemble model
#   anything else      all seeds with the dataset's default models
#                      (ResNet34TemperatureScaling for CIFAR10, none otherwise)
# Returns:
#   0 for a known dataset. For an unknown or unset dataset: prints an error on
#   stderr, empties seeds/models so nothing is launched, and returns 1.
#######################################
configure_experiment() {
    local -ra seed_list=(321587 26981 156685 296144 572116)
    local arch base_models default_models=''

    case "$dataset" in
        MNIST)
            ood_dataset='MNISTC'
            pretrained=False
            batch_size_test=1024
            arch='LeNet5'
            base_models='LeNet5 LeNet5TemperatureScaling LeNet5LaplaceLastLayer LeNet5LaplaceLastLayerGridSearch'
            ;;
        CIFAR10 | CIFAR100 | TinyImageNet)
            ood_dataset="${dataset}C"
            pretrained=True
            batch_size_test=128
            arch='ResNet34'
            base_models='ResNet34 ResNet34TemperatureScaling ResNet34LaplaceLastLayerMargLik ResNet34LaplaceLastLayerGridSearch'
            # NOTE(review): only CIFAR10 launches a model for unmatched task
            # IDs (e.g. when SLURM_ARRAY_TASK_ID is unset); kept as-is —
            # confirm this is intended.
            if [[ "$dataset" == CIFAR10 ]]; then
                default_models='ResNet34TemperatureScaling'
            fi
            ;;
        *)
            # momentum, batch_size_test, max_epochs and a model list are not
            # defined for other datasets, so any run would receive a malformed
            # argument list; queue nothing instead.
            seeds=''
            models=''
            printf 'error: unknown or unset dataset "%s" (expected MNIST, CIFAR10, CIFAR100 or TinyImageNet); no runs queued\n' \
                "${dataset-}" >&2
            return 1
            ;;
    esac

    # Hyperparameters shared by all known datasets.
    freeze_pretrained_weights=False
    lr=0.005
    momentum=0.9
    batch_size=128
    max_epochs=200

    # Fallback for task IDs that select nothing specific.
    seeds="${seed_list[*]}"
    models=$default_models

    local task=${SLURM_ARRAY_TASK_ID-}
    case "$task" in
        99)
            models="${arch}Ensemble"
            return 0
            ;;
        [0-4] | [1-9][0-4] | 10[0-4]) ;;
        *) return 0 ;;
    esac

    # Leading digits select the model group; one-digit IDs are group 0.
    local group=${task%?}
    local group_models
    case "$dataset:${group:-0}" in
        *:0) group_models=$base_models ;;
        *:1) group_models="${arch}WeightSpaceVIDiagonal" ;;
        *:2) group_models="${arch}IVILowRank" ;;
        CIFAR10:3) group_models="${arch}IVIKronecker" ;;
        *:4) group_models="${arch}SWAG" ;;
        TinyImageNet:10) group_models="${arch}LaplaceLastLayerMargLik" ;;
        *) return 0 ;; # group not defined for this dataset: keep fallback
    esac

    # The last digit (0-4) selects a single seed.
    seeds=${seed_list[${task: -1}]}
    models=$group_models
}

configure_experiment

#######################################
# Launch train.py once for every (model, parametrization, seed) combination,
# iterating models outermost and seeds innermost.
#
# Globals (read):
#   models parametrizations seeds - whitespace-separated lists.
#   dataset ood_dataset pretrained freeze_pretrained_weights max_epochs lr
#   momentum batch_size batch_size_test num_workers pin_memory
# Outputs:
#   One line on stderr for each train.py invocation that exits non-zero.
# Returns:
#   0 if every run succeeded (or nothing was queued), 1 if any run failed.
#   A failing run does not stop the remaining ones.
#######################################
run_experiments() {
    local model parametrization seed epochs
    local failures=0
    # Learning-rate and batch-size tuning are disabled for every model.
    local -r tune_lr=False tune_batch_size=False

    # The three lists are split on whitespace on purpose.
    # shellcheck disable=SC2086
    for model in $models; do
        case "$model" in
            # Laplace, ensemble and temperature-scaling models get zero
            # training epochs (presumably post-hoc on existing checkpoints —
            # confirm in train.py); everything else trains for max_epochs.
            *Laplace* | *Ensemble* | *TemperatureScaling*) epochs=0 ;;
            *) epochs=$max_epochs ;;
        esac
        for parametrization in $parametrizations; do
            for seed in $seeds; do
                # Every value is quoted so an empty or unset setting is passed
                # as an empty argument instead of shifting the later flags.
                if ! python train.py \
                    --seed "$seed" \
                    --dataset "$dataset" \
                    --ood_dataset "$ood_dataset" \
                    --model "$model" \
                    --parametrization "$parametrization" \
                    --pretrained "$pretrained" \
                    --freeze_pretrained_weights "$freeze_pretrained_weights" \
                    --max_epochs "$epochs" \
                    --lr "$lr" \
                    --momentum "$momentum" \
                    --batch_size "$batch_size" \
                    --batch_size_test "$batch_size_test" \
                    --num_workers "$num_workers" \
                    --pin_memory "$pin_memory" \
                    --tune_lr "$tune_lr" \
                    --tune_batch_size "$tune_batch_size"; then
                    printf 'train.py failed: model=%s parametrization=%s seed=%s\n' \
                        "$model" "$parametrization" "$seed" >&2
                    failures=$((failures + 1))
                fi
            done
        done
    done

    ((failures == 0))
}

run_experiments