# Reproducibility settings.
# NOTE(review): the key name matches Lightning's `seed_everything` helper; presumably the
# training script passes this value to it — confirm.
seed_everything: 42
# NOTE(review): presumably the number of seeds used to produce `multiseed_report` — confirm
# against the evaluation script.
num_evaluation_seeds: 5

# Shared limits. Some of these are reused further down via ${...} interpolation, so
# changing a value here changes every place that references it.
# Not referenced elsewhere in this file; presumably consumed by configs that extend it.
max_predictions: 32
# Number of label classes; interpolated as the `in` size of
# module.seq_encoder.embeddings.labels.
num_classes: 205
# Interpolated by thinning_params.max_delta, metric.max_time_delta and
# module.seq_encoder.max_time_delta.
max_time_delta: 8
# Not referenced elsewhere in this file.
# NOTE(review): units presumably match max_time_delta — confirm.
max_duration: 8
time_smoothing: 1  # Convert discrete time to a continuous variable during loss computation.

# RNN-related settings. None of these keys is referenced elsewhere in this file;
# presumably they are interpolated by configs that extend this one — confirm.
rnn_hidden_size: 512
# Same value as data_module.max_length (1200).
# NOTE(review): keep in sync if that is intentional.
rnn_inference_context: 1200
rnn_inference_context_step: 200

# Transformer-related settings. None of these keys is referenced elsewhere in this file;
# presumably they are interpolated by configs that extend this one — confirm.
transformer_hidden_size: 512
transformer_heads: 4
transformer_layers: 2
# NOTE(review): pos_m / pos_M look like the lower / upper scale of a time-based positional
# encoding, per the inline notes below — confirm against the encoder implementation.
transformer_pos_m: 1  # ~ minimum delta.
transformer_pos_M: 800  # ~ maximum delta.

# Settings for intensity / ODE / sampling-based methods. Apart from the interpolation
# inside thinning_params, nothing here is referenced elsewhere in this file.
# Explicit null.
# NOTE(review): presumably means "no upper bound on intensity" — confirm how the consumer
# treats a missing value.
max_intensity: null
ode_lipschitz: 0.1
# NOTE(review): presumably parameters of a thinning sampler — confirm against the consumer.
thinning_params:
  max_steps: 100
  # Reuses the top-level max_time_delta.
  max_delta: ${max_time_delta}
  bound_samples: 16
next_k: 16
hypro_loss_step: 64

# Experiment logger. `_target_` names the class to construct (Hydra-style instantiation);
# the remaining keys are its constructor arguments.
logger:
  _target_: pytorch_lightning.loggers.WandbLogger
  project: hotpp-agepred
  # `name` is not defined in this file; it must be supplied by the including config or on
  # the command line, otherwise the interpolation cannot resolve.
  name: ${name}
  save_dir: lightning_logs
# Output locations. Every path is keyed by ${name} (defined outside this file), so runs
# with different names do not overwrite each other's artifacts.
model_path: checkpoints/${name}.ckpt
report: results/${name}.yaml
multiseed_report: results/multiseed_${name}.yaml
downstream_report: results/downstream_${name}.txt

# Data pipeline: constructs hotpp.data.HotppDataModule with the arguments below.
data_module:
  _target_: hotpp.data.HotppDataModule
  batch_size: 64
  # NOTE(review): presumably bounds on the sampled sequence length (1000..1200 events) —
  # confirm against HotppDataModule.
  min_length: 1000
  max_length: 1200
  num_workers: 4
  # Parquet splits, given relative to the working directory.
  train_path: data/train.parquet
  val_path: data/val.parquet
  test_path: data/test.parquet

# Evaluation metric definition. Reused for both validation and test through the
# ${metric} interpolation in module.val_metric / module.test_metric.
metric:
  _target_: hotpp.metrics.HorizonMetric
  # Reuses the top-level max_time_delta.
  max_time_delta: ${max_time_delta}
  horizon: 7  # One week.
  horizon_evaluation_step: 64
  # NOTE(review): presumably the time-delta thresholds used by the mAP metric (cf.
  # trainer.model_selection.metric = val/T-mAP) — confirm against HorizonMetric.
  map_deltas: [2]
  map_target_length: 24
  # NOTE(review): the otd_* keys presumably configure an optimal-transport-distance metric
  # (sequence length and insert/delete costs) — confirm.
  otd_steps: 5
  otd_insert_cost: 1
  otd_delete_cost: 1

# Prediction head template for hotpp.nn.Head. Not referenced elsewhere in this file;
# presumably interpolated by configs that extend it.
head:
  _target_: hotpp.nn.Head
  # Partial instantiation: the remaining constructor arguments are expected to be supplied
  # by the consumer at call time.
  _partial_: true
  hidden_dims: [512, 256]
  use_batch_norm: true

# Prediction head template for hotpp.nn.ConditionalHead, with the same layer settings as
# `head`. Not referenced elsewhere in this file; presumably interpolated by configs that
# extend it.
conditional_head:
  _target_: hotpp.nn.ConditionalHead
  # Partial instantiation: the remaining constructor arguments are expected to be supplied
  # by the consumer at call time.
  _partial_: true
  hidden_dims: [512, 256]
  use_batch_norm: true

# Model/training module arguments.
# NOTE(review): there is no `_target_` at this level in this file; presumably the concrete
# module class is set by a config that extends this one — confirm.
module:
  # Sequence encoder built from hotpp.nn.RnnEncoder.
  seq_encoder:
    _target_: hotpp.nn.RnnEncoder
    embeddings:
      # Embedding for the `labels` field: num_classes inputs -> 256-dimensional output.
      labels:
        in: ${num_classes}
        out: 256
    # Reuses the top-level max_time_delta.
    max_time_delta: ${max_time_delta}
  # Partially-instantiated Adam; the module is expected to supply the parameters to
  # optimize when it calls the partial.
  optimizer_partial:
    _partial_: true
    _target_: torch.optim.Adam
    lr: 0.001
    weight_decay: 0.0
  # Partially-instantiated StepLR; per torch semantics the learning rate is multiplied by
  # `gamma` every `step_size` scheduler steps.
  lr_scheduler_partial:
    _partial_: true
    _target_: torch.optim.lr_scheduler.StepLR
    step_size: 5
    gamma: 0.8
  # Validation and test share the single top-level `metric` definition.
  val_metric: ${metric}
  test_metric: ${metric}

# Trainer settings.
# NOTE(review): there is no `_target_` here; presumably the training script forwards these
# keys to the Lightning Trainer itself — confirm.
trainer:
  accelerator: cuda
  devices: 1
  max_epochs: 30
  enable_checkpointing: true
  deterministic: true
  precision: 16-mixed
  gradient_clip_val: 1  # Increases training stability.
  # Validation runs only every 3rd epoch, so the selection metric below is updated at that
  # cadence as well.
  check_val_every_n_epoch: 3
  # NOTE(review): `model_selection` does not look like a standard Lightning Trainer
  # argument; presumably the training script pops it and uses it to pick the best
  # checkpoint (highest val/T-mAP) — confirm.
  model_selection:
    metric: val/T-mAP
    mode: max
