
# Evaluation and logging cadence.
# Units are not stated in this file; presumably training iterations -- TODO confirm.
eval_interval: 1000 # keep frequent because we'll overfit
eval_iters: 200
# NOTE(review): training.log_interval (20) is also set further down in this
# file; confirm which of the two keys the training loop actually reads.
log_interval: 10 # don't print too often

# Overfitting is expected on this small dataset, so a checkpoint is written
# only when the validation metric improves (not unconditionally).
always_save_checkpoint: false

# Weights & Biases reporting. Logging is switched ON in this config.
wandb_log: true
wandb_project: 'MX_nanoGPT_pretrain'
wandb_entity: 'harvardml'
# NOTE(review): 'gp2' may be a typo for 'gpt2'. Left untouched because the
# group name is a runtime string that existing runs may already be filed under.
wandb_group: 'gp2_format_sweep'
# Alternative naming scheme noted by the author: 'run' + str(time.time())
wandb_run_name: 'gpt2'

# Number of gradient-accumulation steps. The comment on training.batch_size
# (further down) refers to this value when describing the micro-batch size.
gradient_accumulation_steps: 1

# Top-level learning-rate schedule.
# NOTE(review): the `optimizer` section of this file sets its own lr (5e-5),
# min_lr (1e-5) and betas ([0.9, 0.999]), and `training.train_steps` is 100000
# while max_iters here is 600000. Confirm which set of keys the code reads;
# the other set looks like a leftover and could mislead readers.
learning_rate: 0.001 # 1e-3 # with baby networks can afford to go a bit higher
max_iters: 600000
lr_decay_iters: 600000 # make equal to max_iters usually, should be ~= max_iters per Chinchilla
min_lr: 0.0001 # 1e-4 # learning_rate / 10 usually
beta2: 0.99 # make a bit bigger because number of tokens per iter is small

warmup_iters: 100 # not super necessary potentially

# Hydra defaults list. `_self_` marks the position of this file's own values
# in the composition order; no other config groups are pulled in here.
defaults:
  - _self_

# Training data source.
# NOTE(review): these three entries disagree with one another:
#   - `name` says algebraic-stack,
#   - `path` contains `fineweb-edu/tokenized/meta-llama-3`,
#   - `tokenizer_used` says t5-base.
# Confirm which dataset and tokenizer are actually intended before relying on
# any of these fields.
datasets:
  name: algebraic-stack
  path: "/n/holylfs06/LABS/kempner_shared/Everyone/testbed/text/fineweb-edu/tokenized/meta-llama-3/default"
  tokenizer_used: t5-base

# Model architecture and numeric-format settings.
model:
  name: gpt
  n_head: 4
  d_model: 512
  n_layer: 6
  dropout_p: 0.0
  context_length: 1024
  autocast_precision: bfloat16
  # Format names listed by the author for the two *_mx_format keys below:
  #   float16, fp16, bfloat16, bf16, fp8_e5m2, fp8_e4m3,
  #   fp6_e3m2, fp6_e2m3, fp4_e2m1, fp4, int8, int4
  w_mx_format: 'int8'
  a_mx_format: 'int8'
  # NOTE(review): `flash`, `flex`, `mask` and `compile_flex` presumably select
  # and configure the attention implementation -- confirm against the model code.
  flash: false
  flex: true
  mlp_scale_factor: 4
  # Per-component bias toggles.
  mlp_bias: true
  attn_bias: false
  proj_bias: true
  ln_bias: true
  cls_head_bias: true
  activation: relu
  mask: causal_document
  compile_flex: true

# Optimizer hyperparameters.
# Scientific-notation floats are written with an explicit decimal point
# (5.0e-5, not 5e-5). YAML 1.1 loaders such as PyYAML only recognise the
# dotted form as a float and would load `5e-5` as the string "5e-5".
# NOTE(review): the top-level learning_rate / min_lr / beta2 keys in this file
# hold different values; confirm which set the training code reads.
optimizer:
  name: AdamW
  lr: 5.0e-5 # max learning rate (alternative noted by the author: 0.0001)
  min_lr: 1.0e-5 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
  weight_decay: 0.1
  betas: [0.9, 0.999]
  eps: 1.0e-8
  precision: float32
  
# Tokenizer identity and vocabulary size.
# NOTE(review): datasets.path contains `meta-llama-3`; if that data was
# tokenized with a different tokenizer than t5-base, verify that its token ids
# fit within vocab_size.
tokenizer:
  name: t5-base
  vocab_size: 32128


# Training-loop settings.
training:
  epochs: 1
  # Training runs until whichever limit is smaller: train_steps or epochs.
  train_steps: 100000
  # When gradient_accumulation_steps > 1 this is the micro-batch size.
  batch_size: 48
  log_interval: 20
  val_interval: 100
  shuffle: true
  save_model: true
  # Seconds between state saves (3600 s = once per hour).
  save_every: 3600
  artifacts_path: tmrc_dev_artifacts
  use_oracle: false
  torch_profiling: false

# Profiler schedule parameters.
# NOTE(review): the key names match the arguments of torch.profiler.schedule
# (wait, warmup, active, repeat); presumably forwarded there when
# training.torch_profiling is enabled -- confirm against the training code.
profiler:
  wait: 1
  warmup: 3
  active: 1
  repeat: 1

# wandb_log:
#   name: tmrc_dev_log202410

# NOTE(review): this looks intended to pin Hydra's `version_base`; confirm the
# application actually reads this key (version_base is normally passed to
# hydra.main / hydra.initialize rather than read from the config tree).
# The value is quoted so it stays the string "1.1" instead of the float 1.1.
HydraConf:
  version_base: "1.1"
