---
# Run identity, data locations, and shared scalars
name: 'dbrx-18b'
data_local: './my-copy-c4'
data_remote: null  # null => no remote source; files must already exist in data_local
max_seq_len: 2048  # read elsewhere in this file via ${llm_config.max_seq_len}
global_seed: 17    # read elsewhere in this file via ${llm_config.global_seed}
autoresume: false

# Run Name
# Bound to `run_uuid` via interpolation. The $RUN_NAME environment-variable
# fallback only applies when this value is left blank, which it is not here.
run_name: ${run_uuid}

# Evaluation gauntlet: each key forwards a node defined in another config group
eval_gauntlet: '${eval_gauntlet_config.eval_gauntlet}'
icl_tasks_config: '${icl_tasks_config}'
icl_tasks: '${icl_tasks_config.icl_tasks}'

# Model: MPT causal LM with a mixture-of-experts FFN and grouped-query attention
model:
  name: 'mpt_causal_lm'
  d_model: 3072
  n_heads: 24
  no_bias: true
  n_layers: 22
  norm_type: 'low_precision_layernorm'
  # MoE FFN: 16 experts, top-4 routing, GLU-style MLP with a SiLU activation
  ffn_config:
    ffn_type: 'torch_dmoe'
    mlp_type: 'glu'
    moe_top_k: 4
    ffn_act_fn: {name: 'silu'}
    moe_jitter_eps: 0
    moe_num_experts: 16
    uniform_expert_assignment: false
    moe_normalize_expert_weights: 1
  vocab_size: 50368
  # RoPE (ALiBi off); 24 query heads grouped over 8 key/value heads
  attn_config:
    rope: true
    alibi: false
    clip_qkv: 8
    attn_impl: 'flash'
    attn_type: 'grouped_query_attention'
    kv_n_heads: 8
    rope_theta: 500000
    attn_uses_sequence_id: false
  init_device: 'meta'
  max_seq_len: '${llm_config.max_seq_len}'
  param_init_fn: 'kaiming_normal_'
  expansion_ratio: 1.75
  init_nonlinearity: 'relu'
  use_train_metrics: false
  fuse_norm_attn_norm: true
  tie_word_embeddings: true
  # NOTE(review): presumably only consulted when fsdp_config.activation_checkpointing
  # is enabled; that flag is false in this file, so this target looks inert — confirm.
  activation_checkpointing_target: ['norm_attn_norm']

# Tokenizer: loaded by name; its max length follows the shared max_seq_len
tokenizer:
  name: 'EleutherAI/gpt-neox-20b'
  kwargs:
    model_max_length: '${llm_config.max_seq_len}'

# Dataloaders: both use the `text` loader; dataset specs come from the `dataset` node
train_loader:
  name: 'text'
  dataset: '${dataset.train}'
  drop_last: true   # discard a trailing partial batch during training
  num_workers: 8

eval_loader:
  name: 'text'
  dataset: '${dataset.val}'
  drop_last: false  # keep the final (possibly short) batch so every sample is evaluated
  num_workers: 8

# Optimization
# Cosine schedule with a 100-batch warmup and final LR multiplier alpha_f = 0.1;
# t_max tracks the overall training length.
scheduler:
  schedulers:
    lr:
      name: 'cosine_with_warmup'
      t_warmup: '100ba'
      alpha_f: 0.1
      t_max: '${llm_config.max_duration}'

# AdamW with decoupled weight decay (decay disabled here)
optimizer:
  name: 'decoupled_adamw'
  lr: 1.0e-4
  betas:
    - 0.9
    - 0.95
  eps: 1.0e-08
  weight_decay: 0.0

# Clip gradients by norm, threshold 1.0
algorithms:
  gradient_clipping: {clipping_type: 'norm', clipping_threshold: 1.0}

# Training length and evaluation cadence.
# Token budget: 124000 ba x 1024 seq/batch x 2048 tok/seq = ~260.0B tokens
# (assuming every sample is a full max_seq_len sequence). If ~270B tokens was
# the intended budget, max_duration would need to be ~128750ba.
max_duration: 124000ba  # ~260B tokens
eval_interval: 5000ba
eval_first: false
eval_subset_num_batches: -1  # -1 => run the full eval dataloader
global_train_batch_size: 1024

# System
seed: '${llm_config.global_seed}'
device_eval_batch_size: 8
# Fixed per-device train microbatch; the commented line below is the `auto` alternative
device_train_microbatch_size: 8
# device_train_microbatch_size: auto
precision: 'amp_bf16'

# FSDP: FULL_SHARD over an 8-way shard degree with PURE mixed precision and
# sharded state dicts; activation checkpointing and CPU offload are both off.
fsdp_config:
  verbose: false
  data_parallel_shard_degree: 8
  mixed_precision: 'PURE'
  state_dict_type: 'sharded'
  use_orig_params: true
  limit_all_gathers: true
  sharding_strategy: 'FULL_SHARD'
  activation_cpu_offload: false
  activation_checkpointing: false
  activation_checkpointing_reentrant: false

# Logging: console output every batch instead of a progress bar
progress_bar: false
log_to_console: true
console_log_interval: '1ba'

# Training callbacks
callbacks:
  speed_monitor:
    window_size: 20
    gpu_flops_available: null  # not provided here; presumably auto-detected — confirm
  lr_monitor: {}
  memory_monitor: {}
  runtime_estimator: {}
  activation_monitor_full_model:
    interval: '10ba'
  optimizer_monitor:
    interval: '10ba'
    only_global: true

# Experiment loggers: W&B init kwargs come from the `wandb.setup` node
loggers:
  wandb:
    init_kwargs: '${wandb.setup}'
  tensorboard:
    flush_interval: 10

# Checkpoint saving (local filesystem or remote object store)
save_interval: '500ba'
# Important: keeps only the most recent checkpoint on local DISK; older local
# checkpoints are cleaned up.
save_num_checkpoints_to_keep: 1
# `{run_name}` has no `$`, so it is a trainer-side format placeholder rather than
# an OmegaConf interpolation.
save_folder: './{run_name}/checkpoints'
save_overwrite: false

# Checkpoint loading (local filesystem or remote object store); null => nothing is loaded
load_path: null
