#!/bin/bash

################################################################################
# This script runs a grid search over two hyperparameters: anchor_multiplier
# and max_dist. Every combination is submitted as a Slurm job; each job trains
# two models in parallel (GPU 0 and GPU 1), one with max_dist and one with
# 2*max_dist.
################################################################################

set -euo pipefail

# Root directory for all sweep outputs. There is no trailing slash, so joined
# paths do not contain "//".
readonly SAVE_DIR="results/ablation"
readonly DATASET="cifar10_full"
readonly MODEL="standard_vgg32"

# Grid values. Each generated job trains max_dist AND 2*max_dist side by side,
# so the effective max_dist grid is {1, 2, 4, 8, 10, 20}.
readonly -a ANCHOR_MULTIPLIERS=(1 2 4 8 10 16)
readonly -a MAX_DISTS=(1 4 10)

#######################################
# Print a shell snippet that starts one training run in the background.
# Globals:
#   DATASET, MODEL (read)
# Arguments:
#   $1 - GPU index, passed to the run as CUDA_VISIBLE_DEVICES
#   $2 - output directory (created by the snippet; receives run_log.txt)
#   $3 - anchor_multiplier value
#   $4 - max_dist value
# Outputs:
#   Writes the snippet to stdout. The snippet appends the background PID to
#   the generated script's `pids` variable so that it can be waited on later.
#######################################
emit_train_job() {
    local gpu=$1 out_dir=$2 anchor_mul=$3 max_d=$4
    # Unquoted heredoc: our $vars expand now, while \\ and \$ stay literal so
    # the generated script keeps its line continuations and its own $!/$pids.
    cat <<EOF
mkdir -p "$out_dir"
CUDA_VISIBLE_DEVICES=$gpu python3 train_model.py \\
    --save_path "$out_dir" \\
    --dataset $DATASET --config 0 \\
    --epochs 600 \\
    --lr 0.1 \\
    --model $MODEL \\
    --loss dist \\
    --randaug_n 1 --randaug_m 6 \\
    --nb_features 5 \\
    --anchor_multiplier $anchor_mul \\
    --max_dist $max_d \\
    --verbose 2 \\
    --osr_score min \\
    --summary \\
    --split_train_val > "$out_dir/run_log.txt" 2>&1 &
pids="\$pids \$!"
EOF
}

for anchor_multiplier in "${ANCHOR_MULTIPLIERS[@]}"; do
    exp_dir="$SAVE_DIR/$DATASET/$MODEL/anchor_mul-$anchor_multiplier"
    mkdir -p -- "$exp_dir"

    # Job files are named to_run-<i>.sh, where i indexes MAX_DISTS.
    for i in "${!MAX_DISTS[@]}"; do
        max_dist=${MAX_DISTS[i]}
        double_dist=$((2 * max_dist))
        run_file="$exp_dir/to_run-$i.sh"

        {
            printf '%s\n' '#!/bin/bash' 'pids='
            emit_train_job 0 "$exp_dir/max_d-$max_dist" "$anchor_multiplier" "$max_dist"
            emit_train_job 1 "$exp_dir/max_d-$double_dist" "$anchor_multiplier" "$double_dist"
            # A bare `wait` always returns 0, which would hide crashed runs
            # from Slurm. Wait on each PID individually, and let the final
            # test's status become the script's status (no `exit`, so this
            # is also harmless if the file is sourced by a wrapper).
            cat <<'EOF'
status=0
for pid in $pids; do
    wait "$pid" || status=1
done
[ "$status" -eq 0 ]
EOF
        } > "$run_file"

        # Submission is best-effort per job: report the failure and keep
        # launching the remaining grid points.
        if ! sbatch bash_scripts/execute_bash_file.sh slurm "$run_file"; then
            printf 'warning: failed to submit %s\n' "$run_file" >&2
        fi
    done
done


