#!/bin/bash

# Shared setup for the fusion experiment sections further down in this file.

# Append the current directory to Python's module search path (presumably the
# script is launched from the repository root -- confirm against how it is run).
# ${PYTHONPATH:+...} emits the existing value and its ':' separator only when
# PYTHONPATH is already set and non-empty, so an unset variable yields "."
# rather than ":." with a stray empty leading entry.
export PYTHONPATH="${PYTHONPATH:+${PYTHONPATH}:}."


# Subdirectory of result/ that the experiment sections read base-model
# checkpoints from.
base_model_prefix="resnet18_nmp_v1"
# Leading part of the --experiment_name value passed to fuse_models.py.
prefix="${base_model_prefix}_fusion"

# Uncomment one of the sections below to run the corresponding experiments.

################# WB Fusion ######################

#### Fuse 2 models, initializing with base model 1 (--tlp_init_model 0)
# seed=(847 53 43 348 437 82 233 31 786 234)

# for reg in 0.0005
# do
#   for idx in {0..8}
#   do
#     fusion_type="tlp"
#     tlp_cost_choice="weight"
#     seed1=${seed[($idx % 10)]}
#     seed2=${seed[(($idx+1) % 10)]}
#     python src/tlp_model_fusion/fuse_models.py \
#                   --experiment_name "${prefix}_model_fusion_${fusion_type}" \
#                   --dataset_name 'CIFAR10' \
#                   --batch_size 256 \
#                   --model_name 'resnet18' \
#                   --output_dim 10 \
#                   --num_epochs 20 \
#                   --seed "43" \
#                   --gpu_ids "0" \
#                   --fusion_type "$fusion_type" \
#                   --activation_batch_size 256 \
#                   --tlp_cost_choice "${tlp_cost_choice}" \
#                   --use_pre_activations \
#                   --tlp_init_type 'identity' \
#                   --tlp_init_model 0 \
#                   --tlp_ot_solver 'sinkhorn' \
#                   --tlp_sinkhorn_regularization "${reg}" \
#                   --resnet_skip_connection_handling "pre" \
#                   --model_path_list \
#                   "resnet18,result/${base_model_prefix}/resnet18_CIFAR10/runs/debug_seed_${seed1}/snapshots/best_val_acc_model.pth" \
#                   "resnet18,result/${base_model_prefix}/resnet18_CIFAR10/runs/debug_seed_${seed2}/snapshots/best_val_acc_model.pth"
#   done
# done




################# OT Fusion ######################

##### Fuse 2 models, initializing with base model 1 (--ad_hoc_init_model 0)
# seed=(847 53 43 348 437 82 233 31 786 234)

# for reg in 0.0005
# do
#   for idx in {0..8}
#   do
#     fusion_type="ot"
#     tlp_cost_choice="weight"
#     seed1=${seed[($idx % 10)]}
#     seed2=${seed[(($idx+1) % 10)]}
#     python src/tlp_model_fusion/fuse_models.py \
#                   --experiment_name "${prefix}_model_fusion_${fusion_type}" \
#                   --dataset_name 'CIFAR10' \
#                   --batch_size 256 \
#                   --model_name 'resnet18' \
#                   --output_dim 10 \
#                   --num_epochs 20 \
#                   --seed "43" \
#                   --gpu_ids "0" \
#                   --fusion_type "$fusion_type" \
#                   --activation_batch_size 256 \
#                   --tlp_cost_choice "${tlp_cost_choice}" \
#                   --use_pre_activations \
#                   --ad_hoc_init_type 'identity' \
#                   --ad_hoc_init_model 0 \
#                   --ad_hoc_ot_solver 'sinkhorn' \
#                   --ad_hoc_sinkhorn_regularization "${reg}" \
#                   --resnet_skip_connection_handling "pre" \
#                   --model_path_list \
#                   "resnet18,result/${base_model_prefix}/resnet18_CIFAR10/runs/debug_seed_${seed1}/snapshots/best_val_acc_model.pth" \
#                   "resnet18,result/${base_model_prefix}/resnet18_CIFAR10/runs/debug_seed_${seed2}/snapshots/best_val_acc_model.pth"
#   done
# done





################# Vanilla Averaging Fusion ######################
# seed=(847 53 43 348 437 82 233 31 786 234)

# for idx in {0..8}
# do
#   fusion_type="avg"
#   seed1=${seed[($idx % 10)]}
#   seed2=${seed[(($idx+1) % 10)]}
#   python src/tlp_model_fusion/fuse_models.py \
#                 --experiment_name "${prefix}_model_fusion_${fusion_type}" \
#                 --dataset_name 'CIFAR10' \
#                 --batch_size 256 \
#                 --model_name 'resnet18' \
#                 --output_dim 10 \
#                 --num_epochs 20 \
#                 --seed "43" \
#                 --gpu_ids "0" \
#                 --fusion_type "$fusion_type" \
#                 --activation_batch_size 256 \
#                 --model_path_list \
#                 "resnet18,result/${base_model_prefix}/resnet18_CIFAR10/runs/debug_seed_${seed1}/snapshots/best_val_acc_model.pth" \
#                 "resnet18,result/${base_model_prefix}/resnet18_CIFAR10/runs/debug_seed_${seed2}/snapshots/best_val_acc_model.pth"
# done