# Top-level labels for this configuration: the dataset and the model variant.
# NOTE(review): nothing in the visible section reads these keys; presumably
# they are consumed by interpolations defined elsewhere (e.g. run naming) —
# confirm against the composing config.
dataset: gp-regression
model: explicit-mlp

# Task definition: `tasks.regression.RegressionICL`, configured with a
# learning rate and a model composed of a `context_model` and a
# `prediction_model`. Each `_target_` is a dotted import path to the class
# that is built from the sibling keys.
task:
  _target_: tasks.regression.RegressionICL
  # Written with an explicit mantissa dot: YAML 1.1 loaders such as plain
  # PyYAML resolve a bare `1e-4` to the string "1e-4" instead of a float.
  lr: 1.0e-4
  model:
    _target_: models.explicit.ExplicitModelWith
    context_model:
      _target_: models.explicit.TransformerContext
      # NOTE(review): x_dim / y_dim carry the same values as data.x_dim /
      # data.y_dim; presumably they must agree — confirm.
      x_dim: 1
      y_dim: 1
      n_features: 256
      n_heads: 4
      n_hidden: 512
      n_layers: 4
    prediction_model:
      _target_: models.explicit.MLPPrediction
      x_dim: 1
      y_dim: 1
      # Relative interpolation: resolves to the sibling
      # context_model.n_features (256 here), so z_dim always equals it.
      z_dim: ${..context_model.n_features}
      hidden_dim: 512

# Data module settings handed to `data.regression.RegressionDataModule`.
data:
  _target_: data.regression.RegressionDataModule
  kind: gp
  # Same values as the x_dim / y_dim given to the task's sub-models.
  x_dim: 1
  y_dim: 1
  # Extra keyword arguments for the selected `kind`.
  kind_kwargs:
    kernel: 'RBF'
  # NOTE(review): presumably the lower / upper bound on the number of context
  # points per task — confirm against RegressionDataModule.
  min_context: 16
  max_context: 128
  batch_size: 128
  train_size: 1000
  val_size: 1000
  context_style: 'same'
  # NOTE(review): presumably the out-of-distribution evaluation variants —
  # confirm against RegressionDataModule.
  ood_styles:
    - 'far'
    - 'wide'

# Trainer settings; only the maximum number of epochs is set in this section.
trainer:
  max_epochs: 1000

# Callbacks: a single Lightning ModelCheckpoint that monitors "val/iid_MSE"
# (lower is better) and additionally saves a "last" checkpoint.
callbacks:
  - _target_: pytorch_lightning.callbacks.ModelCheckpoint
    monitor: "val/iid_MSE"
    mode: "min"
    save_last: true
    # `save_dir` and `logger.name` are not defined in this section; they must
    # be supplied by the composed configuration.
    dirpath: ${save_dir}/checkpoints/${logger.name}
    # The template already spells out the "epoch=" prefix, so automatic
    # metric-name insertion is disabled below. With it enabled, Lightning
    # rewrites "{epoch" to "epoch={epoch" and the files come out as
    # "epoch=epoch=NNN.ckpt".
    filename: "epoch={epoch:03d}"
    auto_insert_metric_name: false