#----------------------------------------------------------------------------------------------------
# ModelConfig
# Base checkpoint, numeric precision / quantization, and attention backend.
model_path: LLaDA-sft-s1k-merged
torch_dtype: bfloat16
load_in_4bit: true
attn_implementation: flash_attention_2

# PEFT / LoRA adapter settings.
# With the values below, lora_alpha / lora_r = 64 / 128 = 0.5.
use_peft: true
peft_task_type: CAUSAL_LM
lora_r: 128
lora_alpha: 64
lora_dropout: 0.05

# GRPOConfig
# Dataset selection, RNG seed, and bf16 training precision.
dataset: gsm8k
seed: 42
bf16: true

# Reference-model synchronisation, with a step interval of 64.
# Booleans are written lowercase, matching every other flag in this file.
sync_ref_model: true
ref_model_sync_steps: 64

# Optimizer settings: weight decay, gradient-norm limit, and Adam betas.
weight_decay: 0.1
max_grad_norm: 0.2
adam_beta1: 0.9
adam_beta2: 0.99

# vLLM generation backend. It is switched off here, so the vllm_* values
# below are presumably unused for this run — confirm against the trainer.
use_vllm: false
vllm_gpu_memory_utilization: 0.9
vllm_device: auto

# Learning-rate schedule.
# The mantissa carries an explicit decimal point so that YAML 1.1 loaders
# (e.g. PyYAML) resolve the value as a float; a bare `3e-6` is loaded as the
# string '3e-6' by those loaders.
learning_rate: 3.0e-6
lr_scheduler_type: constant_with_warmup
warmup_ratio: 0.0001

# Logging interval, in steps.
logging_steps: 1

# Epoch count and num_generations (presumably the number of completions
# sampled per prompt — confirm against the trainer).
num_train_epochs: 10
num_generations: 6

# Training batch sizing: 6 per device x 2 accumulation steps = 12 per device
# per optimizer step (presumably the "bs12" in run_name — confirm).
per_device_train_batch_size: 6
gradient_accumulation_steps: 2

# Evaluation cadence and batch size.
eval_strategy: steps
eval_steps: 1000
per_device_eval_batch_size: 6


# Run identity and output location.
run_name: gsm_sft_bs12
# Replace with the desired output directory.
output_dir: /checkpoints/gsm_sft_bs12

# Resuming from an existing checkpoint is turned off.
resume_from_checkpoint: false

# Gradient checkpointing is disabled; the kwargs below presumably only take
# effect once it is enabled — confirm against the trainer.
gradient_checkpointing: false
gradient_checkpointing_kwargs:
  use_reentrant: false

# Parameters for saving checkpoints.
# The save interval is 500 steps because disk space is limited.
save_strategy: steps
save_steps: 500
save_total_limit: 20

# Diffusion and GRPO-specific
# Sequence-length limits and diffusion generation settings.
max_completion_length: 256
max_prompt_length: 200
block_length: 32
diffusion_steps: 128
generation_batch_size: 6
remasking: low_confidence

# Masking options. Booleans are written lowercase, matching the rest of this file.
random_masking: true
p_mask_prompt: 0.15

# GRPO objective settings.
beta: 0.04
# `epsilon` is declared exactly once: duplicate mapping keys are invalid YAML;
# strict loaders reject them and lenient ones silently keep the last value.
epsilon: 0.5
num_iterations: 12