---
# PyTorch training configuration for a single-channel KL-VAE run.
model:
  model_arch: klvae
  # NOTE(review): these keys appear to mirror the diffusers AutoencoderKL
  # constructor arguments — confirm against the model factory.
  model_config:
    in_channels: 1
    out_channels: 1
    down_block_types: ['DownEncoderBlock2D', 'DownEncoderBlock2D', 'DownEncoderBlock2D', 'DownEncoderBlock2D']
    up_block_types: ['UpDecoderBlock2D', 'UpDecoderBlock2D', 'UpDecoderBlock2D', 'UpDecoderBlock2D']
    # One channel width per stage; presumably must have the same length as
    # the block-type lists above (4 entries each).
    block_out_channels: [128, 256, 512, 512]
    layers_per_block: 2
    act_fn: silu
    latent_channels: 4
    norm_num_groups: 32
    sample_size: 256
    force_upcast: true
    use_quant_conv: true
    use_post_quant_conv: true
    mid_block_add_attention: true
    # NOTE(review): not a standard AutoencoderKL argument; presumably the
    # required divisor of input height/width (4 stages -> 3 downsamplings
    # -> factor 8) — confirm where this key is read.
    divisible: 8

datasets:
  # Root directory that presumably contains one sub-folder per dataset name
  # listed below — confirm against the dataset loader.
  datasets_folder: "./datasets_preprocess"
  # Uncomment entries to add them to the training mix.
  train_datasets:
    # - BosonPlus_day
    # - DJI_day
    # - Boson_night
    # - BosonPlus_night
    # - Caltech
    # - LLVIP
    # - NII_CU
    # - TARDAL
    # - Freiburg_day
    # - Freiburg_night
    # - MSRS
    # - KAIST
    # - SMOD_day
    # - SMOD_night
    # - FLIR
    - AVIID
  val_datasets:
    # - BosonPlus_day
    # - Boson_night
    # - BosonPlus_night
    # - LLVIP
    # - NII_CU
    # - MSRS
    # - TARDAL
    # - FLIR
    - AVIID
  test_datasets:
    - BosonPlus_day
    - Boson_night
    - BosonPlus_night
    - LLVIP
    - NII_CU
    - Freiburg_day
    - Freiburg_night
    - MSRS
    - TARDAL
    - FLIR
    - AVIID
  # NOTE(review): BosonPlus_day is not in the active val_datasets list above
  # (only AVIID is). If this key selects which validation result drives
  # checkpoint selection, confirm this mismatch is intentional.
  target_val_dataset: BosonPlus_day

training:
  num_epochs: 1
  # NOTE(review): presumably the number of samples drawn per epoch rather
  # than a full pass over the datasets — confirm in the data pipeline.
  num_samples_per_epoch: 10000
  train_batch_size: 16
  test_batch_size: 8
  num_workers: 8
  val_freq: 50
  train_image_size: [256, 256]
  optimizer:
    name: AdamW
    # Written with a decimal point and signed exponent on purpose: YAML 1.1
    # loaders (e.g. PyYAML safe_load) resolve bare `6e-5` / `1e-3` as
    # strings, not floats.
    lr: 6.0e-5
    weight_decay: 1.0e-3
    # NOTE(review): torch.optim.AdamW takes `betas`, not `momentum`; confirm
    # whether the optimizer factory reads or ignores this key.
    momentum: 0.9
  scheduler:
    # The string "none" (not YAML null); presumably disables scheduling.
    name: none
    args: null
  loss:
    name: 'klvae'
    config:
      disc_conditional: false
      disc_in_channels: 1
      # NOTE(review): presumably the step at which the discriminator term
      # is enabled — confirm units (global step vs. epoch) in the loss.
      disc_start: 1
      kl_weight: 1.0e-06
      disc_weight: 0.5
  mixed_precision: false
  # Checkpoint path loaded at startup; confirm whether it resumes training
  # state or only initializes weights.
  load: ./checkpoints/klvae_3rd/checkpoints/last.ckpt