# Hydra defaults list: the config-group options composed into this config.
# `_self_` comes first, so the group options listed after it are merged on top
# of the values defined in this file.
defaults:
  - _self_
  - /callbacks:
      - checkpoint_every_n_steps
      - checkpoint_monitor
      - learning_rate_monitor
  - /data: openwebtext
  - /model: small
  - /strategy: ddp
  - /noise: log-linear
  - /lr_scheduler: constant_warmup
  - /prior: none
  - /algo: duo_base

# Top-level run settings.
# `mode` options: train / ppl_eval / sample_eval
mode: train
debug: false
seed: 1

# Adversarial-distillation settings. The meaning of the individual flags is
# defined by the consuming code, which is not visible from this file.
adversarial_distill:
  is_distill: false
  # Same expression as `trainer.accumulate_grad_batches`:
  # div_up(global_batch_size, devices * batch_size * num_nodes).
  # `div_up` is presumably ceiling division — TODO confirm against the resolver.
  accum_batches: ${div_up:${loader.global_batch_size}, ${eval:${trainer.devices} * ${loader.batch_size} * ${trainer.num_nodes}}}
  fake_frequency: 1
  gradient_clipping_val: 1.0
  # NOTE(review): YAML has no `None` literal, so this loads as the string
  # "None", not null. Confirm the consumer expects the string before changing it.
  checkpoint_path: None
  is_argmax: false
  use_xt_usdm: false
  use_xt_usdm_only_student: false
  gumbel_softmax_relaxation: false
  argmax_reparameterization: false
  process_student: true
  multistep:
    with_data: true
    steps: -1

# Data-loader settings; per-device batch sizes are derived from the global ones.
loader:
  global_batch_size: 512
  # Relative interpolation: defaults to the training global batch size.
  eval_global_batch_size: ${.global_batch_size}
  # Note: batch_size and eval_batch_size are **per device**: each is the global
  # batch size divided by (trainer.devices * trainer.num_nodes).
  # `div_up` presumably rounds the quotient up — TODO confirm against the resolver.
  batch_size: ${div_up:${.global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
  eval_batch_size: ${div_up:${.eval_global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
  # Number of CPUs in the current process's affinity set, assuming the `eval`
  # resolver evaluates the string as Python — verify. `os.sched_getaffinity`
  # is not available on every platform, so this expression is Linux-oriented.
  num_workers: ${eval:"len(__import__('os').sched_getaffinity(0))"}
  pin_memory: True

# Sampling / generation settings.
sampling:
  # Options: ancestral_cache (only for MDLM), ancestral, analytic
  predictor: ancestral
  steps: 1000
  # Options: 'ancestral', 'greedy', 'none'
  noise_removal: ancestral
  use_float64: true
  p_nucleus: 1.0
  # Total samples: `num_gpus` * `loader.eval_batch_size` * num_sample_batches
  num_sample_batches: 2
  num_sample_log: 2
  semi_ar: false
  stride_length: 1
  num_strides: 1

# Training-objective settings.
training:
  ema: 0.9999
  antithetic_sampling: true
  importance_sampling: false
  # Written with an explicit mantissa dot: YAML 1.1 loaders (e.g. plain PyYAML)
  # read a dot-less `1e-3` as a string, not a float. The value is unchanged.
  sampling_eps: 1.0e-3
  change_of_variables: false
  # Options: bf16, float32, float64
  loss_precision: 'bf16'
  # Empty string by default.
  finetune_path: ''

# Evaluation settings.
eval:
  # Used to evaluate a checkpoint after training.
  checkpoint_path: ""
  disable_ema: false
  compute_generative_perplexity: false
  perplexity_batch_size: 8
  compute_perplexity_on_sanity: false
  # Options noted by the authors: gpt2-large, meta-llama/Llama-2-7b-hf
  gen_ppl_eval_model_name_or_path: gpt2-large
  generate_samples: true
  # `cwd` is a custom resolver — presumably the working directory; verify.
  generated_samples_path: ${cwd:}/samples.json

# Optimizer hyperparameters.
# Exponent-form floats carry an explicit mantissa dot: YAML 1.1 loaders (e.g.
# plain PyYAML) read dot-less forms such as `3e-4` as strings, not floats.
# The numeric values are unchanged.
optim:
  weight_decay: 0
  lr: 3.0e-4
  beta1: 0.9
  beta2: 0.999
  eps: 1.0e-8

# Arguments for the object named by `_target_` (lightning.Trainer).
trainer:
  _target_: lightning.Trainer
  accelerator: cuda
  num_nodes: 1
  # `device_count` is a custom resolver — presumably the number of visible
  # devices; verify against its registration.
  devices: ${device_count:}
  # div_up(global_batch_size, devices * per-device batch_size * num_nodes).
  accumulate_grad_batches: ${div_up:${loader.global_batch_size}, ${eval:${trainer.devices} * ${loader.batch_size} * ${trainer.num_nodes}}}
  gradient_clip_val: 1.0
  precision: bf16
  num_sanity_val_steps: 2
  max_steps: 1_000_000
  log_every_n_steps: 100
  # train on full dataset, can be used to toggle quick run
  limit_train_batches: 1.0
  # validate on full dataset, can be used to toggle quick run
  limit_val_batches: 1.0
  val_check_interval: 5000

# wandb:
#   project: flow-ode
#   notes: Flow ODEs for UDLM
#   group: null
#   job_type: null
#   name: null
#   id: ${.name}_${seed}
#   tags:
#     - ${noise.type}
#     - ${data.train}
#     - ${data.valid}
#     - ${algo.name}

# Experiment logger, built from the class named by `_target_`.
logger:
  _target_: lightning.pytorch.loggers.TensorBoardLogger
  # `cwd` is a custom resolver — presumably the working directory; verify.
  save_dir: ${cwd:}/tb_logs
  name: duo

# Hydra runtime settings.
hydra:
  run:
    # One output directory per run:
    # ./outputs/<data.train>/<YYYY.MM.DD>/<HHMMSS>
    dir: ./outputs/${data.train}/${now:%Y.%m.%d}/${now:%H%M%S}
  job:
    # Have Hydra switch the process working directory to the run directory.
    chdir: true

# Checkpoint location and resume settings.
checkpointing:
  # Use custom `save_dir` if, e.g., saving to S3 bucket, otherwise leave this parameter as is
  save_dir: ${cwd:}
  resume_from_ckpt: true
  # Note: `checkpoints` path should correspond to `checkpoint_every_n_steps.dirpath`
  # Resolves relative to `save_dir` above: <save_dir>/checkpoints/last.ckpt
  resume_ckpt_path: ${.save_dir}/checkpoints/last.ckpt
