# --- Optimisation and training schedule ---
# Learning rate (previously tried: 0.00025).
lr: 6.0e-4
# Name of the learning-rate decay schedule.
lr_decay_fn: "cosine"
# End value of the learning rate; per the original note, used for linear decay only.
lr_end_value: 6.0e-6
# Number of training steps (earlier note: 600000 * grad_accum).
train_steps: 10000
# Disabled alternative to `warmup`, expressed as a fraction:
# warmup_pc:  # 0.025, roughly 15000 (the original note is truncated here)
# Warmup length (previously 2000) -- presumably in steps; TODO confirm against the trainer.
warmup: 100
# Batch size; depends on the number of devices available.
# Previously tried: 4*12, 192 (= 24*8), 256, 128, 64.
batch_size: 1
grad_accumulation_steps: 1
# Mixed-precision mode; values noted in this file: "no", "bf16".
mixed_precision: "no"
# epochs: 100
weight_decay: 0.01
# --- Dataset ---
dataset_name: "owt"

# --- Regularisation ---
# Dropout for each layer.
dropout: 0.0
dropout_att: 0.0

# --- Model architecture ---
prenorm: true
batchnorm: false
# Alternative tried: 1024.
hidden_dim: 512
# Alternatives tried: 6, 8.
nlayers: 8
nheads: 8
# NOTE(review): the meaning of `L` is not documented in this file -- confirm
# against the model code. Alternatives tried: 64, 128, 512, 768, 516.
L: 128
state_dim: 128
# Size of local attention.
att_block_len: 128
embed_type: "rope"
# Alternatives tried: "rot_stable_latte", "rot_latte_mach_simple", "xpos_causal",
# "rope_causal", "standard_causal", "nope_standard_causal", "stable_latte", "latte".
attention_type: "xpos_stable_latte_convQR"
# Alternatives tried: "rwkv", "griffin", "mega", "transformer", "mamba".
block_type: "transformer-qual"

# --- Sequence lengths and evaluation ---
# NOTE(review): semantics of `unroll` are not visible here -- confirm with the consumer.
unroll: 100
eval_gen_len: 50
max_seq_len: 128
# Previously 2048.
# NOTE(review): 4048 is not a power of two; confirm that 4096 was not intended.
pos_embed_max_len: 4048
eval_samples: 20
# Earlier notes on this value: 100, 2000 * grad_accum.
eval_steps: 100
# --- Experiment tracking and misc ---
# Presumably the wandb project/entity names (see `wandb_log`); TODO confirm.
project: "torch-jax"
# NOTE(review): "baesian" may be a misspelling of "bayesian"; the value must
# match the real account name, so it is left untouched here.
entity: "baesian-learning"
# TODO: get rid of the hardcoded checkpoint path below.
# check_path: "/data_user/data/out_latte/owt_64l/checkpoints"
# run_id: "l2exlub6"
# wandb logging toggle (currently off).
wandb_log: false
# Alternative tried: true.
disable_cache: false