# Autoencoder definition: the engine class plus the loss, regularizer,
# encoder and decoder sub-configs it is built from. Each `target` is a dotted
# import path and `params` holds the keyword arguments for that target.
model:
  # Written with an explicit mantissa dot so that YAML 1.1 loaders (e.g. plain
  # PyYAML) also resolve it as a float; a bare `1e-4` is read as a string there.
  base_learning_rate: 1.0e-4
  target: pit.models.autoencoder.AutoencodingEngine
  params:
    input_key: img

    # Training objective: perceptual term plus an adversarial discriminator.
    loss_config:
      target: pit.modules.losses.discriminator_loss.GeneralLPIPSWithDiscriminator
      params:
        perceptual_weight: 1.0
        # NOTE(review): presumably the training step from which the
        # discriminator term becomes active - confirm against the loss class.
        disc_start: 20001
        disc_weight: 0.75
        learn_logvar: true
        regularization_weights: null
        additional_log_keys: null

        discriminator_config:
          target: pit.modules.lpips.model.model.NLayerDiscriminator
          params:
            input_nc: 3
            ndf: 160
            n_layers: 6
            use_actnorm: true

    # Latent regularizer / quantizer settings.
    regularizer_config:
      target: pit.quantization.gaussian.GaussianQuantRegularizer
      params:
        format: blc
        group: 8
        n_samples: 65536  # 2^16
        seed: 42
        beta: 0.0
        backend: torch

    encoder_config:
      target: pit.modules.vit.TransformerEncoder
      params:
        double_z: true
        z_channels: 16
        image_size: 256
        patch_size: 8
        width: 768
        layers: 12
        heads: 12
        mlp_ratio: 4
        drop_rate: 0.0

    decoder_config:
      target: pit.modules.vit.TransformerDecoder
      # Interpolation to the encoder's params node, so the decoder is given the
      # same keys as the encoder (including `double_z`). Resolution is done by
      # the config loader - presumably OmegaConf; confirm.
      params: '${model.params.encoder_config.params}'

    clamp_range: [-1, 1]

# Training-loop settings: checkpointing, callbacks and trainer arguments.
lightning:
  modelcheckpoint:
    params:
      every_n_train_steps: 5000
      # NOTE(review): if this is forwarded to Lightning's ModelCheckpoint, -1
      # keeps every saved checkpoint - confirm against the launcher script.
      save_top_k: -1

  callbacks:
    metrics_over_trainsteps_checkpoint:
      params:
        every_n_train_steps: 10000

    image_logger:
      target: main.ImageLogger
      params:
        disabled: false
        enable_autocast: false
        batch_frequency: 1000
        max_images: 8
        increase_log_steps: true
        log_first_step: true
        log_images_kwargs:
          N: 8
          n_rows: 2

  trainer:
    # A single comma-separated string, not a YAML list; quoted so the type is
    # explicit. Presumably the device indices to train on - confirm.
    devices: '0,1,2,3,4,5,6,7'
    num_nodes: 1
    precision: 32
    benchmark: true
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
    max_epochs: 10000

# Data module configuration: loader settings plus the training dataset.
data:
  target: pit.data.ImageDataModuleFromConfig
  params:
    num_workers: 16
    batch_size: 16

    train:
      target: pit.data.SimpleDataset
      params:
        # NOTE(review): '...' looks like a placeholder for the dataset
        # location and likely must be replaced before training - confirm.
        # Quoted so it is not mistaken for the YAML document-end marker.
        root: '...'
        # Same value as model.params.encoder_config.params.image_size.
        image_size: 256
