#!/bin/bash

# Append the current working directory to PYTHONPATH. Every command below uses
# paths relative to it (src/..., result/...), so run this script from the
# repository root. `${PYTHONPATH:+...}` avoids producing a dangling ":." entry
# when PYTHONPATH was previously unset or empty.
export PYTHONPATH="${PYTHONPATH:+${PYTHONPATH}:}."

################## FUSION INTO SAME ARCHITECTURE ###################

# Uncomment an individual section below to run that fusion experiment.

# Directory under result/ that holds the trained base models (used to build
# the --model_path_list arguments in the sections below).
base_model_prefix="mlp_sgd_models_layer_3"
# NOTE(review): `layer` is not referenced by any section in this script; kept
# in case other tooling sources this file and reads it.
layer=3
# Prefix for --experiment_name. The sections below reference ${prefix}, which
# was never assigned, so names came out as e.g. "_model_fusion_tlp". Default it
# to the base-model prefix, while still honoring a value supplied by the caller.
prefix="${prefix:-${base_model_prefix}}"

################## WB Fusion #####################

### Fusing 2 models & initializing from the first base model (--tlp_init_model 0)
# seed=(847 53 43 348 437 82 233 31 786 234)

# for reg in 0.002
# do
#   for idx in {0..8}
#   do
#     seed1=${seed[($idx % 10)]}
#     seed2=${seed[(($idx+1) % 10)]}
#     fusion_type="tlp"
#     tlp_cost_choice="weight"
#     python src/tlp_model_fusion/fuse_models.py \
#                   --experiment_name "${prefix}_model_fusion_${fusion_type}" \
#                   --dataset_name 'MNISTNorm' \
#                   --batch_size 128 \
#                   --model_name 'FC' \
#                   --input_dim 784 \
#                   --hidden_dims 400 200 100 \
#                   --output_dim 10 \
#                   --num_epochs 20 \
#                   --seed "43" \
#                   --gpu_ids "0" \
#                   --fusion_type "$fusion_type" \
#                   --activation_batch_size 256 \
#                   --use_pre_activations \
#                   --tlp_cost_choice "${tlp_cost_choice}" \
#                   --tlp_init_type 'identity' \
#                   --tlp_init_model 0 \
#                   --tlp_ot_solver 'sinkhorn' \
#                   --tlp_sinkhorn_regularization "${reg}" \
#                   --model_path_list \
#                   "FC,result/${base_model_prefix}/FC_MNISTNorm/runs/debug_seed_${seed1}/snapshots/best_val_acc_model.pth" \
#                   "FC,result/${base_model_prefix}/FC_MNISTNorm/runs/debug_seed_${seed2}/snapshots/best_val_acc_model.pth" 
#   done
# done


### Fusing 2 models & initializing randomly
# seed=(847 53 43 348 437 82 233 31 786 234)

# for reg in 0.002
# do
#   for idx in {0..8}
#   do
#     seed1=${seed[($idx % 10)]}
#     seed2=${seed[(($idx+1) % 10)]}
#     fusion_type="tlp"
#     tlp_cost_choice="weight"
#     python src/tlp_model_fusion/fuse_models.py \
#                   --experiment_name "${prefix}_model_fusion_${fusion_type}" \
#                   --dataset_name 'MNISTNorm' \
#                   --batch_size 128 \
#                   --model_name 'FC' \
#                   --input_dim 784 \
#                   --hidden_dims 400 200 100 \
#                   --output_dim 10 \
#                   --num_epochs 20 \
#                   --seed "43" \
#                   --gpu_ids "0" \
#                   --fusion_type "$fusion_type" \
#                   --activation_batch_size 256 \
#                   --use_pre_activations \
#                   --tlp_cost_choice "${tlp_cost_choice}" \
#                   --tlp_ot_solver 'sinkhorn' \
#                   --tlp_sinkhorn_regularization "${reg}" \
#                   --model_path_list \
#                   "FC,result/${base_model_prefix}/FC_MNISTNorm/runs/debug_seed_${seed1}/snapshots/best_val_acc_model.pth" \
#                   "FC,result/${base_model_prefix}/FC_MNISTNorm/runs/debug_seed_${seed2}/snapshots/best_val_acc_model.pth" 
#   done
# done


################## OT Fusion #####################

### Fusing 2 models & initializing from the first base model (--ad_hoc_init_model 0)
# seed=(847 53 43 348 437 82 233 31 786 234)

# for reg in 0.002
# do
#   for idx in {0..8}
#   do
#     seed1=${seed[($idx % 10)]}
#     seed2=${seed[(($idx+1) % 10)]}
#     fusion_type="ot"
#     ot_cost_choice="weight"
#     python src/tlp_model_fusion/fuse_models.py \
#                   --experiment_name "${prefix}_model_fusion_${fusion_type}" \
#                   --dataset_name 'MNISTNorm' \
#                   --batch_size 128 \
#                   --model_name 'FC' \
#                   --input_dim 784 \
#                   --hidden_dims 400 200 100 \
#                   --output_dim 10 \
#                   --num_epochs 20 \
#                   --seed "43" \
#                   --gpu_ids "0" \
#                   --fusion_type "$fusion_type" \
#                   --activation_batch_size 256 \
#                   --ad_hoc_cost_choice "${ot_cost_choice}" \
#                   --use_pre_activations \
#                   --ad_hoc_init_model 0 \
#                   --ad_hoc_ot_solver 'sinkhorn' \
#                   --ad_hoc_sinkhorn_regularization "${reg}" \
#                   --model_path_list \
#                   "FC,result/${base_model_prefix}/FC_MNISTNorm/runs/debug_seed_${seed1}/snapshots/best_val_acc_model.pth" \
#                   "FC,result/${base_model_prefix}/FC_MNISTNorm/runs/debug_seed_${seed2}/snapshots/best_val_acc_model.pth"
#   done
# done

### Fusing 2 models & initializing randomly
# seed=(847 53 43 348 437 82 233 31 786 234)

# for reg in 0.002
# do
#   for idx in {0..8}
#   do
#     seed1=${seed[($idx % 10)]}
#     seed2=${seed[(($idx+1) % 10)]}
#     fusion_type="ot"
#     ot_cost_choice="weight"
#     python src/tlp_model_fusion/fuse_models.py \
#                   --experiment_name "${prefix}_model_fusion_${fusion_type}" \
#                   --dataset_name 'MNISTNorm' \
#                   --batch_size 128 \
#                   --model_name 'FC' \
#                   --input_dim 784 \
#                   --hidden_dims 400 200 100 \
#                   --output_dim 10 \
#                   --num_epochs 20 \
#                   --seed "43" \
#                   --gpu_ids "0" \
#                   --fusion_type "$fusion_type" \
#                   --activation_batch_size 256 \
#                   --ad_hoc_cost_choice "${ot_cost_choice}" \
#                   --use_pre_activations \
#                   --ad_hoc_ot_solver 'sinkhorn' \
#                   --ad_hoc_sinkhorn_regularization "${reg}" \
#                   --model_path_list \
#                   "FC,result/${base_model_prefix}/FC_MNISTNorm/runs/debug_seed_${seed1}/snapshots/best_val_acc_model.pth" \
#                   "FC,result/${base_model_prefix}/FC_MNISTNorm/runs/debug_seed_${seed2}/snapshots/best_val_acc_model.pth"
#   done
# done


################## Vanilla Averaging Fusion #####################

### Fusing 2 models
# seed=(847 53 43 348 437 82 233 31 786 234)

# for idx in {0..8}
# do
#   fusion_type="avg"
#   seed1=${seed[($idx % 10)]}
#   seed2=${seed[(($idx+1) % 10)]}
#   python src/tlp_model_fusion/fuse_models.py \
#                 --experiment_name "${prefix}_model_fusion_${fusion_type}" \
#                 --dataset_name 'MNISTNorm' \
#                 --batch_size 128 \
#                 --model_name 'FC' \
#                 --input_dim 784 \
#                 --hidden_dims 400 200 100 \
#                 --output_dim 10 \
#                 --num_epochs 20 \
#                 --seed "43" \
#                 --gpu_ids "0" \
#                 --fusion_type "$fusion_type" \
#                 --model_path_list \
#                 "FC,result/${base_model_prefix}/FC_MNISTNorm/runs/debug_seed_${seed1}/snapshots/best_val_acc_model.pth" \
#                 "FC,result/${base_model_prefix}/FC_MNISTNorm/runs/debug_seed_${seed2}/snapshots/best_val_acc_model.pth"
# done

