# Sweep definition laid out in the Weights & Biases sweep-config format
# (program / command / method / metric / parameters).
# Entry-point script that each run of the sweep launches.
program: main.py
# Command-line template for a run. The ${...} entries are placeholders that
# the sweep runner fills in (per the W&B docs: ${env} -> /usr/bin/env,
# ${program} -> the script named above, ${args} -> the entries under
# `parameters` rendered as --name=value flags).
command:
  - ${env}
  - python3
  - ${program}
  - ${args}
# Grid search: one run per combination of the listed parameter values.
method: grid
# Metric associated with the sweep; `goal` marks higher values as better.
metric:
  name: validation/mix/accuracy/total
  goal: maximize
# Hyperparameters handed to the program. Every entry except the last is pinned
# to a single `value`, so the grid (`method: grid`) only varies
# `sweep_id_for_grid_search`.
# Dotted names such as `transformer.variant` are single flat keys, not nested
# mappings.
parameters:
  log:
    value: wandb
  profile:
    value: listops_trafo
  listops.variant:
    value: big
  transformer.variant:
    value: universal
  test_batch_size:
    value: 16
  transformer.encoder_n_layers:
    value: 6
  # Plain string `none`, not a YAML null (which would be `null` or `~`).
  # Presumably main.py interprets it as "unset" - TODO confirm.
  n_test_layers:
    value: none
  lr:
    value: 0.0004
  state_size:
    value: 256
  dropout:
    value: 0.015
  transformer.attention_dropout:
    value: 0.05
  wd:
    value: 0.05
  transformer.n_heads:
    value: 16
  batch_size:
    value: 512
  optimizer:
    value: adamw
  stop_after:
    value: 200000
  # Same string `none` convention as `n_test_layers`.
  max_length_per_batch:
    value: none
  # NOTE(review): 1 here (and 0/1 on the other switch-like options below)
  # looks like an integer-encoded on/off flag - confirm against main.py.
  amp:
    value: 1
  wandb_bug_workaround:
    value: 1
  geometric.pos_encoding:
    value: direction
  ndr.scalar_gate:
    value: 0
  ndr.drop_gate:
    value: 0.15
  grad_clip:
    value: 1
  length_bucketed_sampling:
    value: 1
  transformer.ff_multiplier:
    value: 4
  trafo_classifier.out_mode:
    value: linear
  trafo_classifier.norm_att:
    value: 0
  # Same string `none` convention as `n_test_layers`.
  debug_plot_interval:
    value: none
  # The only parameter with more than one candidate value, so the grid expands
  # to five runs that share all settings above.
  # Presumably a dummy index used purely to repeat the configuration (e.g. for
  # independent seeds) - TODO confirm how main.py uses it.
  sweep_id_for_grid_search:
    distribution: categorical
    values:
      - 1
      - 2
      - 3
      - 4
      - 5
