---
# LightningCLI-style run config: trainer / model / data / ckpt_path sections.
seed_everything: true

trainer:
  # Hardware: 8 devices per node across 8 nodes (64 processes in total).
  accelerator: 'gpu'
  devices: 8
  num_nodes: 8
  # Lightning's registered alias for DDP with find_unused_parameters=True.
  strategy: 'ddp_find_unused_parameters_true'
  # fp16 automatic mixed precision.
  precision: '16-mixed'

  # Schedule: validate after every epoch, no pre-training sanity-check pass.
  max_epochs: 300
  check_val_every_n_epoch: 1
  num_sanity_val_steps: 0
  log_every_n_steps: 100

  callbacks:
    # save_top_k: -1 keeps every checkpoint; save_last also writes last.ckpt.
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        dirpath: 'results/ibq/ibq16384_256'
        save_top_k: -1
        save_last: true
    # Logs the learning rate at each optimizer step.
    - class_path: lightning.pytorch.callbacks.LearningRateMonitor
      init_args:
        logging_interval: 'step'

  logger:
    class_path: lightning.pytorch.loggers.TensorBoardLogger
    init_args:
      save_dir: 'results/ibq/'
      name: 'ibq16384_256'
      # null lets the logger pick the next free version_N directory.
      version: null

model:
  class_path: taming.models.vq.VQModel
  init_args:
    # Encoder/decoder architecture settings.
    ddconfig:
      double_z: false
      z_channels: 256
      resolution: 256
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult: [1, 1, 2, 2, 4]
      num_res_blocks: 4
      norm_type: "groupnorm"
      norm_groups: 32

    # Quantizer settings.
    quantconfig:
      class_path: taming.modules.vqvae.simvq.IBQ
      init_args:
        # NOTE(review): n_e / e_dim presumably are codebook size / code
        # dimension; e_dim likely has to match ddconfig.z_channels — confirm.
        n_e: 16384
        e_dim: 256
        beta: 0.25
        use_entropy_loss: true
        disentangle_loss_type: "codebook_orth"  # Options: "codebook", "z_q", "codebook_orth", "z_q_orth"
        disentangle_loss_weight: 0.001  # Weight for disentanglement loss
        entropy_temperature: 0.01
        sample_minimization_weight: 1.0
        batch_maximization_weight: 1.0

    # Perceptual + adversarial loss settings.
    lossconfig:
      class_path: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
      init_args:
        disc_conditional: false
        disc_in_channels: 3
        disc_start: 0
        disc_weight: 0.4
        gen_loss_weight: 0.1
        commit_weight: 1.0

    # Written as 1.0e-4 rather than 1e-4: YAML 1.1 resolvers (e.g. PyYAML)
    # only recognise exponent floats that contain a '.', so a bare 1e-4
    # would load as the string "1e-4".
    learning_rate: 1.0e-4
    # The string "None" (not YAML null); presumably matched as a string by
    # VQModel to disable the LR scheduler — confirm.
    scheduler_type: "None"
    # NOTE(review): presumably only used when a scheduler is enabled — confirm.
    warmup_epochs: 1.0
    use_ema: true

data:
  class_path: main.DataModuleFromConfig
  init_args:
    # NOTE(review): batch_size is presumably per process; with the trainer's
    # devices x num_nodes = 64 this would give a global batch of 256 — confirm.
    batch_size: 4
    num_workers: 16
    # Each split uses a `target` + `params` pair (rather than class_path /
    # init_args), presumably instantiated by DataModuleFromConfig — confirm.
    # `subset: null` is written explicitly; a bare `subset:` also loads as
    # null but reads like a forgotten value.
    train:
      target: taming.data.imagenet.ImageNetTrain
      params:
        config:
          size: 256
          subset: null
    validation:
      target: taming.data.imagenet.ImageNetValidation
      params:
        config:
          size: 256
          subset: null
    test:
      target: taming.data.imagenet.ImageNetValidation
      params:
        config:
          size: 256
          subset: null

# Checkpoint to resume from; null starts a fresh run.
# NOTE(review): presumably forwarded by LightningCLI to Trainer.fit(ckpt_path=...) — confirm.
ckpt_path: null