# Base model configuration referenced through the `include` key.
# NOTE(review): `include` is not a YAML feature; how it is resolved (and whether
# the keys below override the base file) is up to the project's config loader
# — confirm there.
include: cfg/model/base.yaml

# Name of the model to use; presumably looked up by name in the code — verify.
model: SbmclVae

# Encoder selection and its arguments.
# NOTE(review): presumably `enc_args` is handed to the component named by
# `encoder` — confirm against the model code.
encoder: CnnEncoder
enc_args:
  # Single-element shape; same size (512) as the hidden/latent sizes in this file.
  output_shape: [512]
  output_activation: relu

# Arguments for the MLP configured under `sbmcl_mlp_args`.
sbmcl_mlp_args:
  hidden_dim: 512
  layers: 2
  # Plain `none` is the string "none" in YAML, not null; presumably the code
  # interprets it as "no activation" — TODO confirm.
  output_activation: none

# Arguments for the MLP configured under `vae_mlp_args`; same values as
# `sbmcl_mlp_args` in this file.
vae_mlp_args:
  hidden_dim: 512
  layers: 2
  # Plain `none` is the string "none" in YAML, not null; presumably the code
  # interprets it as "no activation" — TODO confirm.
  output_activation: none

# Arguments for the component configured under `xz_enc_args`.
xz_enc_args:
  # input_shape and output_shape are intentionally not set here; they are
  # determined in the code.
  hidden_dim: 512
  layers: 2
  output_activation: relu

# Decoder selection and its arguments.
# NOTE(review): presumably `dec_args` is handed to the component named by
# `decoder` — confirm against the model code.
decoder: CnnDecoder
dec_args:
  # Single-element shape; same size (512) as `enc_args.output_shape`.
  input_shape: [512]
  # Plain `none` is the string "none" in YAML, not null.
  output_activation: none

# Sizes of the two latent-style variables; both are 512 here.
# NOTE(review): the distinction between `z` and `latent` is defined by the
# model code, not by this file — confirm there.
z_dim: 512
latent_dim: 512
# Boolean switches, written as canonical lowercase YAML booleans so they are
# typed the same way under YAML 1.1 and YAML 1.2 parsers.
# NOTE(review): the names suggest MAP-style inference, a reconstruction
# training term and a KL term — confirm the exact meaning in the model code.
map: false
train_recon: true
sbmcl_kl: true

# Evaluation sample counts and KL warm-up length.
# NOTE(review): presumably the number of samples drawn for `z` / `latent` at
# evaluation time — confirm in the evaluation code.
eval_z_samples: 5
eval_latent_samples: 10
# NOTE(review): unit (steps vs. epochs) is not visible in this file — confirm.
kl_warmup: 1000

# Optimizer arguments; only the learning rate is set in this file.
# NOTE(review): the optimizer type and any other defaults are presumably
# supplied elsewhere (e.g. the included base config) — confirm.
optim_args:
  lr: 0.001
