# PyTorch Training Parameters
# Model section: selects the architecture and carries its nested
# configuration (VAE settings, injection options, transport options).
model:
  model_arch: sit
  model_config:
    vae: ema
    arch: XL
    patch_size: 8
    cfg_scale: 1.0
    cache_rate: 1
    # Two separate constants, one per modality.
    # NOTE(review): presumably latent scaling factors; confirm with the consumer.
    thermal_normalizer: 0.95941
    RGB_normalizer: 0.18215
    divisible: 16
    vae_model: klvae
    injection_args:
      injection_method: cross
      self_attn: true
      injection_position: q
    # Relative path to a checkpoint file for the VAE.
    vae_path: ./ThermalGen/sim5ciq3/checkpoints/last.ckpt
    # Single-channel in/out, four encoder/decoder stages, 4 latent channels.
    vae_config:
      in_channels: 1
      out_channels: 1
      down_block_types:
        - DownEncoderBlock2D
        - DownEncoderBlock2D
        - DownEncoderBlock2D
        - DownEncoderBlock2D
      up_block_types:
        - UpDecoderBlock2D
        - UpDecoderBlock2D
        - UpDecoderBlock2D
        - UpDecoderBlock2D
      block_out_channels: [128, 256, 512, 512]
      layers_per_block: 2
      act_fn: silu
      latent_channels: 4
      norm_num_groups: 32
      sample_size: 256
      force_upcast: true
      use_quant_conv: true
      use_post_quant_conv: true
      mid_block_add_attention: true
    transport_config:
      path_type: Linear
      prediction: velocity
      # NOTE(review): a plain `None` is loaded by YAML as the string "None",
      # not as null. Left unchanged; confirm the consumer expects the string.
      loss_weight: None

# Dataset section: root folder plus the dataset names used for each split.
# Entries that are commented out are disabled for this run and kept for
# reference.
datasets:
  datasets_folder: './datasets_preprocess'
  # Only AVIID is enabled for training.
  train_datasets:
    # - BosonPlus_day
    # - DJI_day
    # - Boson_night
    # - BosonPlus_night
    # - Caltech
    # - LLVIP
    # - NII_CU
    # - TARDAL
    # - Freiburg_day
    # - Freiburg_night
    # - MSRS
    # - KAIST
    # - SMOD_day
    # - SMOD_night
    # - FLIR
    - AVIID
  # Only AVIID is enabled for validation.
  val_datasets:
    # - BosonPlus_day
    # - Boson_night
    # - BosonPlus_night
    # - LLVIP
    # - NII_CU
    # - MSRS
    # - TARDAL
    # - FLIR
    - AVIID
  test_datasets:
    - BosonPlus_day
    - Boson_night
    - BosonPlus_night
    - LLVIP
    - NII_CU
    - Freiburg_day
    - Freiburg_night
    - MSRS
    - TARDAL
    - FLIR
    - AVIID
  # NOTE(review): BosonPlus_day is not among the enabled val_datasets above
  # (only AVIID is). Confirm this target is intended for this run.
  target_val_dataset: BosonPlus_day

# Training section: run length, batch sizes, optimizer, scheduler, loss and
# the checkpoint path given under `load`.
training:
  num_epochs: 300
  num_samples_per_epoch: 50000
  train_batch_size: 64
  test_batch_size: 16
  num_workers: 8
  val_freq: 50
  train_image_size: [256, 256]
  optimizer:
    name: AdamW
    # Written with a decimal point on purpose: YAML 1.1 loaders such as
    # PyYAML resolve the dot-less form `1e-4` as a string, not a float.
    lr: 1.0e-4
    weight_decay: 0.0
    # NOTE(review): AdamW in torch.optim takes `betas`, not `momentum`;
    # confirm whether the consumer reads this key.
    momentum: 0.9
  scheduler:
    # The literal string "none" (not a YAML null).
    name: none
    args: null
  loss:
    name: 'sit'
    # NOTE(review): a plain `None` is loaded by YAML as the string "None",
    # not as null. Left unchanged; confirm the consumer expects the string.
    config: None
  mixed_precision: false
  gradient_accumulation: 1
  validation_type: ema
  load: ./checkpoints/sit_xl8/checkpoints/last.ckpt