# Hydra runtime settings. Each run writes to a directory keyed by the selected
# `model` and `data` config-group choices, followed by the `run_name` value.
hydra:
  run:
    dir: outputs/pretrain/${hydra:runtime.choices.model}/${hydra:runtime.choices.data}/${run_name}
# Config-group selections. `???` marks a mandatory value, so `model` and `data`
# must be chosen by the caller; `val_data` is not selected by default.
# `_self_` is listed last, so the keys in this file are merged after the
# selected groups.
defaults:
  - model: ???
  - data: ???
  - val_data: null
  - _self_
# Mandatory run identifier; it forms the last component of hydra.run.dir.
run_name: ???
seed: 0
# NOTE(review): presumably enables TF32 matmul in the training entry point;
# the consumer is not visible in this file, so confirm there.
tf32: true
# Leave as false, or set to one of the mode names: default, reduce-overhead,
# max-autotune (presumably forwarded to torch.compile; confirm in the consumer).
compile: false
ckpt_path: null  # set to "last" to resume training
# Arguments for lightning.Trainer; `_target_` names the class to instantiate.
trainer:
  _target_: lightning.Trainer
  accelerator: auto
  strategy: auto
  devices: 1
  num_nodes: 1
  precision: 32
  # TensorBoard event files are written under <run output dir>/logs.
  logger:
    _target_: lightning.pytorch.loggers.TensorBoardLogger
    save_dir: ${hydra:runtime.output_dir}
    name: logs
  callbacks:
    # Record the learning rate once per epoch.
    - _target_: lightning.pytorch.callbacks.LearningRateMonitor
      logging_interval: epoch
    # Resume checkpoint: monitoring `epoch` with mode `max` and save_top_k 1
    # keeps only the most recent save, written every 10 epochs as `last`.
    - _target_: lightning.pytorch.callbacks.ModelCheckpoint
      dirpath: ${hydra:runtime.output_dir}/checkpoints
      filename: last
      monitor: epoch
      mode: max
      save_top_k: 1
      every_n_epochs: 10
    # Weights-only snapshots, all retained (save_top_k: -1). The interval comes
    # from the `floordiv` resolver applied to max_epochs and 10
    # (presumably max_epochs // 10; the resolver is registered elsewhere).
    - _target_: lightning.pytorch.callbacks.ModelCheckpoint
      dirpath: ${hydra:runtime.output_dir}/checkpoints
      monitor: epoch
      save_weights_only: true
      mode: max
      save_top_k: -1
      every_n_epochs: ${floordiv:${trainer.max_epochs},10}
    # Project callback configured with the same keys as ModelCheckpoint; it
    # targets a separate HF_checkpoints directory and fires every epoch.
    # NOTE(review): the output format is defined in uni2ts; confirm there.
    - _target_: uni2ts.callbacks.HuggingFaceCheckpoint.HuggingFaceCheckpoint
      dirpath: ${hydra:runtime.output_dir}/HF_checkpoints
      filename: last
      monitor: epoch
      mode: max
      save_top_k: 1
      every_n_epochs: 1
  # Training length is expressed in epochs because epoch-based training
  # provides averaged metrics. Do not combine it with max_steps: resuming from
  # a checkpoint would then continue on the wrong epoch.
  max_epochs: 1_000
  enable_progress_bar: true
  accumulate_grad_batches: 1
  gradient_clip_val: 1.0
  gradient_clip_algorithm: norm
# Arguments for the project's uni2ts.data.loader.DataLoader.
# NOTE(review): this is not torch's DataLoader; keys such as batch_size_factor,
# cycle, num_batches_per_epoch and fill_last are defined by uni2ts, so confirm
# their exact meaning there.
train_dataloader:
  _target_: uni2ts.data.loader.DataLoader
  batch_size: 64
  batch_size_factor: 2.0
  cycle: true
  num_batches_per_epoch: 100
  shuffle: true
  num_workers: 1
  # Collate settings are derived from the selected model config: max_length
  # mirrors the model's max_seq_len, and the two `cls_getattr` interpolations
  # reference the `seq_fields` and `pad_func_map` attributes of the class named
  # by model._target_ (the resolver itself is registered elsewhere).
  collate_fn:
    _target_: uni2ts.data.loader.PackCollate
    max_length: ${model.module_kwargs.max_seq_len}
    seq_fields: ${cls_getattr:${model._target_},seq_fields}
    pad_func_map: ${cls_getattr:${model._target_},pad_func_map}
  pin_memory: true
  drop_last: true
  fill_last: false
  worker_init_fn: null
  prefetch_factor: 2
  persistent_workers: true