---
# configs/trainer/torchjd.yaml
#
# Trainer selection plus its settings.
# NOTE(review): `module` / `class` look like a Python import path and class
# name resolved by the config loader, with `args` passed to that class --
# confirm against the loader; it is not visible from this file.
module: 'trainers.torchjd_trainer'
class: 'TorchJDTrainer'
args:
  # Integer settings. Units are not visible here (e.g. whether log_interval
  # counts steps or epochs) -- TODO confirm in the trainer.
  epochs: 500
  batch_size: 100
  log_interval: 10
  # Quoted defensively so the embedded colon is always read as string content.
  device: 'cuda:0'
  # Relative path; presumably resolved against the process working
  # directory -- verify with the caller.
  save_dir: 'checkpoints/torchjd'
