# Names of the dataset and model variant that this config pairs together.
# NOTE(review): how the launcher consumes these two keys is not visible here.
dataset: "lowrankmlp-classification"
model: "known-latent"

# Task definition: the ClassificationICL target, its `lr` setting, and the
# model it wraps (a context model plus a transformer prediction model).
task:
  _target_: tasks.classification.ClassificationICL
  # Written with an explicit fractional mantissa: YAML 1.1 loaders such as
  # PyYAML resolve the dot-less form `1e-4` as a string instead of a float.
  lr: 1.0e-4
  no_plot: true
  model:
    _target_: models.explicit.ExplicitModelWith
    context_model:
      _target_: models.explicit.KnownLatent
      # Relative interpolation: the four leading dots climb from this mapping
      # up to the level that holds both `task` and `data`, so the latent size
      # always follows `data.kind_kwargs.low_dim`.
      z_dim: ${....data.kind_kwargs.low_dim}
    prediction_model:
      _target_: models.explicit.TransformerPrediction
      # Input, output and latent dimensions are taken from the `data` section
      # (same relative-interpolation depth as above) to keep them in sync.
      x_dim: ${....data.x_dim}
      y_dim: ${....data.y_dim}
      z_dim: ${....data.kind_kwargs.low_dim}
      n_features: 256
      n_heads: 4
      n_hidden: 512
      n_layers: 4

# Data module configuration. `x_dim`, `y_dim` and `kind_kwargs.low_dim` are
# also read by the model config under `task` through interpolation, so
# changing them here changes the model dimensions as well.
data:
  _target_: data.classification.ClassificationDataModule
  kind: low_rank_mlp
  x_dim: 2
  y_dim: 2
  # Extra keyword arguments grouped under the selected `kind`.
  kind_kwargs:
    low_dim: 10
    hidden_dim: 64
    layers: 1
  # NOTE(review): presumably the lower/upper bounds on context size per
  # sample — confirm against ClassificationDataModule.
  min_context: 16
  max_context: 128
  batch_size: 128
  train_size: 1000
  val_size: 1000
  temperature: 0.1
  context_style: "same"
  ood_styles:
    - "far"
    - "wide"

# Trainer settings; only the epoch cap is set in this file.
# NOTE(review): presumably merged with trainer defaults defined elsewhere —
# confirm against the base config.
trainer:
  max_epochs: 1000

callbacks: