---
# Base model identifier. The value has the org/name shape of a Hugging Face
# Hub repo id — NOTE(review): confirm against the loading code.
model:
  name: 'meta-llama/Llama-3.2-3B-Instruct'

# Architecture hyperparameters for the added modules. Meanings below are
# inferred from key names only — NOTE(review): confirm against model code.
architecture:
  # Presumably the width of the added value/intervention representations.
  value_dim: 256
  n_intervention_tokens: 7
  # Presumably the index of the base-model hidden layer that is read from;
  # it must then be smaller than the base model's layer count — confirm.
  extract_layer: 20
  n_self_attn_layers: 2
  # If heads split value_dim, this gives 256 / 16 = 16 dims per head
  # (assumption — verify).
  n_heads: 16
  dropout: 0.1
  use_attention_pooling: true
  use_transformer_aggregate: true

# Generator-side options.
generator:
  use_vlp: true
  vlp_n_heads: 16
  use_transformer_projector: false
  # NOTE(review): presumably only consulted when use_transformer_projector
  # is true — confirm before relying on it.
  transformer_n_layers: 2

# Per-stage optimisation settings.
#
# Learning rates are written with an explicit mantissa dot (1.0e-4, not
# 1e-4). YAML 1.1 resolvers, such as PyYAML's safe_load, only treat a scalar
# as a float when it contains a '.', so a bare `1e-4` loads as the STRING
# "1e-4". Numeric use of it, e.g. passing it to an optimizer, then fails
# unless the consumer casts it. The numeric values themselves are unchanged.
training:
  stage1:
    batch_size: 8
    lr: 1.0e-4
    n_epochs: 5
  stage2:
    batch_size: 8
    # Presumably: lr for newly added params vs. pre-existing ones — confirm.
    lr_new: 5.0e-4
    lr_finetune: 1.0e-5
    n_epochs: 5
  stage3:
    batch_size: 4
    lr_new: 5.0e-4
    lr_finetune: 1.0e-5
    n_epochs: 5
    use_gradient_delta: true
    gradient_step_size: 1.0
    # Loss-term weights; presumably combined as a weighted sum — confirm.
    lambda_ce: 0.5
    lambda_safe: 2.0
    lambda_reg: 0.1
    max_grad_norm: 1.0
    # Logging cadence; presumably counted in optimizer steps — confirm.
    log_interval: 100

# Filesystem locations. Both are relative paths, presumably resolved against
# the process working directory — NOTE(review): confirm.
paths:
  checkpoint_dir: 'checkpoints/llama3.2_3b_instruct'
  data_dir: 'data/processed'

