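# Config for training GPT-2 (124M) with LoRA layers on 1 node of 8X A100 40GB.
# Adapted from a GPT-2 baseline config (batch_size=12, ~0.5M tokens/step) that reached a
# loss of ~2.85 in ~5 days; those figures describe that baseline, not this LoRA /
# batch_size=8 variant.
# Launch as follows (e.g. in a screen session):
# $ torchrun --standalone --nproc_per_node=8 train.py config/train_gpt2.py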

# ----- Weights & Biases logging -----
wandb_log = True
# Earlier project, kept for reference only:
# wandb_project = 'seed-corrected-projection-sweep'
wandb_project = 'fig1-pretraining-sweep'
# The run name appears to encode the setup (8 GPUs, LoRA rank 4, 40 grad-accum steps,
# batch size 8) -- keep it in sync when those settings change.
wandb_run_name = 'gpus8-r4-acc40-b8'

# ----- Batch size -----
# Tokens per optimizer step = batch_size * block_size * gradient_accumulation_steps
#   = 8 * 1024 * 40 = 327,680 (~0.33M).
# (The ~0.5M / 491,520 figure belonged to a batch_size=12 setup and no longer applies.)
# gradient_accumulation_steps is written as 5 per GPU * 8 GPUs.
# NOTE(review): this assumes the training script divides it by the DDP world size
# (as nanoGPT's train.py does) -- confirm.
batch_size = 8
block_size = 1024
gradient_accumulation_steps = 5 * 8

# ----- Data -----
# Placeholder value: replace with the real dataset directory before launching.
dataset_dir = '[TODO]'


# +++++ Optimizer +++++ #
# All values below follow the LoRA NLG examples:
#   https://github.com/microsoft/LoRA/tree/main/examples/NLG
# The original "?" annotations suggest they are provisional for this setup -- TODO confirm.
beta1 = 0.9
beta2 = 0.999
learning_rate = 2e-4
weight_decay = 0.01
# The "epislon" spelling is deliberate: the key is presumably read under this exact name
# by the training script, so rename it only together with the consumer.
adam_epislon = 1e-6
correct_bias = True
no_decay_bias = False

# ----- Training length -----
# Total training tokens = tokens per step * max_iters.
# With this file's batch settings (8 * 1024 * 40 = 327,680 tokens/step) that is
# 327,680 * 600,000 ~= 196.6B tokens. The 300B figure came from a batch_size=12 setup
# (491,520 tokens/step). Update this note if the batch settings change.
max_iters = 600000
# LR decay horizon, kept equal to max_iters.
lr_decay_iters = 600000

# ----- Evaluation & logging cadence -----
# NOTE(review): presumably all counted in training iterations, with eval_iters being
# the number of batches per evaluation -- confirm against the training loop.
log_interval = 10
eval_interval = 1000
eval_iters = 200

# ----- Model: GPT-2 (124M) dimensions -----
# 768 embedding dims / 12 heads = 64 dims per head.
n_embd = 768
n_head = 12
n_layer = 12


# TODO: check lora parameters
# ===== LoRA Settings ===== #
# 0. Basics
# Master switch: True if any LoRA layer is used.
use_lora = True
# NOTE(review): under the usual LoRA convention the update is scaled by alpha / r
# (here 32 / 4 = 8) -- confirm the training code uses that convention.
lora_alpha = 32
lora_dropout = 0.1
# Options deliberately NOT set by this config (kept for reference):
# enable_lora = (True, False, True)
# fan_in_fan_out = True
# merge_weights = False

# 1. Attention linear layer
attention_linear_use_lora = True
attention_linear_lora_r = 4
# attention_linear_enable_lora = (True, False, True)

# 2. Linear head
# The head uses an *_enable_lora flag where the attention layer uses *_use_lora. The names
# presumably must match the training script, so they are left as-is.
linear_head_enable_lora = True
linear_head_lora_r = 4

# NOTE(review): semantics defined by the training script; 0 presumably disables it -- confirm.
intrinsic_dim = 0

# ----- Sentence perturbation -----
# NOTE(review): a window size of 0 presumably disables word-order perturbation --
# confirm in the data pipeline.
perturb_word_order_window_size = 0

