# Run-level bookkeeping for this experiment.
experiment:
  # Free-text label for the run.
  description: 'regression w/ GCM regularization'
  # NOTE(review): this looks like a placeholder path -- point it at a real
  # directory before launching a run.
  output_location: 'path/to/store/learned/weights'
  # No checkpoint is named here; presumably training starts from scratch
  # when this is null -- confirm against the code that reads `load`.
  load: null
  resume: false

# Dataset selection, split files and loader settings.
data:
  # Dataset selector; resolved by project code that is not visible here.
  data_key: dsprites-nonlinear

  # Split files (.npz archives).
  # NOTE(review): `val_ood` points at the same file as `val`, and `test_ood`
  # at the same file as `test`. Presumably the out-of-distribution shift is
  # produced at load time rather than stored on disk -- confirm this is
  # intentional and not a copy-paste slip.
  train: data/dsprites-dataset/dsprites_train.npz
  val: data/dsprites-dataset/dsprites_val.npz
  val_ood: data/dsprites-dataset/dsprites_val.npz
  test: data/dsprites-dataset/dsprites_test.npz
  test_ood: data/dsprites-dataset/dsprites_test.npz
  # Factor to predict and the factor treated as a distractor; presumably
  # dSprites latent-factor names -- verify against the dataset loader.
  target: position_y
  distractor: position_x
  # Meaning and units of `nl_type` / `noise` are defined by the loader --
  # TODO confirm before changing them.
  nl_type: tricky
  noise: 2
  # Canonical lowercase boolean (was `True`); parses to the same value.
  regress: true

  # Loader throughput settings.
  batch_size: 1024
  num_workers: 16

# Model, trainer, network layout and optimisation settings.
model:
  # Registry keys; resolved by project code that is not visible here.
  model_key: regressor
  # NOTE(review): "gcm" presumably refers to the GCM regularisation named in
  # the experiment description -- confirm.
  trainer_key: gcm
  # Splits to run; the names match the split keys under `data`.
  modes:
    - train
    - val
    - val_ood
    - test
    - test_ood
  epochs: 200
  # Presumably early-stopping patience, counted in epochs -- confirm.
  patience: 25
  # Key spelling `lamda` is kept exactly as-is: renaming it would change what
  # the consuming code looks up. Presumably the regulariser weight -- confirm.
  lamda: 0.1
  # Presumably the ridge penalty used by the `kernel-ridge` regression
  # selected below -- confirm.
  ridge_lambda: 1.0
  # Gaussian kernel settings. What `ft`, `y` and `z` denote is defined by the
  # trainer (presumably features / target / distractor) -- TODO confirm.
  kernel_ft:
    gaussian:
      sigma2: 0.01
  kernel_y:
    gaussian:
      sigma2: 0.01
  kernel_z:
    gaussian:
      sigma2: 1
  n_last_reg_layers: 1
  # Trainer switches; semantics live in the trainer code, not in this file.
  zy_cov: true
  loo_cond_mean: true
  biased: true
  regression: kernel-ridge

  # Layer specs: each list item is a single-key mapping of layer name to its
  # keyword arguments (names look like torch.nn classes -- confirm how the
  # builder resolves them).
  network:
    featurizer:
      - Conv2d:
          in_channels: 1
          out_channels: 16
          kernel_size: 3
          stride: 2
          padding: 1
      - LeakyReLU:
          inplace: true
      - Conv2d:
          in_channels: 16
          out_channels: 32
          kernel_size: 3
          stride: 2
          padding: 1
      - LeakyReLU:
          inplace: true
      - MaxPool2d:
          kernel_size: 2
          stride: 2
      - Conv2d:
          in_channels: 32
          out_channels: 64
          kernel_size: 3
          stride: 2
          padding: 1
      - LeakyReLU:
          inplace: true
      - MaxPool2d:
          kernel_size: 2
          stride: 2

    # NOTE(review): in_features 256 is consistent with a 64x64 single-channel
    # input (64 -> 32 -> 16 -> 8 -> 4 -> 2 spatially, 64 channels * 2 * 2).
    # The input size is not stated in this file -- confirm before changing
    # the featurizer.
    fc1:
      - Linear:
          in_features: 256
          out_features: 128
      - LeakyReLU:
          inplace: true

    fc2:
      - Linear:
          in_features: 128
          out_features: 64

    # Single-output head.
    target:
      - Linear:
          in_features: 64
          out_features: 1

  optimizer:
    AdamW:
      lr: 0.0001
      weight_decay: 0.01

  # T_max equals `epochs` above; keep the two in sync if either changes
  # (assumes the scheduler is stepped once per epoch -- confirm).
  scheduler:
    CosineAnnealingLR:
      T_max: 200

