model:
  # Keep the decimal point: YAML 1.1 resolvers such as PyYAML only treat
  # exponent notation as a float when it contains a '.', so `1e-4` would
  # load as a string.
  base_learning_rate: 1.0e-4
  target: ldm.models.autoencoder_joint.VQDiGANJoint
  params:
    monitor: val/rec_loss
    # embed_dim matches ddconfig.z_channels (both 6).
    embed_dim: 6
    n_embed: 1024
    image_key: tensor
    # Placeholder path; replace with a real checkpoint before training.
    ckpt_path: path/to/model/vqae.ckpt
    # NOTE(review): also set in ddconfig.num_directions below; keep both
    # values in sync.
    num_directions: 6
    ddconfig:
      double_z: false
      z_channels: 6
      resolution: 80
      in_channels: 1
      out_ch: 1
      ch: 64
      # NOTE(review): the original note read `f = 2 ^ len(ch_mult)`. In the
      # standard LDM/taming Encoder only the first len(ch_mult) - 1 levels
      # downsample, giving f = 2^(len(ch_mult) - 1) = 4 here. Confirm this
      # for the joint encoder.
      ch_mult: [1, 2, 4]
      num_res_blocks: 2
      cond_type: max_cross_attn
      attn_type: max
      attn_resolutions: []
      dropout: 0.0
      num_classes_encoder: 4
      num_classes_decoder: 2
      num_directions: 6
    lossconfig:
      target: ldm.modules.losses.vqperceptual.VQLPIPSJoint
      params:
        pixelloss_weight: 1.0

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 2
    num_workers: 8
    wrap: false
    # Both splits use the same dataset class and the same (placeholder)
    # root; they differ only in `stage`.
    train:
      target: ldm.data.diffusion_joint.DiffusionJoint
      params: {dataroot: path/to/data, stage: train}
    validation:
      target: ldm.data.diffusion_joint.DiffusionJoint
      params: {dataroot: path/to/data, stage: val}


lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 1000
        val_batch_frequency: 100
        max_images: 6
        increase_log_steps: false
        log_images_kwargs:
          # Keep the key quoted: a bare N can resolve to boolean false
          # under YAML 1.1 rules.
          'N': 1
    finetune_decoder_joint:
      target: ldm.models.autoencoder_joint.FinetuneDecoderJoint
      params:
        unfreeze_at_epoch: 0

  trainer:
    accumulate_grad_batches: 8
    benchmark: true
    # NOTE(review): presumably passed to the PyTorch Lightning Trainer,
    # where max_epochs = -1 requests training with no epoch limit. Confirm
    # the installed Lightning version supports -1.
    max_epochs: -1
