---
# Sweep definition (layout follows the Weights & Biases sweep schema).

# Entry-point script; this is what the ${program} macro below refers to.
program: main.py

# Launch command for each run. ${env}, ${program} and ${args} are sweep
# macros substituted at launch time; they are quoted so they stay strings.
command:
  - '${env}'
  - python3
  - '${program}'
  - '${args}'

# Search strategy over the parameters section.
method: grid

# Metric the sweep tracks, and the direction in which it is considered better.
metric:
  name: validation/mean_accuracy
  goal: maximize
# Hyperparameters handed to the program. Every entry pins one constant through
# `value` (none of them lists alternatives), so this file describes a single
# configuration. Each leaf is written as a one-line flow mapping.
parameters:
  # NOTE(review): `name` is enwik8 while `task` is c4_transformer -- confirm
  # that the mismatch is intentional.
  name: {value: enwik8}
  log: {value: wandb}
  task: {value: c4_transformer}
  test_interval: {value: 2000}
  state_size: {value: 412}
  transformer.ff_multiplier: {value: 4}
  transformer.encoder_n_layers: {value: 16}
  transformer.n_heads: {value: 2}
  dropout: {value: 0.1}
  lr: {value: 0.00025}
  optimizer: {value: adamw}
  lm.unroll: {value: 256}
  batch_size: {value: 64}
  grad_clip: {value: 0.25}
  amp: {value: 1}
  save_interval: {value: 10000}
  transformer.variant: {value: preln_moe}
  stop_after: {value: 100000}
  moe.n_experts: {value: 16}
  moe.expert_size: {value: 128}
  pkm.n_heads: {value: 4}
  transformer.p_drop_layer: {value: 0.0}
  moe.selection_mode: {value: sigmoid}
  moe.perplexity_reg_mode: {value: global}
  moe.reg_type: {value: entropy}
  moe.perplexity_reg: {value: 0.0001}
  moe.norm_expert_sel_init: {value: 1}
  lr_sched.type: {value: cos}
  transformer.head_projection_size: {value: 76}
  moe.att.variant: {value: full}
  moe.att.selection_mode: {value: sigmoid}
  moe.att.q_expert: {value: 0}
  moe.att.v_expert: {value: 1}
  moe.att.k_expert: {value: 0}
  moe.att.enable: {value: 1}
  moe.att.perplexity_reg: {value: 0}
  moe.att.expert_dropout: {value: 0}
  moe.att.k: {value: 3}
  moe.att.n_experts: {value: 5}
  # Quoted on purpose: this is the literal string "none", not a YAML null.
  debug_plot_interval: {value: 'none'}
  plot.n_steps: {value: -128}
  transformer.plot_head_details: {value: 1}
  lmds.valid_ratio: {value: 0.005}