# Hydra defaults list: compose the diffusion and dataset config groups first.
# The leading "/" makes each group path absolute (resolved from the config
# root rather than relative to this file's location).
defaults:
  - /diffusion: shapes3d_dm
  - /dataset: shapes3d_uniform
  # _self_ is last, so the values in this file are merged after the groups
  # above and take precedence over them.
  - _self_

# Random seed for the run (presumably handed to the training entry point's
# seeding call -- confirm where it is consumed).
seed: 42
# Training callbacks. Each entry names a class through Hydra's `_target_`;
# the sibling keys are presumably forwarded as constructor keyword arguments
# when the entry is built with hydra.utils.instantiate (confirm in the
# training entry point).
callbacks:
  # Project-defined callback (presumably an exponential moving average of the
  # model weights -- confirm against score.ema.EMA).
  ema_callback:
    _target_: score.ema.EMA
    decay: 0.9999
    validate_original_weights: true
    every_n_steps: 100
  # Records the optimizer learning rate to the loggers once per epoch.
  lr_monitor:
    _target_: pytorch_lightning.callbacks.LearningRateMonitor
    logging_interval: epoch
  model_checkpoint:
    _target_: pytorch_lightning.callbacks.ModelCheckpoint
    filename: 'epoch_{epoch}-step_{step}'
    # The template already spells out the "epoch_" / "step_" prefixes. With
    # Lightning's default (auto_insert_metric_name: true) every {name} expands
    # to "name=value", producing "epoch_epoch=3-step_step=1000.ckpt".
    # Disabling it yields the intended "epoch_3-step_1000.ckpt".
    auto_insert_metric_name: false
  # Project-defined callback; everything below is handed to
  # callbacks.tracker.GenerationMetrics.
  generation_metrics:
    _target_: callbacks.tracker.GenerationMetrics
    sampling_pipe:
      _target_: score.pipelines.CondDDIMPipeline
      # With _partial_, Hydra returns a functools.partial instead of calling
      # the target, presumably so the callback can supply the remaining
      # arguments itself.
      _partial_: true
    metrics: ["quality", "r2", "cs"]
    num_classes_per_label: ${diffusion.model.num_class_per_label}
    output_dir: ${hydra:run.dir}
    classifier:
      _target_: cs_classifier.models.MultiLabelClassifier
      base_model:
        _target_: torchvision.models.resnet18
      num_classes_per_label: ${diffusion.model.num_class_per_label}
    # NOTE(review): "classifer" looks like a typo for "classifier", but the
    # key has to match the GenerationMetrics constructor parameter -- rename
    # both together or not at all.
    classifer_checkpoint: checkpoints/cs_classifier/shapes3d.ckpt
  jsd:
    _target_: callbacks.tracker.JSDTracker
    num_classes_per_label: ${diffusion.model.num_class_per_label}
    log_interval: 4
# Experiment loggers, one Hydra `_target_` entry per backend.
loggers:
  wandb_logger:
    _target_: pytorch_lightning.loggers.WandbLogger
    name: null  # no explicit run name is set here
    project: 'LogicalComposition'
    offline: false
  tensorboard_logger:
    _target_: pytorch_lightning.loggers.TensorBoardLogger
    # TensorBoardLogger writes beneath <save_dir>/<name>, i.e. inside the
    # Hydra run directory under "tensorboard".
    save_dir: ${hydra:run.dir}
    name: tensorboard
  # Disabled CSV logger, kept for reference:
  # csv_logger:
  #   _target_: pytorch_lightning.loggers.CSVLogger
  #   save_dir: ${hydra:run.dir}
  #   name: csv

# Trainer settings (the keys match pytorch_lightning.Trainer arguments;
# presumably forwarded to it by the training entry point -- confirm).
trainer:
  # The step budget is read from the diffusion scheduler config rather than
  # being repeated here.
  max_steps: ${diffusion.scheduler.num_training_steps}
  accelerator: gpu
  gradient_clip_val: 1.0
  enable_checkpointing: true
  # Quoted so the value is unambiguously a string.
  precision: '16-mixed'


