# Hydra defaults list: picks the model, data and trainer config-group options
# this run builds on. Each entry targets the global package (`@_global_`).
defaults:
  - override /model_cfg@_global_: smollm2/smollm360mi
  - override /data_cfg@_global_: fineweb_edu_dedup_45B
  - override /trainer_cfg@_global_: logging_trainer
  # `_self_` comes last, so the keys defined in this file are merged after the
  # group options above and win on conflict.
  - _self_

# Naming for this run: tokenizer identifier, model identifier, and the label
# that is interpolated into `wandb_run_name` further down in this file.
tokenizer_name_or_path: HuggingFaceTB/SmolLM2-360M
model_name_or_path: transformer1p5Bcustom
model_log_name: transformer1p5Bcustom

# Single top-level switch for dataset caching; `make_dataset_fn` receives the
# same value through interpolation, so override only the top-level key.
disable_caching_datasets: true
make_dataset_fn:
  disable_caching_datasets: ${disable_caching_datasets}


# Arguments for the custom model configuration object. `_target_` names the
# class to build (custom_models.sparse_models.SparseLlamaConfig); the remaining
# keys are its settings. Values written as `${...}` are read from top-level
# keys of this config so they stay in sync with the training settings.
model_custom_config:
  _target_: custom_models.sparse_models.SparseLlamaConfig

  # --- Llama-style architecture and initialization ---
  attention_bias: false
  attention_dropout: 0.0
  bos_token_id: 0
  eos_token_id: 0
  hidden_act: silu
  hidden_size: 2048
  initializer_range: 0.02
  intermediate_size: 5632
  is_llama_config: true
  max_position_embeddings: ${max_seq_length}
  model_type: llama
  num_attention_heads: 32
  num_hidden_layers: 28
  # Equal to num_attention_heads, i.e. one key/value head per attention head.
  num_key_value_heads: 32
  pretraining_tp: 1
  # Written with an explicit decimal point in the mantissa (same form as
  # `learning_rate: 1.0e-3` in this file) so that every YAML loader reads it
  # as a float; YAML 1.1 loaders treat a bare `1e-05` as a string.
  rms_norm_eps: 1.0e-05
  rope_interleaved: false
  rope_scaling: null
  rope_theta: 100000
  tie_word_embeddings: true
  torch_dtype: bfloat16
  # NOTE(review): `gradient_checkpointing: true` is also set in this file;
  # caching and checkpointing are commonly mutually exclusive during
  # training - confirm this value is intended.
  use_cache: true
  vocab_size: 49152

  # --- Sparsity settings (`sparsity_*`) ---
  sparsity_l1_coeff: ${sparsity_l1_coeff}
  sparsity_train_batch_size: ${train_batch_size}
  sparsity_per_device_train_batch_size: ${per_device_train_batch_size}
  sparsity_max_seq_length: ${max_seq_length}
  sparsity_tokens_until_dead: null
  # Same value as `intermediate_size` above.
  sparsity_new_intermediate_size: 5632
  sparsity_gated_mlp: true
  sparsity_reinitialize: true
  sparsity_mlp_execution_logic: ${sparsity_mlp_execution_logic}
  mlp_bias: false

# Read by `model_custom_config.sparsity_mlp_execution_logic` via interpolation.
sparsity_mlp_execution_logic: training

# `tokenizer` is not defined in this file.
# NOTE(review): presumably supplied by another config in the composition -
# confirm.
make_tokenizer_fn: ${tokenizer}

# Numeric precision flags.
bf16: true
tf32: true

# Run length: both an epoch count and a step cap are given.
# NOTE(review): if these keys are forwarded to Hugging Face
# TrainingArguments, a positive `max_steps` overrides `num_train_epochs` -
# confirm.
num_train_epochs: 1
max_steps: 30000

# Log on every step.
logging_strategy: steps
logging_steps: 1

# Checkpoint strategy is the string 'no' (quoted so YAML does not read it as
# a boolean). NOTE(review): `save_steps` presumably only takes effect once a
# step-based strategy is selected - confirm. Evaluation is switched off.
save_strategy: 'no'
save_steps: 10000
do_eval: false

# Batch sizing. Both values are also interpolated into `model_custom_config`
# (`sparsity_per_device_train_batch_size`, `sparsity_train_batch_size`).
# NOTE(review): 512 / 8 = 64, presumably made up by device count times
# gradient accumulation - confirm where that is derived.
per_device_train_batch_size: 8
train_batch_size: 512

# Gradient checkpointing is enabled with `use_reentrant: false`.
gradient_checkpointing_kwargs:
  use_reentrant: false
gradient_checkpointing: true

# Optimizer hyperparameters: learning rate, weight decay, Adam betas and
# epsilon, and the gradient-norm clipping threshold.
# `learning_rate` is also interpolated into `wandb_run_name`.
learning_rate: 1.0e-3
weight_decay: 0.1
adam_beta1: 0.9
adam_beta2: 0.95
# Written with an explicit decimal point in the mantissa, matching
# `learning_rate` above, so that every YAML loader reads it as a float;
# YAML 1.1 loaders treat a bare `1e-8` as a string.
adam_epsilon: 1.0e-8
max_grad_norm: 1.0

# Learning-rate schedule: cosine, with warmup given as a fraction.
# 0.02 x `max_steps` (30000) = 600, the same figure as
# `trainer_args.warmup_steps` in this file.
lr_scheduler_type: cosine
warmup_ratio: 0.02

# NOTE(review): presumably seconds (18000 s = 5 h) - confirm the unit.
ddp_timeout: 18000

# Model construction flags.
# NOTE(review): with `from_pretrained: false` the model is presumably built
# from `model_custom_config` rather than loaded from a checkpoint - confirm.
from_pretrained: false
custom_class: null

# Packing is turned off.
packing: false

# Sequence length; also interpolated into `model_custom_config`
# (`max_position_embeddings`, `sparsity_max_seq_length`) and `wandb_run_name`.
max_seq_length: 2048
per_device_eval_batch_size: 128

# Settings nested under `trainer_args`. 600 steps equals
# `warmup_ratio` (0.02) x `max_steps` (30000) from this file.
trainer_args:
  warmup_steps: 600

# The final model is saved even though `save_strategy` is 'no'.
save_final_model: true


# Sparsity L1 coefficient; interpolated into `model_custom_config` and into
# the run name below. Written with an explicit decimal point in the mantissa
# (same form as `learning_rate: 1.0e-3`) so that every YAML loader reads it as
# a float; YAML 1.1 loaders treat a bare `2e-5` as a string.
sparsity_l1_coeff: 2.0e-5
streaming: true

# Weights & Biases project and run name. The run name is assembled from the
# model label, context length, learning rate and L1 coefficient.
wandb_project: sparse-llm-reg
wandb_run_name: ${model_log_name}_Gated_${max_seq_length}ctx_lr_${learning_rate}_l1${sparsity_l1_coeff}

