include: cfg/model/base.yaml

# Model selected by this config. The name is presumably resolved to a class by
# the code that loads this file (not visible here) — TODO confirm.
model: MamlVae

# OML
# Settings grouped under the OML heading. The meanings noted below are
# inferred from the key names only — TODO confirm against the consuming code.
# inner_lr: presumably the inner-loop step size; note that it is 10x the
# value of optim_args.lr in this file.
inner_lr: 0.003
# Booleans are written as canonical lowercase YAML literals (true/false).
# learnable_lr: presumably makes the inner learning rate a trained parameter.
learnable_lr: true
# reptile: presumably toggles a Reptile-style update; disabled here.
reptile: false

# Latent size. The same value (512) also appears in enc_args.output_shape,
# enc_mlp_args.hidden_dim, dec_mlp_args.hidden_dim and dec_args.input_shape;
# presumably these need to stay in sync — verify before changing one alone.
latent_dim: 512

# Encoder
# `encoder` names the encoder component and `enc_args` holds its arguments
# (presumably forwarded to its constructor — confirm in the loader).
encoder: MamlCnnEncoder
enc_args:
  # Single-entry list whose value (512) matches latent_dim.
  output_shape: [512]
  output_activation: relu

# Arguments for an MLP on the encoder side (inferred from the `enc_` prefix —
# TODO confirm where it sits relative to the encoder configured above).
enc_mlp_args:
  hidden_dim: 512
  layers: 2
  # Plain scalar `none` is loaded by YAML as the string "none", not as null;
  # the consumer presumably maps that string to "no activation".
  output_activation: none

# Arguments for an MLP on the decoder side (inferred from the `dec_` prefix —
# TODO confirm). Differs from enc_mlp_args only in output_activation
# (relu here, none there).
dec_mlp_args:
  hidden_dim: 512
  layers: 2
  output_activation: relu

# Decoder
# `decoder` names the decoder component and `dec_args` holds its arguments
# (presumably forwarded to its constructor — confirm in the loader).
decoder: MamlCnnDecoder
dec_args:
  # Single-entry list whose value (512) mirrors enc_args.output_shape.
  input_shape: [512]
  # Plain scalar `none` is loaded by YAML as the string "none", not as null.
  output_activation: none

# Presumably the number of latent samples drawn per input at evaluation
# time — TODO confirm.
eval_latent_samples: 10
# Presumably the length of the KL-term warm-up schedule; the unit (steps vs.
# epochs) is not visible in this file — TODO confirm.
kl_warmup: 1000

# Optimizer arguments (presumably forwarded to the optimizer constructor —
# TODO confirm). lr here (0.0003) is one tenth of inner_lr (0.003).
optim_args:
  lr: 0.0003
