#!/bin/bash

################################################################################
# This script runs a fine-grained hyperparameter search over the max_dist
# parameter, motivated by the observation that the best results are obtained
# when max_dist is around 1/3 of the distance between anchors. For each of a
# few anchor_multiplier values, max_dist is swept over 8 evenly spaced values
# from 1/3 to 1/2 of the distance between anchors.
################################################################################

SAVE_DIR="results/ablation_details"
nb_features=5
nb_values=8                 # number of max_dist values per anchor_multiplier
dataset="cifar10_full"
model="standard_vgg32"
gpus=(0 1)                  # one job per GPU is packed into each submitted file

#######################################
# Print the shell lines that launch one background training run.
# Globals (read): nb_features, dataset, model
# Arguments:
#   $1 - GPU index, set as CUDA_VISIBLE_DEVICES for this run only
#   $2 - max_dist value passed to train_model.py
#   $3 - anchor_multiplier value passed to train_model.py
#   $4 - experiment directory; the run writes into "$4/max_d-$2"
# Outputs:
#   Shell code on stdout. The generated code appends the launched job's
#   PID to the array `pids` so the job file can wait on it.
#######################################
emit_training_job() {
    local gpu=$1 max_dist=$2 anchor_mul=$3 exp_dir=$4
    local run_dir="$exp_dir/max_d-$max_dist"
    # Unquoted heredoc: variables are expanded now, at generation time;
    # backslash-newline is consumed, so the python command is emitted as
    # one line. \$! is escaped so it is evaluated when the job file runs.
    cat <<EOF
mkdir -p "$run_dir"
CUDA_VISIBLE_DEVICES=$gpu python3 train_model.py \
    --save_path "$run_dir" \
    --dataset $dataset --config 0 \
    --epochs 600 \
    --lr 0.1 \
    --model $model \
    --loss dist \
    --randaug_n 1 --randaug_m 6 \
    --nb_features $nb_features \
    --anchor_multiplier $anchor_mul \
    --max_dist $max_dist \
    --verbose 2 \
    --osr_score min \
    --summary \
    --split_train_val > "$run_dir/run_log.txt" 2>&1 &
pids+=(\$!)
EOF
}

for anchor_multiplier in 8 10 16; do

    # max_dist candidates: nb_values points evenly spaced between 1/3 and 1/2
    # of the anchor-to-anchor distance, sqrt(anchor_multiplier^2 * nb_features * 2),
    # rounded to 3 decimals. Values are passed via argv, not spliced into code.
    values_to_test=()
    mapfile -t values_to_test < <(python3 -c '
import sys
import numpy as np
mul, nb_feat, nb_vals = (int(a) for a in sys.argv[1:4])
anchor_to_anchor = np.sqrt(mul**2 * nb_feat * 2)
for v in np.linspace(1/3, 1/2, nb_vals) * anchor_to_anchor:
    print(v.round(3))
' "$anchor_multiplier" "$nb_features" "$nb_values")

    if (( ${#values_to_test[@]} == 0 )); then
        printf 'error: could not compute max_dist values for anchor_multiplier=%s\n' \
            "$anchor_multiplier" >&2
        continue
    fi

    exp_dir="$SAVE_DIR/$dataset/$model/anchor_mul-$anchor_multiplier"
    mkdir -p -- "$exp_dir"

    # Pack up to ${#gpus[@]} runs into each job file; the file is named after
    # the index of its first value (to_run-0.sh, to_run-2.sh, ... for 2 GPUs).
    for (( start = 0; start < ${#values_to_test[@]}; start += ${#gpus[@]} )); do
        run_file="$exp_dir/to_run-$start.sh"
        {
            printf '#!/bin/bash\npids=()\n'
            for (( k = 0; k < ${#gpus[@]} && start + k < ${#values_to_test[@]}; k++ )); do
                emit_training_job "${gpus[k]}" "${values_to_test[start + k]}" \
                    "$anchor_multiplier" "$exp_dir"
            done
            # Wait on every recorded PID and exit with the number of failed
            # runs, so any failure yields a non-zero status (never wraps to 0).
            cat <<'EOF'
failed=0
for pid in "${pids[@]}"; do
    wait "$pid" || failed=$((failed + 1))
done
exit "$failed"
EOF
        } > "$run_file"

        sbatch bash_scripts/execute_bash_file.sh slurm "$run_file" \
            || printf 'warning: submission failed for %s\n' "$run_file" >&2
    done
done


