# Names of the dataset and model variant this configuration is written for.
# NOTE(review): how these two keys are consumed (config-group selection,
# interpolation, logging) is not visible in this file — confirm.
dataset: gene_crispr
model: explicit-transformer

# Training task: a GeneCrisprICL task wrapping a two-part model made of a
# context model and a prediction model. Each `_target_` names the Python class
# to build (presumably through Hydra's `instantiate` — confirm against the
# training entry point).
task:
  _target_: tasks.gene_crispr.GeneCrisprICL
  # Written with an explicit mantissa dot on purpose: YAML 1.1 loaders such as
  # plain PyYAML resolve a bare `1e-4` to the string "1e-4" rather than a float.
  lr: 1.0e-4
  model:
    _target_: models.explicit.ExplicitModelWith
    context_model:
      _target_: models.explicit.TransformerContext
      x_dim: 5000
      y_dim: 0
      n_features: 256
      n_heads: 4
      n_hidden: 512
      n_layers: 4
      dropout: 0.0
    prediction_model:
      _target_: models.explicit.TransformerPrediction
      x_dim: 5000
      y_dim: 5000
      # Relative interpolation: resolves to this mapping's own `n_features`
      # (256), not the context model's.
      # NOTE(review): if z_dim is meant to track the context model's output
      # width, `${..context_model.n_features}` would express that link — the
      # two values are equal today, so behavior is the same either way; confirm.
      z_dim: ${.n_features}
      n_features: 256
      n_heads: 4
      n_hidden: 512
      n_layers: 4
      dropout: 0.0

# Data module configuration for the gene CRISPR dataset. All keys below are
# passed as written to the class named by `_target_`; their exact meaning is
# defined by GeneCrisprDataModule, which is not visible in this file.
data:
  _target_: data.gene_crispr.GeneCrisprDataModule
  # Quoted, so `~` is a literal string here (not YAML null).
  # NOTE(review): assumes the data module expands `~` to the home directory —
  # confirm.
  data_path: "~/data"
  contexts_per_ptb: 50
  n_context: 40
  n_queries: 5
  query_dim_pct: 0.5
  perturb_type: "both"
  # Canonical lowercase boolean; parses to the same value as `True`.
  include_control: true
  batch_size: 128
  train_size: 0.8
  num_workers: 0

# Trainer settings; the epoch cap is the only option set in this file.
# NOTE(review): presumably forwarded to the PyTorch Lightning Trainer — confirm.
trainer:
  max_epochs: 1000

# Training callbacks. A single checkpoint callback that tracks the lowest
# "val/MSE" and also keeps the most recent checkpoint.
callbacks:
  - _target_: pytorch_lightning.callbacks.ModelCheckpoint
    monitor: "val/MSE"
    mode: "min"
    save_last: true
    # `save_dir` and `logger.name` are not defined in this file.
    # NOTE(review): presumably supplied by the composed root config — confirm.
    dirpath: ${save_dir}/checkpoints/${logger.name}
    # The template already spells out the "epoch=" prefix, so automatic
    # metric-name insertion is turned off below. With insertion enabled,
    # Lightning rewrites `{epoch:03d}` to `epoch={epoch:03d}` and the file
    # would be named "epoch=epoch=003.ckpt" instead of "epoch=003.ckpt".
    filename: "epoch={epoch:03d}"
    auto_insert_metric_name: false