#!/bin/bash

################################################################################
# This script runs all experiments on the Neal et al. (2018) open-set recognition (OSR) benchmark.
################################################################################

# Output root and hyper-parameters shared by every run of the sweep.
SAVE_DIR="results/benchmark_osr/"
model="standard_vgg32"
LR=0.1
NB_FEATURES=5
anchor_multiplier=10

# Print the value of a Python expression that may use numpy as `np`.
np_eval() {
    python3 -c "import numpy as np; print($1)"
}

# Derived distances: anchor_to_anchor = sqrt(anchor_multiplier^2 * NB_FEATURES * 2),
# and max_dist = 0.4 * anchor_to_anchor rounded to 3 decimals.
anchor_to_anchor=$(np_eval "np.sqrt($anchor_multiplier**2 * $NB_FEATURES*2)")
max_dist=$(np_eval "np.round(0.4*$anchor_to_anchor, 3)")


# Usage: bash <this script> [slurm]
# With "slurm" as the first argument each generated run file is submitted via
# sbatch; with any other value (or no argument) run files are executed locally,
# one after another.
run_mode="${1:-}"
nb_splits=5  # splits 0..nb_splits-1 are passed to train_model.py as --config

#######################################
# Append to a run file the commands that train one split in the background.
# Globals (read): exp_dir, model, dataset, loss, lr, randaug_m, image_size,
#   nb_features, anchor_multiplier, max_dist, osr_score, fc_end
# Arguments:
#   $1 - run file to append to
#   $2 - split index, written as --config
#   $3 - GPU index, written as CUDA_VISIBLE_DEVICES
#   $4 - name of the run-file variable that receives the job's PID
#######################################
append_train_cmd() {
    local run_file=$1 split=$2 gpu=$3 pid_var=$4
    local save_dir="$exp_dir/split-$split/nb_f-$nb_features/anchor_mul-$anchor_multiplier/max_d-$max_dist"

    cat >> "$run_file" << EOF
export CUDA_VISIBLE_DEVICES=$gpu
mkdir -p "$save_dir"
python3 train_model.py \
    --save_path "$save_dir" \
    --dataset $dataset --config $split \
    --epochs 600 \
    --lr $lr \
    --model $model \
    --loss $loss \
    --randaug_n 1 --randaug_m $randaug_m \
    --image_size $image_size \
    --nb_features $nb_features \
    --anchor_multiplier $anchor_multiplier \
    --max_dist $max_dist \
    --verbose 2 \
    --osr_score $osr_score \
    --batch_size 128 \
    --$fc_end \
    --summary > "$save_dir/run_log.txt" 2>&1 &
$pid_var=\$!
EOF
}

for dataset in mnist svhn cifar10 cifar+10 cifar+50 tiny_imagenet; do
    # Per-dataset RandAugment magnitude, learning rate and input image size.
    case "$dataset" in
        svhn)
            randaug_m=18; lr=$LR; image_size=32 ;;
        tiny_imagenet)
            randaug_m=9; lr=0.01; image_size=64 ;;
        *)
            randaug_m=6; lr=$LR; image_size=32 ;;
    esac

    for loss in crossentropy dist cac; do
        # Per-loss OSR score direction, head flag (passed as --<fc_end>) and
        # feature count.
        case "$loss" in
            crossentropy)
                osr_score="max"; fc_end="fc_end"; nb_features=1 ;;
            cac)
                osr_score="min"; fc_end="fc_end"; nb_features=1 ;;
            *)
                osr_score="min"; fc_end="nofc_end"; nb_features=$NB_FEATURES ;;
        esac

        exp_dir="$SAVE_DIR/$model/$loss/$dataset/"
        mkdir -p "$exp_dir"

        # Each run file trains two consecutive splits concurrently (GPU 0 and
        # GPU 1); with an odd split count the last file holds a single split.
        for ((config = 0; config < nb_splits; config += 2)); do
            run_file="$exp_dir/to_run-$config.sh"
            : > "$run_file"  # start from an empty file, like the original `cat >`

            append_train_cmd "$run_file" "$config" 0 pid_1

            if ((config + 1 < nb_splits)); then
                append_train_cmd "$run_file" "$((config + 1))" 1 pid_2
                # Report the first non-zero status. Summing the two statuses
                # (as before) could wrap to 0 mod 256 and hide a failure.
                cat >> "$run_file" << EOF

wait \$pid_1
success_1=\$?

wait \$pid_2
success_2=\$?

if [ \$success_1 -ne 0 ]; then exit \$success_1; fi
exit \$success_2
EOF
            else
                cat >> "$run_file" << EOF
wait \$pid_1
success_1=\$?
exit \$success_1
EOF
            fi

            if [[ "$run_mode" == "slurm" ]]; then
                sbatch bash_scripts/execute_bash_file.sh slurm "$run_file"
            else
                bash "$run_file"
            fi
        done
    done
done