---
# general settings: experiment identity, devices and random seed
name: DSPF_S2
model_type: DSPF_S2
scale: 1
num_gpu: 1  # use 0 to run on CPU
gpu_ids: [0]
manual_seed: 100
find_unused_parameters: true
# dataset and data loader settings
datasets:
  val:
    name: ValSet
    # NOTE(review): validation uses PairedImageFusionDataset_Hybrid while
    # training uses PairedImageFusionDataset_DSPF; confirm this is intended.
    type: PairedImageFusionDataset_Hybrid
    dataroot_lq_ir: datasets/Hybrid_Datasets/val/IR
    dataroot_gt_ir: datasets/Hybrid_Datasets/val/IR_enhanced
    dataroot_lq_vi: datasets/Hybrid_Datasets/val/VI
    dataroot_gt_vi: datasets/Hybrid_Datasets/val/VI_enhanced
    filename_tmpl: '{}'
    io_backend:
      type: disk
  train:
    name: TrainSet
    type: PairedImageFusionDataset_DSPF
    dataroot_lq_ir: datasets/Hybrid_Datasets/train/IR
    dataroot_gt_ir: datasets/Hybrid_Datasets/train/IR_enhanced
    dataroot_lq_vi: datasets/Hybrid_Datasets/train/VI
    dataroot_gt_vi: datasets/Hybrid_Datasets/train/VI_enhanced
    filename_tmpl: '{}'
    io_backend:
      type: disk

    # data loader
    use_shuffle: true
    num_worker_per_gpu: 4
    batch_size_per_gpu: 4  # equals the largest entry of mini_batch_sizes
    dataset_enlarge_ratio: 1
    prefetch_mode: null

    ### -------------Progressive training--------------------------
    mini_batch_sizes: [4, 3]  # batch size per GPU
    iters: [300000, 200000]   # sums to train.total_iter (500000)
    gt_size: 256              # max patch size; equals max(gt_sizes)
    gt_sizes: [224, 256]      # patch sizes for progressive training
    ### ------------------------------------------------------------
    
# network structures
# main network
network_g:
  type: Transformer_DSPF
  inp_channels: 3
  out_channels: 3
  dim: 32
  num_blocks: [2, 2, 4, 4]
  num_refinement_blocks: 1
  heads: [1, 1, 2, 2]
  ffn_expansion_factor: 2.66
  bias: false
  LayerNorm_type: WithBias
  dual_pixel_task: false
  embed_dim: 64
  type_embed_dim: 32
  group: 4  # N=4*4
  with_contra: false
  with_SFP: false

network_dp:
  type: degradation_encoder_gelu
  in_chans: 3
  # equals network_g.type_embed_dim (32), not network_g.embed_dim (64)
  # NOTE(review): confirm this is the intended coupling
  embed_dim: 32
  block_num: 4
  group: 4  # same as network_g.group
  stage: 1
  patch_expansion: 0.5
  channel_expansion: 4

network_sp:
  type: semantic_encoder_gelu
  # NOTE(review): 6 = 2 x 3 channels, presumably IR and VI inputs stacked;
  # confirm against the model code
  in_chans: 6
  embed_dim: 64  # same as network_g.embed_dim (not network_dp.embed_dim)
  block_num: 6
  group: 4  # same as network_g.group
  stage: 1
  patch_expansion: 0.5
  channel_expansion: 4

network_dm:
  type: denoising
  # 256 = embed_dim * 4 = 64 * 4
  # NOTE(review): presumably network_sp.embed_dim; confirm
  in_channel: 256
  out_channel: 256
  inner_channel: 512
  block_num: 6
  group: 4  # same as network_g.group
  patch_expansion: 0.5
  channel_expansion: 2
  # NOTE(review): 128 = network_dp.embed_dim * 4 (32 * 4); confirm
  dp_channel: 128

diffusion_schedule:
  apply_ldm: true
  schedule: linear
  timesteps: 10
  linear_start: 0.1  # alternative value noted here: 1e-6
  linear_end: 0.99   # alternative value noted here: 1e-2

# checkpoint paths: pretrained weights per network and optional resume state
path:
  pretrain_network_g: ckpt/net_g.pth
  param_key_g: params
  strict_load_g: true

  pretrain_network_dp: ckpt/net_dp.pth
  param_key_dp: params
  strict_load_dp: true

  pretrain_network_sp: ckpt/net_sp.pth
  param_key_sp: params
  strict_load_sp: true

  pretrain_network_dm: null  # no checkpoint path set for network_dm
  param_key_dm: params
  strict_load_dm: true

  resume_state: null

# training settings
train:
  total_iter: 500000
  warmup_iter: -1  # no warm-up
  use_grad_clip: true

  scheduler:
    type: CosineAnnealingRestartCyclicLR
    periods: [300000, 200000]  # sums to total_iter
    restart_weights: [1, 1]
    eta_mins: [0.0002, 0.000001]

  mixing_augs:
    mixup: false
    mixup_beta: 1.2
    use_identity: true

  optim_total:
    type: AdamW
    lr: !!float 2e-4
    weight_decay: !!float 1e-4
    betas: [0.9, 0.999]

  # losses
  pixel_opt:
    type: fusion_loss

  pixel_diff_opt:
    type: L1Loss

  # disabled:
  # fidelity_opt:
  #   type: Fidelity_loss
# validation settings
val:
  val_freq: !!float 5e3
  save_img: true
  rec_flag: false

  metrics:
    psnr:  # metric name is arbitrary
      type: calculate_psnr
      crop_border: 0
      test_y_channel: false

# logging settings
logger:
  print_freq: 200
  save_checkpoint_freq: !!float 5e3
  use_tb_logger: true
  wandb:
    project: null
    resume_id: null

# distributed training settings
dist_params:
  backend: nccl
  port: 5320
