# Top-level settings for the `multi-round-sdtt` configuration.
name: multi-round-sdtt
# NOTE(review): -1 looks like a "disabled" sentinel; confirm against the consumer.
log_loss_buckets: -1

# NOTE(review): `checkpoint_path` looks like a Hugging Face repo id while
# `start_from_hf` is false; confirm this combination is intended.
checkpoint_path: 'kuleshov-group/mdlm-owt'
start_from_hf: false

# Distillation mode, one of: mse / tvd / kl-fwd / kl-bwd
distill_mode: kl-bwd
num_distill_steps: 2
min_num_sampling_steps: 8
# Written without a `_` digit separator: YAML 1.2 core-schema parsers read
# `10_000` as a string rather than an integer.
grow_dt_every: 10000

orig_num_sampling_steps: 1024
sampling_mode: ancestral
loss_precision: null

# Flags named after "growth"; presumably tied to `grow_dt_every` above.
# TODO confirm against the training code.
reset_optimizer_on_growth: false
use_ema_on_growth: false

# Sampling settings, in two groups: `uncond` and `cond_prefix`.
# Each group has its own `run` switch.
sampling:
  uncond:
    run: true
    # Shared (the same three keys also appear under `cond_prefix`)
    num_samples: 1024
    batch_size: 16
    from_ema: false

    # Passed to `sample`
    # n_samples: ${..batch_size}
    num_steps: 8
    seq_len: 1024
    sampler: ancestral
    cache_preds: false
    add_bos: false
    add_eos: false

  cond_prefix:
    run: false
    # Shared (the same three keys also appear under `uncond`)
    num_samples: 1024
    batch_size: 16
    from_ema: false

    # NOTE(review): `min_seq_len` (1024) is larger than `seq_len` (100);
    # presumably it filters source documents by length. Confirm.
    dataset: webtext
    seq_len: 100
    prefix_len: 50
    num_cont_per_prefix: 5
    min_seq_len: 1024

    num_steps: 8
    sampler: ancestral
    cache_preds: false
    add_bos: false
    add_eos: false

    calc_sbleu: true
