# Lightning Trainer settings: W&B logging, checkpointing and early stopping.
trainer:
  logger:
    - class_path: pytorch_lightning.loggers.WandbLogger
      init_args:
        project: diffsurv
        entity: anon
        log_model: true
        tags:
          - mlp
          - risk
          - diffsort
          - case-control-sampling
  callbacks:
    - class_path: pytorch_lightning.callbacks.ModelCheckpoint
      init_args:
        auto_insert_metric_name: false
        monitor: hp_metric
        # ModelCheckpoint defaults to mode "min". hp_metric is maximised
        # (EarlyStopping below uses mode: max), so the best checkpoint must
        # be selected with "max" as well.
        mode: max
        filename: '{epoch}-{hp_metric:.2f}'
    - class_path: pytorch_lightning.callbacks.EarlyStopping
      init_args:
        patience: 50
        monitor: hp_metric
        min_delta: 0.002
        mode: max
        verbose: true
        check_finite: false
  max_epochs: 100
  check_val_every_n_epoch: 1
  log_every_n_steps: 100
  enable_checkpointing: true
  # Sanity-check validation steps before training are disabled.
  num_sanity_val_steps: 0
  accumulate_grad_batches: 1
# Data module settings.
data:
  # W&B artifact reference; quoted because the value contains ':'.
  wandb_artifact: 'anon/diffsurv/flchain.pt:latest'
  risk_set_size: 2
  batch_size: 64
  # NOTE(review): torch's DataLoader rejects a negative num_workers, so
  # presumably the data module translates -1 itself (e.g. to "all cores").
  # Confirm against the data module.
  num_workers: -1
# Model hyperparameters.
model:
  only_covs: true
  embedding_dim: 5
  # Head settings.
  head_layers: 1
  head_hidden_dim: 128
  hidden_dropout_prob: 0.2
  batch_norm: false
  # Learning rate taken from the automatic LR finder.
  lr: 0.005
