# Training mode configuration.
# NOTE(review): paths.model and train.exp_name follow the 'llada' setting, but
# denoiser.encoder.name is 'dream' — confirm the sections are meant to differ.
training_mode: 'llada'  # 'llada' or 'dream'

# Model and data path configuration
paths:
  # Base model identifier (looks like a hub-style '<org>/<name>' id — confirm
  # how the consumer resolves it).
  model: 'GSAI-ML/LLaDA-8B-Instruct'
  # Experiment path; the comment in the `train` section says it is used there.
  experiment: 'ckpt_llada_instruct'
  # Dataset identifiers keyed by short name.
  # NOTE(review): the ids suggest hard/easy subsets filtered at max length 600,
  # while data.max_length is 512 — confirm which limit applies.
  data:
    bs: 'Lansechen/bs17k_collection_filtered_hard_maxlength600'
    bs_easy: 'Lansechen/bs17k_collection_filtered_easy_maxlength600'

# Denoiser configuration: an encoder section plus a decoder section.
denoiser:
  encoder:
    # NOTE(review): training_mode is 'llada' and paths.model is an LLaDA
    # checkpoint, yet the encoder is named 'dream' — confirm this is intended.
    name: 'dream'
    # Presumably the tokenizer's mask-token id; verify it matches the tokenizer
    # of the model named in paths.model.
    mask_id: 151666

  decoder:
    # NOTE(review): the meaning of `wiinit` is not evident from this file
    # (possibly an abbreviation or a typo) — check the consumer before renaming.
    wiinit: true
    name: 'eagle_rope'
    num_blocks: 1
    # Anchor &seq_len is reused by decoder.block.seq_len and data.max_length.
    seq_len: &seq_len 512
    # NOTE(review): input_dim, hidden_dim and vocab_size presumably have to
    # match the base model's hidden size and vocabulary — verify against
    # paths.model.
    input_dim: 3584
    # Anchor &dim is reused by decoder.block.hidden_dim.
    hidden_dim: &dim 3584
    vocab_size: 152064
    block:
      # Both values below are aliases of the anchors defined above, so the
      # block always agrees with the decoder-level settings.
      seq_len: *seq_len
      hidden_dim: *dim
      # 3584 / 32 = 112; assumes hidden_dim is split evenly across heads —
      # TODO confirm.
      num_heads: 32

# Training-run settings: resume options, experiment naming/logging, and the
# optimisation schedule.
train:
  # Resume options, all unset (explicit null) in this config.
  # Original note: "Will use paths.experiment path" — presumably the fallback
  # location when these are null; confirm against the training script.
  decoder_resume_path: null
  head_resume_path: null
  skipped_keys: null
  global_step: null
  # Anchor &exp_name is reused as the wandb project name below.
  exp_name: &exp_name 'llada_ddt_maskteacher'
  wandb_proj: *exp_name
  output_dir: 'ddt_test'
  logging_dir: 'logs'
  mixed_precision: 'fp16'
  gradient_accumulation_steps: 5
  report_to: 'wandb'
  block_size: 16

  # Written with an explicit decimal point: YAML 1.1 loaders (e.g. PyYAML)
  # read a bare `1e-5` as the string '1e-5' rather than a float.
  lr: 1.0e-5
  num_iters: 50000
  # NOTE(review): eval_every (100000) exceeds num_iters (50000); if evaluation
  # fires every `eval_every` iterations it never runs in this schedule —
  # confirm that is intended.
  eval_every: 100000
  save_every: 1000

  # Training-behaviour switches; their semantics are defined by the consuming
  # training code, which is not visible in this file.
  enable_shift: true
  share_steps: 2
  self_align: true
  feature_align: false
  self_step: true

# Dataset selection and batching.
data:
  # Dataset selector; the options listed inline are 'numinamath' and 'bs17k'.
  name: 'bs17k'  # ['numinamath', 'bs17k']
  batch_size: 1
  # Alias of the &seq_len anchor defined at denoiser.decoder.seq_len (512).
  max_length: *seq_len