# config.yaml
# Optimizer settings: one list entry per optimizer, identified by `name`.
# NOTE(review): `lr` is a list in every entry, presumably so each value is
# run as its own sweep point -- confirm against the training script.
optimizer_params:
  - name: adamw
    # Disabled alternative: lr: [0.00005]
    lr: [0.0005, 0.001, 0.005, 0.01, 0.02]
    weight_decay: 0.1
    lr_schedule: constant-linear
    warm_up_fraction: 0.4

  # The three muon-* entries below additionally set `ns_steps`.
  # NOTE(review): `ns_steps` is presumably an iteration count used by the
  # Muon variants -- confirm against the optimizer implementation.
  - name: muon-you
    # Disabled alternative: lr: [0.05]
    lr: [0.005, 0.01, 0.02]
    weight_decay: 0.1
    lr_schedule: constant-linear
    warm_up_fraction: 0.4
    ns_steps: 5

  - name: muon-jordan
    # Disabled alternative: lr: [0.05]
    lr: [0.005, 0.01, 0.02]
    weight_decay: 0.1
    lr_schedule: constant-linear
    warm_up_fraction: 0.4
    ns_steps: 5

  - name: muon-polarexpress
    # Disabled alternative: lr: [0.02]
    lr: [0.005, 0.01, 0.05]
    weight_decay: 0.1
    lr_schedule: constant-linear
    warm_up_fraction: 0.4
    ns_steps: 5


# Training-run settings.
# NOTE(review): `val_tokens_processed` also appears under `logging_params`
# with the same value -- confirm which of the two the code actually reads.
training_params:
  tokens_processed: 524288  # 2^19
  val_tokens_processed: 8388608  # 2^23
  batch_size: 16  # Could also try 64?
  num_epochs: 1
  context_length: 1024
  gradnorm: 1.0
  tensorcore_precision: high  # Can be highest, high, or medium
  autocast: true
  mixed_precision: bfloat16
  compile: true

# Settings grouped under logging: step values, output directories, and the
# Weights & Biases (`wandb`) options.
# NOTE(review): the `*_step` values look like step intervals -- confirm.
logging_params:
  val_tokens_processed: 8388608  # 2^23
  log_step: 50
  val_step: 500
  save_ckpt_step: 500
  load_ckpt_step: 0
  keep_last: 2
  ckpt_dir: 'outputs/checkpoints'
  results_dir: 'outputs/results'
  wandb:
    project: 'polar-express'
    dir: 'outputs/wandb'
    # Optional fields, currently disabled:
    # tags: ['tag1', 'tag2']
    # notes: 'Here are some detailed notes'
    # name: 'my fun run'

# Model size preset "Large": 774M parameters
# (n_embd: 1280, n_layer: 36, n_head: 20)

# Model shape and attention settings for the GPT model.
gpt_model:
  n_embd: 1280
  n_layer: 36
  n_head: 20
  vocab_size: 50257
  flash_attention: true

# Dataset selection, by name.
dataset:
  name: 'fineweb1B'