# Hugging Face Hub id of the base model.
model_path: google/gemma-2-2b
# Per-run defaults: batch size and sequence length in tokens.
default_batch_size: 32
seq_len: 128
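# Torch dtype, produced by a custom "dtype" OmegaConf resolver (assumed to be
# registered by the code that loads this config).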
dtype: ${dtype:torch.bfloat16}
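# Named groups of module-name patterns for selecting submodules (e.g. to hook
# activations). Glob-style ("*" for a layer index) and regex-style (".+", ".*")
# entries both appear; how they are matched is assumed to be up to the loader.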
default_module_names:
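  # MLP projections in every decoder layer (the group name presumably refers
  # to an ICML 2024 setup).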
  icml24:
    - "model.layers.*.mlp.up_proj"
    - "model.layers.*.mlp.down_proj"
    - "model.layers.*.mlp.gate_proj"
  # Any layernorm module (input, post-attention, pre/post-feedforward).
  layernorm:
    - ".+layernorm"
  # Post-attention and post-feedforward norms only.
  post_layernorm:
    - ".*post_attention_layernorm"
    - ".*post_feedforward_layernorm"
  # A single module, e.g. for quick smoke-test runs.
  fast:
    - "model.layers.0.mlp.down_proj"
  # OmegaConf relative interpolation: reuses the post_layernorm group above.
  default: ${.post_layernorm}
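# Example: assuming the loader matches these patterns against named modules,
# ".*post_attention_layernorm" resolves to model.layers.{i}.post_attention_layernorm
# for every layer index i, and "model.layers.*.mlp.up_proj" resolves to each
# layer's MLP up-projection.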