# Settings applied under the `fit` key: trainer, data and model sections.
fit:
  trainer:
    # Experiment tracking: a single Weights & Biases logger.
    logger:
      - class_path: pytorch_lightning.loggers.WandbLogger
        init_args:
          project: diffsurv
          entity: anon
          log_model: true
          # Run tags; they describe this experiment variant.
          tags:
            - mlp
            - risk
            - diffsort
            - case-control-sampling
            - nwtco
            - topk-experiment-diffsort-topk
    callbacks:
      # Checkpointing driven by `val/topk`, where larger values are better.
      - class_path: pytorch_lightning.callbacks.ModelCheckpoint
        init_args:
          auto_insert_metric_name: false
          monitor: val/topk
          mode: max
          # NOTE(review): the file name formats `hp_metric`, not the monitored
          # `val/topk` -- confirm the model actually logs `hp_metric`, otherwise
          # the value embedded in checkpoint names will not be meaningful.
          filename: '{epoch}-{hp_metric:.4f}'
      # Early stopping on the same metric/direction as the checkpoint callback.
      - class_path: pytorch_lightning.callbacks.EarlyStopping
        init_args:
          patience: 10
          monitor: val/topk
          min_delta: 0.000
          mode: max
          verbose: true
          check_finite: false
          # The check runs at the end of the training epoch; with validation
          # every epoch (see `check_val_every_n_epoch`) -- TODO confirm that
          # `val/topk` is already logged at that point.
          check_on_train_epoch_end: true
    max_epochs: 100
    check_val_every_n_epoch: 1
    log_every_n_steps: 1
    enable_checkpointing: true
    # No sanity-check validation batches are requested before training.
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
  # Arguments for the data section. NOTE(review): the meaning of each key is
  # defined by the project's data module, which is not visible here.
  data:
    # Dataset reference; quoted so the embedded `:` is always read as text.
    wandb_artifact: 'anon/diffsurv/nwtco.pt:latest'
    num_workers: 4
    # Size settings (training batch, validation batch, risk set).
    batch_size: 64
    val_batch_size: 32
    risk_set_size: 32
  # Arguments for the model section. NOTE(review): key semantics are defined
  # by the project's model class, which is not visible here.
  model:
    # Network shape and regularisation settings.
    head_layers: 2
    embedding_dim: 5
    batch_norm: false
    head_hidden_dim: 128
    lr: 0.0001  # from auto lr find
    only_covs: true
    hidden_dropout_prob: 0.2
    # Sorting-network settings.
    sorting_network: bitonic
    steepness: 20.0
    distribution: cauchy
    # Objective switches; `optimize_topk` is the one enabled here, in line with
    # the `val/topk` metric monitored by the trainer callbacks.
    ignore_censoring: false
    optimize_topk: true
    optimize_combined: false
    norm_risk: true
