# Hydra defaults list: config-group selections composed into this config.
# Entries are merged in the order listed, so placing `_self_` last lets the
# keys defined in this file override values coming from the selected groups.
defaults:
  # A leading `/` resolves the group from the config root rather than
  # relative to this file's own config group.
  - /diffusion: celeba_dm_sit
  - /dataset: celeba_blond_full
  - _self_

# Global seed value.
# NOTE(review): presumably read by the training entry point to seed RNGs —
# the consumer is not visible in this file; confirm.
seed: 42
# Callback specs. Every entry carries a Hydra `_target_`; the remaining keys
# are passed to that target as keyword arguments when it is instantiated.
# The mapping order is kept as-is in case the consumer relies on it.
callbacks:
  # Project-local EMA callback (score.ema.EMA).
  ema_callback:
    _target_: score.ema.EMA
    decay: 0.9999
    validate_original_weights: true
    every_n_steps: 100

  lr_monitor:
    _target_: pytorch_lightning.callbacks.LearningRateMonitor
    logging_interval: epoch

  model_checkpoint:
    _target_: pytorch_lightning.callbacks.ModelCheckpoint
    # NOTE(review): with Lightning's default `auto_insert_metric_name=True`,
    # `{epoch}` / `{step}` expand to `epoch=N` / `step=N`, so this template
    # likely yields names such as `epoch_epoch=3-step_step=1000`. Left
    # unchanged to keep existing checkpoint names stable — confirm intent.
    filename: "epoch_{epoch}-step_{step}"

  # Project-local generation-metrics tracker (callbacks.tracker.GenerationMetrics).
  generation_metrics:
    _target_: callbacks.tracker.GenerationMetrics
    # `_partial_: true` makes Hydra hand over a functools.partial of the
    # pipeline class instead of constructing it here.
    sampling_pipe:
      _target_: score.pipelines.CondDDIMPipeline
      _partial_: true
    # VAE loaded via `AutoencoderKL.from_pretrained` from the `vae` subfolder
    # of the FLUX.1-schnell repository, cached under `checkpoints`.
    vae:
      _target_: diffusers.AutoencoderKL.from_pretrained
      pretrained_model_name_or_path: black-forest-labs/FLUX.1-schnell
      cache_dir: checkpoints
      subfolder: vae
    metrics:
      - quality
      - cs
    # Interpolated from the composed diffusion config (defined outside this file).
    num_classes_per_label: ${diffusion.model.num_class_per_label}
    output_dir: ${hydra:run.dir}
    # Classifier wrapper around a torchvision resnet18 built with no arguments.
    classifier:
      _target_: cs_classifier.models.MultiLabelClassifier
      base_model:
        _target_: torchvision.models.resnet18
      num_classes_per_label: ${diffusion.model.num_class_per_label}
    # NOTE(review): the key is spelled `classifer_checkpoint` (missing "i").
    # It must match the GenerationMetrics constructor argument, which is not
    # visible here, so the spelling is preserved — verify before renaming.
    classifer_checkpoint: checkpoints/celeba/cs_celeba_judge/classifier/version_0/checkpoints/epoch=49-step=508700.ckpt

  # Project-local JSD tracker (callbacks.tracker.JSDTracker).
  jsd:
    _target_: callbacks.tracker.JSDTracker
    num_classes_per_label: ${diffusion.model.num_class_per_label}

# Logger specs. Every entry carries a Hydra `_target_`; the remaining keys
# are passed to that target as keyword arguments when it is instantiated.
# The mapping order is kept as-is in case the consumer relies on it.
loggers:
  # TensorBoard and CSV logs are written under the Hydra run directory.
  tensorboard_logger:
    _target_: pytorch_lightning.loggers.TensorBoardLogger
    save_dir: ${hydra:run.dir}
    name: tensorboard
  csv_logger:
    _target_: pytorch_lightning.loggers.CSVLogger
    save_dir: ${hydra:run.dir}
    name: csv
  # NOTE(review): unlike the two loggers above, no `save_dir` is set here, so
  # the W&B logger uses its own default location — confirm this is intended.
  wandb_logger:
    _target_: pytorch_lightning.loggers.WandbLogger
    # No explicit run name is supplied.
    name: null
    project: "LogicalComposition"
    offline: false

# Trainer options. There is no `_target_` in this section.
# NOTE(review): presumably forwarded as keyword arguments to
# pytorch_lightning.Trainer by the training script — confirm.
trainer:
  # Hardware
  accelerator: gpu

  # Run length: the step budget is interpolated from the composed diffusion
  # scheduler config (defined outside this file).
  max_steps: ${diffusion.scheduler.num_training_steps}

  # Optimisation and checkpointing
  gradient_clip_val: 1.0
  enable_checkpointing: true