# Sweep header. The layout (program / command / method / metric) matches a
# Weights & Biases sweep configuration -- confirm against the launcher in use.

# Script that each run executes.
program: main.py

# Search strategy and the metric the sweep optimises.
method: grid
metric:
  goal: maximize
  name: validation/mean_accuracy

# Command line used to start a run. The ${...} entries are placeholders that
# the sweep launcher substitutes before executing the command.
command:
  - ${env}
  - python3
  - ${program}
  - ${args}
# Arguments handed to the program. Every entry except
# `sweep_id_for_grid_search` is pinned to a single value, so the grid
# consists of exactly five runs.
parameters:
  name:
    value: dm_math_diff
  batch_size:
    value: 128
  dm_math.task:
    value: calculus__differentiate
  log:
    value: wandb
  lr:
    value: 0.0001
  mask_loss_weight:
    value: 0.00001
  profile:
    value: deepmind_math
  step_per_mask:
    value: 40000
  stop_after:
    value: 40000
  # The only parameter with more than one value.
  # NOTE(review): presumably a repeat/seed index so that grid search launches
  # five otherwise identical runs -- confirm how main.py consumes it.
  sweep_id_for_grid_search:
    distribution: categorical
    # Sequence items are indented to match the `command` list at the top of
    # the file (one sequence style file-wide).
    values:
      - 1
      - 2
      - 3
      - 4
      - 5
  test_interval:
    value: 5000
  mask_lr:
    value: 0.03