# @package _global_
# Hydra defaults list. Entries are composed in order: the `ssdd` option of the
# `/ae` config group first, then this file (`_self_`), so the values defined
# below are merged last and take precedence over the group's values.
defaults:
  - /ae: ssdd
  - _self_


# Task identifier for this run.
# NOTE(review): presumably dispatched by the launcher to the autoencoder
# training routine — confirm against the code that reads `task`.
task: ae.train


# Logging options for sample outputs.
# NOTE(review): the per-field notes below are inferred from the key names only;
# confirm against the logging code.
logging:
  samples:
    n: 8  # presumably the number of samples logged each time
    every_epoch: 1  # presumably the logging interval, in epochs
    dpi: 100  # presumably the resolution of the rendered figure

# === Model arguments ===

# Values set here are merged after the `/ae: ssdd` entry of the defaults list
# (this file is `_self_`, listed last), so they override the group's settings.
ae:
  # No checkpoint is configured.
  # NOTE(review): presumably means the model is not initialised from a saved
  # checkpoint — confirm.
  checkpoint: null
  # NOTE(review): presumably toggles model compilation — confirm.
  compile: false

# Configuration of the auxiliary-loss component. It is listed as a trainable
# model under `training.optimizers.main.models`, and the `repa` / `lpips` names
# used here also appear as weighted terms under `training.losses`.
# NOTE(review): per-field meanings below are inferred from names — confirm.
aux_losses:
  _module: ae_aux_losses  # presumably the module/class name to instantiate
  repa:
    model: dinov2_base  # presumably the feature-extractor backbone
    i_extract: 4  # presumably the index of the layer features are taken from
    n_layers: 2  # presumably the depth of a projection head — TODO confirm
  lpips: true  # enables the LPIPS term (weighted by `training.losses.lpips`)
  model_init:
    method: kaiming_normal  # presumably the weight-initialisation scheme

# === Training arguments ===

# Training-loop hyper-parameters, optimizer definition and loss weights.
# NOTE(review): per-field meanings are inferred from key names — confirm
# against the training code.
#
# Floats in scientific notation are written with an explicit decimal point
# (`3.0e-4`, not `3e-4`). YAML 1.1 loaders such as PyYAML only resolve the
# dotted form as a float and read `3e-4` as the string "3e-4"; the dotted form
# yields the same numeric value under every loader, including OmegaConf's.
training:
  grad_accumulate: 1  # presumably gradient-accumulation steps per update
  batch_size: 256  # also reused by `testing.batch_size` via interpolation
  grad_clip: 0.1  # presumably the gradient-clipping threshold
  epochs: 300
  eval_freq: 4  # presumably evaluate every N epochs — TODO confirm unit
  save_on_best: FID  # metric name; `FID` is enabled under `testing.metrics`
  resume_from: null  # no run to resume from is configured

  optimizers:
    main:
      # Names of the top-level config sections whose parameters this
      # optimizer handles (both `ae` and `aux_losses` are defined in this file).
      models: [ae, aux_losses]
      name: radamw_sf
      lr: 3.0e-4
      args:
        weight_decay: 1.0e-3
      index: 0

  # Per-term loss weights.
  # NOTE(review): presumably combined as a weighted sum — confirm.
  losses:
    diffusion: 1
    repa: 0.25
    lpips: 0.5
    kl: 1.0e-6


# === Data arguments ===

# Augmentation settings for the training dataset.
# NOTE(review): assumed to be the training split because a separate
# `test_dataset` section exists — confirm.
dataset:
  augs:
    interpolation: lanczos  # options noted by the author: lanczos / bilinear
    # NOTE(review): presumably disables random-scale resizing — confirm.
    rand_resize_scale: false

# Augmentation settings for the test dataset.
# NOTE(review): this uses `bilinear` while `dataset.augs.interpolation` is
# `lanczos`; confirm the train/test resampling mismatch is intentional.
test_dataset:
  augs:
    interpolation: bilinear

# === Testing arguments ===

# Evaluation settings.
testing:
  # OmegaConf interpolation: resolves to the value of `training.batch_size`,
  # so evaluation uses the same batch size as training unless overridden.
  batch_size: ${training.batch_size}

  # Metric switches: `true` enables a metric, `false` disables it.
  # NOTE(review): the on/off meaning is inferred from the boolean values;
  # confirm against the evaluation code.
  metrics:
    FID: true  # also referenced by `training.save_on_best`
    MSE: true
    MAE: true
    LPIPS: true
    PSNR: true
    SSIM: true
    dreamsim: true
    z_stats: false

    # NOTE(review): presumably multipliers applied to the reported MSE / MAE
    # values — confirm.
    MSE_scale: 1000
    MAE_scale: 1
