# Autoencoder model definition. Throughout this section a `target` key holds a
# dotted class path and its sibling `params` mapping holds that class's
# arguments (presumably resolved by the training script; confirm).
model:
  # Written with an explicit mantissa dot: YAML 1.1 loaders such as PyYAML
  # resolve a bare `1e-4` as a string instead of a float.
  base_learning_rate: 1.0e-4
  target: pit.models.autoencoder.AutoencodingEngine
  params:
    input_key: img

    # Reconstruction / perceptual / adversarial loss.
    loss_config:
      target: pit.modules.losses.discriminator_loss.GeneralLPIPSWithDiscriminator
      params:
        perceptual_weight: 1.0
        # NOTE(review): presumably the training step at which the
        # discriminator term is switched on; confirm in the loss class.
        disc_start: 20001
        disc_weight: 0.75
        learn_logvar: true

        # Weight applied per regularization term; only `kl` is configured.
        regularization_weights:
          kl: 1.0

        additional_log_keys: null

        discriminator_config:
          target: pit.modules.lpips.model.model.NLayerDiscriminator
          params:
            input_nc: 3
            ndf: 160
            n_layers: 6
            use_actnorm: true

    # Latent regularizer.
    regularizer_config:
      target: pit.quantization.gaussian.TargetAdaptativeGaussianRegularizer
      params:
        format: bchw
        group: 16
        # This `target` is a numeric argument of the regularizer, not a class
        # path like the `target` keys one level up.
        target: 16

    encoder_config:
      target: pit.modules.unet.Encoder
      params:
        attn_type: vanilla
        double_z: true
        z_channels: 16
        resolution: 256
        in_channels: 3
        out_ch: 3
        ch: 128
        ch_mult: [1, 2, 4, 4]
        num_res_blocks: 2
        attn_resolutions: [32]
        dropout: 0.0

    decoder_config:
      target: pit.modules.unet.Decoder
      # `${...}` interpolation (OmegaConf-style): the decoder reuses the
      # encoder's params mapping verbatim, so the two stay in sync.
      params: ${model.params.encoder_config.params}

# Training-loop settings: checkpointing, callbacks and trainer arguments.
lightning:

  modelcheckpoint:
    params:
      every_n_train_steps: 5000
      # NOTE(review): -1 presumably means "keep every checkpoint" (Lightning
      # ModelCheckpoint semantics); confirm which class consumes this.
      save_top_k: -1

  callbacks:
    metrics_over_trainsteps_checkpoint:
      params:
        every_n_train_steps: 10000

    image_logger:
      target: main.ImageLogger
      params:
        disabled: false
        enable_autocast: false
        batch_frequency: 1000
        max_images: 8
        increase_log_steps: true
        log_first_step: true
        log_images_kwargs:
          N: 8
          n_rows: 2

  trainer:
    # Quoted on purpose: a comma-separated list is a string scalar in YAML
    # either way; the quotes make that type explicit.
    devices: '0,1,2,3,4,5,6,7'
    num_nodes: 1
    precision: 32
    benchmark: true
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
    max_epochs: 10000

# Data module configuration; `target`/`params` follow the same class-path /
# arguments convention used in the rest of this file.
data:
  target: pit.data.ImageDataModuleFromConfig
  params:
    num_workers: 16
    batch_size: 16

    train:
      target: pit.data.SimpleDataset
      params:
        # NOTE(review): looks like a placeholder; presumably must be replaced
        # with the real dataset location. Quoted so the value cannot be
        # mistaken for YAML's `...` document-end marker.
        root: '...'
        image_size: 256
