#!/usr/bin/env bash
# Train the beta-VAE model on the ProcTHOR causal-triplet data.
#
# Unpacks the dataset archive into $SLURM_TMPDIR, moves to the parent of the
# current working directory, and launches main.py with the settings below.
# Requires bash (flag groups are stored in arrays).
#
# Required environment:
#   SLURM_TMPDIR  directory the dataset archive is extracted into.
set -euo pipefail

die() { printf '%s\n' "$*" >&2; exit 1; }

model_name=betavae
des=atomic_soft_new_soft_data_with_significantly_different_post_intervention_data
DESCRIPTION=("--description=$des")

# Model settings
MODEL=(--encoder=resnet50 --scm=unstructured "--model=$model_name")

# Latent and input dim
DIMS=(--dim_x=3 --dim_z=7)

# Training epochs and settings
TRAINING=(
  --epochs=400
  --pretrain_epochs=10
  --model_interventions_after_epoch=30
  --batch_size=64
  --num_workers=4
  --seed=4901
)

# Settings for e_norm regularizer (--z_regularization_* flags)
Z_REGULARIZE=(
  --z_regularization_schedule=constant_linear_constant
  --z_regularization_schedule_initial=0.01
  --z_regularization_schedule_final=0.0
  --z_regularization_schedule_initial_constant_epochs=10
  --z_regularization_schedule_decay_epochs=10
)

# Settings for consistency_mse regularizer
C_REGULARIZE=(
  --consistency_regularization_schedule=constant
  --consistency_regularization_schedule_initial=0.01
  --consistency_regularization_schedule_final=0.01
  --consistency_regularization_schedule_initial_constant_epochs=0
  --consistency_regularization_schedule_decay_epochs=0
)

# Settings for inverse consistency regularizer
INV_C_REGULARIZE=(
  --inverse_consistency_regularization_schedule=constant
  --inverse_consistency_regularization_schedule_initial=0.01
  --inverse_consistency_regularization_schedule_final=0.01
  --inverse_consistency_regularization_schedule_initial_constant_epochs=0
  --inverse_consistency_regularization_schedule_decay_epochs=0
)

# Learning-rate schedule
LR_SCHEDULE=(
  --lr_schedule=cosine
  --lr=0.002
  --lr_schedule_minimal=1e-5
  --lr_schedule_increase_period_by_factor=1
  --lr_schedule_restart_every_epochs=30
  --lr_schedule_step_every_epochs=0
  --lr_schedule_step_gamma=0.1
)

# Number of GPUs. NOTE(review): not referenced below and not passed to
# main.py; kept for reference only.
# shellcheck disable=SC2034
NUM_TRAINERS=2

# Fail fast instead of extracting nowhere / pointing --path_data at /procthor.
: "${SLURM_TMPDIR:?SLURM_TMPDIR must be set (run inside a SLURM job)}"

# unzip exits 1 when it only emitted warnings but extraction succeeded;
# exit codes >= 2 mean the dataset was not (fully) extracted.
unzip_status=0
unzip /home/sshirahm/scratch/data/procthor.zip -d "$SLURM_TMPDIR" || unzip_status=$?
(( unzip_status <= 1 )) || die "failed to extract dataset archive (unzip exit $unzip_status)"

# Deliberately relative to the working directory, not to this script's path:
# main.py is expected one level above the directory the job starts in.
cd .. || die "cannot cd to parent directory of $PWD"

# ------ python ------
# Dataset and log paths
DATADIR=(
  "--path_data=$SLURM_TMPDIR/procthor/causaltriplet-thor/"
  --dataset=procthor
  --num_actions=7
  --num_objects=23
  --nature_seed=1
)
EXPDIR=("--expdir=./experiments/procthor/$model_name")

python main.py \
  "${DATADIR[@]}" \
  "${DESCRIPTION[@]}" \
  "${EXPDIR[@]}" \
  "${LR_SCHEDULE[@]}" \
  "${MODEL[@]}" \
  "${TRAINING[@]}" \
  "${Z_REGULARIZE[@]}" \
  "${C_REGULARIZE[@]}" \
  "${INV_C_REGULARIZE[@]}" \
  "${DIMS[@]}"
