# Note: This requires ~16x80GB GPUs
variables:
  # Shared values; other sections of this file read them through
  # ${variables.<name>} interpolation, so each is defined only here.
  icl_seq_len: 1024

  max_seq_len: 4096

  # Run Name
  run_name: null  # If left null, will be read from env var $RUN_NAME

# Top-level settings that mirror the entries declared under `variables`.
# Each one is an interpolation, so the value itself lives in one place only.
run_name: ${variables.run_name}
max_seq_len: ${variables.max_seq_len}
icl_seq_len: ${variables.icl_seq_len}

# Model
# `hf_causal_lm` entry pointing at databricks/dbrx-instruct with pretrained
# weights enabled, plus a LoRA adapter described under `peft_config`.
model:
  name: hf_causal_lm
  pretrained_model_name_or_path: databricks/dbrx-instruct
  pretrained: true
  init_device: mixed
  use_auth_token: true
  use_flash_attention_2: true
  config_overrides: {}
  # NOTE(review): these keys presumably map one-to-one onto the PEFT LoRA
  # config (r = rank, lora_alpha = scaling) — confirm against the consumer.
  peft_config:
    peft_type: LORA
    task_type: CAUSAL_LM
    r: 64
    lora_alpha: 128
    lora_dropout: 0.05
    # Only modules named Wqkv are targeted by the adapter.
    target_modules: [Wqkv]

# Tokenizer
# Uses the same identifier as model.pretrained_model_name_or_path; the
# maximum length follows variables.max_seq_len.
tokenizer:
  name: databricks/dbrx-instruct
  kwargs:
    trust_remote_code: true
    model_max_length: ${variables.max_seq_len}

# Dataloaders
# Training loader: `train` split of mosaicml/dolly_hhrlhf, shuffled, with
# packing_ratio set to `auto`.
train_loader:
  name: finetuning
  dataset:
    hf_name: mosaicml/dolly_hhrlhf
    split: train
    shuffle: true
    decoder_only_format: true
    max_seq_len: ${variables.max_seq_len}
    eos_token_id: 0
    packing_ratio: auto
    allow_pad_trimming: false
  # Loader/worker settings (same values as eval_loader).
  num_workers: 8
  prefetch_factor: 2
  persistent_workers: true
  pin_memory: true
  drop_last: true

# Evaluation loader: `test` split of the same dataset, unshuffled, with
# packing_ratio explicitly null.
eval_loader:
  name: finetuning
  dataset:
    hf_name: mosaicml/dolly_hhrlhf
    split: test
    shuffle: false
    decoder_only_format: true
    max_seq_len: ${variables.max_seq_len}
    packing_ratio: null
    allow_pad_trimming: false
  # Loader/worker settings (same values as train_loader).
  num_workers: 8
  prefetch_factor: 2
  persistent_workers: true
  pin_memory: true
  drop_last: true

# Optimization
# Optimizer selection and its hyperparameters.
optimizer:
  name: decoupled_lionw
  lr: 0.0001
  betas: [0.9, 0.95]
  weight_decay: 1.0e-06

# Learning-rate schedule.
# NOTE(review): `0.02dur` presumably means warmup over 2% of the total run
# duration — confirm against the trainer's time-string format.
scheduler:
  name: cosine_with_warmup
  t_warmup: 0.02dur
  alpha_f: 0

# Training algorithms: gradient clipping of type `norm` with threshold 1.
algorithms:
  gradient_clipping:
    clipping_threshold: 1
    clipping_type: norm

# Run length, evaluation cadence, and global batch size.
global_train_batch_size: 16
max_duration: 2ep
eval_interval: 1ep
eval_first: false
# eval_subset_num_batches: -1

# System
# Seed, precision, per-device batch sizes, and distributed/runtime settings.
seed: 17
precision: amp_bf16
device_train_microbatch_size: 1
device_eval_batch_size: 1
dist_timeout: 3600
expandable_segments: true

# FSDP
# Sharding settings first, then the activation-memory options.
fsdp_config:
  sharding_strategy: FULL_SHARD
  state_dict_type: sharded
  mixed_precision: PURE
  limit_all_gathers: true
  activation_checkpointing: true
  activation_checkpointing_reentrant: false
  activation_cpu_offload: false

# Logging
# Console output is enabled at an interval of 1ba; the progress bar is off.
log_to_console: true
console_log_interval: 1ba
progress_bar: false

# Callbacks
# Each key names a callback and its value holds that callback's arguments
# ({} means no arguments are given here).
# NOTE(review): key order is left untouched in case the consumer instantiates
# callbacks in mapping order — confirm.
callbacks:
  lr_monitor: {}
  speed_monitor:
    window_size: 1
  memory_monitor: {}
  # Checkpoints go under ./{run_name}/checkpoints with overwrite enabled.
  # NOTE(review): `1dur` presumably means a single save at the end of the
  # run — confirm against the trainer's time-string format.
  hf_checkpointer:
    overwrite: true
    precision: bfloat16
    save_folder: ./{run_name}/checkpoints
    save_interval: 1dur
  runtime_estimator: {}
# Checkpointing (optional, top-level): uncomment to save to the local filesystem or a remote object store
# save_interval: 5000ba
# save_num_checkpoints_to_keep: 1  # Important, this cleans up checkpoints saved to DISK
# save_folder: ./{run_name}/checkpoints
# save_folder: s3://my-bucket/my-folder/{run_name}/checkpoints

# Loggers (optional): uncomment to enable wandb and/or mlflow
# loggers:
#   wandb:
#     name:
#     group:
#   mlflow:
#     tracking_uri:
#     experiment_name:
