# Experiment configuration. The sections below cover data, VAE, model,
# training, optimizer, transport, and sampling settings.

# Output location for the run. NOTE(review): the path is relative, so it
# presumably resolves against the launcher's working directory — verify.
output_dir: './outputs'

# Dataset and latent pre-processing settings.
data:
  # Training data location. The path name suggests pre-extracted
  # 'vavae_f16d32' latents at 256px — confirm against the data loader.
  data_path: './buffers/data/vavae_f16d32/train_256'
  # Presumably precomputed reference statistics (.npz) for FID evaluation.
  fid_reference_file: './buffers/refs/in1k256_fid_ref.npz'
  image_size: 256
  num_classes: 1000
  num_workers: 8
  # Booleans are written in canonical lowercase form so every YAML
  # loader/linter reads them the same way.
  latent_norm: true
  latent_multiplier: 1.0

# VAE selection.
vae:
  type: 'vavae_f16d32'
  # Matches the 'f16' part of the type name; presumably the spatial
  # downsampling factor between image and latent — TODO confirm.
  downsample_ratio: 16

# Network architecture selection.
model:
  type: TiT-XL/1
  # Matches the 'd32' part of vae.type; presumably the latent channel
  # count fed to the model — TODO confirm.
  in_chans: 32

# Training run settings.
train:
  # Starting checkpoint: per the original author's note, this is the ckpt
  # from the 500K multi-step model.
  ckpt: "0000000.pt"
  # NOTE(review): the meaning of the three no_* switches is defined by the
  # training code, which is not visible from this file. Booleans are written
  # in canonical lowercase form.
  no_reopt: true
  no_reuni: true
  no_buffer: false
  max_steps: 50000
  global_batch_size: 1024
  global_seed: 0
  log_every: 100
  # With max_steps 50000 this presumably yields checkpoints every 10k steps;
  # sample.ckpt ("0040000.pt") appears to follow that naming — verify.
  ckpt_every: 10000
  ema_decay: 0.9999

# Optimizer settings.
optimizer:
  type: 'RAdam'
  lr: 0.0001
  weight_decay: 0.0
  beta1: 0.9
  beta2: 0.999
  # Presumably the gradient-norm clipping threshold — verify in the trainer.
  max_grad_norm: 0.1

# Transport / training-objective settings.
# NOTE(review): the exact meaning of most keys here is defined by the
# consuming code, which is not visible from this file.
transport:
  type: Linear
  lab_drop_ratio: 0.1
  consistc_ratio: 1.0
  enhanced_ratio: 2.0
  scaled_cbl_eps: 5.0
  ema_decay_rate: 1.0
  enhanced_range: [0.0, 0.75]
  # A key with the same name but different values also exists under `sample`.
  time_dist_ctrl: [0.8, 1.0, 1.0]
  # Key spelling ('funcion') is kept as-is: the consumer presumably looks it
  # up by this exact name, so rename it only together with the code.
  weight_funcion: Cosine
  # Canonical lowercase boolean.
  wt_cosine_loss: false

# Sampling / FID-evaluation settings. The inline FID value is the score the
# original author recorded for this block.
sample: # FID=1.42
  # Checkpoint to sample from; the name suggests a step-40000 checkpoint of
  # the training run configured above — verify.
  ckpt: "0040000.pt"
  type: UNI
  cfg_scale: 0.0
  cfg_interval: [0.00, 0.75]
  sampling_steps: 2
  stochast_ratio: 1.0
  extrapol_ratio: 0.0
  sampling_order: 1
  # A key with the same name but different values also exists under
  # `transport`.
  time_dist_ctrl: [1.0, 1.0, 1.0]
  rfba_gap_steps: [0.001, 0.3]
  # 50000 / 125 = 400, so the sample count divides evenly into batches.
  per_batch_size: 125
  fid_sample_num: 50000

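# Alternative `sample` block (recorded FID=1.21, 30 sampling steps). To use
# it, comment out the active `sample:` mapping above and uncomment the lines
# below; keeping both active would create a duplicate `sample` key.
# sample: # FID=1.21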
#   ckpt: "0000000.pt"
#   type: UNI
#   cfg_scale: 0.0
#   cfg_interval: [0.00, 0.75]
#   sampling_steps: 30
#   stochast_ratio: 0.0
#   extrapol_ratio: 0.5
#   sampling_order: 1
#   time_dist_ctrl: [1.0, 1.0, 1.0]
#   rfba_gap_steps: [0.001, 0.001]
#   per_batch_size: 125
#   fid_sample_num: 50000