# Training mode configuration: selects the base model, the fine-tuning method,
# and the task.
# NOTE(review): the earlier inline hint listed 'llada' or 'dream' as model
# values, but the value set here is 'llada_base' — confirm the accepted names
# against the training code.
model: 'llada_base'
finetuning_method: 'prompt_tuning'
task_name: 'commonsense170k'

finetuning_parameters:
  # PEFT task type. CAUSAL_LM is used to stay consistent with the base model
  # architecture; prompt tuning adds trainable vectors at the input layer.
  task_type: 'CAUSAL_LM'

  # Number of trainable virtual tokens (the prompt-length budget).
  # NOTE(review): earlier notes in this file described this value as 64, chosen
  # to match the P-Tuning configuration for a fair comparison, but it is set to
  # 20 here — confirm which value the comparison actually needs.
  num_virtual_tokens: 20

  # Prompt tuning optimizes the prompt embeddings directly, with no LSTM/MLP
  # prompt encoder as in P-Tuning, so random initialization is often unstable.
  # 'TEXT' gives a warm start from the embeddings of the text below.
  # NOTE(review): PEFT requires tokenizer_name_or_path when the init is TEXT;
  # it is not set in this file, so presumably the training code supplies it —
  # confirm.
  prompt_tuning_init: 'TEXT'

  # Initialization text. The library repeats or truncates its tokens to fit
  # num_virtual_tokens. The wording is meant to give the model a clear semantic
  # instruction for the denoising task.
  prompt_tuning_init_text: 'Reconstruct the original clean text from the noisy input sequences:'

  # Dimensions that help PEFT size the embedding layer. The original note says
  # these were verified against LLaDAConfig (a LLaMA-like architecture).
  token_dim: 4096
  num_attention_heads: 32
  num_layers: 32

# Dataset splitting and batching settings.
data:
  # Seed and size of the validation split (presumably the size is a number of
  # examples — confirm against the data loader).
  val_split_seed: 42
  val_split_size: 128
  # Batch size; note that train.gradient_accumulation_steps is configured
  # separately in the train section.
  batch_size: 1
  # Presumably the maximum sequence length in tokens — TODO confirm.
  max_length: 512

# Training-loop settings: resume state, output/logging, optimization schedule,
# and periodic evaluation.
train:
  # Resume state. Every entry is left null in this config; per the original
  # note, the paths.experiment path is used (confirm in the training code).
  decoder_resume_path: null
  head_resume_path: null
  skipped_keys: null
  global_step: null
  random_length: false
  global_epoch: null
  global_sample_number: null
  global_update_number: null
  global_token_number: null

  # Output, logging, and precision.
  output_dir: 'ckpts'
  logging_dir: 'logs'
  mixed_precision: 'fp16'
  gradient_accumulation_steps: 32
  report_to: 'wandb'
  epoch_num: 1

  # Learning rate. Written with a decimal point on purpose: YAML 1.1 loaders
  # such as PyYAML resolve a bare `1e-4` as a string, not a float.
  # NOTE(review): the previous note said P-Tuning usually needs a higher LR
  # than LoRA (about 1e-3 vs 1e-4), yet this run uses prompt tuning at 1e-4 —
  # confirm the value is intended.
  lr: 1.0e-4

  warmup_ratio: 0.05
  # Evaluation and checkpoint cadence (presumably counted in optimizer
  # updates — confirm against the training loop).
  eval_every: 64
  eval_from_start: true
  save_every: 64
  per_example_ratio: true
  # Run naming. The anchor is kept so that any *exp_name alias keeps resolving;
  # no alias is visible in this section.
  exp_name: &exp_name ''
  wandb_proj: ''
  eval:
    # Evaluation metric; 'accuracy' is the alternative that was commented out.
    # metric: 'accuracy'
    metric: 'loss'
    noise_levels: null
    eval_epoches_num: 1
    # Presumably generation/sampling settings used during evaluation —
    # confirm their exact meaning against the evaluation code.
    steps: 32
    gen_length: 512
    block_length: 16
    temperature: 0.0
    cfg_scale: 0.0
    remasking: 'low_confidence'
    observe_steps: true
# Filesystem locations referenced by the rest of the configuration.
paths:
  # Experiment path, referenced from the train section; empty in this file.
  experiment: ''