# Hydra defaults list: config-group options composed together with this file.
# NOTE(review): `_self_` is listed first, so under Hydra's ordering rules the
# group configs below presumably take precedence over keys defined in this
# file — confirm this is intended.
defaults:
  - _self_
  - /callbacks:
      - checkpoint_every_n_steps
      - learning_rate_monitor
      - text_throughput
      - grad_norm
  - /data: openwebtext-split
  - /model: dit-orig-small
  - /strategy: ddp
  - /noise: loglinear
  - /lr_scheduler: constant_warmup
  # Other options noted by the author for this group: distill-mdlm / ar
  - /parameterization: multi-round-sdtt

# Run mode. Values listed by the author: train / sample / eval / di4c
mode: train
time_conditioning: false
# Number of diffusion steps:
#   T: 0  -> continuous time
#   T > 0 -> that many discrete steps, e.g. 1000
T: 1024
compile: false
# Evaluation period, in gradient steps. Kept at the top level so that
# trainer.val_check_interval can multiply it by the accumulation factor
# (more accumulation -> validation less frequent in terms of batches).
# NOTE(review): this value is larger than trainer.max_steps (1_000_000), so
# periodic validation would presumably never trigger within a run — confirm
# that this is intended.
eval_every: 5_000_000

# ---- di4c settings ----
is_di4c: false
is_teacher_di4c: false
latent_bsize: 16
# Options listed by the author for alpha_t: const / sigmoid / linear
alpha_t: sigmoid
# NOTE(review): presumably only consulted when alpha_t is `const` — confirm.
alpha_const: 0.1
distil_delta: 0.02
log_unif: false
# NOTE(review): presumably the current distillation round (the selected
# parameterization is multi-round-sdtt) — confirm against the training code.
round: 7

# ---- moment-sampler settings ----
use_moment: true
# Options listed by the author for moment_sampler: maskgit / moment
moment_sampler: moment
moment_temp: 10
num_substeps: 1
guidance_w: 0
print_sample: false
temp_sampling: false
dim_selection: false
temp_schedule: true
halton: false

# Dataset preprocessing options.
data_preprocess:
  # The real value was redacted in this copy; 0 is a placeholder.
  data_cache: 0 # REDACTED
  min_seq_len: -1
  # Sequence length follows the model's context length.
  seq_len: ${model.length}
  group_text: true
  remove_text: true
  num_seqs: -1
  add_bos: false
  add_eos: true
  legacy_start_end_bos: true

# Data-loading options.
loader:
  global_batch_size: 2
  # Evaluation reuses the training global batch size unless overridden.
  eval_global_batch_size: ${.global_batch_size}
  # Note: batch_size and eval_batch_size are **per device**, not per machine:
  # the global size is split across trainer.devices * trainer.num_nodes.
  # `div_up` is a custom resolver (presumably ceiling division — confirm).
  batch_size: ${div_up:${.global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
  eval_batch_size: ${div_up:${.eval_global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
  # Evaluated as Python: size of the CPU affinity set of the current process.
  # NOTE(review): os.sched_getaffinity is not available on every platform
  # (e.g. macOS) — confirm the target environment is Linux.
  num_workers: ${eval:"len(__import__('os').sched_getaffinity(0))"}
  pin_memory: true
  persistent_workers: true

# Training options.
training:
  ema: 0.9999
  antithetic_sampling: true
  importance_sampling: false
  # Written with an explicit decimal point in the mantissa: a bare `1e-5` is
  # only resolved as a float by loaders that extend the YAML 1.1 float pattern
  # (OmegaConf does); stock YAML 1.1 resolvers read it as a string.
  sampling_eps: 1.0e-5
  change_of_variables: false

# Tokenizer selection.
tokenizer:
  # Tokenizer identifier; presumably a pretrained-tokenizer name resolved by
  # the loading code — TODO confirm.
  name: gpt2

# Optimizer hyperparameters.
optim:
  name: adamw
  weight_decay: 0.0
  # `lr` and `eps` are written with an explicit decimal point in the mantissa:
  # exponent-only literals such as `3e-5` are resolved as floats only by
  # loaders that extend the YAML 1.1 float pattern (OmegaConf does); stock
  # YAML 1.1 resolvers read them as strings.
  lr: 3.0e-5
  beta1: 0.9
  beta2: 0.999
  eps: 1.0e-8

# Arguments for the trainer object named by `_target_`.
trainer:
  _target_: lightning.Trainer
  accelerator: cuda
  num_nodes: 1
  # `device_count` is a custom resolver; presumably the number of visible
  # accelerators — confirm.
  devices: ${device_count:}
  # Accumulation factor: global batch size divided (via `div_up`) by
  # devices * per-device batch size * nodes.
  accumulate_grad_batches: ${div_up:${loader.global_batch_size}, ${eval:${trainer.devices} * ${loader.batch_size} * ${trainer.num_nodes}}}
  gradient_clip_val: 1.0
  precision: 'bf16-mixed'
  # Sanity validation is off; the author noted 2 as the alternative value.
  num_sanity_val_steps: 0
  max_steps: 1_000_000
  log_every_n_steps: 10
  # 1.0 trains on the full dataset; lower it to toggle a quick run.
  limit_train_batches: 1.0
  # 1.0 validates on the full dataset; lower it to toggle a quick run.
  limit_val_batches: 1.0
  # Validate every `eval_every` gradient steps: the interval is expressed in
  # batches, so it is scaled by the gradient-accumulation factor.
  val_check_interval: ${eval:${trainer.accumulate_grad_batches} * ${eval_every}}
  # Kept null so validation runs every `val_check_interval` batches; per the
  # author, a non-null value is rejected when an epoch has fewer steps.
  check_val_every_n_epoch: null
  benchmark: true

# Experiment-tracker (wandb) run metadata.
wandb:
  project: text-diffusion
  notes: ''
  group: null
  job_type: null
  name: null
  # Tags are interpolated from the selected noise and data configs.
  tags:
    - ${noise.type}
    - ${data.train}
    - ${data.valid}

# Hydra runtime settings.
hydra:
  run:
    # Per-run output directory: ./outputs/<data.train>/<date>/<time>.
    dir: ./outputs/${data.train}/${now:%Y.%m.%d}/${now:%H%M%S}
  job:
    # NOTE(review): presumably makes the job run with the directory above as
    # its working directory — confirm against the Hydra version in use.
    chdir: true

# Checkpoint saving and resuming.
checkpointing:
  # Use custom `save_dir` if, e.g., saving to S3 bucket, otherwise leave this parameter as is.
  # `cwd` is a custom resolver; presumably it yields the current working
  # directory — confirm.
  save_dir: ${cwd:}
  resume_from_ckpt: true
  # Note: the `checkpoints` path segment should correspond to
  # `checkpoint_every_n_steps.dirpath`.
  resume_ckpt_path: ${.save_dir}/checkpoints/last.ckpt

# Evaluation settings. Every sub-section except `valid` carries a `run`
# switch; only `ppl_with_ar` is switched on here.
eval:
  valid:
    n_samples: 16
    num_steps: 128

  ppl_with_ar:
    run: true
    model: gpt2-large
    batch_size: 8

  cond_ppl:
    run: false
    model: gpt2-large
    batch_size: 8

  mauve:
    run: false
    num_rounds: 5
    max_num_tokens: 100
    batch_size: 8
    scaling_factor: 5

  lambada_openai:
    run: false
    batch_size: 32
    num_samples: 32
    from_ema: false
    add_bos: true
    add_eos: true
