# Elijah Cole, Nicholas Ho
# Fixed global random seed for the run (presumably consumed by the CLI's
# seed_everything option; confirm against the launcher).
seed_everything: 42
# Data module configuration: dataset location, batch size, split files and
# label column handling. The active dataset is zheng.
data:
  class_path: modelgenerator.data.CellClassificationDataModule
  init_args:
    batch_size: 16
    # Directory for the active dataset; the split file names below are
    # presumably resolved relative to it (confirm against the data module).
    path: '/workspace/modelgenerator/cell-downstream-tasks/zheng/'
    train_split_files: ['zheng_train.h5ad']
    valid_split_files: ['zheng_valid.h5ad']
    test_split_files: ['zheng_test.h5ad']
    # NOTE(review): looks like 'cell_type_label' is the column kept from the
    # data and 'labels' is the name it is given; verify against the module.
    filter_columns: ['cell_type_label']
    rename_columns: ['labels']
    # Alternative dataset (Segerstolpe). To switch, replace the path and the
    # three split-file entries above with the ones below, and use the
    # Segerstolpe n_classes value in the model section.
    # path: '/workspace/modelgenerator/cell-downstream-tasks/Segerstolpe/'
    # train_split_files: ['Segerstolpe_train.h5ad']
    # valid_split_files: ['Segerstolpe_valid.h5ad']
    # test_split_files: ['Segerstolpe_test.h5ad']
# Task configuration: a sequence-classification task assembled from a
# backbone, an adapter, an optimizer and a learning-rate scheduler.
model:
  class_path: modelgenerator.tasks.SequenceClassification
  init_args:
    # Dataset-specific class count; keep in sync with the dataset selected
    # in the data section.
    n_classes: 11  # zheng
    # n_classes: 13  # Segerstolpe
    use_legacy_adapter: false
    optimizer:
      class_path: torch.optim.AdamW
      init_args:
        # Written in decimal form on purpose: a bare `1e-3` has no decimal
        # point, so YAML 1.1 loaders resolve it as a string, not a float.
        lr: 0.001
        weight_decay: 0.01
        betas: [0.9, 0.95]
    lr_scheduler:
      class_path: modelgenerator.lr_schedulers.CosineWithWarmup
      init_args:
        # presumably the fraction of training steps used for warmup; confirm
        # against the scheduler implementation.
        warmup_ratio: 0.01
    backbone:
      class_path: modelgenerator.backbones.aido_cell_3m
      init_args:
        # Canonical lowercase boolean, consistent with use_legacy_adapter.
        from_scratch: false
    adapter:
      class_path: modelgenerator.adapters.LinearMaxPoolAdapter
# Trainer configuration: logging cadence, precision, device count,
# checkpointing callbacks, output directory and distributed strategy.
trainer:
  log_every_n_steps: 10
  precision: bf16
  # `devices` used to appear twice in this mapping (`auto`, then `1`).
  # Duplicate keys are invalid YAML; last-wins loaders silently used `1`,
  # so that effective value is the one kept here.
  devices: 1
  max_epochs: 10
  gradient_clip_val: 0
  profiler: null
  callbacks:
    # Save a checkpoint for max val f1:
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      dict_kwargs:
        monitor: val_f1
        mode: max
        save_top_k: 1
        filename: "best_val_f1:{step}-{val_f1:.3f}"
    # Save a checkpoint for min train loss:
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      dict_kwargs:
        monitor: train_loss
        mode: min
        save_top_k: 1
        filename: "best_train_loss:{step}-{train_loss:.3f}"
    # Save a checkpoint for min val loss:
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      dict_kwargs:
        monitor: val_loss
        mode: min
        save_top_k: 1
        filename: "best_val_loss:{step}-{val_loss:.3f}"
    # Save latest (no monitored metric is configured for this one):
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      dict_kwargs:
        filename: "latest:{step}"
  default_root_dir: '/workspace/modelgenerator/logs'
  strategy:
    class_path: lightning.pytorch.strategies.DDPStrategy
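# TODO: Clean up parameter dependencies (e.g. the dataset path, the split file names and n_classes currently have to be switched together by hand).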

