# Launch recipe for each run of this sweep (the layout matches a Weights &
# Biases sweep definition -- confirm against the tool that consumes it).
# `program` names the script; `command` is the argv template used to start it.
program: main.py
# In W&B sweeps the ${...} entries are macros expanded by the agent:
# ${env} -> the `env` launcher, ${program} -> the `program` value above,
# ${args} -> the chosen hyperparameters rendered as command-line flags.
command:
  - ${env}
  - python3
  - ${program}
  - ${args}
# Search strategy: `grid` runs every combination of the candidate values
# listed under `parameters`.
method: grid
# Objective attached to the sweep: the logged quantity named below, where a
# lower value is better (`goal: minimize`).
metric:
  name: validation/val/perplexity
  goal: minimize
# Hyperparameters handed to the program. An entry with `value` is fixed for
# every run; an entry with `values` lists the candidates to enumerate.
# Only sa_moe.sparsity has more than one candidate, so there are
# 1 x 1 x 6 = 6 distinct combinations in total.
# The meaning and units of each key are defined by the program, not by this
# file; notes below marked "presumably" are assumptions to confirm.
parameters:
  log:
    value: wandb
  task:
    value: c4_flopmatched_parthead_routing_transformer
  test_interval:
    value: 20000
  state_size:
    value: 1024
  transformer.ff_multiplier:
    value: 4
  transformer.n_layers:
    value: 9
  # Written as single-item `values` lists, so these stay constant across runs
  # while remaining easy to extend with further candidates.
  sa_moe.baseline_dense_heads:
    values:
      - 9
  sa_moe.shared_dense_heads:
    values:
      - 4
  # The one swept parameter: six candidates, each double the previous one.
  sa_moe.sparsity:
    values:
      - 2
      - 4
      - 8
      - 16
      - 32
      - 64
  # Integer 0 here (and 1 for `amp` / `compile` below) presumably acts as an
  # off/on flag -- confirm against the program's argument parsing.
  sa_moe.include_first:
    value: 0
  transformer.head_projection_size:
    value: 64
  dropout:
    value: 0.0
  lr:
    value: 0.00025
  optimizer:
    value: adamw
  lm.unroll:
    value: 1024
  batch_size:
    value: 64
  grad_clip:
    value: 0.25
  amp:
    value: 1
  # NOTE(review): save_interval (600000) is larger than stop_after (100000).
  # If both are counted in the same unit, no periodic save would occur before
  # the run stops -- confirm this is intended.
  save_interval:
    value: 600000
  stop_after:
    value: 100000
  lr_sched.type:
    value: cos
  lr_warmup:
    value: 4000
  wd:
    value: 0.01
  min_lr_multiplier:
    value: 0.1
  compile:
    value: 1