#!/bin/bash

# `conda activate` only works after conda's shell hook is loaded. A
# non-interactive script does not read ~/.bashrc, so load the hook explicitly.
# Abort if activation fails; otherwise every run would use the wrong env.
if ! command -v conda >/dev/null 2>&1; then
	printf 'error: conda not found on PATH\n' >&2
	exit 1
fi
eval "$(conda shell.bash hook)"
if ! conda activate mvvae; then
	printf 'error: could not activate conda environment "mvvae"\n' >&2
	exit 1
fi

# ---------------------------------------------------------------------------
# Sweep configuration. Replace every "ADD HERE" placeholder before running.
# Each array below is one sweep axis. The nested loops further down launch
# one training run per combination of values.
# ---------------------------------------------------------------------------

# local wandb instance
wandb_entity="ADD HERE"
project_name="ADD HERE"
dir_experiments="ADD HERE"
dir_data="ADD HERE"
dir_clf="ADD HERE"
dir_alphabet="ADD HERE"
logdir="${dir_experiments}/logs/CelebA"

device="cuda"                            # 'cuda' if you are using a GPU
models=("unimodal" "joint" "mixedprior") # any of "joint", "mixedprior", "unimodal"
dataset_names=("CelebA")
seeds=(1 2 3 4 5)
# Full beta sweep. To use it, uncomment this line and remove the
# single-value assignment below (a later assignment overwrites an earlier one).
# betas=(0.25 0.5 1.0 2.0 4.0)
betas=(1.0)
gammas=(0.0001)
latent_dims=(128)
drpm_prior=(False)
alpha_annealing=(True)
alpha_annealing_n_steps=(150000) # passed as ++model.alpha_annealing_steps
n_epochs=(400)
learning_rates=(5e-4)
batch_sizes=(256)
# Values for the ++log.*_frequency overrides. Their units are defined by
# main_mv_wsl.py (NOTE(review): confirm whether these are epochs or steps).
log_freq_downstream=50
log_freq_coherence=50
log_freq_lhood=500
log_freq_plotting=50

# ---------------------------------------------------------------------------
# Launch one main_mv_wsl.py run per combination of the sweep axes above.
# Every value is forwarded as a Hydra override. All expansions are quoted so
# that values containing spaces stay a single argument.
# ---------------------------------------------------------------------------

# Refuse to launch while configuration placeholders are still unfilled.
for required in wandb_entity project_name dir_experiments dir_data dir_clf dir_alphabet; do
	if [[ "${!required}" == "ADD HERE" ]]; then
		printf 'error: set %s at the top of this script before running\n' "${required}" >&2
		exit 1
	fi
done

# These are identical for every run, so compute and create them once.
# An empty run name is forwarded as-is (`++log.wandb_run_name=`).
run_name=""
wandb_logdir="${logdir}"
if ! mkdir -p -- "${wandb_logdir}"; then
	printf 'error: could not create log directory %s\n' "${wandb_logdir}" >&2
	exit 1
fi

for dataset in "${dataset_names[@]}"; do
	for model in "${models[@]}"; do
		for seed in "${seeds[@]}"; do
			for beta in "${betas[@]}"; do
				for gamma in "${gammas[@]}"; do
					for ld in "${latent_dims[@]}"; do
						for n_ep in "${n_epochs[@]}"; do
							for dp in "${drpm_prior[@]}"; do
								for aa in "${alpha_annealing[@]}"; do
									for aa_n_steps in "${alpha_annealing_n_steps[@]}"; do
										for l_r in "${learning_rates[@]}"; do
											for bs in "${batch_sizes[@]}"; do
												# A failed run is reported but does not stop the rest of the sweep.
												python main_mv_wsl.py \
													model="${model}" \
													"++model.device=${device}" \
													"++model.seed=${seed}" \
													"++model.epochs=${n_ep}" \
													"++model.batch_size=${bs}" \
													"++model.beta=${beta}" \
													"++model.gamma=${gamma}" \
													"++model.latent_dim=${ld}" \
													"++model.drpm_prior=${dp}" \
													"++model.alpha_annealing=${aa}" \
													"++model.alpha_annealing_steps=${aa_n_steps}" \
													"++model.lr=${l_r}" \
													"dataset=${dataset}" \
													"++dataset.dir_data=${dir_data}" \
													"++dataset.dir_clf=${dir_clf}" \
													"++dataset.dir_alphabet=${dir_alphabet}" \
													"++log.downstream_logging_frequency=${log_freq_downstream}" \
													"++log.coherence_logging_frequency=${log_freq_coherence}" \
													"++log.likelihood_logging_frequency=${log_freq_lhood}" \
													"++log.img_plotting_frequency=${log_freq_plotting}" \
													++log.wandb_offline="False" \
													++log.wandb_local_instance="True" \
													"++log.wandb_entity=${wandb_entity}" \
													"++log.wandb_run_name=${run_name}" \
													++log.wandb_group="20240216" \
													"++log.wandb_project_name=${project_name}" \
													"++log.dir_logs=${wandb_logdir}" ||
													printf 'warning: run failed (dataset=%s model=%s seed=%s beta=%s)\n' \
														"${dataset}" "${model}" "${seed}" "${beta}" >&2
											done
										done
									done
								done
							done
						done
					done
				done
			done
		done
	done
done
