data:
  # Dimensionality of the true latent space.
  dim_z: 3
  # Dimensionality of the observations.
  dim_x: 4
  # How many samples are drawn for each environment.
  num_samps: 5000
  base_dist:
    # NOTE(review): 'gmm' presumably selects a Gaussian mixture model; confirm
    # against the code that consumes this key.
    type: 'gmm'
    # Number of components.
    n_comps: 3
  envs:
    num_concepts: 2
    type: 'normal'
    # Number of environments.
    n_envs: 3
    # NOTE(review): signs and means each hold 3 entries, the same count as
    # both n_envs and dim_z; confirm which of the two they are indexed by.
    signs: [0, 1, 1]
    means: [0.2, -0.2, 0.1]
    var: 0.005
  # Kind of mapping (presumably latent -> observed; confirm with the consumer).
  f_type: 'linear'

model:
  # Which kind of model to use.
  type: 'contrastive_linear'
  # How many hidden layers.
  hidden_layers: 2
  # Size of the hidden layers.
  hidden_dim: 32

train:
  batch_size: 2048
  # Number of restarts from a fresh initialisation; the final model is the
  # one with the minimal validation loss.
  restarts: 0
  epochs: 100
  # Quoted so the value is unambiguously a string despite the embedded colon.
  device: 'cuda:0'
  # L1 penalty.
  eta: 0.0001
  optimizer: Adam
  # Floats written with an explicit leading zero, matching eta above.
  lr_parametric: 0.05
  lr_nonparametric: 0.005
  weight_decay: 0.0