[job]
# Human-readable run label; this config produces the warmup checkpoint.
description = "Warmup checkpoint creation for Mosaic Llama"
dump_folder = "./outputs"
print_args = false

# Both trace profiling and memory snapshots are switched off for this run.
[profiling]
enable_profiling = false
profile_freq = 10
save_traces_folder = "profile_trace"
enable_memory_snapshot = false
save_memory_snapshot_folder = "memory_snapshot"

# Weights & Biases on, TensorBoard off; log_freq = 1 presumably means every step.
[metrics]
log_freq = 1
disable_color_printing = false
enable_wandb = true
save_for_all_ranks = true
enable_tensorboard = false
save_tb_folder = "tb"

[model]
name = "mosaic_llama3_mup"
flavor = "16M"
hf_assets_path = "./tests/assets/tokenizer"

[optimizer]
name = "GaLore"
builder = "mosaic"
lr = 0.016
eps = 1e-8
weight_decay = 0.0
betas = [0.9, 0.999]
# Parameters whose names match param_str_match get a rank-8 GaLore projection.
# The pattern is a TOML literal string, so backslashes reach the regex engine
# unescaped (`\.` matches a literal dot).
# NOTE(review): `feed_forward\.w[12]` excludes `feed_forward.w3`; if the FFN
# has a w3 projection (as in standard Llama), confirm leaving it out of
# GaLore is intentional.
galore_param_regexes = [
  { param_str_match = 'attention\.w[qkv]|attention\.wo|feed_forward\.w[12]', rank = 8 },
]
implementation = "for-loop"

# Federated-learning monitors; every one of them uses a 16-step interval.
[fl_metrics]
optimizer_monitor = { interval = 16, only_global = false, log_metrics = true }
activation_monitor = { interval = 16, ignore_module_types = ["dropout", "ln"] }
lr_monitor = { enabled = true, interval = 16 }
betas_monitor = { enabled = true, interval = 16 }
vs_monitor = { enabled = true, interval = 16 }

# warmup_steps and switch_step both equal training.steps (2048), so presumably
# the entire run is LR warmup; keep these three values in sync if steps changes.
[lr_scheduler]
warmup_steps = 2048
decay_ratio = 0.0
decay_type = "sqrt"
min_lr_factor = 0.0
switch_step = 2048
switch_scale = 1.0

[training]
# NOTE(review): under upstream torchtitan semantics, global_batch_size must be
# a multiple of local_batch_size x total data-parallel degree. Here that is at
# least 32 x data_parallel_replicate_degree (4) = 128, so 64 would be rejected
# at startup. Confirm how this fork derives the DP degree used for gradient
# accumulation before relying on this value.
local_batch_size = 32
global_batch_size = 64
seq_len = 2048
max_norm = 1.0
steps = 2048
# NOTE(review): presumably unused when [mosaic_dataloader] supplies the data —
# confirm.
dataset = "c4_test"

[parallelism]
# Pure data parallelism: 4 replicas; a shard degree of -1 is presumably
# resolved from the remaining world size (confirm against the trainer).
data_parallel_replicate_degree = 4
data_parallel_shard_degree = -1
fsdp_reshard_after_forward = "default"
# Context, pipeline and tensor parallelism are all disabled (degree 1).
context_parallel_degree = 1
pipeline_parallel_degree = 1
tensor_parallel_degree = 1
enable_async_tensor_parallel = false

[checkpoint]
enable = true
folder = "checkpoints"
# interval equals training.steps (2048), so presumably a single periodic save
# happens in this run and keep_latest_k has nothing to prune.
interval = 2048
keep_latest_k = 5
async_mode = "async_with_pinned_mem"
export_dtype = "float32"
last_save_model_only = false

[s3_checkpoint]
enable = true
bucket = "checkpoints"
prefix = ""
download_on_start = false
# resume_from_run_step = ""

[activation_checkpoint]
mode = "selective"
# Stored as a string rather than an integer, as in the original config;
# double quotes match the quoting used everywhere else in this file.
selective_ac_option = "2"

[compile]
enable = false
components = ["model", "loss"]

[quantize.linear.float8]
enable_fsdp_float8_all_gather = false
precompute_float8_dynamic_scale_for_fsdp = false
filter_fqns = ["output"]

# Fault tolerance is disabled; the other keys in this table presumably only
# matter once enable = true.
[fault_tolerance]
enable = false
process_group = "gloo"
process_group_timeout_ms = 999_999  # ~16.7 minutes
replica_id = 0
group_size = 4
min_replica_size = 4
sync_steps = 32
semi_sync_method = "desloc"

[validation]
enable = false

[unigram_metric]
enable = true
download_missing = true
allow_failures = false
ignore_index = -100
num_attempts = 1

[mosaic_tokenizer]
name = "HuggingFaceTB/SmolLM-1.7B"
# 2048 matches training.seq_len and the dataset max_seq_len.
kwargs = { model_max_length = 2048 }

[mosaic_dataloader]
name = "text"
num_workers = 1
persistent_workers = false
pin_memory = true
isolate_grouped_streams = true

# Presumably shared by every dataset split (confirm against the loader).
[mosaic_dataloader.dataset.common]
# Same 2048-token length as training.seq_len.
max_seq_len = 2048
keep_zip = false
download_retry = 2
download_timeout = 60
partition_algo = "relaxed"
shuffle = true
shuffle_algo = "py1e"
shuffle_seed = 9176
sampling_method = "balanced"
sampling_granularity = 1
batching_method = "random"

[mosaic_dataloader.dataset.train]
split = "train"
root_remote = "s3://smollm-corpus/shared"
# NOTE(review): this cache directory ends in "-val" but backs the "train"
# split; confirm it is the intended local cache and not carried over from a
# validation config.
root_local = "/nfs-share/datasets/photon/dataset_cache/smollm-corpus-shared-val"
sampling_groups_mode = "grouped"

# Eight client streams: four fineweb_edu_dedup clients (proportion 70 each)
# and four cosmo clients (proportion 30 each). If the loader normalizes
# proportions across all streams, the total mix is 280:120, i.e. 70% / 30%.
# Keep the stream order stable: the loader may derive per-stream indices or
# seeds from it. local/remote are presumably resolved relative to
# root_local/root_remote — confirm.
[mosaic_dataloader.dataset.train.streams.client_streams.stream_0]
local = "fineweb_edu_dedup/client_0"
remote = "fineweb_edu_dedup/client_0"
proportion = 70

[mosaic_dataloader.dataset.train.streams.client_streams.stream_1]
local = "fineweb_edu_dedup/client_1"
remote = "fineweb_edu_dedup/client_1"
proportion = 70

[mosaic_dataloader.dataset.train.streams.client_streams.stream_2]
local = "fineweb_edu_dedup/client_2"
remote = "fineweb_edu_dedup/client_2"
proportion = 70

[mosaic_dataloader.dataset.train.streams.client_streams.stream_3]
local = "fineweb_edu_dedup/client_3"
remote = "fineweb_edu_dedup/client_3"
proportion = 70

[mosaic_dataloader.dataset.train.streams.client_streams.stream_4]
local = "cosmo/client_0"
remote = "cosmo/client_0"
proportion = 30

[mosaic_dataloader.dataset.train.streams.client_streams.stream_5]
local = "cosmo/client_1"
remote = "cosmo/client_1"
proportion = 30

[mosaic_dataloader.dataset.train.streams.client_streams.stream_6]
local = "cosmo/client_2"
remote = "cosmo/client_2"
proportion = 30

[mosaic_dataloader.dataset.train.streams.client_streams.stream_7]
local = "cosmo/client_3"
remote = "cosmo/client_3"
proportion = 30
