# Lightning Trainer settings: device selection, step budget, FSDP sharding,
# and the checkpoint / early-stopping callbacks.
trainer:
  accelerator: auto
  devices: auto
  # Upper bound on training steps; early stopping below may end the run sooner.
  max_steps: 10000
  gradient_clip_val: 0.1
  default_root_dir: logs
  strategy:
    class_path: lightning.pytorch.strategies.FSDPStrategy
    init_args:
      # Each FM4BioLayer is wrapped as its own FSDP unit.
      auto_wrap_policy: [modelgenerator.huggingface_models.fm4bio.modeling_fm4bio.FM4BioLayer]
      sharding_strategy: HYBRID_SHARD
  logger: false
  callbacks:
  # Keep the checkpoint with the highest val_spearman. No save_top_k is set
  # here, so the callback's own default retention applies.
  - class_path: lightning.pytorch.callbacks.ModelCheckpoint
    init_args:
      # null defers to the Trainer's default checkpoint location.
      dirpath: null
      # The label in the filename matches the metric being formatted and
      # monitored (val_spearman), so checkpoint names are not misleading.
      filename: epoch_{epoch}-val_spearman:{val_spearman:.3f}
      monitor: val_spearman
      mode: max
  # Stop once val_spearman has failed to improve for `patience` consecutive
  # validation checks.
  - class_path: lightning.pytorch.callbacks.early_stopping.EarlyStopping
    dict_kwargs:
      monitor: val_spearman
      mode: max
      patience: 10
# Task definition: a sequence-regression task built from a backbone, an
# adapter head, an optimizer, and a learning-rate scheduler.
model:
  class_path: modelgenerator.tasks.SequenceRegression
  init_args:
    backbone:
      class_path: modelgenerator.backbones.aido_protein_16b_v1
      init_args:
        # Turns on the backbone's PEFT option.
        # NOTE(review): the specific PEFT method is defined by the backbone - confirm.
        use_peft: true
        # Presumably the maximum input sequence length in tokens - confirm.
        max_length: 2048
    adapter:
      class_path: modelgenerator.adapters.MLPPoolAdapter
      init_args:
        # A single hidden size (128) is configured for the adapter MLP.
        hidden_sizes:
        - 128
        dropout: 0.1
        dropout_in_middle: false
    optimizer:
      class_path: torch.optim.AdamW
      init_args:
        lr: 0.0001
        weight_decay: 0.01
    lr_scheduler:
      class_path: modelgenerator.lr_schedulers.CosineWithWarmup
      init_args:
        # Presumably the fraction of total training steps spent warming up - confirm.
        warmup_ratio: 0.05
# Data module: DMS fitness prediction on a single assay file, split by
# cross-validation folds.
data:
  class_path: modelgenerator.data.DMSFitnessPrediction
  init_args:
    # Looks like a Hugging Face Hub dataset id - confirm against the data module.
    path: genbio-ai/ProteinGYM-DMS
    # One assay file; the path is relative to the dataset root given in `path`.
    train_split_files:
    - singles_substitutions/VRPI_BPT7_Tsuboyama_2023_2WNM.tsv
    train_split_name: train
    random_seed: 42
    batch_size: 32
    # Cross-validation: 5 folds with fold 0 selected as the test fold, and a
    # validation fold enabled. Fold membership is read from the `fold_id` column.
    # NOTE(review): how the validation fold is chosen is defined by the data
    # module - confirm.
    cv_num_folds: 5
    cv_test_fold_id: 0
    cv_enable_val_fold: true
    cv_fold_id_col: fold_id
