
# Dataset locations. No value is set in this file; they are written as explicit
# `null` rather than a bare `key:`. Both are substituted into the `local` /
# `remote` fields of train_loader.dataset and eval_loader.dataset below.
data_local: null
data_remote: null

# Shared values, referenced further down through ${...} interpolation.
max_seq_len: 128
tokenizer_name: bert-base-uncased
mlm_probability: 0.3 # Used by train_loader only; eval_loader pins its own masking rate
lr: 8.0e-4 # Feeds optimizer.lr and is embedded in run_name

# Run Name
run_name: hf-bert-base-uncased-nova-inputdep-lr${lr}-mlm${mlm_probability}

# Model
model:
  name: hf_bert
  use_pretrained: false # Train the model from scratch. Set to true to start from the HF off-the-shelf weights.
  pretrained_model_name: ${tokenizer_name}
  tokenizer_name: ${tokenizer_name}
  # This implementation generally uses the default architecture values from the Hugging Face BertConfig object.
  # These values can be changed here when pretraining from scratch. Note that they should only be set
  # if use_pretrained: false, otherwise the model will not be loaded properly.
  model_config:
    num_attention_heads: 12 # bert-base default
    num_hidden_layers: 12 # bert-base default
    # NOTE(review): Hugging Face BertConfig spells this field `max_position_embeddings` (plural).
    # Confirm whether the model code really reads this singular key or whether it is a typo.
    max_position_embedding: 512
    attention_probs_dropout_prob: 0.1 # bert-base default
    arch_type: nova-inputdep
    position_embedding_type: 'relative_key_query' # ytchanged: dummy value only, not used in nova
    layer_norm_after: false

# Dataloaders
train_loader:
  name: text
  drop_last: true
  num_workers: 8
  dataset:
    split: train
    shuffle: true
    # The remaining fields are all taken from the top-level keys of this file.
    local: ${data_local}
    remote: ${data_remote}
    tokenizer_name: ${tokenizer_name}
    max_seq_len: ${max_seq_len}
    mlm_probability: ${mlm_probability}

eval_loader:
  name: text
  drop_last: false
  num_workers: 8
  dataset:
    split: val
    shuffle: false
    # Evaluation always masks 15% of tokens (it does not follow the top-level
    # mlm_probability) so that results stay comparable across runs.
    mlm_probability: 0.15
    # The remaining fields are all taken from the top-level keys of this file.
    local: ${data_local}
    remote: ${data_remote}
    tokenizer_name: ${tokenizer_name}
    max_seq_len: ${max_seq_len}

# Optimization
scheduler:
  name: linear_decay_with_warmup
  # Warm up to the full LR over the first 6% of the training duration.
  t_warmup: '0.06dur'
  # Afterwards decay linearly, ending at 0.02x the full LR when training finishes.
  alpha_f: 0.02

optimizer:
  name: decoupled_adamw
  # Peak learning rate, taken from the top-level `lr`.
  lr: ${lr}
  betas: [0.9, 0.98]
  eps: 1.0e-06
  # Amount of weight decay regularization.
  weight_decay: 1.0e-5

# Subsample the training data: 286720000 samples is exactly 70000 batches at
# global_train_batch_size 4096 (the original note rounds this to "~275M samples").
max_duration: '286720000sp'
eval_interval: '1000ba'
global_train_batch_size: 4096

# System
seed: 17
device_eval_batch_size: 128
device_train_microbatch_size: 128
# device_train_microbatch_size: auto
precision: amp_bf16

# Logging
# Console output: the progress bar is off and console logging is on,
# with an interval of 10ba.
progress_bar: false
log_to_console: true
console_log_interval: '10ba'

# Enabled callbacks; lr_monitor is enabled with an empty option set.
callbacks:
  lr_monitor: {}
  speed_monitor:
    window_size: 500

# (Optional) W&B logging
#ytchanged
loggers:
  wandb:
    project: nova_pretrain
    # Interpolated from the top-level run_name so the W&B run name always matches
    # it, including when run_name is overridden.
    name: ${run_name}
    # entity:   # Fill this in

# (Optional) Checkpoint to local filesystem or remote object store
save_interval: '1000ba'
save_num_checkpoints_to_keep: 10
# NOTE(review): `{run_name}` has no `$`, so it is not a ${...} config interpolation;
# presumably the trainer fills it in at runtime, like `{rank}` in the load_path
# examples below. Confirm.
save_folder: ./local-bert-checkpoints/{run_name}

# (Optional) Load from local filesystem or remote object store to
# start from an existing model checkpoint;
# e.g. './ckpt/latest-rank{rank}.pt' (local), or
# 's3://mybucket/mydir/ckpt/latest-rank{rank}.pt' (remote)
# load_path: null