# Base configuration files this config builds on: general defaults, training
# defaults, dataset, and model.
# NOTE(review): how these files are merged (and which one wins on conflict) is
# decided by the config loader, which is not visible here; keep the order.
_base_:
  - ./_base_/default.yaml
  - ./_base_/training/default.yaml
  - ./_base_/dataset/openimage_kodak.yaml
  - ./_base_/model/beta_cond_interp_ca_elic_charm.yaml

# Checkpoint used as the starting point for this run: the pretrained model
# from stage 2. Update the path if your stage-2 checkpoint is stored elsewhere.
pretrained_weight_path: ./checkpoint/crdr_stage2/model/comp_model_iter1000K.pth.tar

# HRRGAN Trainer
# `type` names the trainer class; the class itself is defined outside this
# file (presumably looked up by name by the framework; confirm).
trainer:
  type: MultirateBetaCondHrrGanRateDistortionTrainer

# Discriminator configuration: a list-style wrapper holding several
# sub-discriminators of the same type.
discriminator:
  # 1 discriminator for each rate_level
  type: ModuleListDiscriminator
  # Type of each sub-discriminator held by the wrapper.
  _subd_type: CLIC21GVAEDiscriminator
  # Number of sub-discriminators; must stay in sync with the number of rate
  # levels (the per-level lists under `loss.rate_loss` also have 5 entries).
  _num_subd: 5 # rate_level = 5
  # NOTE(review): presumably input channels (3, e.g. RGB), output channels and
  # base channel width of each sub-discriminator; confirm against the class.
  in_ch: 3
  out_ch: 1
  main_ch: 64
  # Plain `none` is parsed by YAML as the string "none", not as null.
  norm_type: none

# Loss terms for this stage: rate, distortion (MSE), perceptual (LPIPS) and
# GAN. Each entry's `type` names a loss class defined outside this file.
loss:
  rate_loss:
    type: HificVariableRateLoss
    # Five entries, matching the 5 rate levels used by the discriminator
    # (presumably one value per rate level; confirm).
    lambda_A: [3.4, 1.3, 0.4, 0.12, 0.05]
    lambda_B: 0.015625 # 2 ** (-6)
    target_rate: [0.0, 0.0, 0.0, 0.0, 0.0] # target rate is not used in this stage
  distortion_loss:
    type: MSELoss
    loss_weight: 150
  perceptual_loss:
    type: LPIPSLoss
    # NOTE(review): presumably selects the AlexNet backbone for LPIPS; confirm.
    net: alex
    # The weights below are stored divided by 2.56, so that they equal the
    # stated nominal value when beta = 2.56.
    loss_weight: 0.390625 # = 1.0 / 2.56 (1.0 when beta=2.56)
  gan_loss:
    type: VanillaGANLoss
    loss_weight: 0.000390625 # = 0.001 / 2.56 (0.001 when beta=2.56)

# Optimizer and LR-scheduler settings. The generator (g_*) and discriminator
# (d_*) entries are configured identically here.
optim:
  g_optimizer:
    type: Adam
    lr: 0.0001
  g_scheduler:
    type: MultiStepLR
    # Single milestone at 2,400,000 with gamma 0.1; total_iter is 3,000,000.
    # NOTE(review): presumably the scheduler is stepped once per iteration, so
    # the LR drops to 1e-5 for the last 600K iterations; confirm.
    milestones: [2400000]
    gamma: 0.1
  d_optimizer:
    type: Adam
    lr: 0.0001
  d_scheduler:
    type: MultiStepLR
    # Same schedule as the generator's scheduler.
    milestones: [2400000]
    gamma: 0.1

# Total number of training iterations for this stage. The scheduler
# milestones and the keep_step entries in this file are all <= this value.
total_iter: 3000000

# Steps listed for keep_step: the final three 10K-spaced iterations, ending at
# total_iter (3000000).
# NOTE(review): presumably checkpoints saved at these steps are retained;
# confirm against the trainer.
keep_step:
  - 2980000
  - 2990000
  - 3000000