# PyTorch Training Parameters
model:
  model_arch: pix2pixHD
  # Architecture-specific settings for model_arch. The inline comments
  # appear to list the accepted alternative values.
  model_config:
    GAN_upsample: bilinear  # convtrans, bilinear
    GAN_norm: instance  # instance, batch
    D_arch: patchGAN
    G_arch: global
    n_layers_D: 3
    num_D: 2
    n_downsample_global: 4
    # NOTE(review): n_downsample_global: 4 suggests a total generator stride
    # of 2**4 = 16. Confirm that a multiple of 8 is sufficient for input sizes.
    divisible: 8

datasets:
  datasets_folder: "./datasets_preprocess"
  # Dataset names used for each split (presumably resolved under
  # datasets_folder; verify). Commented-out entries are disabled for this run.
  train_datasets:
    # - BosonPlus_day
    # - DJI_day
    # - Boson_night
    # - BosonPlus_night
    # - Caltech
    # - LLVIP
    # - NII_CU
    # - TARDAL
    # - Freiburg_day
    # - Freiburg_night
    # - MSRS
    # - KAIST
    # - SMOD_day
    # - SMOD_night
    # - FLIR
    - AVIID
  val_datasets:
    # - BosonPlus_day
    # - Boson_night
    # - BosonPlus_night
    # - LLVIP
    # - NII_CU
    # - MSRS
    # - TARDAL
    # - FLIR
    - AVIID
  test_datasets:
    - BosonPlus_day
    - Boson_night
    - BosonPlus_night
    - LLVIP
    - NII_CU
    - Freiburg_day
    - Freiburg_night
    - MSRS
    - TARDAL
    - FLIR
    - AVIID
  # NOTE(review): BosonPlus_day is not enabled in val_datasets (only AVIID
  # is). Confirm the consumer does not require target_val_dataset to be one
  # of the validation datasets.
  target_val_dataset: BosonPlus_day

training:
  num_epochs: 1
  num_samples_per_epoch: 10000
  train_batch_size: 8
  test_batch_size: 8
  num_workers: 8
  # NOTE(review): if val_freq counts epochs, a value of 10 never triggers
  # with num_epochs: 1. Confirm its unit.
  val_freq: 10
  train_image_size: [256, 256]
  optimizer:
    name: AdamW
    # Written with a decimal point on purpose: YAML 1.1 loaders (e.g.
    # PyYAML) resolve a bare `2e-4` as the string "2e-4", not a float.
    lr: 2.0e-4
    weight_decay: 0.0
    # NOTE(review): torch.optim.AdamW has no `momentum` argument (it uses
    # `betas`). Verify how the optimizer factory treats this key for AdamW.
    momentum: 0.9
  scheduler:
    name: linear
    # These keys appear to match torch.optim.lr_scheduler.LinearLR's
    # keyword arguments. NOTE(review): confirm whether total_iters counts
    # epochs or optimizer steps.
    args:
      start_factor: 1.0
      end_factor: 0.25
      total_iters: 300
  loss:
    name: 'pix2pixHD'
    config:
      GAN_mode: lsgan  # vanilla, lsgan
      G_loss_lambda: 10.0
  mixed_precision: false
  load: ./checkpoints/pix2pixhd/checkpoints/last.ckpt