# Random seed and the number of seeds used for evaluation.
# NOTE(review): the code that reads these keys is outside this file; confirm
# how `num_evaluation_seeds` is combined with `seed_everything`.
seed_everything: 42
num_evaluation_seeds: 5

# Shared task-level constants.
# `num_classes` and `max_time_delta` are reused further down in this file via
# `${...}` interpolation (label embedding size, thinning limit, metric and
# encoder limits). `max_predictions` and `max_duration` are not referenced
# elsewhere in this file; presumably they are read by derived configs — verify.
max_predictions: 32
num_classes: 3
max_time_delta: 100
max_duration: 100
time_smoothing: 1  # Convert discrete time to a continuous variable during loss computation.

# RNN-related settings.
# None of these keys is referenced elsewhere in this file; presumably they are
# interpolated by model-specific configs that build on this one — verify.
rnn_hidden_size: 64
rnn_inference_context: 264
rnn_inference_context_step: 50

# Transformer-related settings.
# None of these keys is referenced elsewhere in this file; presumably they are
# interpolated by model-specific configs that build on this one — verify.
transformer_hidden_size: 64
transformer_heads: 4
transformer_layers: 2
# NOTE(review): `pos_m` / `pos_M` look like lower/upper bounds for positional
# encoding of time deltas (per the inline notes) — confirm against the encoder.
transformer_pos_m: 1  # ~ minimum delta.
transformer_pos_M: 2500  # ~ maximum delta.

# Settings for intensity-based / ODE / next-k / HYPRO style methods.
# These keys are not referenced elsewhere in this file; presumably they are
# consumed by method-specific configs — verify.
# `max_intensity` is explicitly left unset (null).
max_intensity: null
ode_lipschitz: 0.1
thinning_params:
  max_steps: 100
  # Reuses the top-level `max_time_delta` value via interpolation.
  max_delta: ${max_time_delta}
  bound_samples: 100
next_k: 16
hypro_loss_step: 8

# Experiment logger and output locations.
# `${name}` is not defined anywhere in this file, so it must be provided from
# outside (e.g. a config that includes this one or a command-line override);
# otherwise the interpolation cannot be resolved.
logger:
  # `_target_` names the class to instantiate (Hydra convention); the remaining
  # keys are passed to it as keyword arguments.
  _target_: pytorch_lightning.loggers.WandbLogger
  project: hotpp-retweet
  name: ${name}
  save_dir: lightning_logs
# Per-run checkpoint and report files, all keyed by `${name}`.
model_path: checkpoints/${name}.ckpt
report: results/${name}.yaml
multiseed_report: results/multiseed_${name}.yaml

# Data loading: one Parquet file per split, read by `hotpp.data.HotppDataModule`.
data_module:
  _target_: hotpp.data.HotppDataModule
  batch_size: 64
  # NOTE(review): `min_length` and `max_length` are both 100 — confirm that a
  # fixed length is intended here.
  min_length: 100
  max_length: 100
  num_workers: 4
  # Paths are relative; presumably resolved against the working directory of
  # the run — verify.
  train_path: data/train.parquet
  val_path: data/val.parquet
  test_path: data/test.parquet

# Evaluation metric definition. This whole mapping is reused for both
# `module.val_metric` and `module.test_metric` via `${metric}` interpolation.
metric:
  _target_: hotpp.metrics.HorizonMetric
  # Reuses the top-level `max_time_delta` value.
  max_time_delta: ${max_time_delta}
  horizon: 180  # 3 minutes.
  horizon_evaluation_step: 8
  # Single-element list; written as a list because the key name is plural.
  # NOTE(review): confirm that the metric accepts several deltas here.
  map_deltas: [30]
  map_target_length: 32
  # `otd_*` keys: insert and delete costs are set to the same value (15).
  otd_steps: 10
  otd_insert_cost: 15
  otd_delete_cost: 15

# Output head template.
# `_partial_: true` follows the Hydra convention for partial instantiation:
# the result is a callable with the arguments below pre-bound; the remaining
# arguments are presumably supplied by the code that uses it — verify.
# This key is not referenced elsewhere in this file.
head:
  _target_: hotpp.nn.Head
  _partial_: true
  hidden_dims: [64]
  use_batch_norm: true

# Conditional output head template; partially instantiated in the same way as
# any `_partial_: true` entry (Hydra convention). Uses two hidden layers
# (128, 256). This key is not referenced elsewhere in this file.
conditional_head:
  _target_: hotpp.nn.ConditionalHead
  _partial_: true
  hidden_dims: [128, 256]
  use_batch_norm: true

# Training module configuration.
# NOTE(review): `module` has no `_target_` of its own in this file; presumably
# a derived config adds it (together with a loss/head) — verify.
module:
  # Sequence encoder with a single categorical embedding for `labels`.
  seq_encoder:
    _target_: hotpp.nn.RnnEncoder
    embeddings:
      labels:
        # Vocabulary size comes from the top-level `num_classes`.
        in: ${num_classes}
        out: 16
    max_time_delta: ${max_time_delta}
  # Optimizer factory (`_partial_: true`): Adam with the settings below;
  # the parameters to optimize are presumably bound later by the module.
  optimizer_partial:
    _partial_: true
    _target_: torch.optim.Adam
    lr: 0.001
    weight_decay: 0.0
  # LR scheduler factory (`_partial_: true`): StepLR configured with
  # `step_size: 5` and `gamma: 0.8`.
  lr_scheduler_partial:
    _partial_: true
    _target_: torch.optim.lr_scheduler.StepLR
    step_size: 5
    gamma: 0.8
  # Validation and test use the same metric configuration (top-level `metric`).
  val_metric: ${metric}
  test_metric: ${metric}

# Trainer settings: single CUDA device, 30 epochs, mixed 16-bit precision.
trainer:
  accelerator: cuda
  devices: 1
  max_epochs: 30
  enable_checkpointing: true
  deterministic: true
  precision: 16-mixed
  gradient_clip_val: 1  # Increases training stability.
  check_val_every_n_epoch: 3
  # Which validation metric to select the model by, and in which direction.
  # NOTE(review): this nested mapping is presumably consumed by the training
  # script rather than passed to the trainer constructor — verify.
  model_selection:
    metric: val/T-mAP
    mode: max
