# Checkpoint path; only used during inference.
# NOTE(review): the step in the file name (0010000) is not a multiple of
# train.ckpt_every (20000), so a run with this config as-is would presumably
# never write this file — confirm the path before sampling.
ckpt_path: 'output/gmem_b_vavae_f16d32/checkpoints/0010000.pt'

# ImageNet safetensors data; see datasets/img_latent_dataset.py for details.
data:
  data_path: 'data/preprocessed/in1k256/vavae_f16d32/imagenet_train_256'
  # FID reference file; see ADM <https://github.com/openai/guided-diffusion> for details.
  fid_reference_file: 'data/fids/VIRTUAL_imagenet256_labeled.npz'
  image_size: 256
  num_classes: 1000
  num_workers: 8
  # Latent normalization, originating from our previous research FasterDiT <https://arxiv.org/abs/2410.10356>.
  # The standard deviation of the latents directly affects the SNR distribution during training.
  # Channel-wise normalization provides stability but may not be optimal in all cases.
  latent_norm: true
  # NOTE(review): presumably an extra scale factor applied to the latents
  # (1.0 would leave them unchanged) — confirm in the dataset code.
  latent_multiplier: 1.0

# Our pre-trained VAE, aligned with a vision foundation model. See VA-VAE <to be released> for details.
vae:
  model_name: 'vavae_f16d32'
  # NOTE(review): 16 matches the "f16" in model_name; with data.image_size 256
  # this would presumably give 16x16 latents — confirm.
  downsample_ratio: 16

# Transformer settings. We explored several optimization techniques for
# transformers; the boolean switches below presumably toggle them one by one
# (see the model code for their exact meaning).
model:
  model_type: LightningDiT-B/1
  use_qknorm: false
  use_swiglu: true
  use_rope: true
  use_rmsnorm: true
  wo_shift: false
  # NOTE(review): 32 matches the "d32" in vae.model_name, presumably the
  # number of latent channels — confirm.
  in_chans: 32

# Training parameters.
train:
  max_steps: 80000
  # We train with a large global batch (1024) and adjust the learning rate and beta2 accordingly;
  # this is inspired by AuraFlow and muP.
  global_batch_size: 1024
  global_seed: 0
  # NOTE(review): output_dir/exp_name matches the directory prefix of the
  # top-level ckpt_path, so checkpoints are presumably written under
  # <output_dir>/<exp_name>/checkpoints — confirm.
  output_dir: 'output'
  exp_name: 'gmem_b_vavae_f16d32'
  # NOTE(review): presumably a checkpoint to resume training from, with null
  # meaning "start from scratch" — confirm.
  ckpt: null
  log_every: 100
  ckpt_every: 20000

# Optimizer hyperparameters. lr and beta2 are the values adjusted for the
# large global batch size (see the comment on train.global_batch_size).
optimizer:
  lr: 0.0002
  beta2: 0.95
  # NOTE(review): presumably the gradient-norm clipping threshold — confirm in
  # the training loop.
  max_grad_norm: 1.0

# We use rectified flow for fast training.
transport:
  # These settings are inherited from SiT; no parameters are changed.
  path_type: Linear
  prediction: velocity
  loss_weight: null
  sample_eps: null
  train_eps: null

  # Inspired by SD3 and our previous work FasterDiT.
  # In small-scale experiments, we keep lognorm enabled.
  # In large-scale experiments, we disable lognorm midway through training.
  use_lognorm: true
  # Cosine loss is enabled at all times.
  use_cosine_loss: true

  # REPA settings.
  # NOTE(review): presumably the weight of the REPA projection loss term —
  # confirm in the training loss code.
  proj_loss_weight: 0.5
  
# Sampling and FID evaluation settings.
sample:
  mode: SDE
  sampling_method: Heun
  diffusion_form: linear
  diffusion_norm: 1
  last_step: Mean
  last_step_size: 0.004
  # NOTE(review): num_steps and num_sampling_steps are both 50; which one the
  # sampler actually reads is not visible here — keep them in sync or confirm.
  num_steps: 50
  # NOTE(review): a scale of 1 conventionally means no classifier-free
  # guidance, in which case cfg_interval_start below would presumably have no
  # effect — confirm.
  cfg_scale: 1
  num_sampling_steps: 50
  per_proc_batch_size: 32
  fid_num: 50000

  # CFG interval, inspired by <https://arxiv.org/abs/2404.07724>.
  cfg_interval_start: 0.89
  # Timestep shift, inspired by FLUX. Please refer to the ode function in transport/integrators.py for details.
  # NOTE(review): the reference above is to the ODE path while mode is SDE —
  # confirm that the shift is also applied during SDE sampling.
  timestep_shift: 0.3
# GMem bank settings.
GMem:
  bank_type: 'full'
  # NOTE(review): the file name ends in ".pth.pth" — confirm this matches the
  # file on disk and is not an accidentally doubled extension.
  bank_path: 'data/preprocessed/in1k256/banks/in1k256_Kfull_dim768_init_seed0.pth.pth'
