# Hydra-style defaults list: the configs composed to build this experiment.
# `_self_` comes last, so the keys set in this file are applied after (and
# therefore take precedence over) the entries listed before it.
defaults:
  - base
  - data: cifar10
  - task: inpaint
  - model: warm_start_fm
  - _self_

# Training-run settings: step budget, gradient clipping, gradient accumulation
# and EMA rate.
execution:
  train_steps: 3000000
  gradient_clip_norm: 3.0
  # 4 accumulation steps x data.trainloader.batch_size (128, set below in this
  # file) = effective batch size 512.
  # NOTE(review): "the repo" is not identified here - name the reference
  # implementation being matched.
  accumulate_steps: 4 # Effective batch size = 512, matching the repo
  ema_rate: 0.999

# Metrics to report: one loss metric plus three FID metrics.
# The FID entries are identical except for `nfe` (50, 10 and 5); each uses
# 1000 samples and takes its normalisation statistics and device from other
# parts of the composed config via `${...}` interpolation.
# NOTE(review): `nfe` is presumably the number of function evaluations used
# when sampling - confirm against cdnp.evaluate.FIDMetric.
output:
  metrics:
    - _target_: cdnp.evaluate.LossMetric
    # FID at nfe = 50.
    - _target_: cdnp.evaluate.FIDMetric
      num_samples: 1000
      means: ${data.dataset.norm_means}
      stds: ${data.dataset.norm_stds}
      device: ${runtime.device}
      nfe: 50
    # FID at nfe = 10.
    - _target_: cdnp.evaluate.FIDMetric
      num_samples: 1000
      means: ${data.dataset.norm_means}
      stds: ${data.dataset.norm_stds}
      device: ${runtime.device}
      nfe: 10
    # FID at nfe = 5.
    - _target_: cdnp.evaluate.FIDMetric
      num_samples: 1000
      means: ${data.dataset.norm_means}
      stds: ${data.dataset.norm_stds}
      device: ${runtime.device}
      nfe: 5

# Overrides applied on top of the `warm_start_fm` model config selected in the
# defaults list.
model:
  # Warm-start model: loaded from a path under ${paths.weights} via
  # load_model_from_path, with `freeze: true`.
  # NOTE(review): presumably `freeze` keeps these weights fixed during
  # training - confirm against cdnp.util.instantiate.load_model_from_path.
  warm_start_model:
    _target_: cdnp.util.instantiate.load_model_from_path
    path: ${paths.weights}/cifar10_cnp
    freeze: true
    device: ${runtime.device}
  generative_model:
    backbone:
      # 13 = 3 + 3 + 1 + 3 + 3, assuming the CNP mean and stddev are 3-channel
      # (RGB) each - TODO confirm.
      in_channels: 13 # noised RGB image + masked RGB input + mask + CNP mean + CNP stddev
      out_channels: 3 # RGB image
      use_warmth_embedding: true

  # These two keys sit at the `model` level, not under `generative_model`.
  # NOTE(review): presumably the bounds of the warmth value consumed when
  # `use_warmth_embedding` is on - verify against the model class.
  min_warmth: 0.0
  max_warmth: 1.0


# Optimizer hyperparameter overrides; the optimizer class itself is not set in
# this file.
optimizer:
  # NOTE(review): `1e-4` has no decimal point; some YAML 1.1 loaders read such
  # a scalar as a string rather than a float - confirm the config loader in
  # use resolves it to a float.
  lr: 1e-4
  betas: [0.9, 0.95]

# Overrides applied on top of the `cifar10` data config selected in the
# defaults list.
data:
  trainloader:
    # Per-step batch size; with execution.accumulate_steps = 4 this gives the
    # effective batch size of 512 noted above.
    batch_size: 128
  preprocess_fn:
    # min_frac == max_frac, so the fraction is pinned to a single value (0.1).
    # NOTE(review): presumably the fraction of pixels kept or masked for the
    # inpainting task - confirm against the preprocess function.
    min_frac: 0.1
    max_frac: 0.1