# config for training GPT-2 (124M) down to a very nice loss of ~2.85 on 1 node of 8x A100 40GB
# launch as the following (e.g. in a screen session) and wait ~5 days:
# $ torchrun --standalone --nproc_per_node=8 train.py config/train_gpt2.py

# Weights & Biases logging: project / run labels, then the on-off switch.
wandb_project = "owt"
wandb_run_name = "gpt2-124M"
wandb_log = True

# Checkpoint output directory.
out_dir = "ckpt_dir"

# Total batch size works out to ~0.5M tokens:
#   12 batch size * 1024 block size * 5 gradaccum * 8 GPUs = 491,520
block_size = 1024
batch_size = 12
# Written as a product so both factors from the arithmetic above stay visible (= 40).
gradient_accumulation_steps = 8 * 5

# Total number of tokens is ~300B (600,000 iterations * 491,520 tokens is about 295B).
max_iters = 600000
# Same value as max_iters.
lr_decay_iters = max_iters

# Evaluation and logging cadence.
# NOTE(review): presumably counted in training iterations -- confirm in train.py.
log_interval = 10
eval_iters = 200
eval_interval = 1000

# Hyperparameters.
weight_decay = 0.1

# Learning rate. Values noted for reference:
#   6e-4         default
#   1e-5 ~ 6e-5  smaller
#   6e-3         larger, for apollo
learning_rate = 6e-4
# min_lr is one tenth of learning_rate, computed from the value assigned just above.
min_lr = 0.1 * learning_rate

# Optimizer selection; must be one of the names checked immediately below.
opt_type = "adamw"
# Validate with an explicit raise instead of `assert`: assert statements are
# removed when Python runs with -O, which would let an invalid name through
# without any error, and a bare assert gives no hint about the accepted values.
if opt_type not in ("adamw", "adam-mini", "apollo", "apollo-mini"):
    raise ValueError(
        "unsupported opt_type %r; expected one of: "
        "adamw, adam-mini, apollo, apollo-mini" % (opt_type,)
    )

# NOTE(review): the settings below look like bit-width controls associated with
# `is_diffq`; their exact meaning is defined by the code that reads this
# config -- confirm there before changing them.
bit_lambda = 1e-4
init_bit = 6.0
target_bit = 4.0
is_diffq = False
