# Dataset settings.
data:
  dataset: imagenet256-latent
  category: lmdb
  # resolution / num_channels carry the same values as model.in_size /
  # model.in_channels in this file (32 and 4).
  resolution: 32
  num_channels: 4
  random_flip: true
  root: ImageNet256
  # NOTE(review): a bare `None` is not YAML null; loaders read it as the
  # string "None". Left unchanged because the consumer is not visible here --
  # confirm how it handles this value before switching to `null`.
  feat_path: None

# Model settings.
model:
  precond: edm_mi
  model_type: DiT-XL/2
  # Same values as data.resolution / data.num_channels in this file.
  in_size: 32
  in_channels: 4
  num_classes: 1000
  use_decoder: true
  ext_feature_dim: 0
  pad_cls_token: false
  # Masking options. Presumably a fixed ratio of 0.5, given
  # `mask_ratio_fn: constant` -- TODO confirm against the consumer.
  mask_ratio: 0.5
  mask_ratio_fn: constant
  mask_ratio_min: 0
  mae_loss_coef: 0.1
  class_dropout_prob: 0.1

# Training settings.
train:
  tf32: false
  amp: true
  # batchsize per GPU
  batchsize: 128
  grad_accum: 1
  # NOTE(review): both `epochs` and `max_num_steps` are set; which limit ends
  # training is decided by the consumer -- confirm.
  epochs: 2800
  lr: 0.0001
  lr_rampup_kimg: 0
  # NOTE(review): `xflip` is false here while data.random_flip is true;
  # confirm which flag the consumer actually reads.
  xflip: false
  max_num_steps: 2000000

eval: # FID evaluation
  # Batch size used during evaluation.
  batchsize: 50
  # Reference statistics file (.npz). The path is relative; presumably
  # resolved against the working directory -- TODO confirm.
  ref_path: assets/fid_stats/fid_stats_imagenet256_guided_diffusion.npz

# Logging and checkpoint cadence. Units (presumably training steps) are
# defined by the consumer -- TODO confirm.
log:
  log_every: 50
  ckpt_every: 10000
  tag: pretrain

# wandb (Weights & Biases) run metadata.
wandb:
  entity: MaskDiT
  # NOTE(review): the project name says "MaskDiT-B" while model.model_type in
  # this file is DiT-XL/2 -- confirm this is the intended project.
  project: MaskDiT-B-ImageNet256-latent-train
  group: pretrain
