# Hydra defaults list: the config groups composed into this config.
# `_self_` is listed first, so (per Hydra's documented composition order) the
# group configs listed after it take precedence over values set in this file.
defaults:
  - _self_
  - /callbacks:
      - checkpoint_every_n_steps
      - checkpoint_monitor
      - learning_rate_monitor
  - /data: openwebtext
  - /model: small
  - /strategy: ddp
  - /noise: loglinear
  - /lr_scheduler: constant_warmup

# Options: train / ppl_eval / sample_eval
mode: train
diffusion: absorbing_state
# Options: dit / dimamba / ar
backbone: dit
# Options: subs / d3pm / sedd
parameterization: subs
time_conditioning: false
# Options: 0 (continuous time) / 1000
T: 0
subs_masking: false

# Also interpolated into `wandb.id`.
seed: 2

text_embedder:
  use_text_embedder: false
  # Embedding model identifier, e.g. nvidia/NV-Embed-v2 or
  # sentence-transformers/all-MiniLM-L6-v2.
  # Written as YAML `null`, matching the other unset keys in this file; a bare
  # `None` is not a YAML null and loads as the (truthy) string "None".
  model_name: null
  cond_dropout: 0.0
  # If > 0, when condition is dropped during training, fill with Gaussian noise
  # sampled from N(0, cond_dropout_std^2) instead of zeros. Default 0.0 keeps
  # the original zeroing behavior.
  cond_dropout_std: 0.0
  # If set, applies random projection to this dimension (non-learnable)
  random_projection_dim: null
  # Stddev of Gaussian noise added to condition embeddings during training
  noise: 0.01
  # Fraction of sampling steps [0, 1] during which conditioning is applied.
  # After this fraction, sampling proceeds unconditionally.
  # Important: model must be trained with cond_dropout > 0 and < 1 to use this.
  use_condition_during_sampling_until: 1.0
  # EMA decay for adjusting condition embedding during sampling.
  # new_cond := decay * old_cond + (1 - decay) * embedding(x_0_sample)
  embedding_ema_decay: 0.9
  # Number of times to update condition embedding during sampling (expensive).
  # Updates are spaced roughly evenly across the conditioning-enabled steps.
  num_embedding_updates: 0

# Standalone embedding diffusion model config
# (used by scripts/train_embedding_diffusion.py).
embedding_diffusion:
  timesteps: 1000
  hidden_dim: 512
  num_layers: 3
  # Options: mlp (v1) | transformer (v2)
  net_type: mlp
  seq_len: 8
  num_heads: 8
  t_sampling_exponent: 2.0
  # Number of samples to draw for Fréchet distance (clipped to val size)
  fid_sample_size: 2048

loader:
  global_batch_size: 512
  # Defaults to the training global batch size (relative interpolation).
  eval_global_batch_size: ${.global_batch_size}
  # Note: batch_size and eval_batch_size are **per device**: the global batch
  # size is divided by `trainer.devices * trainer.num_nodes`
  # (assumes the `div_up` resolver is a ceiling division; TODO confirm).
  batch_size: ${div_up:${.global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
  eval_batch_size: ${div_up:${.eval_global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
  # Size of the CPU set returned by `os.sched_getaffinity(0)` for this process
  # (assumes the `eval` resolver evaluates the quoted Python expression).
  num_workers: ${eval:"len(__import__('os').sched_getaffinity(0))"}
  pin_memory: true

sampling:
  # Options: analytic, ddpm, ddpm_cache, remaskator
  predictor: ddpm_cache
  steps: 128
  noise_removal: true
  # TODO(yair): @subham, why aren't these params under `eval`?
  # Total samples: `num_gpus` * `loader.eval_batch_size` * num_sample_batches
  num_sample_batches: 2
  num_sample_log: 2
  semi_ar: false
  stride_length: 1
  num_strides: 1
  # Source for conditioning embeddings: validation, train, null, gaussian
  sample_embeddings_from: validation
  # When sample_embeddings_from == 'gaussian', load EmbeddingDiffusionModule
  # ckpt and sample embeddings
  gaussian_checkpoint_path: ''
  # Remaskator-guided sampling parameters
  remaskator_temperature: 1.0
  # Temperature applied to denoiser logits during remasking proposal sampling
  denoiser_temp_during_remasking: 1.0
  # Optional path to Remaskator Lightning ckpt or state_dict
  remaskator_checkpoint_path: ''
  # Apply remaskator-guided updates only when t in
  # [remaskator_t_off, remaskator_t_on]
  remaskator_t_off: 0.05
  remaskator_t_on: 0.55
  # Remaskator training time sampling parameters
  # Choose how to sample t during training of Remaskator: 'uniform' or 'const'
  t_sampling: uniform
  # If t_sampling == 'const', use this t value
  # (clamped to [training.sampling_eps, 1.0])
  t_const: 1.0
  freeze_backbone: false
  nucleus_p: 1.0
  eta: 0.008
  # Options: null / "cap"
  remdm_mode: null
  save_x0_sample: false
  use_fp64: false
  remaskator_adaptive_sampling: false

remaskator:
  # If set to an integer N, Remaskator will initialize its DIT backbone
  # with only the first N transformer blocks from the checkpoint.
  # Default null uses all blocks specified by model.n_blocks.
  # NOTE(review): the key name looks like a misspelling of
  # `take_first_n_layers`. It is left unchanged here because renaming it would
  # presumably require updating every place that reads this key.
  take_fist_n_layrs: null

training:
  ema: 0.9999
  antithetic_sampling: true
  importance_sampling: false
  # Also the lower clamp bound mentioned in the `sampling.t_const` comment.
  sampling_eps: 1e-3
  change_of_variables: false
  # When true, the BCE in the Remaskator loss is weighted so that the 0 and 1
  # classes contribute equally
  remaskator_reweighting: false

eval:
  # Used to evaluate a checkpoint after training.
  checkpoint_path: ''
  disable_ema: false
  compute_generative_perplexity: false
  perplexity_batch_size: 8
  compute_perplexity_on_sanity: false
  # Options: gpt2-large, meta-llama/Llama-2-7b-hf
  gen_ppl_eval_model_name_or_path: gpt2-large
  generate_samples: true
  # Save per-step sampling trajectories during evaluation/sampling
  save_sampling_trajectory: true
  # Maximum number of trajectories to save across the entire run
  # (across batches)
  max_trajectories_to_save: 5

# Optimizer hyperparameters (presumably for an Adam-style optimizer, given
# beta1/beta2/eps; confirm against the code that builds the optimizer).
# NOTE(review): dot-less exponent forms such as `3e-4` and `1e-8` are typed
# differently by different YAML loaders; confirm they load as floats here.
optim:
  weight_decay: 0
  lr: 3e-4
  beta1: 0.9
  beta2: 0.999
  eps: 1e-8

# Keyword arguments for `lightning.Trainer`; `_target_` names the class to
# build (presumably via Hydra's instantiate; confirm in the entry point).
trainer:
  _target_: lightning.Trainer
  accelerator: cuda
  num_nodes: 1
  # Filled in by the custom `device_count` resolver (registered outside this file).
  devices: ${device_count:}
  # global_batch_size / (devices * per-device batch_size * num_nodes), via
  # `div_up`. With the derived default `loader.batch_size` this works out to 1
  # (assuming `div_up` is a ceiling division); it only exceeds 1 when
  # `loader.batch_size` is overridden to something smaller.
  accumulate_grad_batches: ${div_up:${loader.global_batch_size}, ${eval:${trainer.devices} * ${loader.batch_size} * ${trainer.num_nodes}}}
  gradient_clip_val: 1.0
  precision: 'bf16'
  num_sanity_val_steps: 2
  # 1_000_000 scaled by the gradient-accumulation factor.
  # NOTE(review): confirm whether the intent is 1M optimizer steps or 1M
  # micro-batches, given how Lightning counts `max_steps`.
  max_steps: ${eval:${trainer.accumulate_grad_batches} * 1_000_000}
  log_every_n_steps: 10
  limit_train_batches: 1.0   # train on full dataset, can be used to toggle quick run
  limit_val_batches: 1.0     # validate on full dataset, can be used to toggle quick run
  val_check_interval: 10000

# Run metadata (presumably passed to the Weights & Biases logger; confirm
# against the logger setup).
wandb:
  project: text-diffusion
  notes: Mulan for text
  group: null
  job_type: null
  # Fixed prefix plus the current date (`now` resolver, format %Y.%m.%d).
  name: eval_hf_raw_${now:%Y.%m.%d}
  # `<name>_<seed>`: two runs on the same day with the same seed get the same id.
  id: ${.name}_${seed}
  # Tags are interpolated from the composed `noise` and `data` configs.
  tags:
    - ${noise.type}
    - ${data.train}
    - ${data.valid}

# Hydra runtime settings.
hydra:
  run:
    # Per-run output directory: ./outputs/<data.train>/<date>/<time>
    dir: ./outputs/${data.train}/${now:%Y.%m.%d}/${now:%H%M%S}
  job:
    # Hydra documents this flag as changing the working directory to the
    # run directory above before the job runs.
    chdir: true

checkpointing:
  # Use custom `save_dir` if, e.g., saving to S3 bucket, otherwise leave this parameter as is
  # (`cwd` is a custom resolver registered outside this file; presumably it
  # returns the current working directory, so confirm).
  save_dir: ${cwd:}
  # Note: `checkpoints` path should correspond to `checkpoint_every_n_steps.dirpath`
  resume_from_ckpt: true
  # Resolves to `<save_dir>/checkpoints/last.ckpt` via relative interpolation.
  resume_ckpt_path: ${.save_dir}/checkpoints/last.ckpt
