#!/bin/bash

# Append the current directory to PYTHONPATH before launching the Python
# experiments below (presumably so src/tlp_model_fusion/fuse_models.py can
# import project modules relative to the repo root — verify against the code).
# ${PYTHONPATH:+...} only emits the existing value plus a ':' separator when
# PYTHONPATH is non-empty, so an unset/empty PYTHONPATH yields "." rather than
# ":." (a spurious empty entry), and the expansion stays safe under `set -u`.
export PYTHONPATH="${PYTHONPATH:+${PYTHONPATH}:}."

# Prefix for the --experiment_name of every fusion run launched below.
prefix="vgg11_cnn"
# Sub-directory of result/ holding the trained base models that get fused.
base_model_prefix="deepcnn"

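# Uncomment an individual section below to run that experiment. Each section
# fuses consecutive pairs of the base VGG11/CIFAR10 models stored under
# result/${base_model_prefix}/vgg11_CIFAR10/runs/debug_seed_<seed>/snapshots/.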

############## WB Fusion ###############

#### Fuse 2 models & initialize with base model 1
# seed=(847 53 43 348 437 82 233 31 786 234)

# for reg in 0.0005
# do
#   for idx in {0..8}
#   do
#     fusion_type="tlp"
#     tlp_cost_choice="weight"
#     seed1=${seed[($idx % 10)]}
#     seed2=${seed[(($idx+1) % 10)]}
#     python src/tlp_model_fusion/fuse_models.py \
#                   --experiment_name "${prefix}_model_fusion_${fusion_type}" \
#                   --dataset_name 'CIFAR10' \
#                   --batch_size 128 \
#                   --model_name 'vgg11' \
#                   --output_dim 10 \
#                   --num_epochs 20 \
#                   --seed "43" \
#                   --gpu_ids "0" \
#                   --fusion_type "$fusion_type" \
#                   --activation_batch_size 256 \
#                   --tlp_cost_choice "${tlp_cost_choice}" \
#                   --tlp_init_type 'identity' \
#                   --tlp_init_model 0 \
#                   --tlp_ot_solver 'sinkhorn' \
#                   --tlp_sinkhorn_regularization "${reg}" \
#                   --model_path_list \
#                   "vgg11,result/${base_model_prefix}/vgg11_CIFAR10/runs/debug_seed_${seed1}/snapshots/best_val_acc_model.pth" \
#                   "vgg11,result/${base_model_prefix}/vgg11_CIFAR10/runs/debug_seed_${seed2}/snapshots/best_val_acc_model.pth"
#   done
# done








############## OT Fusion ##################

#### Fuse 2 models & initialize with base model 1
# seed=(847 53 43 348 437 82 233 31 786 234)

# for reg in 0.0005
# do
#   for idx in {0..8}
#   do
#     fusion_type="ot"
#     tlp_cost_choice="weight"
#     seed1=${seed[($idx % 10)]}
#     seed2=${seed[(($idx+1) % 10)]}
#     python src/tlp_model_fusion/fuse_models.py \
#                   --experiment_name "${prefix}_model_fusion_${fusion_type}" \
#                   --dataset_name 'CIFAR10' \
#                   --batch_size 128 \
#                   --model_name 'vgg11' \
#                   --output_dim 10 \
#                   --num_epochs 20 \
#                   --seed "43" \
#                   --gpu_ids "0" \
#                   --fusion_type "$fusion_type" \
#                   --activation_batch_size 256 \
#                   --ad_hoc_cost_choice "${tlp_cost_choice}" \
#                   --ad_hoc_init_model 0 \
#                   --ad_hoc_ot_solver 'sinkhorn' \
#                   --ad_hoc_sinkhorn_regularization "${reg}" \
#                   --model_path_list \
#                   "vgg11,result/${base_model_prefix}/vgg11_CIFAR10/runs/debug_seed_${seed1}/snapshots/best_val_acc_model.pth" \
#                   "vgg11,result/${base_model_prefix}/vgg11_CIFAR10/runs/debug_seed_${seed2}/snapshots/best_val_acc_model.pth"
#   done
# done







################## VANILLA AVERAGING #################

### Fuse 2 models
# seed=(847 53 43 348 437 82 233 31 786 234)

# for idx in {0..8}
# do
#   fusion_type="avg"
#   seed1=${seed[($idx % 10)]}
#   seed2=${seed[(($idx+1) % 10)]}
#   python src/tlp_model_fusion/fuse_models.py \
#                 --experiment_name "${prefix}_model_fusion_${fusion_type}" \
#                 --dataset_name 'CIFAR10' \
#                 --batch_size 128 \
#                 --model_name 'vgg11' \
#                 --output_dim 10 \
#                 --num_epochs 20 \
#                 --seed "43" \
#                 --gpu_ids "0" \
#                 --fusion_type "$fusion_type" \
#                 --activation_batch_size 256 \
#                 --model_path_list \
#                 "vgg11,result/${base_model_prefix}/vgg11_CIFAR10/runs/debug_seed_${seed1}/snapshots/best_val_acc_model.pth" \
#                 "vgg11,result/${base_model_prefix}/vgg11_CIFAR10/runs/debug_seed_${seed2}/snapshots/best_val_acc_model.pth"
# done


