# Experiment identifiers: which dataset and which model variant this config
# describes. How they are consumed (naming, logging, config-group selection)
# is not visible in this file.
dataset: "raven"
model: "explicit-aux-transformer"

# Task definition. Each `_target_` is a dotted import path in the Hydra
# instantiate style; the referenced classes are defined outside this file.
task:
  _target_: tasks.raven.RavenICL
  # Written with an explicit mantissa dot so that YAML 1.1 loaders (e.g. plain
  # PyYAML) parse it as a float instead of the string "1e-4".
  lr: 1.0e-4
  # Single source of truth for the model width; referenced below through
  # relative interpolations.
  embedding_dim: 256
  backprop_rule_pred: true
  model:
    _target_: models.explicit.ExplicitModelWith
    context_model:
      _target_: models.explicit.RavenTransformerContext
      # `${...embedding_dim}` is a relative interpolation: the leading dots
      # climb from this mapping up to `task`, i.e. it reads task.embedding_dim.
      dim: ${...embedding_dim}
      n_heads: 4
      n_hidden: 512
      n_layers: 4
      dropout: 0.1
      z_dim: ${...embedding_dim}
    prediction_model:
      _target_: models.explicit.RavenTransformerPrediction
      dim: ${...embedding_dim}
      # Every remaining hyperparameter mirrors the sibling `context_model`
      # entry (`..` resolves to task.model), so the two transformers stay in
      # sync when only context_model is edited.
      n_heads: ${..context_model.n_heads}
      n_hidden: ${..context_model.n_hidden}
      n_layers: ${..context_model.n_layers}
      dropout: ${..context_model.dropout}
      z_dim: ${..context_model.z_dim}

# Data module settings; `_target_` names the class to build (defined outside
# this file) and the remaining keys are its arguments.
data:
  _target_: data.raven.RavenDataModule
  # NOTE(review): YAML does not expand "~"; presumably the data module calls
  # an expanduser-style helper on this path — confirm.
  data_dir: "~/data/raven"
  setting: "inpo"
  n_values: 40
  batch_size: 512
  num_workers: 0

# Trainer options; only the epoch budget is set here.
# NOTE(review): presumably forwarded to a PyTorch Lightning Trainer (the
# commented-out callbacks in this file reference pytorch_lightning) — confirm.
trainer:
  max_epochs: 1000

# Callback list. Every entry below is commented out, so this key currently
# parses as null (not an empty list). Uncomment the entries to enable early
# stopping and checkpointing; `${save_dir}` and `${logger.name}` are not
# defined in this file and must be supplied by the surrounding configuration.
callbacks:
  # - _target_: pytorch_lightning.callbacks.EarlyStopping
  #   monitor: "val_iid/loss"
  #   mode: "min"
  #   patience: 100
  # - _target_: pytorch_lightning.callbacks.ModelCheckpoint
  #   monitor: "val_iid/loss"
  #   mode: "min"
  #   save_last: True
  #   dirpath: ${save_dir}/checkpoints/${logger.name}
  #   filename: "epoch={epoch:03d}"
  #   auto_insert_metric_name: True
