# Top-level identifiers for this experiment. Nothing in this file references
# them; presumably they are consumed by config composition or naming logic
# defined elsewhere -- TODO confirm.
dataset: gene_crispr
model: implicit

# Training task definition. The `_target_` keys name the Python classes to
# build (Hydra-style instantiation; the instantiating code is not in this file).
task:
  _target_: tasks.gene_crispr.GeneCrisprICL
  # Written with an explicit decimal point so that YAML 1.1 loaders (e.g.
  # plain PyYAML) also read it as a float rather than the string "1e-4".
  lr: 1.0e-4
  # Nested model config handed to the task.
  model:
    _target_: models.implicit.TransformerImplicit
    x_dim: 5000
    y_dim: 5000
    n_features: 256
    n_heads: 4
    n_hidden: 512
    n_layers: 8
    input_has_y: false
    dropout: 0.0

# Data module configuration; `_target_` names the class to build.
data:
  _target_: data.gene_crispr.GeneCrisprDataModule
  # NOTE(review): YAML/Hydra do not expand "~"; this assumes the data module
  # expands the home directory itself -- confirm.
  data_path: "~/data"
  contexts_per_ptb: 50
  n_context: 40
  n_queries: 5
  query_dim_pct: 0.5
  perturb_type: "both"
  include_control: true
  batch_size: 128
  train_size: 0.8
  num_workers: 0

# Trainer settings. There is no `_target_` here, so the consuming code is
# presumably responsible for constructing the trainer from these keys --
# TODO confirm.
trainer:
  max_epochs: 1000

# Training callbacks; each list entry names a class to build via `_target_`.
callbacks:
  # Checkpointing driven by the "val/MSE" metric (mode "min": lower is
  # better); save_last additionally keeps the most recent checkpoint.
  - _target_: pytorch_lightning.callbacks.ModelCheckpoint
    monitor: "val/MSE"
    mode: "min"
    save_last: true
    # ${save_dir} and ${logger.name} are interpolations whose values are
    # defined outside this file.
    dirpath: ${save_dir}/checkpoints/${logger.name}
    # The template already spells out the "epoch=" prefix, so automatic
    # metric-name insertion must stay off; with it on, Lightning prepends the
    # name a second time and files come out as "epoch=epoch=003.ckpt".
    filename: "epoch={epoch:03d}"
    auto_insert_metric_name: false