# PyTorch Training Parameters
# Generator architecture. Keys under model_config are presumably forwarded
# to the `vqgan` model constructor selected by model_arch — confirm.
model:
  model_arch: vqgan
  model_config:
    embed_dim: 256
    n_embed: 1024
    double_z: false
    z_channels: 256
    resolution: 256
    in_channels: 3
    # NOTE(review): output has fewer channels than input (3 -> 1);
    # presumably a single-channel target modality — confirm.
    out_ch: 1
    ch: 128
    # num_down = len(ch_mult)-1 = 4, so 256 / 2**4 = 16, which matches
    # attn_resolutions and divisible below.
    ch_mult: [1, 1, 2, 2, 4]
    num_res_blocks: 2
    attn_resolutions: [16]
    dropout: 0.0
    divisible: 16

datasets:
  # Root directory; the dataset names below are presumably subdirectories
  # of it — confirm against the dataset loader.
  datasets_folder: "./datasets_preprocess"
  # Uncomment entries to include more datasets.
  train_datasets:
    # - BosonPlus_day
    # - DJI_day
    # - Boson_night
    # - BosonPlus_night
    # - Caltech
    # - LLVIP
    # - NII_CU
    # - TARDAL
    # - Freiburg_day
    # - Freiburg_night
    # - MSRS
    # - KAIST
    # - SMOD_day
    # - SMOD_night
    # - FLIR
    - AVIID
  val_datasets:
    # - BosonPlus_day
    # - Boson_night
    # - BosonPlus_night
    # - LLVIP
    # - NII_CU
    # - MSRS
    # - TARDAL
    # - FLIR
    - AVIID
  test_datasets:
    - BosonPlus_day
    - Boson_night
    - BosonPlus_night
    - LLVIP
    - NII_CU
    - Freiburg_day
    - Freiburg_night
    - MSRS
    - TARDAL
    - FLIR
    - AVIID
  # Must name an entry that is active in val_datasets above.
  # NOTE(review): presumably selects the validation dataset whose metric is
  # tracked for model selection — confirm against the training loop.
  target_val_dataset: AVIID

training:
  num_epochs: 1
  num_samples_per_epoch: 10000
  train_batch_size: 8
  test_batch_size: 8
  num_workers: 8
  # NOTE(review): exceeds num_epochs; if counted in epochs, validation may
  # never run — confirm the unit used by the training loop.
  val_freq: 10
  train_image_size: [256, 256]
  optimizer:
    name: AdamW
    # Written with a '.': YAML 1.1 loaders (e.g. PyYAML) resolve `2e-4` as
    # the string "2e-4", not a float.
    lr: 2.0e-4
    weight_decay: 0.0
    # NOTE(review): torch.optim.AdamW takes no `momentum` argument; presumably
    # ignored or used only for other optimizers — confirm in the builder.
    momentum: 0.9
  scheduler:
    name: linear
    # Argument names match torch.optim.lr_scheduler.LinearLR; presumably
    # forwarded to it (LR factor 1.0 -> 0.25 over 300 steps) — confirm.
    args:
      start_factor: 1.0
      end_factor: 0.25
      total_iters: 300
  loss:
    name: vqgan
    config:
      disc_conditional: true
      # 4 = model in_channels (3) + out_ch (1). NOTE(review): assumes the
      # conditional discriminator sees input and output concatenated.
      disc_in_channels: 4
      # NOTE(review): if one step is one train batch, this run is
      # 10000 / 8 = 1250 steps, far below this threshold, unless the step
      # counter resumes from the `load` checkpoint — confirm intent.
      disc_start: 200001
      disc_weight: 0.8
      codebook_weight: 1.0
  mixed_precision: false
  load: ./checkpoints/vqgan/checkpoints/last.ckpt