# Run identification and output settings for this Llama 3 debug config.
# NOTE(review): dump_folder is a relative path; presumably it resolves
# against the launch working directory — confirm with the config loader.
[job]
dump_folder = "./outputs"
description = "Llama 3 debug training with Mosaic streaming"
print_args = false

# Profiling and memory snapshots are both disabled here. The folder and
# frequency keys are presumably only read when the matching enable_* flag
# is turned on.
[profiling]
enable_profiling = false
save_traces_folder = "profile_trace"
profile_freq = 10
enable_memory_snapshot = false
save_memory_snapshot_folder = "memory_snapshot"

# Metric logging: W&B is enabled and TensorBoard is disabled, so
# save_tb_folder is presumably unused in this config.
[metrics]
log_freq = 1  # presumably in training steps, i.e. log every step
disable_color_printing = false
enable_tensorboard = false
save_tb_folder = "tb"
enable_wandb = true
# NOTE(review): presumably makes every rank log its own metrics rather than
# rank 0 only — confirm this is wanted for W&B.
save_for_all_ranks = true

[model]
name = "mosaic_llama3_mup"
flavor = "16M"
# Test folder containing tokenizer.json; for debugging purposes only.
# NOTE(review): [mosaic_tokenizer] names "HuggingFaceTB/SmolLM-1.7B" as the
# tokenizer for the streamed data. If the model's vocabulary size is taken
# from this test tokenizer instead, token IDs from the data may exceed the
# embedding table — confirm which tokenizer sizes the model.
hf_assets_path = "./tests/assets/tokenizer"
# Float8 conversion is off; [quantize.linear.float8] presumably only matters
# if this line is re-enabled.
# converters = ["float8"]


[optimizer]
name = "ADOPT"
# NOTE(review): the base learning rate is zero. If the scheduler and any muP
# scaling (model.name = "mosaic_llama3_mup") multiply this base value, the
# effective LR stays zero and weights never update. Confirm this is
# intentional, e.g. for a monitor-only debug run.
lr = 0.0
eps = 1e-6
weight_decay = 0.0
# betas = [0.9,0.95]
# implementation = "foreach"

# DES-LOC settings, currently commented out. [fault_tolerance] sets
# semi_sync_method = "desloc" (with enable = false); this table presumably
# needs to be uncommented before that mode is enabled — confirm.
# [optimizer.desloc]
# enabled = false
# param_sync_every = 32
# optimizer_sync_every = [64,128]
# backup_device = "cpu"
# pin_memory = true

# fl_metrics monitors (presumably federated-learning diagnostics), all
# sampled every step (interval = 1).
# NOTE(review): lr/betas/vs monitors set enabled = true explicitly, but
# optimizer_monitor and activation_monitor have no enabled key and so rely
# on the loader's default — confirm that default is "on".
[fl_metrics.optimizer_monitor]
interval = 1
only_global = false
log_metrics = true

[fl_metrics.activation_monitor]
interval = 1
ignore_module_types = ["dropout", "ln"]

[fl_metrics.lr_monitor]
enabled = true
interval = 1

[fl_metrics.betas_monitor]
enabled = true
interval = 1

[fl_metrics.vs_monitor]
enabled = true
interval = 1

[lr_scheduler]
# The 20% / 80% figures below are rules of thumb. In this file,
# warmup_steps = 50 is about 10% of training.steps (512).
warmup_steps = 50  # lr scheduler warm up, normally 20% of the train steps
decay_ratio = 0.8  # lr scheduler decay ratio, 80% of the train steps
decay_type = "linear"
min_lr_factor = 0.0
# NOTE(review): the semantics of switch_step / switch_scale are not visible
# in this file; confirm them against the scheduler implementation.
switch_step = 100
switch_scale = 0.8

[training]
# global_batch_size / local_batch_size = 4. Presumably the trainer covers
# that factor with data-parallel ranks and/or gradient accumulation — confirm.
local_batch_size = 2
global_batch_size = 8
# Keep equal to mosaic_dataloader.dataset.common.max_seq_len and
# mosaic_tokenizer.kwargs.model_max_length (all 2048 in this file).
seq_len = 2048
max_norm = 1.0  # grad norm clipping
steps = 512
# NOTE(review): training data is configured in the [mosaic_dataloader]
# tables; this built-in dataset name is presumably ignored on the Mosaic
# path — confirm.
dataset = "c4_test"  # supported datasets: c4_test (2K), c4 (177M)

# Mosaic-specific settings live in their own top-level tables
# ([mosaic_dataloader], [mosaic_tokenizer]).
[mosaic_dataloader]
name = "text"
num_workers = 1
prefetch_factor = 2
pin_memory = true
persistent_workers = true
isolate_grouped_streams = true

# Presumably shared by the train and val datasets defined below.
[mosaic_dataloader.dataset.common]
# Keep equal to training.seq_len and mosaic_tokenizer.kwargs.model_max_length
# (all 2048 in this file).
max_seq_len = 2048
download_retry = 2
download_timeout = 60  # NOTE(review): unit not shown here; presumably seconds
keep_zip = false
partition_algo = "relaxed"
shuffle = true
shuffle_algo = "py1e"
shuffle_seed = 9176
sampling_method = "balanced"
sampling_granularity = 1
batching_method = "random"

[mosaic_dataloader.dataset.train]
split = "train"
root_remote = "s3://smollm-corpus/shared"
root_local = "/nfs-share/datasets/photon/dataset_cache/smollm-corpus-shared"
sampling_groups_mode = "grouped"  # set to "concatenate" to merge all sampling groups

# Both streams read client_0 shards and use the same proportion (70), which
# weights them equally if proportions are relative.
[mosaic_dataloader.dataset.train.streams.client_streams.stream_0]
local = "fineweb_edu_dedup/client_0"
remote = "fineweb_edu_dedup/client_0"
proportion = 70

[mosaic_dataloader.dataset.train.streams.client_streams.stream_1]
local = "cosmo/client_0"
remote = "cosmo/client_0"
proportion = 70

# One stream per sampling group (used with sampling_groups_mode = "grouped").
[mosaic_dataloader.dataset.train.sampling_groups.group_0]
streams = ["stream_0"]

[mosaic_dataloader.dataset.train.sampling_groups.group_1]
streams = ["stream_1"]


[mosaic_dataloader.dataset.val]
# The validation samples are stored under the "train" split on disk.
# NOTE(review): root_remote, split, and every stream's remote path here are
# identical to [mosaic_dataloader.dataset.train]; only root_local differs.
# Validation therefore appears to read the same remote samples as training,
# which would make val metrics measure training data. Confirm a distinct
# held-out remote location is not intended.
split = "train"
root_remote = "s3://smollm-corpus/shared"
root_local = "/nfs-share/datasets/photon/dataset_cache/smollm-corpus-shared-val"
# subset_num_samples = 512
sampling_groups_mode = "grouped"

[mosaic_dataloader.dataset.val.streams.client_streams.stream_0]
local = "fineweb_edu_dedup/client_0"
remote = "fineweb_edu_dedup/client_0"
proportion = 70

[mosaic_dataloader.dataset.val.streams.client_streams.stream_1]
local = "cosmo/client_0"
remote = "cosmo/client_0"
proportion = 70

[mosaic_dataloader.dataset.val.sampling_groups.group_0]
streams = ["stream_0"]

[mosaic_dataloader.dataset.val.sampling_groups.group_1]
streams = ["stream_1"]


# Tokenizer presumably used by the Mosaic dataloader.
# NOTE(review): model.hf_assets_path points at a separate test tokenizer;
# confirm the model's vocabulary covers this tokenizer's token IDs.
[mosaic_tokenizer]
name = "HuggingFaceTB/SmolLM-1.7B"

[mosaic_tokenizer.kwargs]
# Keep equal to training.seq_len (2048 in this file).
model_max_length = 2048

# Only data-parallel sharding is active; every other parallel degree is 1.
[parallelism]
data_parallel_replicate_degree = 1
# -1: presumably "use all ranks left over after the other degrees"
# (torchtitan convention) — confirm with the trainer in use.
data_parallel_shard_degree = -1
fsdp_reshard_after_forward = "default" # default / never / always
tensor_parallel_degree = 1
enable_async_tensor_parallel = false
pipeline_parallel_degree = 1
context_parallel_degree = 1

# Local checkpointing is disabled (enable = false); the remaining keys
# presumably only take effect once it is turned on.
[checkpoint]
enable = false
folder = "checkpoints"
interval = 10  # presumably in training steps
last_save_model_only = false
export_dtype = "float32"
async_mode = "async_with_pinned_mem"  # ["disabled", "async", "async_with_pinned_mem"]

# S3 checkpoint settings; disabled in this config (enable = false).
[s3_checkpoint]
enable = false
bucket = "checkpoints"
prefix = ""  # Root of bucket
download_on_start = true
# run_uuid and remote_checkpoint_folder will be set via RUN_UUID environment variable

[activation_checkpoint]
mode = "selective"  # ["none", "selective", "full"]
# Selective AC option, given as a string: a positive integer N applies
# activation checkpointing to every N-th layer ('2' = every 2nd layer);
# 'op' uses the op-based policy instead.
selective_ac_option = '2'

# Compilation is disabled; components lists what would be compiled if
# enable were set to true.
[compile]
enable = false
components = ["model", "loss"]

# Float8 linear-layer settings. model.converters (which would add "float8")
# is commented out, so these presumably have no effect in this config.
[quantize.linear.float8]
enable_fsdp_float8_all_gather = false
precompute_float8_dynamic_scale_for_fsdp = false
filter_fqns = ["output"]  # presumably module FQNs excluded from float8 conversion

# Fault tolerance / semi-synchronous training; disabled here (enable = false).
[fault_tolerance]
enable = false
process_group = "gloo"
process_group_timeout_ms = 999999
replica_id = 0
group_size = 2
min_replica_size = 2
sync_steps = 32
# Options: "diloco", "local_sgd", "desloc", or comment out for async quorum.
# NOTE(review): "desloc" presumably also reads [optimizer.desloc], which is
# commented out in this file — confirm before enabling.
semi_sync_method = "desloc"

# The validation loop is disabled here (enable = false).
# NOTE(review): dataset names a C4 validation set, while
# [mosaic_dataloader.dataset.val] also defines validation data; confirm
# which source is used once enable is turned on.
[validation]
enable = false
dataset = "c4_validation"
freq = 5  # presumably in training steps
steps = 32

# Unigram-metric settings (enabled).
[unigram_metric]
enable = true
download_missing = true
allow_failures = false
# NOTE(review): -100 presumably matches the label value used for ignored
# (masked/padded) tokens in the loss — confirm against the loss setup.
ignore_index = -100
num_attempts = 1
