# Seeding and multi-seed evaluation.
# NOTE(review): presumably `seed_everything` is the base random seed and
# `num_evaluation_seeds` is the number of seeds behind `multiseed_report`
# — confirm against the training/evaluation entry point.
seed_everything: 42
num_evaluation_seeds: 5

# Shared dataset-level limits.
# `num_classes` is reused as the label embedding input size and
# `max_time_delta` is reused by `thinning_params`, `metric` and
# `module.seq_encoder` through `${...}` interpolation.
# NOTE(review): time values are presumably in days (see the comments on
# `metric.horizon` and `metric.map_deltas`) — confirm against the dataset.
max_predictions: 32
num_classes: 35  # Zero is excluded by PTLS.
max_time_delta: 32
max_duration: 32
time_smoothing: null  # Time is continuous.

# RNN-related hyperparameters.
# These keys are not referenced anywhere else in the visible part of this
# file; presumably they are consumed through interpolation by configs that
# extend this one — TODO confirm.
rnn_hidden_size: 64
rnn_inference_context: 64
rnn_inference_context_step: 32

# Transformer-related hyperparameters.
# These keys are not referenced anywhere else in the visible part of this
# file; presumably they are consumed through interpolation by configs that
# extend this one — TODO confirm.
# NOTE(review): `transformer_pos_M` (5000) is far above `max_time_delta`
# (32) although its comment says "maximum delta" — confirm the intended
# scale of the positional encoding bounds.
transformer_hidden_size: 64
transformer_heads: 4
transformer_layers: 2
transformer_pos_m: 1  # ~ minimum delta.
transformer_pos_M: 5000  # ~ maximum delta.

# Method-specific hyperparameters (intensity, ODE, thinning, next-k, HYPRO).
# None of these keys is referenced elsewhere in the visible part of this
# file; presumably they are picked up by method-specific configs — TODO
# confirm.
max_intensity: null
ode_lipschitz: 1
thinning_params:
  max_steps: 100
  # Kept in sync with the global time-delta limit via interpolation.
  max_delta: ${max_time_delta}
  bound_samples: 64
next_k: 16
hypro_loss_step: 4

# Experiment logger and output locations.
# `${name}` is not defined in the visible part of this file, so it must be
# supplied by the config that includes this one or by a command-line
# override; otherwise interpolation will fail when these values are read.
logger:
  # Hydra-style instantiation target: a Weights & Biases logger.
  _target_: pytorch_lightning.loggers.WandbLogger
  project: hotpp-mimiciv
  name: ${name}
  save_dir: lightning_logs
# Per-experiment artifacts; all three file names are derived from `${name}`.
model_path: checkpoints/${name}.ckpt
report: results/${name}.yaml
multiseed_report: results/multiseed_${name}.yaml

# Data loading configuration, instantiated as `hotpp.data.HotppDataModule`.
# The dataset paths are relative.
# NOTE(review): presumably resolved against the run's working directory —
# confirm how the launcher sets it.
data_module:
  _target_: hotpp.data.HotppDataModule
  batch_size: 64
  max_length: 64
  num_workers: 4
  train_path: data/train.parquet
  val_path: data/val.parquet
  test_path: data/test.parquet

# Evaluation metric configuration, instantiated as
# `hotpp.metrics.HorizonMetric`. The same node is reused for both
# `module.val_metric` and `module.test_metric` via `${metric}`.
metric:
  _target_: hotpp.metrics.HorizonMetric
  max_time_delta: ${max_time_delta}
  horizon: 28  # 4 weeks.
  horizon_evaluation_step: 4
  map_deltas: [4]  # Days.
  map_target_length: 16
  otd_steps: 5  # Half the minimum length.
  otd_insert_cost: 2
  otd_delete_cost: 2

# Prediction head template, targeting `hotpp.nn.Head`.
# `_partial_: true` requests a partially-applied constructor, so the
# remaining arguments are supplied by whoever calls it.
# Not referenced elsewhere in the visible part of this file; presumably
# used via `${head}` by method-specific configs — TODO confirm.
head:
  _target_: hotpp.nn.Head
  _partial_: true
  hidden_dims: [64]
  use_batch_norm: true

# Conditional prediction head template, targeting `hotpp.nn.ConditionalHead`.
# `_partial_: true` requests a partially-applied constructor, so the
# remaining arguments are supplied by whoever calls it.
# Not referenced elsewhere in the visible part of this file; presumably
# used via `${conditional_head}` by method-specific configs — TODO confirm.
conditional_head:
  _target_: hotpp.nn.ConditionalHead
  _partial_: true
  hidden_dims: [128, 256]
  use_batch_norm: true

# Shared settings for the training module.
# There is no `_target_` at this level in the visible file; presumably the
# concrete module class is added by the config that extends this one —
# TODO confirm.
module:
  # Sequence encoder: an RNN encoder with a 16-dimensional embedding for
  # the `labels` field.
  seq_encoder:
    _target_: hotpp.nn.RnnEncoder
    embeddings:
      labels:
        in: ${num_classes}
        out: 16
    max_time_delta: ${max_time_delta}
  # Partially-applied Adam constructor; the remaining arguments are
  # supplied by the caller (presumably the model parameters — confirm).
  optimizer_partial:
    _partial_: true
    _target_: torch.optim.Adam
    lr: 0.001
    weight_decay: 0.0
  # Partially-applied StepLR constructor; the remaining arguments are
  # supplied by the caller (presumably the optimizer — confirm).
  lr_scheduler_partial:
    _partial_: true
    _target_: torch.optim.lr_scheduler.StepLR
    step_size: 5
    gamma: 0.8
  # Validation and test share the single top-level `metric` definition.
  val_metric: ${metric}
  test_metric: ${metric}

# Trainer settings: a single CUDA device, 30 epochs, mixed 16-bit precision.
# NOTE(review): these keys presumably map onto `pytorch_lightning.Trainer`
# arguments — confirm in the training entry point.
trainer:
  accelerator: cuda
  devices: 1
  max_epochs: 30
  enable_checkpointing: true
  deterministic: true
  precision: 16-mixed
  gradient_clip_val: 1  # Increases training stability.
  check_val_every_n_epoch: 3
  # NOTE(review): `model_selection` does not look like a Lightning Trainer
  # argument; presumably the training script removes it before building
  # the Trainer and uses it to pick the best checkpoint by the highest
  # `val/T-mAP` — confirm.
  model_selection:
    metric: val/T-mAP
    mode: max
