#!/usr/bin/env bash
# export OMPI_MCA_btl_vader_single_copy_mechanism=none
# run_exp.py and plot_hyper_params.py are invoked by relative path from the
# parent directory; stop here rather than launch jobs from the wrong place.
cd .. || exit 1

# Private-SGD (differential privacy) settings forwarded to run_exp.py.
use_private_SGD=0
delta=0.00001
noise_multiplier=1.0
l2_norm_clip=1.0

# Local training settings.
n_epochs=1
optimizer='adam'
verbose=1

# Dataset, model and federation size (CIFAR-100 / ResNet-18).
dataset='cifar100'
private_model_type='ResNet18'
n_rounds=150
n_clients=8   # also the MPI process count passed to `mpiexec -n`
batch_size=50

# Algorithm hyper-parameters.
frac=0.15
cw_ratio=0.2 # 0.3
n_clusters=2 # 4
n_components=4
lr=0.01
n_neighbors=3   # default; the n_neighbors sweep overrides it per run
greedy_eps=0.3

cw_momentum=0.6
class_seed=22
# Values swept over by the n_neighbors experiment.
n_neighborss=(0 1 3 5 7)
seed=0

algorithm=('Federico')

#######################################
# Launch one run_exp.py job (n_clients MPI processes) per n_neighbors value,
# keeping every other setting fixed.
# Globals:   algorithm and the experiment settings above (read)
# Arguments: the n_neighbors values to sweep over
# Outputs:   a warning on stderr for every run that exits non-zero
#######################################
run_neighbor_sweep() {
    local n_neighbors
    local -a run_args
    for n_neighbors in "$@"; do
        run_args=(
            --dataset="${dataset}"
            --algorithm="${algorithm[0]}"
            --n_components="${n_components}"
            --n_neighbors="${n_neighbors}"
            --seed="${seed}"
            --class_seed="${class_seed}"
            --frac="${frac}"
            --n_clusters="${n_clusters}"
            --n_clients="${n_clients}"
            --private_model_type="${private_model_type}"
            --use_private_SGD="${use_private_SGD}"
            --delta="${delta}"
            --noise_multiplier="${noise_multiplier}"
            --l2_norm_clip="${l2_norm_clip}"
            --optimizer="${optimizer}"
            --lr="${lr}"
            --n_epochs="${n_epochs}"
            --n_rounds="${n_rounds}"
            --batch_size="${batch_size}"
            --verbose="${verbose}"
            --cw_momentum="${cw_momentum}"
            --greedy_eps="${greedy_eps}"
        )
        # Uncomment to force already-finished experiments to run again.
        # run_args+=(--rerun)
        # NOTE(review): unlike the cw_momentum sweep, this sweep passes neither
        # --cw_ratio nor --n_local_epochs, so run_exp.py's defaults apply -
        # confirm that is intended.
        if ! mpiexec -n "${n_clients}" python run_exp.py "${run_args[@]}"; then
            printf 'warning: run_exp.py failed (n_neighbors=%s)\n' \
                "${n_neighbors}" >&2
        fi
    done
}

run_neighbor_sweep "${n_neighborss[@]}"

# Sweep over cw_momentum with n_neighbors fixed at 3.
cw_momentums=(0.2 0.4 0.6 0.8 1 0)
# Lists iterated by the inner loops. Neither was assigned anywhere in the
# script, so the inner loops ran zero times and this sweep launched no job.
seeds=(0)
algorithms=('Federico')
n_neighbors=3

# The loops run at top level, so cw_momentum, seed and algorithm keep their
# last values afterwards; the plotting command at the end of the script reads
# cw_momentum.
for cw_momentum in "${cw_momentums[@]}"; do
    for seed in "${seeds[@]}"; do
        for algorithm in "${algorithms[@]}"; do
            run_args=(
                --dataset="${dataset}"
                --algorithm="${algorithm}"
                --n_components="${n_components}"
                --n_neighbors="${n_neighbors}"
                --seed="${seed}"
                --class_seed="${class_seed}"
                --frac="${frac}"
                --n_clusters="${n_clusters}"
                --n_clients="${n_clients}"
                --private_model_type="${private_model_type}"
                --use_private_SGD="${use_private_SGD}"
                --delta="${delta}"
                --noise_multiplier="${noise_multiplier}"
                --l2_norm_clip="${l2_norm_clip}"
                --optimizer="${optimizer}"
                --lr="${lr}"
                --n_epochs="${n_epochs}"
                --n_rounds="${n_rounds}"
                --batch_size="${batch_size}"
                --verbose="${verbose}"
                --cw_ratio="${cw_ratio}"
                --cw_momentum="${cw_momentum}"
                --greedy_eps="${greedy_eps}"
                --n_local_epochs=1
            )
            # major_percent is not assigned in this script; forward it only
            # when set instead of passing an empty "--major_percent=".
            if [[ -n "${major_percent:-}" ]]; then
                run_args+=(--major_percent="${major_percent}")
            fi
            # Uncomment to force already-finished experiments to run again.
            # run_args+=(--rerun)
            if ! mpiexec -n "${n_clients}" python run_exp.py "${run_args[@]}"; then
                printf 'warning: run_exp.py failed (cw_momentum=%s, seed=%s, algorithm=%s)\n' \
                    "${cw_momentum}" "${seed}" "${algorithm}" >&2
            fi
        done
    done
done


#######################################
# Plot the results of the sweeps above with plot_hyper_params.py.
# Globals: the experiment settings above (read); the seeds / class_seeds
#          lists when defined, otherwise the single seed / class_seed values
#######################################
plot_results() {
    # Fall back to the single seed/class_seed when no list is defined:
    # class_seeds is never assigned in this script, and an unset list would
    # leave --class_seeds / --seeds with no values at all.
    local -a seed_list=("${seeds[@]}")
    local -a class_seed_list=("${class_seeds[@]}")
    if (( ${#seed_list[@]} == 0 )); then
        seed_list=("${seed}")
    fi
    if (( ${#class_seed_list[@]} == 0 )); then
        class_seed_list=("${class_seed}")
    fi

    python plot_hyper_params.py \
        --dataset="${dataset}" \
        --n_clients="${n_clients}" \
        --private_model_type="${private_model_type}" \
        --use_private_SGD="${use_private_SGD}" \
        --delta="${delta}" \
        --noise_multiplier="${noise_multiplier}" \
        --l2_norm_clip="${l2_norm_clip}" \
        --n_rounds="${n_rounds}" \
        --optimizer="${optimizer}" \
        --lr="${lr}" \
        --frac="${frac}" \
        --n_clusters="${n_clusters}" \
        --class_seeds "${class_seed_list[@]}" \
        --seeds "${seed_list[@]}" \
        --batch_size="${batch_size}" \
        --cw_ratio="${cw_ratio}" \
        --cw_momentum="${cw_momentum}" \
        --greedy_eps="${greedy_eps}"
}

plot_results

