# Hydra defaults list: compose the `model` group (option rtt_diff) and the
# `data` group (option mnist). `_self_` is listed last, so values set in this
# file are merged on top of (and override) those groups.
defaults:
  - model: rtt_diff
  - data: mnist
  - _self_

# Merged on top of the `data` group selected in `defaults` (because `_self_`
# comes last there).
data:
  train:
    augmentation: false

# Training-loop settings.
# Float written as 1.0e-3 rather than 1e-3: YAML 1.1 loaders (e.g. plain
# PyYAML) only resolve floats that contain a '.', and read 1e-3 as a string.
lr: 1.0e-3
n_epochs: 1000
batch_size: 64

n_samples: 64
num_workers: 8
eval_every: 1

# torch.compile toggle; compile_kwargs are presumably forwarded to
# torch.compile only when compile is true (confirm in the training script).
compile: false
compile_kwargs:
  # mode: reduce-overhead
  mode: null
  options:
    # NOTE(review): confirm this is a recognised inductor option for the
    # installed torch version; torch.compile rejects unknown option keys.
    matmul-padding: true

# Optimizer class plus keyword arguments for it (the keys match
# torch.optim.AdamW parameters); lr is interpolated from the top-level `lr`.
optim:
  cls: torch.optim.AdamW
  kwargs:
    lr: ${lr}
    weight_decay: 5.0e-4  # decimal point keeps this a float under YAML 1.1
    amsgrad: true
    fused: false

# LR scheduler settings; presumably only instantiated when use_scheduler is
# true (confirm in the training script).
use_scheduler: false
scheduler:
  cls: lr_scheduler.CosLRScheduler
  kwargs:
    warmup_steps: 4000
    decay_steps: 80000

# Checkpoint to load; null presumably means start from scratch (confirm).
load_ckpt: null

# Mixed precision: use_amp drives the `enabled` flag of both the gradscaler
# and autocast sections through interpolation, so they stay in sync.
use_amp: false
gradscaler:
  enabled: ${use_amp}
autocast:
  device_type: cuda
  enabled: ${use_amp}
  dtype: float16

# Gradient clipping toggle and its max-norm threshold.
clip_grad: true
clip_grad_max_norm: 1.0

# Run bookkeeping: RNG seed, device index, output directory and W&B logging.
seed: 42
gpu: 0
save_path: "./output"
wandb:
  project: "mnist_inr_fm"
  entity: null
  name: null

# ODE solver settings (presumably forwarded to an odeint-style solver —
# confirm). With a fixed-step method such as euler, rtol/atol are typically
# unused and options.step_size sets the integration step.
# Tolerances use a decimal point so YAML 1.1 loaders parse them as floats.
ode_solver:
  rtol: 1.0e-5
  atol: 1.0e-5
  method: euler
  options:
    step_size: 0.001

# Backend knobs; `high` presumably feeds torch.set_float32_matmul_precision
# (confirm in the training script).
matmul_precision: high
cudnn_benchmark: false

# Diffusion process settings.
diffusion:
  predict_xstart: true
  steps: 500
  noise_schedule: linear
  learn_sigma: false

sample_kwargs: {}
# Presumably an exponential moving average of model weights, enabled by `ema`
# with decay `ema_decay` (confirm in the training script).
ema: true
ema_decay: 0.9999

debug: false
num_gpus: 1