---
# Base model selection. NOTE(review): 'gpt2' is presumably a pretrained
# model identifier passed to the model loader — confirm with the caller.
model:
  name: 'gpt2'

# Model architecture hyperparameters.
# NOTE(review): the meanings below are inferred from key names only; the
# consuming model code is the source of truth — confirm before relying on them.
architecture:
  value_dim: 128
  n_intervention_tokens: 5
  # Presumably the index of the base-model layer whose hidden states are
  # extracted — verify the indexing base (0- vs 1-based) against the code.
  extract_layer: 7
  n_self_attn_layers: 2
  n_heads: 4
  dropout: 0.1
  # Optional aggregation variants; both are disabled in this config.
  use_attention_pooling: false
  use_transformer_aggregate: false

# Generator / projector options.
# NOTE(review): semantics are inferred from key names — confirm with the code.
generator:
  use_vlp: true
  # Presumably only read when use_vlp is true.
  vlp_n_heads: 8
  use_transformer_projector: false
  # Presumably only read when use_transformer_projector is true (it is
  # false here, so this value may currently be unused) — verify.
  transformer_n_layers: 2

# Per-stage training hyperparameters.
# Learning rates are written with an explicit decimal point (1.0e-4, not
# 1e-4). YAML 1.1 loaders such as PyYAML only resolve a scalar as a float
# when it contains a '.', so a bare 1e-4 would load as the string "1e-4".
training:
  stage1:
    batch_size: 8
    lr: 1.0e-4
    n_epochs: 5
  stage2:
    batch_size: 8
    # NOTE(review): presumably lr_new applies to newly added modules and
    # lr_finetune to pretrained weights — confirm with the training script.
    lr_new: 5.0e-4
    lr_finetune: 1.0e-5
    n_epochs: 5
  stage3:
    batch_size: 4
    lr_new: 5.0e-4
    lr_finetune: 1.0e-5
    n_epochs: 5
    use_gradient_delta: true
    gradient_step_size: 1.0
    # Presumably relative weights of the stage-3 loss terms — confirm.
    lambda_ce: 0.5
    lambda_safe: 2.0
    lambda_reg: 0.1
    # Presumably the gradient-clipping norm threshold — confirm.
    max_grad_norm: 1.0
    # Presumably the logging period in steps — confirm.
    log_interval: 100

# Filesystem locations. NOTE(review): these are relative paths; presumably
# resolved against the process working directory — confirm with the caller.
paths:
  checkpoint_dir: 'checkpoints/gpt2'
  data_dir: 'data/processed'

