# Composition list: selects the `mod_subtract_dataset` option for the `dataset`
# group and the `grokk_model` option for the `model` group.
# NOTE(review): this looks like a Hydra defaults list; if so, listing `_self_`
# last means keys set in this file override the composed configs — confirm.
defaults:
  - dataset: mod_subtract_dataset
  - model: grokk_model
  - _self_

# Dataset settings specified in this file.
dataset:
  # Presumably the fraction of the full example set used for training — confirm.
  frac_train: 0.4
  # Presumably the modulus of the modular-arithmetic task — confirm against the
  # dataset code. NOTE(review): 96 is composite; check this is intended if the
  # task assumes a prime modulus.
  p: 96

# Model settings specified in this file.
model:
  transformer_config:
    # Presumably selects pre-normalization transformer blocks — confirm against
    # the model code.
    pre_norm: true

# Training hyperparameters. Key meanings are not defined in this file; check
# the training script for exact semantics.
train:
  num_workers: 0
  bsize: 512
  # Written with an explicit mantissa point so that YAML 1.1 loaders (e.g.
  # PyYAML) also resolve it as a float rather than the string "1e-3".
  lr: 1.0e-3
  weight_decay: 0.0
  betas: [0.9, 0.98]
  warmup_steps: 10
  eval_every: 10
  eval_batches: 8
  # Step count kept as an integer literal: `1e6` resolves to a float (or to a
  # string under YAML 1.1), which breaks integer-only uses such as range().
  max_steps: 1000000
  seed: 0


# Experiment-tracking settings (presumably Weights & Biases, going by the key
# names — confirm).
wandb:
  # Presumably toggles experiment logging on or off — confirm.
  use_wandb: true
  # Presumably the project name that runs are logged under — confirm.
  wandb_project: grokking_replica

# Settings for the "WD" computation run at evaluation time.
# NOTE(review): "WD" is never expanded in this file and is easy to confuse with
# `train.weight_decay`; spell the name out here once confirmed.
wd:
  enabled: true
  num_pairs: 2
  resolution: 64
  max_degree: 40
  log_every_eval: 1   # Compute WD on every Nth evaluation: 1 = every evaluation, 10 = every 10th.
  # Presumably `pca_k` is the number of principal components kept when
  # `use_pca` is true — confirm.
  use_pca: true
  pca_k: 10

# Settings for the sharpness computation run at evaluation time.
sharpness:
  enabled: true
  # Presumably the perturbation radius used when measuring sharpness — confirm.
  rho: 0.05
  log_every_eval: 1   # Compute sharpness on every Nth evaluation: 1 = every evaluation, 10 = every 10th.