# Settings for the `fit` subcommand (class_path/init_args layout; presumably
# consumed by a Lightning CLI entry point -- confirm against the launcher).
fit:
  # Trainer options: experiment logging, callbacks and epoch scheduling.
  trainer:
    logger:
      # Weights & Biases logger with model logging switched on.
      - class_path: pytorch_lightning.loggers.WandbLogger
        init_args:
          project: diffsurv
          entity: anon
          log_model: true
          tags: [mlp, risk, case-control-sampling, support]
    callbacks:
      # Checkpoint selection on `val/topk`; `mode: max` means higher is better.
      - class_path: pytorch_lightning.callbacks.ModelCheckpoint
        init_args:
          monitor: val/topk
          mode: max
          auto_insert_metric_name: false
          # NOTE(review): the name is formatted from `hp_metric`, not from the
          # monitored `val/topk` -- confirm the model logs `hp_metric`.
          filename: '{epoch}-{hp_metric:.4f}'
      # Early stopping on the same metric and direction as the checkpoint.
      - class_path: pytorch_lightning.callbacks.EarlyStopping
        init_args:
          monitor: val/topk
          mode: max
          patience: 10
          min_delta: 0.0
          verbose: true
          check_finite: false
          check_on_train_epoch_end: true
    # Epoch budget and validation/logging cadence.
    max_epochs: 100
    check_val_every_n_epoch: 1
    log_every_n_steps: 1
    enable_checkpointing: true
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
  # Data module options: dataset source plus loader sizing.
  data:
    # Quoted because the reference contains colons; the string is unchanged.
    wandb_artifact: 'anon/diffsurv/support.pt:latest'
    # Train and validation batch sizes are configured separately.
    batch_size: 64
    val_batch_size: 32
    risk_set_size: 32
    num_workers: 4
  # Model hyperparameters.
  model:
    # Network shape.
    embedding_dim: 5
    head_layers: 2
    head_hidden_dim: 64
    # Regularisation and normalisation switches.
    hidden_dropout_prob: 0.2
    batch_norm: false
    # Input selection flag; semantics are defined by the model class.
    only_covs: true
    # Optimiser learning rate.
    lr: 0.005
