# Hydra defaults list: the `default` config is composed first and this file
# (`_self_`) last, so values set below take precedence on conflicting keys.
defaults:
  - default
  - _self_

# Identifier of this configuration. How it is consumed is not visible in this
# file (presumably run / artifact naming -- confirm in the base config).
name: hypro_rmtpp

# Only the batch size is set here; any other data_module settings are
# presumably inherited from the `default` config -- confirm there.
data_module:
  batch_size: 256

# Model definition for the HyPro module.
# `_target_` / `_partial_` follow Hydra's instantiate convention. Every
# `${...}` value is an interpolation of a key that is NOT defined in this file
# (presumably provided by the `default` config or by overrides -- confirm).
module:
  _target_: hotpp.modules.hypro.HyproModule
  # Settings for the main sequence encoder. No `_target_` is given at this
  # level, so the encoder class is presumably supplied by the base config and
  # these keys are merged into it -- TODO confirm.
  seq_encoder:
    # Partially-applied GRU constructor; the remaining arguments are supplied
    # by whoever calls the partial.
    rnn_partial:
      _target_: hotpp.nn.GRU
      _partial_: true
      hidden_size: ${rnn_hidden_size}
    max_inference_context: ${rnn_inference_context}
    inference_context_step: ${rnn_inference_context_step}
  # The head definition is taken wholesale from the top-level `head` key.
  head_partial: ${head}
  # Next-item loss with one sub-loss per field (timestamps and labels).
  loss:
    _target_: hotpp.losses.NextItemLoss
    # Per-field prediction mode names; their exact semantics are defined by
    # NextItemLoss and are not visible here.
    prediction:
      timestamps: mean
      labels: sample
    losses:
      # RMTPP loss for the timestamps field.
      timestamps:
        _target_: hotpp.losses.TimeRMTPPLoss
        init_influence: -0.1
        max_intensity: ${max_intensity}
        thinning_params: ${thinning_params}
      # Cross-entropy loss for the labels field.
      labels:
        _target_: hotpp.losses.CrossEntropyLoss
        num_classes: ${num_classes}
  # Bound to the same `max_predictions` value as `hypro_context` below.
  autoreg_max_steps: ${max_predictions}
  # Checkpoint path parameterized by the seed. Judging by the file name this is
  # a previously trained RMTPP model for the same seed -- verify that it exists
  # before running this config.
  base_checkpoint: checkpoints/rmtpp-seed-${seed_everything}.ckpt
  # Separate RNN encoder for the HyPro part: labels are embedded from
  # `num_classes` inputs into 16 dimensions and fed to a GRU with hidden size 16.
  hypro_encoder:
    _target_: hotpp.nn.RnnEncoder
    embeddings:
      labels:
        in: ${num_classes}
        out: 16
    rnn_partial:
      _target_: hotpp.nn.GRU
      _partial_: true
      hidden_size: 16
    max_time_delta: ${max_time_delta}
  # Partially-applied head for the HyPro part with a single hidden dimension
  # entry of 16.
  hypro_head_partial:
    _target_: hotpp.nn.Head
    _partial_: true
    hidden_dims: [16]
  # Loss object used for the HyPro part; it takes no arguments from this file.
  hypro_loss:
    _target_: hotpp.losses.hypro.HyproBCELoss
  hypro_loss_step: ${hypro_loss_step}
  # Bound to the same `max_predictions` value as `autoreg_max_steps` above.
  hypro_context: ${max_predictions}
