#!/bin/bash

# Append the current directory to Python's module search path for the
# `python src/...` commands in the sections below (presumably the script is
# meant to be launched from the repository root, since "." and the relative
# script paths are resolved against the current working directory).
# The ${VAR:+...} expansion keeps any existing entries in front of "." but
# avoids a stray leading ":" (an empty path element) when PYTHONPATH is unset
# or empty; it is also safe under `set -u`.
export PYTHONPATH="${PYTHONPATH:+${PYTHONPATH}:}."

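# Uncomment an individual section below (its python command together with any
# variable assignments, e.g. base_model_prefix, that precede it) to run that experiment.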

######## Generate Grid Plane for FC NN Model Trained on MNIST Dataset ########

# base_model_prefix="mlp_sgd_models_layer_3"
# layer=3

# python src/tlp_model_fusion/plane.py \
#           --experiment_name 'visualization' \
# 	      	--dataset_name 'MNISTNorm' \
# 		      --batch_size 128 \
# 		      --model_name 'FC' \
# 		      --input_dim 784 \
# 		      --hidden_dims 400 200 100 \
# 		      --output_dim 10 \
# 		      --seed "43" \
# 		      --gpu_ids '0' \
# 		      --init_start "result/${base_model_prefix}/FC_MNISTNorm/runs/debug_seed_847/snapshots/best_val_acc_model.pth" \
# 		      --init_end "result/${base_model_prefix}/FC_MNISTNorm/runs/debug_seed_53/snapshots/best_val_acc_model.pth" \
# 		      --fused_model_path "result/_model_fusion_tlp/FC_MNISTNorm/runs/fusion_tlp_num_models_2_layers_3_seed_43_cost_choice_weight_solver_sinkhorn_preact_init_identity_model_0_reg_0.002/snapshots/fused_model.pth" \
#           --permuted_model_1_path "result/_model_fusion_tlp/FC_MNISTNorm/runs/fusion_tlp_num_models_2_layers_3_seed_43_cost_choice_weight_solver_sinkhorn_preact_init_identity_model_0_reg_0.002/snapshots/permuted_model_1.pth" \
#           --permuted_model_2_path "result/_model_fusion_tlp/FC_MNISTNorm/runs/fusion_tlp_num_models_2_layers_3_seed_43_cost_choice_weight_solver_sinkhorn_preact_init_identity_model_0_reg_0.002/snapshots/permuted_model_2.pth" \
#           --grid_points 21 \
#           --margin_left 0.2 \
#           --margin_right 0.2 \
#           --margin_top 0.2 \
#           --margin_bottom 0.2



######### Generate Grid Plane for VGG11 Model trained on CIFAR10 Dataset #########

# base_model_prefix="deepcnn"

# python src/tlp_model_fusion/plane.py \
#           --experiment_name 'visualization' \
# 	      	--dataset_name 'CIFAR10' \
# 		      --batch_size 128 \
# 		      --model_name 'vgg11' \
# 		      --output_dim 10 \
# 		      --seed "43" \
# 		      --gpu_ids '0' \
# 		      --init_start "result/${base_model_prefix}/vgg11_CIFAR10/runs/debug_seed_786/snapshots/best_val_acc_model.pth" \
# 		      --init_end "result/${base_model_prefix}/vgg11_CIFAR10/runs/debug_seed_234/snapshots/best_val_acc_model.pth" \
# 		      --fused_model_path "result/vgg11_cnn_model_fusion_tlp/vgg11_CIFAR10/runs/fusion_tlp_num_models_2_layers_0_seed_43_cost_choice_weight_solver_sinkhorn_init_identity_model_0_reg_0.0005/snapshots/fused_model.pth" \
#           --permuted_model_1_path "result/vgg11_cnn_model_fusion_tlp/vgg11_CIFAR10/runs/fusion_tlp_num_models_2_layers_0_seed_43_cost_choice_weight_solver_sinkhorn_init_identity_model_0_reg_0.0005/snapshots/permuted_model_1.pth" \
#           --permuted_model_2_path "result/vgg11_cnn_model_fusion_tlp/vgg11_CIFAR10/runs/fusion_tlp_num_models_2_layers_0_seed_43_cost_choice_weight_solver_sinkhorn_init_identity_model_0_reg_0.0005/snapshots/permuted_model_2.pth" \
#           --grid_points 21 \
#           --margin_left 0.2 \
#           --margin_right 0.2 \
#           --margin_top 0.2 \
#           --margin_bottom 0.2


######### Generate Grid Plane for ResNet18 Model Trained on CIFAR10 Dataset #########

# base_model_prefix="resnet18_nmp_v1"

# python src/tlp_model_fusion/plane.py \
#           --experiment_name 'visualization' \
# 	      	--dataset_name 'CIFAR10' \
# 		      --batch_size 128 \
# 		      --model_name 'resnet18' \
# 		      --output_dim 10 \
# 		      --seed "43" \
# 		      --gpu_ids '0' \
# 		      --init_start "result/resnet18_nmp_v1/resnet18_CIFAR10/runs/debug_seed_847/snapshots/best_val_acc_model.pth" \
# 		      --init_end "result/resnet18_nmp_v1/resnet18_CIFAR10/runs/debug_seed_43/snapshots/best_val_acc_model.pth" \
# 		      --fused_model_path "result/resnet18_nmp_v1_fusion_model_fusion_tlp/resnet18_CIFAR10/runs/fusion_tlp_num_models_2_layers_0_seed_43_cost_choice_weight_solver_sinkhorn_preact_init_identity_model_0_reg_0.0005_skip_conn_pre/snapshots/fused_model.pth" \
#           --permuted_model_1_path "result/resnet18_nmp_v1_fusion_model_fusion_tlp/resnet18_CIFAR10/runs/fusion_tlp_num_models_2_layers_0_seed_43_cost_choice_weight_solver_sinkhorn_preact_init_identity_model_0_reg_0.0005_skip_conn_pre/snapshots/permuted_model_1.pth" \
#           --permuted_model_2_path "result/resnet18_nmp_v1_fusion_model_fusion_tlp/resnet18_CIFAR10/runs/fusion_tlp_num_models_2_layers_0_seed_43_cost_choice_weight_solver_sinkhorn_preact_init_identity_model_0_reg_0.0005_skip_conn_pre/snapshots/permuted_model_2.pth" \
#           --grid_points 21 \
#           --margin_left 0.15 \
#           --margin_right 0.15 \
#           --margin_top 0.15 \
#           --margin_bottom 0.15




######### Generate Grid Plane for RNN Model Trained on MNIST Dataset #############

# python src/tlp_model_fusion/plane.py \
#           --experiment_name 'visualization' \
# 	      	--dataset_name 'SplitMNIST' \
# 		      --batch_size 64 \
# 		      --model_name 'RNN' \
# 		      --input_dim 28 \
# 		      --hidden_dims 128 \
# 		      --output_dim 10 \
# 		      --seed "43" \
# 		      --gpu_ids '0' \
# 		      --init_start "result/layer_1_dim_256_scale_1_adam/rnn_mnist/seed_847/best_val_acc_model.pth" \
# 		      --init_end "result/layer_1_dim_256_scale_1_adam/rnn_mnist/seed_53/best_val_acc_model.pth" \
# 		      --fused_model_path "result/tlp_adam/rnn_256_mnist_split/layer_1/idenitity/models_1_2_reg_0.001.pth" \
#           --permuted_model_1_path "result/tlp_adam/rnn_256_mnist_split/layer_1/idenitity/permuted_model_1.pth" \
#           --permuted_model_2_path "result/tlp_adam/rnn_256_mnist_split/layer_1/idenitity/permuted_model_2.pth" \
#           --grid_points 21 \
#           --margin_left 0.2 \
#           --margin_right 0.2 \
#           --margin_top 0.2 \
#           --margin_bottom 0.2

######### Generate Grid Plane for LSTM Model Trained on MNIST Dataset #############

# python src/tlp_model_fusion/plane.py \
#           --experiment_name 'visualization' \
# 	      	--dataset_name 'SplitMNIST' \
# 		      --batch_size 64 \
# 		      --model_name 'LSTM' \
# 		      --input_dim 28 \
# 		      --hidden_dims 128 \
# 		      --output_dim 10 \
# 		      --seed "43" \
# 		      --gpu_ids '0' \
# 		      --init_start "result/layer_1_dim_256_scale_1_adam/lstm_mnist/dataset_1/best_val_acc_model.pth" \
# 		      --init_end "result/layer_1_dim_256_scale_1_adam/lstm_mnist/dataset_2/best_val_acc_model.pth" \
# 		      --fused_model_path "result/tlp_adam/lstm_256_mnist_split/layer_1/idenitity/models_1_2_reg_0.003.pth" \
#           --permuted_model_1_path "result/tlp_adam/lstm_256_mnist_split/layer_1/idenitity/permuted_model_1.pth" \
#           --permuted_model_2_path "result/tlp_adam/lstm_256_mnist_split/layer_1/idenitity/permuted_model_2.pth" \
#           --grid_points 21 \
#           --margin_left 0.2 \
#           --margin_right 0.2 \
#           --margin_top 0.2 \
#           --margin_bottom 0.2 





######## Generate Grid Plane for RNN model Trained on AGNEWS Dataset #############

# python src/tlp_model_fusion/plane.py \
#           --experiment_name 'visualization' \
# 	      	--dataset_name 'AG_NEWS' \
# 		      --batch_size 256 \
# 		      --model_name 'RNN' \
# 		      --input_dim 100 \
# 		      --hidden_dims 256 \
# 		      --output_dim 4 \
# 		      --seed "43" \
# 		      --gpu_ids '0' \
# 		      --init_start "result/layer_1_emb_100_dim_256_adam/rnn_agnews/seed_847/best_val_acc_model.pth" \
# 		      --init_end "result/layer_1_emb_100_dim_256_adam/rnn_agnews/seed_53/best_val_acc_model.pth" \
# 		      --fused_model_path "result/rnn_agnews_emb_100_dim_256_tlp_100.0/layer_1/idenitity/models_847_53_reg_0.005.pth" \
#           --permuted_model_1_path "result/rnn_agnews_emb_100_dim_256_tlp_100.0/layer_1/idenitity/permuted_model_1.pth" \
#           --permuted_model_2_path "result/rnn_agnews_emb_100_dim_256_tlp_100.0/layer_1/idenitity/permuted_model_2.pth" \
#           --train_data_path 'data/ag_news_csv' \
#           --glove_path "data/custom_datasets/glove.6B.100d.txt" \
#           --grid_points 21 \
#           --margin_left 0.2 \
#           --margin_right 0.2 \
#           --margin_top 0.2 \
#           --margin_bottom 0.2


######## Generate Grid Plane for LSTM model Trained on AGNEWS Dataset #############

# python src/tlp_model_fusion/plane.py \
#           --experiment_name 'visualization' \
# 	      	--dataset_name 'AG_NEWS' \
# 		      --batch_size 256 \
# 		      --model_name 'LSTM' \
# 		      --input_dim 50 \
# 		      --hidden_dims 256 \
# 		      --output_dim 4 \
# 		      --seed "43" \
# 		      --gpu_ids '0' \
# 		      --init_start "result/layer_1_emb_50_dim_256_adam/lstm_agnews/seed_847/best_val_acc_model.pth" \
# 		      --init_end "result/layer_1_emb_50_dim_256_adam/lstm_agnews/seed_53/best_val_acc_model.pth" \
# 		      --fused_model_path "result/lstm_agnews_emb_50_dim_256_tlp_1000.0/layer_1/idenitity/models_847_53_reg_0.005.pth" \
#           --permuted_model_1_path "result/lstm_agnews_emb_50_dim_256_tlp_1000.0/layer_1/idenitity/permuted_model_1.pth" \
#           --permuted_model_2_path "result/lstm_agnews_emb_50_dim_256_tlp_1000.0/layer_1/idenitity/permuted_model_2.pth" \
#           --train_data_path 'data/ag_news_csv' \
#           --glove_path "data/custom_datasets/glove.6B.50d.txt" \
#           --grid_points 21 \
#           --margin_left 0.2 \
#           --margin_right 0.2 \
#           --margin_top 0.2 \
#           --margin_bottom 0.2




######## Generate Grid Plane for RNN model Trained on DBpedia Dataset #############


# python src/tlp_model_fusion/plane.py \
#           --experiment_name 'visualization' \
# 	      	--dataset_name 'DBpedia' \
# 		      --batch_size 256 \
# 		      --model_name 'RNN' \
# 		      --input_dim 100 \
# 		      --hidden_dims 256 \
# 		      --output_dim 14 \
# 		      --seed "43" \
# 		      --gpu_ids '0' \
# 		      --init_start "result/layer_1_emb_100_dim_256_adam/rnn_dbpedia/seed_847/best_val_acc_model.pth" \
# 		      --init_end "result/layer_1_emb_100_dim_256_adam/rnn_dbpedia/seed_53/best_val_acc_model.pth" \
# 		      --fused_model_path "result/rnn_dbpedia_emb_100_dim_256_tlp_1000.0/layer_1/idenitity/models_847_53_reg_0.005.pth" \
#           --permuted_model_1_path "result/rnn_dbpedia_emb_100_dim_256_tlp_1000.0/layer_1/idenitity/permuted_model_1.pth" \
#           --permuted_model_2_path "result/rnn_dbpedia_emb_100_dim_256_tlp_1000.0/layer_1/idenitity/permuted_model_2.pth" \
#           --train_data_path 'data/dbpedia_csv' \
#           --glove_path "data/custom_datasets/glove.6B.100d.txt" \
#           --grid_points 21 \
#           --margin_left 0.2 \
#           --margin_right 0.2 \
#           --margin_top 0.2 \
#           --margin_bottom 0.2




######## Generate Grid Plane for LSTM model Trained on DBpedia Dataset #############

# python src/tlp_model_fusion/plane.py \
#           --experiment_name 'visualization' \
# 	      	--dataset_name 'DBpedia' \
# 		      --batch_size 256 \
# 		      --model_name 'LSTM' \
# 		      --input_dim 50 \
# 		      --hidden_dims 256 \
# 		      --output_dim 14 \
# 		      --seed "43" \
# 		      --gpu_ids '0' \
# 		      --init_start "result/layer_1_emb_50_dim_256_adam/lstm_dbpedia/seed_847/best_val_acc_model.pth" \
# 		      --init_end "result/layer_1_emb_50_dim_256_adam/lstm_dbpedia/seed_53/best_val_acc_model.pth" \
# 		      --fused_model_path "result/lstm_dbpedia_emb_50_dim_256_tlp_100.0/layer_1/idenitity/models_847_53_reg_0.002.pth" \
#           --permuted_model_1_path "result/lstm_dbpedia_emb_50_dim_256_tlp_100.0/layer_1/idenitity/permuted_model_1.pth" \
#           --permuted_model_2_path "result/lstm_dbpedia_emb_50_dim_256_tlp_100.0/layer_1/idenitity/permuted_model_2.pth" \
#           --train_data_path 'data/dbpedia_csv' \
#           --glove_path "data/custom_datasets/glove.6B.50d.txt" \
#           --grid_points 21 \
#           --margin_left 0.2 \
#           --margin_right 0.2 \
#           --margin_top 0.2 \
#           --margin_bottom 0.2



######## Generate Grid Plane for RNN model Trained on SST Dataset #############

# python src/tlp_model_fusion/plane.py \
#           --experiment_name 'visualization' \
# 	      	--dataset_name 'SSTPT' \
# 		      --batch_size 64 \
# 		      --model_name 'RNN' \
# 		      --input_dim 100 \
# 		      --hidden_dims 256 \
# 		      --output_dim 2 \
# 		      --seed "43" \
# 		      --gpu_ids '0' \
# 		      --init_start "result/layer_1_emb_100_dim_256_adam/rnn_sst/seed_847/best_val_acc_model.pth" \
# 		      --init_end "result/layer_1_emb_100_dim_256_adam/rnn_sst/seed_53/best_val_acc_model.pth" \
# 		      --fused_model_path "result/rnn_sst_emb_100_dim_256_tlp_100.0/layer_1/idenitity/models_847_53_reg_0.005.pth" \
#           --permuted_model_1_path "result/rnn_sst_emb_100_dim_256_tlp_100.0/layer_1/idenitity/permuted_model_1.pth" \
#           --permuted_model_2_path "result/rnn_sst_emb_100_dim_256_tlp_100.0/layer_1/idenitity/permuted_model_2.pth" \
#           --glove_path "data/custom_datasets/glove.6B.100d.txt" \
#           --grid_points 21 \
#           --margin_left 0.2 \
#           --margin_right 0.2 \
#           --margin_top 0.2 \
#           --margin_bottom 0.2


######## Generate Grid Plane for LSTM model Trained on SST Dataset #############

# python src/tlp_model_fusion/plane.py \
#           --experiment_name 'visualization' \
# 	      	--dataset_name 'SSTPT' \
# 		      --batch_size 64 \
# 		      --model_name 'LSTM' \
# 		      --input_dim 50 \
# 		      --hidden_dims 256 \
# 		      --output_dim 2 \
# 		      --seed "43" \
# 		      --gpu_ids '0' \
# 		      --init_start "result/layer_1_emb_50_dim_256_adam/lstm_sst/seed_847/best_val_acc_model.pth" \
# 		      --init_end "result/layer_1_emb_50_dim_256_adam/lstm_sst/seed_53/best_val_acc_model.pth" \
# 		      --fused_model_path "result/lstm_sst_emb_50_dim_256_tlp_100.0/layer_1/idenitity/models_847_53_reg_0.002.pth" \
#           --permuted_model_1_path "result/lstm_sst_emb_50_dim_256_tlp_100.0/layer_1/idenitity/permuted_model_1.pth" \
#           --permuted_model_2_path "result/lstm_sst_emb_50_dim_256_tlp_100.0/layer_1/idenitity/permuted_model_2.pth" \
#           --glove_path "data/custom_datasets/glove.6B.50d.txt" \
#           --grid_points 21 \
#           --margin_left 0.2 \
#           --margin_right 0.2 \
#           --margin_top 0.2 \
#           --margin_bottom 0.2



