#!/usr/bin/env bash
#
# Launcher for src/tlp_model_fusion/fuse_models.py experiments on MNISTNorm
# FC models. Run it from the directory that contains src/ and result/: every
# command below uses paths relative to that directory.

# Append "." to PYTHONPATH. The ${PYTHONPATH:+...} form avoids producing a
# leading empty entry (":.") when PYTHONPATH is unset or empty.
export PYTHONPATH="${PYTHONPATH:+${PYTHONPATH}:}."

# This file contains two groups of experiments:
#   (a) fusion of 2, 4, or 6 models into a different architecture (MLPLarge);
#   (b) distillation of MLPNet into MLPSmall.
# Uncomment an individual section to run the corresponding experiment.

################## FUSING MULTIPLE MODELS INTO MLPLarge ###################

# Layer count embedded in the base-model result directory name
# (result/${base_model_prefix}/...); derive the prefix from it so the two
# cannot drift apart.
layer=3
base_model_prefix="mlp_sgd_models_layer_${layer}"
# Default experiment-name prefix. The fusion sections below expand ${prefix}
# without assigning it; the WB distillation section reassigns it.
prefix="${base_model_prefix}"

################ WB Fusion ###################

#### Fuse 2 models

# seed=(847 53 43 348 437 82 233 31 786 234)
# for reg in 0.0005
# do
#   for idx in {0..8}
#   do
#     fusion_type="tlp"
#     tlp_cost_choice="weight"
#     seed1=${seed[($idx % 10)]}
#     seed2=${seed[(($idx+1) % 10)]}
#     python src/tlp_model_fusion/fuse_models.py \
#                   --experiment_name "${prefix}_random_init_large_${fusion_type}" \
#                   --dataset_name 'MNISTNorm' \
#                   --batch_size 128 \
#                   --model_name 'FC' \
#                   --input_dim 784 \
#                   --hidden_dims 800 400 200 \
#                   --target_diff_architecture \
#                   --output_dim 10 \
#                   --num_epochs 20 \
#                   --seed "43" \
#                   --gpu_ids "0" \
#                   --fusion_type "$fusion_type" \
#                   --activation_batch_size 256 \
#                   --use_pre_activations \
#                   --tlp_cost_choice "${tlp_cost_choice}" \
#                   --tlp_ot_solver 'sinkhorn' \
#                   --tlp_sinkhorn_regularization "${reg}" \
#                   --model_path_list \
#                   "FC,result/${base_model_prefix}/FC_MNISTNorm/runs/debug_seed_${seed1}/snapshots/best_val_acc_model.pth" \
#                   "FC,result/${base_model_prefix}/FC_MNISTNorm/runs/debug_seed_${seed2}/snapshots/best_val_acc_model.pth"
#   done
# done


#### Fuse 4 models

# seed=(847 53 43 348 437 82 233 31 786 234)
# for reg in 0.0005
# do
#   for idx in {0..8}
#   do
#     fusion_type="tlp"
#     tlp_cost_choice="weight"
#     seed1=${seed[($idx % 10)]}
#     seed2=${seed[(($idx+1) % 10)]}
#     seed3=${seed[(($idx+2) % 10)]}
#     seed4=${seed[(($idx+3) % 10)]}
#     python src/tlp_model_fusion/fuse_models.py \
#                   --experiment_name "${prefix}_random_init_large_${fusion_type}" \
#                   --dataset_name 'MNISTNorm' \
#                   --batch_size 128 \
#                   --model_name 'FC' \
#                   --input_dim 784 \
#                   --hidden_dims 800 400 200 \
#                   --target_diff_architecture \
#                   --output_dim 10 \
#                   --num_epochs 20 \
#                   --seed "43" \
#                   --gpu_ids "0" \
#                   --fusion_type "$fusion_type" \
#                   --activation_batch_size 256 \
#                   --use_pre_activations \
#                   --tlp_cost_choice "${tlp_cost_choice}" \
#                   --tlp_ot_solver 'sinkhorn' \
#                   --tlp_sinkhorn_regularization "${reg}" \
#                   --model_path_list \
#                   "FC,result/${base_model_prefix}/FC_MNISTNorm/runs/debug_seed_${seed1}/snapshots/best_val_acc_model.pth" \
#                   "FC,result/${base_model_prefix}/FC_MNISTNorm/runs/debug_seed_${seed2}/snapshots/best_val_acc_model.pth" \
#                   "FC,result/${base_model_prefix}/FC_MNISTNorm/runs/debug_seed_${seed3}/snapshots/best_val_acc_model.pth" \
#                   "FC,result/${base_model_prefix}/FC_MNISTNorm/runs/debug_seed_${seed4}/snapshots/best_val_acc_model.pth" 
#   done
# done


#### Fuse 6 models

# seed=(847 53 43 348 437 82 233 31 786 234)
# for reg in 0.0005
# do
#   for idx in {0..8}
#   do
#     fusion_type="tlp"
#     tlp_cost_choice="weight"
#     seed1=${seed[($idx % 10)]}
#     seed2=${seed[(($idx+1) % 10)]}
#     seed3=${seed[(($idx+2) % 10)]}
#     seed4=${seed[(($idx+3) % 10)]}
#     seed5=${seed[(($idx+4) % 10)]}
#     seed6=${seed[(($idx+5) % 10)]}
#     python src/tlp_model_fusion/fuse_models.py \
#                   --experiment_name "${prefix}_random_init_large_${fusion_type}" \
#                   --dataset_name 'MNISTNorm' \
#                   --batch_size 128 \
#                   --model_name 'FC' \
#                   --input_dim 784 \
#                   --hidden_dims 800 400 200 \
#                   --target_diff_architecture \
#                   --output_dim 10 \
#                   --num_epochs 20 \
#                   --seed "43" \
#                   --gpu_ids "0" \
#                   --fusion_type "$fusion_type" \
#                   --activation_batch_size 256 \
#                   --use_pre_activations \
#                   --tlp_cost_choice "${tlp_cost_choice}" \
#                   --tlp_ot_solver 'sinkhorn' \
#                   --tlp_sinkhorn_regularization "${reg}" \
#                   --model_path_list \
#                   "FC,result/${base_model_prefix}/FC_MNISTNorm/runs/debug_seed_${seed1}/snapshots/best_val_acc_model.pth" \
#                   "FC,result/${base_model_prefix}/FC_MNISTNorm/runs/debug_seed_${seed2}/snapshots/best_val_acc_model.pth" \
#                   "FC,result/${base_model_prefix}/FC_MNISTNorm/runs/debug_seed_${seed3}/snapshots/best_val_acc_model.pth" \
#                   "FC,result/${base_model_prefix}/FC_MNISTNorm/runs/debug_seed_${seed4}/snapshots/best_val_acc_model.pth" \
#                   "FC,result/${base_model_prefix}/FC_MNISTNorm/runs/debug_seed_${seed5}/snapshots/best_val_acc_model.pth" \
#                   "FC,result/${base_model_prefix}/FC_MNISTNorm/runs/debug_seed_${seed6}/snapshots/best_val_acc_model.pth" 
#   done
# done



################# OT Fusion ####################

##### Fuse 2 models

# seed=(847 53 43 348 437 82 233 31 786 234)
# for reg in 0.001
# do
#   for idx in {0..8}
#   do
#     fusion_type="ot"
#     ot_cost_choice="weight"
#     seed1=${seed[($idx % 10)]}
#     seed2=${seed[(($idx+1) % 10)]}
#     python src/tlp_model_fusion/fuse_models.py \
#                   --experiment_name "${prefix}_random_init_large_${fusion_type}" \
#                   --dataset_name 'MNISTNorm' \
#                   --batch_size 128 \
#                   --model_name 'FC' \
#                   --input_dim 784 \
#                   --hidden_dims 800 400 200 \
#                   --target_diff_architecture \
#                   --output_dim 10 \
#                   --num_epochs 20 \
#                   --seed "43" \
#                   --gpu_ids "0" \
#                   --fusion_type "$fusion_type" \
#                   --activation_batch_size 256 \
#                   --use_pre_activations \
#                   --ad_hoc_cost_choice "${ot_cost_choice}" \
#                   --ad_hoc_ot_solver 'sinkhorn' \
#                   --ad_hoc_sinkhorn_regularization "${reg}" \
#                   --model_path_list \
#                   "FC,result/${base_model_prefix}/FC_MNISTNorm/runs/debug_seed_${seed1}/snapshots/best_val_acc_model.pth" \
#                   "FC,result/${base_model_prefix}/FC_MNISTNorm/runs/debug_seed_${seed2}/snapshots/best_val_acc_model.pth"
#   done
# done


##### Fuse 4 models

# seed=(847 53 43 348 437 82 233 31 786 234)
# for reg in 0.0005
# do
#   for idx in {0..8}
#   do
#     fusion_type="ot"
#     ot_cost_choice="weight"
#     seed1=${seed[($idx % 10)]}
#     seed2=${seed[(($idx+1) % 10)]}
#     seed3=${seed[(($idx+2) % 10)]}
#     seed4=${seed[(($idx+3) % 10)]}
#     python src/tlp_model_fusion/fuse_models.py \
#                   --experiment_name "${prefix}_random_init_large_${fusion_type}" \
#                   --dataset_name 'MNISTNorm' \
#                   --batch_size 128 \
#                   --model_name 'FC' \
#                   --input_dim 784 \
#                   --hidden_dims 800 400 200 \
#                   --target_diff_architecture \
#                   --output_dim 10 \
#                   --num_epochs 20 \
#                   --seed "43" \
#                   --gpu_ids "0" \
#                   --fusion_type "$fusion_type" \
#                   --activation_batch_size 256 \
#                   --use_pre_activations \
#                   --ad_hoc_cost_choice "${ot_cost_choice}" \
#                   --ad_hoc_ot_solver 'sinkhorn' \
#                   --ad_hoc_sinkhorn_regularization "${reg}" \
#                   --model_path_list \
#                   "FC,result/${base_model_prefix}/FC_MNISTNorm/runs/debug_seed_${seed1}/snapshots/best_val_acc_model.pth" \
#                   "FC,result/${base_model_prefix}/FC_MNISTNorm/runs/debug_seed_${seed2}/snapshots/best_val_acc_model.pth" \
#                   "FC,result/${base_model_prefix}/FC_MNISTNorm/runs/debug_seed_${seed3}/snapshots/best_val_acc_model.pth" \
#                   "FC,result/${base_model_prefix}/FC_MNISTNorm/runs/debug_seed_${seed4}/snapshots/best_val_acc_model.pth"
#   done
# done


##### Fuse 6 models

# seed=(847 53 43 348 437 82 233 31 786 234)
# for reg in 0.0005
# do
#   for idx in {0..8}
#   do
#     fusion_type="ot"
#     ot_cost_choice="weight"
#     seed1=${seed[($idx % 10)]}
#     seed2=${seed[(($idx+1) % 10)]}
#     seed3=${seed[(($idx+2) % 10)]}
#     seed4=${seed[(($idx+3) % 10)]}
#     seed5=${seed[(($idx+4) % 10)]}
#     seed6=${seed[(($idx+5) % 10)]}
#     python src/tlp_model_fusion/fuse_models.py \
#                   --experiment_name "${prefix}_random_init_large_${fusion_type}" \
#                   --dataset_name 'MNISTNorm' \
#                   --batch_size 128 \
#                   --model_name 'FC' \
#                   --input_dim 784 \
#                   --hidden_dims 800 400 200 \
#                   --target_diff_architecture \
#                   --output_dim 10 \
#                   --num_epochs 20 \
#                   --seed "43" \
#                   --gpu_ids "0" \
#                   --fusion_type "$fusion_type" \
#                   --activation_batch_size 256 \
#                   --use_pre_activations \
#                   --ad_hoc_cost_choice "${ot_cost_choice}" \
#                   --ad_hoc_ot_solver 'sinkhorn' \
#                   --ad_hoc_sinkhorn_regularization "${reg}" \
#                   --model_path_list \
#                   "FC,result/${base_model_prefix}/FC_MNISTNorm/runs/debug_seed_${seed1}/snapshots/best_val_acc_model.pth" \
#                   "FC,result/${base_model_prefix}/FC_MNISTNorm/runs/debug_seed_${seed2}/snapshots/best_val_acc_model.pth" \
#                   "FC,result/${base_model_prefix}/FC_MNISTNorm/runs/debug_seed_${seed3}/snapshots/best_val_acc_model.pth" \
#                   "FC,result/${base_model_prefix}/FC_MNISTNorm/runs/debug_seed_${seed4}/snapshots/best_val_acc_model.pth" \
#                   "FC,result/${base_model_prefix}/FC_MNISTNorm/runs/debug_seed_${seed5}/snapshots/best_val_acc_model.pth" \
#                   "FC,result/${base_model_prefix}/FC_MNISTNorm/runs/debug_seed_${seed6}/snapshots/best_val_acc_model.pth"
#   done
# done



################ DISTILLATION INTO MLPSmall #################

# This section performs one-shot distillation of an MLPNet model into an
# MLPSmall-type model. Record the results as each hyperparameter run finishes.

#### WB Fusion ####

# prefix="${base_model_prefix}_one_shot_distill"
# for seed in 437 348 233 82 31
# do
#   fusion_type="tlp"
#   tlp_cost_choice="weight"
#   for reg in 0.01 0.005 0.001 0.0005
#   do
#     python src/tlp_model_fusion/fuse_models.py \
#                   --experiment_name "${prefix}_small_${fusion_type}_${seed}" \
#                   --dataset_name 'MNISTNorm' \
#                   --target_diff_architecture \
#                   --batch_size 128 \
#                   --model_name 'FC' \
#                   --input_dim 784 \
#                   --hidden_dims 200 100 50 \
#                   --output_dim 10 \
#                   --num_epochs 20 \
#                   --seed "43" \
#                   --gpu_ids "0" \
#                   --fusion_type "$fusion_type" \
#                   --activation_batch_size 256 \
#                   --use_pre_activations \
#                   --tlp_cost_choice "${tlp_cost_choice}" \
#                   --tlp_ot_solver 'sinkhorn' \
#                   --tlp_sinkhorn_regularization "${reg}" \
#                   --model_path_list \
#                   "FC,result/${base_model_prefix}/FC_MNISTNorm/runs/debug_seed_${seed}/snapshots/best_val_acc_model.pth"
#   done
# done



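#### OT Fusion ####
# NOTE: the experiment name below uses ${prefix}, which is assigned
# ("${base_model_prefix}_one_shot_distill") only in the WB section above.
# Set it here as well when running this section on its own.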

# fusion_type="ot"
# ad_hoc_cost_choice="weight"
# for seed in 437 348 233 82 31
# do
#   for reg in 0.01 0.005 0.001 0.0005
#   do
#     python src/tlp_model_fusion/fuse_models.py \
#                   --experiment_name "${prefix}_small_${fusion_type}_${seed}" \
#                   --dataset_name 'MNISTNorm' \
#                   --batch_size 128 \
#                   --target_diff_architecture \
#                   --model_name 'FC' \
#                   --input_dim 784 \
#                   --hidden_dims 200 100 50 \
#                   --output_dim 10 \
#                   --num_epochs 20 \
#                   --seed "43" \
#                   --gpu_ids "0" \
#                   --fusion_type "$fusion_type" \
#                   --activation_batch_size 256 \
#                   --use_pre_activations \
#                   --ad_hoc_cost_choice "${ad_hoc_cost_choice}" \
#                   --ad_hoc_ot_solver 'sinkhorn' \
#                   --ad_hoc_sinkhorn_regularization "${reg}" \
#                   --model_path_list \
#                   "FC,result/${base_model_prefix}/FC_MNISTNorm/runs/debug_seed_${seed}/snapshots/best_val_acc_model.pth"
#   done
# done

