# Model definition. Each `target` looks like a dotted import path and the
# sibling `params` mapping holds that component's settings (presumably
# forwarded as constructor keyword arguments; confirm against the loader).
model:
  # Decimal point plus signed exponent: the form YAML 1.1 loaders resolve
  # as a float rather than a string.
  base_learning_rate: 1.0e-4
  target: sgm.models.diffusion.DiffusionEngine
  params:
    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.Denoiser
      params:
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
          params:
            # Same value as `sigma_data` under `loss_fn_config`; presumably
            # the two must be changed together (TODO confirm).
            sigma_data: 1.0

    network_config:
      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        # Single-channel input and output.
        in_channels: 1
        out_channels: 1
        model_channels: 32
        # Empty list: no attention resolutions are configured.
        attention_resolutions: []
        num_res_blocks: 4
        channel_mult: [1, 2, 2]
        num_head_channels: 32
        num_classes: sequential
        # Equals `embed_dim` of the ClassEmbedder in `conditioner_config`;
        # presumably the two must agree (TODO confirm).
        adm_in_channels: 128

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
          - target: sgm.modules.encoders.modules.ClassEmbedder
            is_trainable: true
            # Key of the input the embedder reads; presumably the batch
            # entry holding the class label (verify against the loader).
            input_key: cls
            # NOTE(review): presumably the rate at which the condition is
            # dropped during training, pairing with the VanillaCFG guider
            # configured under `sampler_config` (confirm).
            ucg_rate: 0.2
            params:
              embed_dim: 128
              # Ten classes, in line with the MNIST loader named in `data`.
              n_classes: 10

    # No `params`: this component is configured by its `target` alone.
    first_stage_config:
      target: sgm.models.autoencoder.IdentityFirstStage

    loss_fn_config:
      target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
      params:
        loss_weighting_config:
          target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
          params:
            # Same value as `sigma_data` under `denoiser_config`.
            sigma_data: 1.0
        # No `params` given here, so none are set by this file.
        sigma_sampler_config:
          target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling

    sampler_config:
      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
      params:
        num_steps: 50

        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization

        guider_config:
          target: sgm.modules.diffusionmodules.guiders.VanillaCFG
          params:
            scale: 3.0

# Data source: `target` names the loader class and `params` holds its
# settings (presumably passed as constructor keyword arguments; confirm).
data:
  target: sgm.data.mnist.MNISTLoader
  params:
    batch_size: 512
    num_workers: 1

# Settings grouped under `lightning`: checkpointing, callbacks and trainer
# options.
lightning:
  modelcheckpoint:
    params:
      every_n_train_steps: 5000

  callbacks:
    # Uses a step interval five times that of `modelcheckpoint`
    # (25000 vs 5000).
    metrics_over_trainsteps_checkpoint:
      params:
        every_n_train_steps: 25000

    image_logger:
      target: main.ImageLogger
      params:
        disabled: false
        batch_frequency: 1000
        max_images: 16
        increase_log_steps: true
        log_first_step: false
        log_images_kwargs:
          use_ema_scope: false
          # `N` equals `max_images`; with `n_rows: 4` this presumably gives
          # a 4x4 grid (TODO confirm against the logger).
          N: 16
          n_rows: 4

  trainer:
    # The value is the string `0,` (digit plus trailing comma), not the
    # integer 0. It is quoted so the comma is unmistakably part of the
    # value. Presumably a comma-separated device list selecting device 0;
    # confirm against how the trainer reads `devices`.
    devices: '0,'
    benchmark: true
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
    max_epochs: 20