# Experiment identifiers: the dataset name and the model variant for this run.
# NOTE(review): how these two values are consumed (run naming, logging, config
# group selection, ...) is not visible in this file — confirm in the entry point.
dataset: raven
model: implicit

# Task configuration. `_target_` names the class to build
# (tasks.raven.RavenICL); the remaining keys are presumably handed to it as
# constructor arguments via Hydra's instantiate — confirm in the entry point.
task:
  _target_: tasks.raven.RavenICL
  # Written with an explicit mantissa dot so that YAML 1.1 loaders (e.g. plain
  # PyYAML) also read it as the float 0.0001 instead of the string "1e-4".
  lr: 1.0e-4
  embedding_dim: 256
  # Nested model spec, targeting models.implicit.RavenTransformerImplicit.
  model:
    _target_: models.implicit.RavenTransformerImplicit
    # Relative interpolation: `..` climbs from `model` to `task`, so this
    # resolves to task.embedding_dim and keeps the two sizes in sync.
    dim: ${..embedding_dim}
    n_heads: 4
    n_hidden: 512
    n_layers: 8
    dropout: 0.1

# Data module configuration, targeting data.raven.RavenDataModule.
data:
  _target_: data.raven.RavenDataModule
  # NOTE(review): YAML does not expand `~`; this is passed through as the
  # literal string "~/data/raven". Presumably the data module calls
  # expanduser on it — confirm.
  data_dir: "~/data/raven"
  setting: "inpo"
  n_values: 40
  batch_size: 512
  num_workers: 0

# Trainer settings. There is no `_target_` here, so the consumer decides how
# these keys are used.
# NOTE(review): presumably forwarded as keyword arguments to a
# pytorch_lightning Trainer — confirm in the entry point.
trainer:
  max_epochs: 1000

# Training callbacks. Every entry below is currently commented out, so this key
# parses as null (not as an empty list); the consumer must tolerate a null
# value here.
# NOTE(review): if the ModelCheckpoint entry is re-enabled, check the
# `filename` template: with auto_insert_metric_name enabled the metric name is
# likely inserted automatically, which would yield "epoch=epoch=NNN" — confirm
# against the Lightning docs.
callbacks:
  # - _target_: pytorch_lightning.callbacks.EarlyStopping
  #   monitor: "val_iid/loss"
  #   mode: "min"
  #   patience: 100
  # - _target_: pytorch_lightning.callbacks.ModelCheckpoint
  #   monitor: "val_iid/loss"
  #   mode: "min"
  #   save_last: True
  #   dirpath: ${save_dir}/checkpoints/${logger.name}
  #   filename: "epoch={epoch:03d}"
  #   auto_insert_metric_name: True