# Hydra defaults list: entries are composed in order, later entries override earlier ones.
defaults:
  # Select `shapes3d_uniform.yaml` from the `dataset` config group (leading "/" = absolute group path).
  - /dataset: shapes3d_uniform.yaml
  # `_self_` comes last, so the values defined in this file win over the composed dataset config.
  - _self_

# Random seed; presumably applied by the entry-point script for reproducibility (consumer not shown here).
seed: 42
# Callbacks, keyed by name; each entry carries a Hydra `_target_` naming the class to instantiate.
callbacks:
  lr_monitor:
    _target_: pytorch_lightning.callbacks.LearningRateMonitor
    # Record the learning rate once per epoch instead of on every optimizer step.
    logging_interval: epoch

# Experiment logger: TensorBoard event files written under the Hydra run directory.
logger:
  _target_: pytorch_lightning.loggers.TensorBoardLogger
  # `hydra:` resolver -> the output directory Hydra created for this run.
  save_dir: ${hydra:run.dir}
  # Experiment name; TensorBoardLogger uses it as a sub-directory of `save_dir`.
  name: classifier

# Training module configuration, instantiated through its Hydra `_target_`.
trainer:
  _target_: trainer_xt_multihead.ClassifierTrainer
  # Number of classes for each label (six labels here). Single source of truth:
  # the model section below reads this list through an interpolation.
  # NOTE(review): presumably one entry per dataset factor - confirm against the dataset config.
  num_classes_per_label: [10, 10, 10, 8, 4, 15]
  epochs: 50
  # Multi-head classifier wrapped around a torchvision ResNet-18 base model.
  model:
    _target_: models.MultiheadClassifier
    base_model:
      # No arguments are given, so the resnet18 factory is called with its defaults.
      _target_: torchvision.models.resnet18
    # Interpolated so the heads always match `trainer.num_classes_per_label`.
    num_classes_per_label: ${trainer.num_classes_per_label}
  # `_partial_: true` makes Hydra return a functools.partial rather than an instance;
  # the remaining arguments (e.g. the parameters to optimise) are supplied by the caller.
  optimizer:
    _target_: torch.optim.AdamW
    _partial_: true
    lr: 3.0e-4
    weight_decay: 1.0e-3
  # Also a partial: the optimizer is bound by the caller.
  # NOTE(review): CosineAnnealingLR requires `T_max` and it is not set here, so the
  # caller must pass it - confirm in ClassifierTrainer.
  scheduler:
    _target_: torch.optim.lr_scheduler.CosineAnnealingLR
    _partial_: true
  # Diffusion noise scheduler (diffusers DDPM) built with a 1000-step linear beta schedule.
  # NOTE(review): presumably used to noise the classifier inputs - confirm in ClassifierTrainer.
  noise_scheduler:
    _target_: diffusers.DDPMScheduler
    num_train_timesteps: 1000
    clip_sample: true
    prediction_type: epsilon
    beta_schedule: linear


# Run-level training options.
# NOTE(review): key names match pytorch_lightning.Trainer arguments, so these are
# presumably forwarded to it as keyword arguments - confirm in the entry-point script.
training_config:
  accelerator: gpu
  # Gradient clipping threshold.
  gradient_clip_val: 1.0
  enable_checkpointing: true
  # Lightning precision string: 16-bit mixed-precision training.
  precision: 16-mixed