# d3llm_train.yaml
model:
    # Identifier of the base model to train.
    # NOTE(review): looks like a Hugging Face Hub repo id - confirm.
    name: 'GSAI-ML/LLaDA-8B-Instruct'
    # NOTE(review): presumably forwarded to the model loader so that
    # repo-provided modelling code may run - confirm against the loader.
    trust_remote_code: true
    # Dtype given by name; presumably resolved to a torch dtype at load time.
    torch_dtype: 'bfloat16'

training:
    # Trainer settings.
    # NOTE(review): key names match Hugging Face `TrainingArguments`
    # fields - confirm that is the consumer of this section.
    output_dir: 'output_model/d3LLM_LLaDA'
    # Keep in sync with distillation.progressive_block_sizes, which is
    # documented as having one entry per epoch.
    num_train_epochs: 6
    gradient_accumulation_steps: 4
    per_device_train_batch_size: 4
    logging_steps: 10
    # 2e-5, written as a plain decimal so it is read as a float by every
    # YAML loader (exponent forms are parsed as strings by some).
    learning_rate: 0.00002
    weight_decay: 0.01
    # Canonical lowercase boolean, consistent with the other flags here.
    bf16: true
    optim: 'adamw_torch'
    # Warmup is currently switched off; the line is kept for reference.
    # warmup_ratio: 0.05
    max_grad_norm: 1
    group_by_length: false
    lr_scheduler_type: 'constant'
    save_strategy: 'epoch'
    # W&B logging configuration
    report_to: 'wandb'
    run_name: 'd3llm_llada_training'
    logging_first_step: true

# LoRA configuration (optional; enabled in this config via `enabled: true`)
lora:
    # Switch for LoRA training: set to true to enable it.
    enabled: true
    # NOTE(review): `r`, `lora_alpha`, `lora_dropout`, `bias` and
    # `task_type` look like PEFT `LoraConfig` fields - confirm the consumer.
    r: 256
    lora_alpha: 256
    # NOTE(review): presumably the submodule names that adapters attach to;
    # verify that these names exist in the loaded model.
    target_modules:
        - 'q_proj'
        - 'k_proj'
        - 'v_proj'
        - 'o_proj'
        - 'gate_proj'
        - 'up_proj'
        - 'down_proj'
    lora_dropout: 0.0
    bias: 'none'
    task_type: 'CAUSAL_LM'

distillation:
    # Path to the trajectory dataset.
    trajectory_dataset_path: 'trajectory_data_llada_32'
    # Maximum sequence length for training.
    max_length: 384
    # When true, the naive random-masking baseline replaces trajectory
    # selection.
    use_naive_random_mask: false
    # When true, a complementary CE loss is added (dParallel style).
    use_complementary_loss: true
    # Block size used in each epoch; one entry per epoch
    # (len = num_epochs).
    progressive_block_sizes: [16, 16, 24, 24, 32, 32]
    # Lower bound of the mask ratio for progressive sampling.
    min_mask_ratio: 0.0
    # Upper bound of the mask ratio for progressive sampling.
    max_mask_ratio: 0.8
    # Temperature for entropy regularization.
    temperature: 0.5
    # Weight of the entropy loss.
    entropy_weight: 2.0
    # Number of processes for dataset filtering and mapping.
    num_proc: 32
