#!/bin/bash

# ---------------------------------------------------------------------------
# Sweep axes. Each variable holds a space-separated list; the nested loops
# further down iterate over every combination of these values.
# ---------------------------------------------------------------------------

# Model classes to train.
# Known options: MLP MLPIVILowRank MLPIVIKronecker MLPWeightSpaceVIDiagonal
models="MLPIVILowRank"

# Hidden-layer widths to sweep over.
hidden_sizes="128 256 512 1024 2048"

# Datasets to train on (alternative: "MNIST CIFAR10 ImageNet").
datasets="CIFAR10"

# Random seeds. Note left by the original author on how seeds were drawn
# (presumably a Python snippet -- TODO confirm it matches the values below):
#   [int.from_bytes(os.urandom(16), "big") for _ in range(5)]
# Larger seed sets kept around for full runs:
#   seeds="321587 26981 156685 296144 572116"
#   seeds="321587 26981 156685 296144 572116 19338223 66266 226372 4766 724836"
seeds="321587"
# Learning-rate grid, defined once so the per-task value and the
# "run everything" fallback can never drift apart. The ten values are
# approximately geometrically spaced between 2^-8 (0.00390625) and 1.0.
lr_grid=(
    0.00390625 0.0072334 0.01339444 0.02480314 0.0459292
    0.08504938 0.15749013 0.29163226 0.54002987 1.0
)

# Select the learning rate(s) this invocation sweeps over:
#   * SLURM_ARRAY_TASK_ID is a single digit 0-9 -> only that grid entry, so a
#     10-element job array spreads the grid over one task per learning rate.
#   * anything else (unset, empty, out of range) -> the whole grid, run
#     sequentially by this one process.
# `lrs` stays a space-separated string because the sweep loop word-splits it.
case "${SLURM_ARRAY_TASK_ID:-}" in
    [0123456789])
        lrs=${lr_grid[$SLURM_ARRAY_TASK_ID]}
        ;;
    *)
        lrs="${lr_grid[*]}"
        ;;
esac
# Parametrizations to sweep over (space-separated list).
parametrizations="Standard MaximalUpdate"

# Forwarded verbatim to train.py as --pin_memory / --num_workers.
pin_memory="True"
num_workers="0"

#######################################
# Set the per-dataset training options.
# All three variables are reset on every call so that values from a
# previously processed dataset cannot leak into the next one.
# Globals:   ood_dataset, batch_size, batch_size_test (written)
# Arguments: $1 - dataset name
# Outputs:   a warning on stderr for datasets without a dedicated entry
#######################################
configure_dataset() {
    ood_dataset=''
    batch_size=64
    batch_size_test=''
    case "$1" in
        MNIST)
            ood_dataset='FashionMNIST'
            batch_size_test=1024
            ;;
        CIFAR10)
            ood_dataset='CIFAR10' # TEMP FOR THIS RUN
            batch_size_test=1024
            ;;
        *)
            # No OOD dataset / test batch size is known for this dataset.
            # Leave them empty so run_training omits the flags instead of
            # passing a flag without a value.
            printf 'warning: no ood_dataset/batch_size_test configured for dataset "%s"; omitting --ood_dataset and --batch_size_test\n' "$1" >&2
            ;;
    esac
}

#######################################
# Set the per-model training options.
# Model names containing "Laplace" or "Ensemble" get max_epochs=0;
# every other model is trained for 20 epochs. Tuning is always disabled.
# Globals:   max_epochs, tune_lr, tune_batch_size (written)
# Arguments: $1 - model name
#######################################
configure_model() {
    tune_lr=False
    tune_batch_size=False
    case "$1" in
        *Laplace* | *Ensemble*)
            max_epochs=0
            ;;
        *)
            max_epochs=20
            ;;
    esac
}

#######################################
# Launch one train.py run for the current sweep point.
# The argument vector is built as an array so every value is passed as
# exactly one word, and flags whose value is empty are left out entirely.
# Globals:   seed, dataset, ood_dataset, model, hidden_size, parametrization,
#            max_epochs, lr, batch_size, batch_size_test, num_workers,
#            pin_memory, tune_lr, tune_batch_size (read)
# Returns:   exit status of the python process
#######################################
run_training() {
    local -a train_args=(
        --seed "$seed"
        --dataset "$dataset"
    )
    if [[ -n "$ood_dataset" ]]; then
        train_args+=(--ood_dataset "$ood_dataset")
    fi
    train_args+=(
        --model "$model"
        --hidden_size "$hidden_size"
        --parametrization "$parametrization"
        --max_epochs "$max_epochs"
        --lr "$lr"
        --batch_size "$batch_size"
    )
    if [[ -n "$batch_size_test" ]]; then
        train_args+=(--batch_size_test "$batch_size_test")
    fi
    train_args+=(
        --bias True
        --momentum 0.0
        --num_workers "$num_workers"
        --pin_memory "$pin_memory"
        --tune_lr "$tune_lr"
        --tune_batch_size "$tune_batch_size"
        --scale_mean_input_init_weight 16.0
        --scale_mean_input_init_bias 16.0
        --scale_mean_input_forward_weight 0.0625
        --scale_mean_input_forward_bias 0.0625
        --scale_mean_output_forward_weight 32.0
        --scale_mean_output_forward_bias 32.0
        --scale_mean_output_init_weight 0.0
        --scale_mean_output_init_bias 0.0
    )
    python train.py "${train_args[@]}"
}

# Full grid: dataset x model x hidden size x learning rate x parametrization
# x seed. The list variables are deliberately expanded unquoted: they are
# space-separated strings and rely on word splitting to yield their items.
# A failing run does not stop the sweep; the remaining points still execute.
for dataset in $datasets; do
    configure_dataset "$dataset"
    for model in $models; do
        configure_model "$model"
        for hidden_size in $hidden_sizes; do
            for lr in $lrs; do
                for parametrization in $parametrizations; do
                    for seed in $seeds; do
                        run_training
                    done
                done
            done
        done
    done
done