# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

"""Pretrain utilities."""

import dataclasses
from datetime import datetime
import functools
import gc
import inspect
import logging
import math
import os
import sys
from typing import List, Optional

import torch.distributed

from megatron.core.optimizer.distrib_optimizer import DistributedOptimizer
from .log_handler import CustomHandler

# Make default logging level INFO, but filter out all log messages not from MCore.
logging.basicConfig(handlers=[CustomHandler()], level=logging.INFO)
from .theoretical_memory_usage import report_theoretical_memory
import time

# The earliest we can measure the start time.
_TRAIN_START_TIME = time.time()
import torch

try:
    from megatron.rl import rl_utils
    has_rl_utils = True
except ImportError:
    has_rl_utils = False
try:
    from megatron.post_training.algos.distillation import (
        get_tensor_shapes_adjust_fn_for_distillation,
    )

    has_nvidia_modelopt = True
except ImportError:
    has_nvidia_modelopt = False

try:
    from nvidia_resiliency_ext.inprocess import CallWrapper
except ImportError:
    CallWrapper = type(None)


from megatron.core import mpu, tensor_parallel
from megatron.core.utils import (
    check_param_hashes_across_dp_replicas,
    get_model_config,
    StragglerDetector,
)
from megatron.core.fp8_utils import correct_amax_history_if_needed
from megatron.training.checkpointing import load_checkpoint
from megatron.training.checkpointing import save_checkpoint
from megatron.training.checkpointing import checkpoint_exists
from megatron.core.full_cuda_graph import FullCudaGraphWrapper
from megatron.core.transformer.cuda_graphs import TECudaGraphHelper
from megatron.core.transformer.module import Float16Module
from megatron.core.distributed import DistributedDataParallelConfig, TorchFullyShardedDataParallelConfig
from megatron.core.distributed import DistributedDataParallel as DDP
from megatron.core.distributed.fsdp.mcore_fsdp_adapter import FullyShardedDataParallel as megatron_FSDP
from megatron.core.optimizer.optimizer import param_group_identifier_keys

try:
    from megatron.core.distributed import TorchFullyShardedDataParallel as torch_FSDP

    HAVE_FSDP2 = True
except ImportError:
    HAVE_FSDP2 = False

from megatron.core.distributed import finalize_model_grads
from megatron.core.enums import ModelType
from megatron.core.optimizer import get_megatron_optimizer, OptimizerConfig
from megatron.core.rerun_state_machine import (
    get_rerun_state_machine,
    destroy_rerun_state_machine,
    RerunDataIterator,
    RerunMode,
)
from megatron.training.initialize import initialize_megatron
from megatron.training.initialize import write_args_to_tensorboard
from megatron.training.initialize import set_jit_fusion_options
from megatron.training.utils import get_batch_on_this_cp_rank, get_batch_on_this_tp_rank
from megatron.legacy.data.data_samplers import build_pretraining_data_loader
from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler
from megatron.core.transformer.moe import upcycling_utils
from megatron.core.transformer.moe.moe_utils import track_moe_metrics
from megatron.core.transformer.multi_token_prediction import MTPLossLoggingHelper
from megatron.core.parallel_state import destroy_global_memory_buffer, destroy_model_parallel
from megatron.core.pipeline_parallel import get_forward_backward_func
from megatron.core.num_microbatches_calculator import (
    destroy_num_microbatches_calculator,
    get_current_global_batch_size,
    get_current_running_global_batch_size,
    get_num_microbatches,
    update_num_microbatches
)

from .async_utils import maybe_finalize_async_save
from .utils import (
    append_to_progress_log,
    calc_params_l2_norm,
    check_adlr_autoresume_termination,
    logical_and_across_model_parallel_group,
    reduce_max_stat_across_model_parallel_group,
    is_last_rank,
    print_rank_0,
    print_rank_last,
    report_memory,
    unwrap_model,
    update_use_dist_ckpt,
    to_empty_if_meta_device,
)
from .global_vars import (
    destroy_global_vars,
    get_args,
    get_signal_handler,
    get_timers,
    get_tensorboard_writer,
    get_wandb_writer,
    get_one_logger,
    get_tokenizer,
    get_energy_monitor,
)
from . import one_logger_utils

from . import ft_integration

stimer = StragglerDetector()

from megatron.core.msc_utils import MultiStorageClientFeature, open_file


def destroy_global_state():
    """Run every module-level teardown hook, in a fixed order.

    The hooks are invoked one after another: global vars, the microbatch
    calculator, the global memory buffer, model-parallel state and finally the
    rerun state machine.
    """
    teardown_hooks = (
        destroy_global_vars,
        destroy_num_microbatches_calculator,
        destroy_global_memory_buffer,
        destroy_model_parallel,
        destroy_rerun_state_machine,
    )
    for teardown in teardown_hooks:
        teardown()


def print_datetime(string):
    """Emit `string` together with the current wall-clock time via `print_rank_0`.

    Note that this call will sync across all ranks: it starts with a
    `torch.distributed.barrier()`.
    """
    torch.distributed.barrier()
    now = datetime.now()
    time_str = now.strftime('%Y-%m-%d %H:%M:%S')
    message = f'[{string}] datetime: {time_str} '
    print_rank_0(message)


def num_floating_point_operations(args, batch_size):
    """Estimate the floating-point operations needed to train on `batch_size` samples.

    Dispatches on `args.is_hybrid_model`: hybrid models are costed per layer type
    (attention / Mamba / MLP) by `hybrid_flops`, all other models by
    `transformer_flops`. Both estimates include the forward and the backward pass
    (a 3x factor over the forward GEMM cost).

    Args:
        args: Namespace-like object holding the model hyperparameters that are read
            below (e.g. `seq_length`, `hidden_size`, `num_layers`,
            `num_attention_heads`, `padded_vocab_size`, ...). It is only read;
            this function never modifies it.
        batch_size: Number of samples to account for.

    Returns:
        The estimated number of floating-point operations (an int or a float,
        depending on which terms contribute).

    Raises:
        RuntimeError: if `args.moe_layer_freq` is neither an int nor a list while
            `args.num_experts` is set.
        AssertionError: if the MoE layer pattern length does not match
            `args.num_layers`, or if multi-latent attention is combined with
            group-query attention.
    """

    def calculate_layer_counts():
        """Calculate the number of attention, Mamba, and MLP layers."""
        if args.hybrid_override_pattern:
            # Pattern symbols: '*' = attention, 'M' = Mamba, '-' = MLP. Any other
            # character in the pattern is ignored.
            counts = {'M': 0, '*': 0, '-': 0}
            for layer_type in args.hybrid_override_pattern:
                if layer_type in counts:
                    counts[layer_type] += 1
            return counts['*'], counts['M'], counts['-']
        else:
            num_attn_layers = round(args.num_layers * args.hybrid_attention_ratio)
            num_mlp_layers = round(args.num_layers * args.hybrid_mlp_ratio)
            num_mamba_layers = args.num_layers - num_attn_layers - num_mlp_layers
            return num_attn_layers, num_mamba_layers, num_mlp_layers

    def mlp_layer_flops(batch_size, seq_len, hidden_size, expansion=4.0, swiglu=False):
        """Calculate FLOPs for an MLP layer."""
        scale_factor = 3.0 / 2.0 if swiglu else 1.0
        return 4 * expansion * scale_factor * batch_size * seq_len * hidden_size**2

    def attn_layer_flops(
        batch_size, seq_len, hidden_size, num_heads, gqa=True, gqa_groups=8, kv_channels=None
    ):
        """Calculate FLOPs for an attention layer."""
        # Ratio of the query projection size to the hidden size (1 when no explicit
        # kv_channels is given).
        p = (kv_channels * num_heads / hidden_size) if kv_channels else 1
        # Without GQA the number of groups is taken to be the number of heads.
        g = gqa_groups if gqa else num_heads
        return (
            4
            * batch_size
            * seq_len
            * hidden_size
            * p
            * (hidden_size + (hidden_size * (g / num_heads)) + (seq_len / 2))
        )

    def mamba_layer_flops(batch_size, seq_len, hidden_size, state_dim=16,
                          head_dim=64, num_groups=1, num_heads=128):
        """Calculate FLOPs for a Mamba layer."""
        # Note (rwaleffe): flops estimate for scan should be updated based on new SSD kernels,
        # but small percent of overall layer flops
        d_in = 2 * hidden_size
        if num_heads:
            nheads = num_heads
        else:
            nheads = d_in // head_dim
        return (
            (
                2
                * batch_size
                * seq_len
                * hidden_size
                * (2 * d_in + 2 * num_groups * state_dim + nheads)
            )  # in_proj
            + (7 * batch_size * seq_len * d_in * state_dim)  # scan
            + (2 * batch_size * seq_len * d_in * hidden_size)  # out_proj
        )

    def hybrid_flops(batch_size, seq_len, hidden_size,
                     num_attn_layers, num_mamba_layers, num_mlp_layers,
                     mamba_state_dim=128, mamba_head_dim=64,
                     mamba_num_groups=8, mamba_num_heads=128,
                     num_attn_heads=32, gqa=True,
                     gqa_groups=8, kv_channels=None,
                     mlp_expansion=4.0, swiglu=False,
                     vocab_size=256000):
        """Calculate total FLOPs for the hybrid model."""
        flops_fwd = (
                num_attn_layers * attn_layer_flops(batch_size, seq_len, hidden_size,
                                                   num_attn_heads, gqa, gqa_groups, kv_channels) +
                num_mlp_layers * mlp_layer_flops(batch_size, seq_len, hidden_size,
                                                 mlp_expansion, swiglu) +
                num_mamba_layers * mamba_layer_flops(batch_size, seq_len, hidden_size,
                                                     mamba_state_dim, mamba_head_dim,
                                                     mamba_num_groups, mamba_num_heads) +
                (2 * batch_size * seq_len * hidden_size * vocab_size)  # logits computation
        )
        # Forward cost times three accounts for the backward pass as well.
        return flops_fwd * 3

    def transformer_flops():
        """Calculate FLOPs for a standard Transformer model."""
        # TODO(helenn/dnarayanan): Refactor this to reuse the helper methods.
        # Attention projection size.
        query_projection_size = args.kv_channels * args.num_attention_heads
        query_projection_to_hidden_size_ratio = query_projection_size / args.hidden_size
        # Group Query Attention. Without GQA the group count is taken to equal the
        # head count. The value is kept in a local so that estimating FLOPs never
        # rewrites the caller's `args`.
        num_query_groups = (
            args.num_query_groups if args.group_query_attention else args.num_attention_heads
        )
        # MoE.
        if args.num_experts is None:
            # Every Transformer MLP is dense.
            num_dense_layers = args.num_layers
            num_moe_layers = 0
            num_experts_routed_to = 0
            last_layer_is_moe = 0
        else:
            # Calculate number of dense and MoE Transformer MLPs.
            if isinstance(args.moe_layer_freq, int):
                moe_layer_pattern = [
                    1 if (i % args.moe_layer_freq == 0) else 0 for i in range(args.num_layers)
                ]
            elif isinstance(args.moe_layer_freq, list):
                moe_layer_pattern = args.moe_layer_freq
            else:
                raise RuntimeError("Illegal --moe-layer-freq argument provided!")
            assert len(moe_layer_pattern) == args.num_layers, (
                f"Invalid length of moe_layer_pattern: {len(moe_layer_pattern)}, "
                f"expected {args.num_layers}, "
                f"current moe layer pattern: {args.moe_layer_freq}"
            )
            num_moe_layers = sum(moe_layer_pattern)  # Number of 1s in `moe_layer_pattern`.
            num_dense_layers = args.num_layers - num_moe_layers
            num_experts_routed_to = args.moe_router_topk
            last_layer_is_moe = moe_layer_pattern[-1]

        if args.mtp_num_layers is not None:
            # MTP layers are counted as the same kind (MoE or dense) as the last
            # regular layer.
            mtp_num_layers = args.mtp_num_layers
            num_moe_layers += last_layer_is_moe * mtp_num_layers
            num_dense_layers += (1 - last_layer_is_moe) * mtp_num_layers
            num_layers = args.num_layers + mtp_num_layers
        else:
            mtp_num_layers = 0
            num_layers = args.num_layers

        moe_ffn_hidden_size = (
            args.moe_ffn_hidden_size
            if args.moe_ffn_hidden_size is not None
            else args.ffn_hidden_size
        )
        shared_expert_ffn_hidden_size = (
            0
            if args.moe_shared_expert_intermediate_size is None
            else args.moe_shared_expert_intermediate_size
        )
        # SwiGLU.
        gated_linear_multiplier = 3 / 2 if args.swiglu else 1

        # The 12x term below comes from the following factors; for more details, see
        # "APPENDIX: FLOATING-POINT OPERATIONS" in https://arxiv.org/abs/2104.04473.
        # - 3x: Each GEMM in the model needs to be performed 3 times (forward pass,
        #       backward wgrad [weight gradient], backward dgrad [data gradient]).
        # - 2x: GEMMs of a particular size are stacked twice in the standard Transformer model
        #       architectures implemented in this codebase (e.g., h->ffn_h GEMM and ffn_h->h GEMM
        #       in MLP layer).
        # - 2x: A GEMM of a m*n tensor with a n*k tensor requires 2mnk floating-point operations.
        expansion_factor = 3 * 2 * 2

        if args.multi_latent_attention:
            assert not args.group_query_attention
            # Basic arithmetic:
            # let B be the batch size, s the seq_len and h the embedding dim;
            # for one self-attention block (prenorm is not included)
            #   qkv projection:  6Bsh^2
            #   attn:            2Bs^2h
            #   attn over value: 2Bs^2h
            #   oproj:           2Bsh^2
            #
            # References:
            #   https://arxiv.org/abs/2305.10403
            #   https://arxiv.org/abs/2205.05198

            ## MLA
            if args.q_lora_rank is None:
                q_term = (
                    args.hidden_size
                    * args.num_attention_heads
                    * (args.qk_head_dim + args.qk_pos_emb_head_dim)
                )
            else:
                q_term = args.q_lora_rank * (
                    args.hidden_size
                    + args.num_attention_heads * (args.qk_head_dim + args.qk_pos_emb_head_dim)
                    + 1
                )
            self_attn_term = (
                3
                * 2  # fwd(1) + bwd(2) *FMA
                * num_layers
                * (
                    ## q lora + rope + q norm
                    q_term
                    ## kv lora + rope + kv norm
                    + args.kv_lora_rank
                    * (
                        args.hidden_size
                        + args.num_attention_heads * (args.qk_head_dim + args.v_head_dim)
                        + 1
                    )
                    + args.hidden_size * args.qk_pos_emb_head_dim
                    ## o proj
                    + (args.num_attention_heads * args.v_head_dim) * args.hidden_size
                    ## core attn
                    + args.seq_length
                    * (args.num_attention_heads * (args.qk_head_dim + args.qk_pos_emb_head_dim))
                    / 2
                    + args.seq_length * args.num_attention_heads * args.v_head_dim / 2
                )
            )

        else:
            ## MHA or GQA
            self_attn_term = (
                expansion_factor
                * num_layers
                * args.hidden_size
                * args.hidden_size
                * (
                    (
                        1
                        + (num_query_groups / args.num_attention_heads)
                        # Only half of the attention matrix is non-zero and needs to be
                        # multiplied with V.
                        + (args.seq_length / args.hidden_size / 2)
                    )
                    * query_projection_to_hidden_size_ratio
                )
            )

        total_floating_point_operations = (
            batch_size
            * args.seq_length
            * (
                # MLP
                expansion_factor
                * num_layers
                * args.hidden_size
                * (
                    # dense layer (deepseek v2, v3 style)
                    (args.ffn_hidden_size * gated_linear_multiplier)
                    * (num_dense_layers / num_layers)
                    # routed experts
                    + (moe_ffn_hidden_size * num_experts_routed_to * gated_linear_multiplier)
                    * (num_moe_layers / num_layers)
                    # Shared Experts.
                    + (shared_expert_ffn_hidden_size * gated_linear_multiplier)
                    * (num_moe_layers / num_layers)
                )
                # Self Attention
                + self_attn_term
                # MTP norms and proj
                + 3
                * 2
                * mtp_num_layers
                * (
                    # MTP eh norm + final norm
                    3 * args.hidden_size
                    # MTP eh proj
                    + 2 * args.hidden_size * args.hidden_size
                )
                # Logit.
                + 3 * 2 * args.hidden_size * args.padded_vocab_size * (mtp_num_layers + 1)
            )
        )
        return total_floating_point_operations

    # Main entrypoint for FLOPs calculation.
    if args.is_hybrid_model:
        # Calculate the number of each type of layer.
        num_attn_layers, num_mamba_layers, num_mlp_layers = calculate_layer_counts()

        # Compute hybrid model FLOPs.
        return hybrid_flops(
            batch_size=batch_size,
            seq_len=args.seq_length,
            hidden_size=args.hidden_size,
            num_attn_layers=num_attn_layers,
            num_mamba_layers=num_mamba_layers,
            num_mlp_layers=num_mlp_layers,
            mamba_state_dim=args.mamba_state_dim,
            mamba_head_dim=args.mamba_head_dim,
            mamba_num_groups=args.mamba_num_groups,
            mamba_num_heads=args.mamba_num_heads,
            num_attn_heads=args.num_attention_heads,
            gqa=args.group_query_attention,
            gqa_groups=args.num_query_groups,
            kv_channels=args.kv_channels,
            mlp_expansion=args.ffn_hidden_size / args.hidden_size,
            swiglu=args.swiglu,
            vocab_size=args.padded_vocab_size,
        )
    else:
        # Compute standard Transformer model FLOPs.
        return transformer_flops()


def get_start_time_from_progress_log():
    """
    Read `<args.save>/progress.txt` and locate the start of the current run of
    jobs that use the same world size as this one.

    Returns a tuple `(start_datetime, start_flops)`: the timestamp of the first
    "Starting job" entry seen since the world size last differed from
    `args.world_size`, and the floating-point-operation count taken from the most
    recent "Saved checkpoint" entry read before that job started (0 if none).
    """
    args = get_args()
    assert args.save is not None
    log_path = os.path.join(args.save, "progress.txt")

    def _parse_value(field, cast):
        # Fields look like "<name>: <value>"; convert the value part.
        return cast(field.split(': ')[1])

    # Timestamp (as logged) of the first job in the current same-world-size run.
    first_start_timestamp = None
    # FLOPs completed when that job started.
    flops_at_first_start = None
    # FLOPs completed in the most recent saved checkpoint seen so far.
    flops_at_last_checkpoint = 0

    with open_file(log_path, 'r') as log_file:
        for raw_line in log_file:
            fields = raw_line.strip().split('\t')
            line_world_size = _parse_value(fields[2], int)
            event = fields[3]
            if event == "Saved checkpoint":
                flops_at_last_checkpoint = _parse_value(fields[7], float)
            if line_world_size != args.world_size:
                # A different world size interrupts the run: restart the search.
                first_start_timestamp = None
                flops_at_first_start = None
            elif event == "Starting job" and first_start_timestamp is None:
                first_start_timestamp = fields[0]
                flops_at_first_start = flops_at_last_checkpoint

    assert (
        first_start_timestamp is not None and flops_at_first_start is not None
    ), "Should have seen at least one 'Starting job' entry with same world_size"
    start_datetime = datetime.strptime(first_start_timestamp, '%Y-%m-%d %H:%M:%S')
    return start_datetime, flops_at_first_start


def preprocess_common_state_dict(common_state_dict):
    """Return a normalized deep copy of `common_state_dict`.

    The copy differs from the input in these ways:
      * `'args'` (a namespace) is replaced by its attribute dictionary, with the
        `'local_rank'` and `'rank'` entries removed since they are expected to
        differ between ranks;
      * when `use_distributed_optimizer` is set and an `'optimizer'` entry exists,
        empty `'param_state'` entries are dropped and each inner optimizer's
        `'param_groups'` list is sorted by `param_group_identifier_keys`.

    The input dictionary itself is left untouched.
    """
    import copy

    def _normalize_optimizer_state(opt_state):
        # When the distributed optimizer loads, source param groups get reordered,
        # so the groups are sorted here as well to prevent a warning.

        # Drop "param_state" when it is present but empty.
        if "param_state" in opt_state and not opt_state["param_state"]:
            opt_state.pop("param_state")

        # Sort the param groups, if this state has any.
        if "optimizer" in opt_state and "param_groups" in opt_state["optimizer"]:
            inner = opt_state["optimizer"]
            groups = inner["param_groups"]
            groups.sort(key=lambda group: [group[name] for name in param_group_identifier_keys])
            inner["param_groups"] = groups

    result = copy.deepcopy(common_state_dict)

    # Turn the args namespace into a plain dictionary.
    args_dict = vars(result['args'])
    result['args'] = args_dict
    # Rank-specific entries are removed when present.
    for per_rank_key in ('local_rank', 'rank'):
        args_dict.pop(per_rank_key, None)

    if args_dict['use_distributed_optimizer'] and "optimizer" in result:
        opt_state = result['optimizer']
        if "optimizer" in opt_state:
            # Only 1 optimizer in chained optimizer.
            _normalize_optimizer_state(opt_state)
        else:
            # Multiple optimizers in chained optimizer, keyed by integer index.
            for index in range(len(opt_state)):
                if index in opt_state.keys():
                    _normalize_optimizer_state(opt_state[index])

    return result


def pretrain(
    train_valid_test_dataset_provider,
    model_provider,
    model_type,
    forward_step_func,
    process_non_loss_data_func=None,
    extra_args_provider=None,
    # NOTE(review): mutable default; it is only forwarded to initialize_megatron here
    # and not modified in this function.
    args_defaults={},
    get_embedding_ranks=None,
    get_position_embedding_ranks=None,
    non_loss_data_func=None,
    store=None,
    inprocess_call_wrapper: Optional[CallWrapper] = None,
):
    """Main training program.

    This function will run the following steps in the order provided:
        1) initialize Megatron.
        2) set up the model, optimizer and lr schedule using the model_provider.
        3) call train_valid_test_dataset_provider to get train/val/test datasets.
        4) train the model using the forward_step_func (unless `--skip-train`).
        5) optionally evaluate on the validation and test sets, then finalize
           pending async checkpoint saves and shut down logging/fault tolerance.

    Args:
        train_valid_test_dataset_provider: a function that takes the size of
            train/valid/test dataset and returns `train, valid, test` datasets.
            When virtual pipeline parallelism is used it must accept a
            `vp_stage` keyword argument.
        model_provider: a function that returns a vanilla version of the
            model. By vanilla we mean a simple model on cpu with no fp16 or ddp.
        model_type: an enum that specifies the type of model being trained.
        forward_step_func: a function that takes a `data iterator` and `model`,
            and returns a `loss` scalar with a dictionary with key:values being
            the info we would like to monitor during training, for example
            `lm-loss: value`. We also require that this function add
            `batch generator` to the timers class.
        process_non_loss_data_func: a function to post process outputs of the
            network. It can be used for dumping output tensors (e.g images) to
            tensorboard. It takes `collected data`(list of tensors),
            `current iteration index` and `tensorboard writer` as arguments.
        extra_args_provider: a function that takes a parser and adds arguments
            to it. It is used for programs to add their own arguments.
        args_defaults: a dictionary from argument-name to argument-value. It is
            used to set already-parsed arguments.
        get_embedding_ranks (TODO):
        get_position_embedding_ranks (TODO):
        non_loss_data_func (callable): A custom function to call during evaluation.
            It can run e.g. benchmarks.
        store: an optional instance of torch.distributed.Store, to be used by
            torch.distributed.init_process_group
        inprocess_call_wrapper: an optional instance of inprocess.CallWrapper,
            it is automatically injected when in-process restart is in use

    Raises:
        RuntimeError: if local non-persistent checkpointing is requested but the
            `nvidia_resiliency_ext` module cannot be imported.
    """

    if inprocess_call_wrapper is not None:
        # Namespace the store by the wrapper's iteration counter.
        iteration = inprocess_call_wrapper.iteration
        store = torch.distributed.PrefixStore(str(iteration), store)

    # Initialize and get arguments, timers, and Tensorboard writer.
    initialize_megatron(
        extra_args_provider=extra_args_provider,
        args_defaults=args_defaults,
        get_embedding_ranks=get_embedding_ranks,
        get_position_embedding_ranks=get_position_embedding_ranks,
        store=store,
    )

    args = get_args()
    timers = get_timers()

    if args.log_progress:
        append_to_progress_log("Starting job")

    # Initialize fault tolerance
    # NOTE: ft_integration functions other than `setup` are no-op if the FT is not initialized
    if args.enable_ft_package:
        ft_integration.setup(args)
        ft_integration.maybe_setup_simulated_fault()

    # Set pytorch JIT layer fusion options and warmup JIT functions.
    set_jit_fusion_options()

    # Adjust the startup time so it reflects the largest value.
    # This will be closer to what scheduler will see (outside of
    # image ... launches.
    # The MIN all-reduce picks the earliest start timestamp across ranks, i.e. the
    # longest elapsed startup time.
    global _TRAIN_START_TIME
    start_time_tensor = torch.tensor([_TRAIN_START_TIME], dtype=torch.double, device='cuda')
    torch.distributed.all_reduce(start_time_tensor, op=torch.distributed.ReduceOp.MIN)
    _TRAIN_START_TIME = start_time_tensor.item()

    # Application-level timestamps (in milliseconds) reported to one_logger below.
    app_metrics = {}
    app_metrics['app_start_time'] = round(_TRAIN_START_TIME * 1000.0)
    app_metrics['app_model_init_start_time'] = round(_TRAIN_START_TIME * 1000.0)

    print_rank_0(
        'time to initialize megatron (seconds): {:.3f}'.format(time.time() - _TRAIN_START_TIME)
    )
    print_datetime('after megatron is initialized')
    app_metrics['app_model_init_finish_time'] = one_logger_utils.get_timestamp_in_ms()

    # Track E2E metrics on pretrain start
    one_logger_utils.on_pretrain_start()

    # Context used for persisting some state between checkpoint saves.
    if args.non_persistent_ckpt_type == 'local':
        # Local checkpointing depends on the optional nvidia_resiliency_ext package,
        # so it is imported lazily here.
        try:
            from nvidia_resiliency_ext.checkpointing.local.ckpt_managers.local_manager import (
                LocalCheckpointManager,
            )
            from nvidia_resiliency_ext.checkpointing.local.replication.group_utils import (
                parse_group_sequence,
                GroupWrapper,
            )
            from nvidia_resiliency_ext.checkpointing.local.replication.strategies import (
                CliqueReplicationStrategy,
            )
        except ModuleNotFoundError:
            raise RuntimeError(
                "The 'nvidia_resiliency_ext' module is required for local "
                "checkpointing but was not found. Please ensure it is installed."
            )

        if args.replication:
            repl_strategy = CliqueReplicationStrategy.from_replication_params(
                args.replication_jump, args.replication_factor
            )
        else:
            repl_strategy = None

        checkpointing_context = {
            'local_checkpoint_manager': LocalCheckpointManager(
                args.non_persistent_local_ckpt_dir, repl_strategy=repl_strategy
            )
        }
    else:
        checkpointing_context = {}

    # Model, optimizer, and learning rate.
    timers('model-and-optimizer-setup', log_level=0).start(barrier=True)
    model, optimizer, opt_param_scheduler = setup_model_and_optimizer(
        model_provider, model_type, checkpointing_context=checkpointing_context
    )

    timers('model-and-optimizer-setup').stop()
    print_datetime('after model, optimizer, and learning rate ' 'scheduler are built')
    # `model` is a sequence of model chunks; the config is read from the first one.
    config = get_model_config(model[0])

    # Data stuff.
    app_metrics['app_build_dataiters_start_time'] = one_logger_utils.get_timestamp_in_ms()
    timers('train/valid/test-data-iterators-setup', log_level=0).start(barrier=True)
    if args.virtual_pipeline_model_parallel_size is not None:
        # With virtual pipeline parallelism, build one set of iterators per model
        # chunk and keep them in lists indexed by vp_stage.
        train_data_iterator = []
        valid_data_iterator = []
        test_data_iterator = []
        for vp_stage in range(len(model)):
            dataset_provider_parameters = inspect.signature(train_valid_test_dataset_provider).parameters
            assert "vp_stage" in dataset_provider_parameters, \
                "vp_stage must be a kwarg in train_valid_test_dataset_provider when using virtual pipeline parallelism"
            vp_stage_train_valid_test_dataset_provider = \
                functools.partial(train_valid_test_dataset_provider, vp_stage=vp_stage)
            # functools.partial does not carry over custom attributes, so the
            # `is_distributed` marker is copied explicitly.
            if getattr(train_valid_test_dataset_provider, 'is_distributed', False):
                vp_stage_train_valid_test_dataset_provider.is_distributed = True
            iterators = build_train_valid_test_data_iterators(
                vp_stage_train_valid_test_dataset_provider
            )
            train_data_iterator.append(iterators[0])
            valid_data_iterator.append(iterators[1])
            test_data_iterator.append(iterators[2])
    else:
        train_data_iterator, valid_data_iterator, test_data_iterator = (
            build_train_valid_test_data_iterators(train_valid_test_dataset_provider)
        )
    timers('train/valid/test-data-iterators-setup').stop()
    print_datetime('after dataloaders are built')
    app_metrics['app_build_dataiters_finish_time'] = one_logger_utils.get_timestamp_in_ms()

    # Track if training is enabled. Can only be done once args.do_train is assigned after dataloader is built.
    one_logger_utils.track_config_flags(
        args.train_iters,
        args.skip_train,
        args.do_train,
        args.do_valid,
        args.do_test,
        args.dataloader_type,
        args.retro_project_dir,
        args.retro_cyclic_train_iters,
    )

    # Print setup timing.
    print_rank_0('done with setup ...')
    timers.log(['model-and-optimizer-setup', 'train/valid/test-data-iterators-setup'], barrier=True)

    # Only log when a one_logger instance is available.
    one_logger = get_one_logger()
    one_logger and one_logger.log_metrics(app_metrics)

    wandb_writer = get_wandb_writer()
    if wandb_writer:
        # Add job name to the wandb config to make it easier to run more singleton dependency jobs.
        wandb_writer.config.update({'slurm_job_name': os.getenv("SLURM_JOB_NAME", "N/A")})

    if not args.skip_train:
        print_rank_0('training ...')

        if args.dataloader_type == 'cyclic' and args.retro_project_dir:
            assert args.retro_cyclic_train_iters is not None
            args.train_iters = args.retro_cyclic_train_iters
            print_rank_0("retro cyclic train iters : %d" % args.train_iters)

        iteration = 0
        if args.do_train and args.train_iters > 0:
            iteration, num_floating_point_operations_so_far = train(
                forward_step_func,
                model,
                optimizer,
                opt_param_scheduler,
                train_data_iterator,
                valid_data_iterator,
                process_non_loss_data_func,
                config,
                checkpointing_context,
                non_loss_data_func,
            )

        print_datetime('after training is done')

        # Save a final checkpoint when the last iteration does not fall on a save
        # interval. `num_floating_point_operations_so_far` is only bound when train()
        # ran; otherwise `iteration` is still 0 and this branch is skipped.
        if args.save and iteration != 0 and iteration % args.save_interval != 0:
            save_checkpoint(
                iteration,
                model,
                optimizer,
                opt_param_scheduler,
                num_floating_point_operations_so_far,
                checkpointing_context,
                train_data_iterator=train_data_iterator,
                preprocess_common_state_dict_fn=preprocess_common_state_dict,
            )

        one_logger and one_logger.log_metrics(
            {'app_train_loop_finish_time': one_logger_utils.get_timestamp_in_ms()}
        )

    else:
        print_rank_0('skipping training (--skip-train is on) ...')

        # No training happened, so evaluation is reported at the iteration stored in args.
        iteration = args.iteration

    if args.do_valid:
        prefix = f'iteration {iteration} on validation set'
        if getattr(args, 'perform_rl_step', False):
            # NOTE(review): `rl_utils` is only bound when the optional `megatron.rl`
            # import at module top succeeded (see `has_rl_utils`); confirm that
            # `perform_rl_step` cannot be set without it.
            rl_utils.evaluate_and_print_results_rl(
                valid_data_iterator, model, optimizer,
                iteration, write_to_tensorboard=not args.skip_train
            )
        else:
            evaluate_and_print_results(
                prefix, forward_step_func,
                valid_data_iterator, model,
                iteration, process_non_loss_data_func, config,
                verbose=True, write_to_tensorboard=not args.skip_train,
                non_loss_data_func=non_loss_data_func
            )

    if args.do_test:
        prefix = f'iteration {iteration} on test set'
        evaluate_and_print_results(
            prefix,
            forward_step_func,
            test_data_iterator,
            model,
            iteration,
            process_non_loss_data_func,
            config,
            verbose=True,
            write_to_tensorboard=not args.skip_train,
            non_loss_data_func=non_loss_data_func,
        )

    wandb_writer = get_wandb_writer()
    if wandb_writer:
        wandb_writer.finish()

    # Finalize any outstanding async checkpoint save (blocking) before shutting down,
    # bracketed by the fault-tolerance checkpointing hooks.
    ft_integration.on_checkpointing_start()
    maybe_finalize_async_save(blocking=True, terminate=True)
    ft_integration.on_checkpointing_end(is_async_finalization=True)

    one_logger and one_logger.log_metrics(
        {'app_finish_time': one_logger_utils.get_timestamp_in_ms()}
    )

    ft_integration.shutdown()
    one_logger_utils.finish()


def update_train_iters(args):
    """Fill in ``args.train_iters`` when training length is given in samples.

    Nothing happens if ``args.train_iters`` is already truthy. Otherwise the
    iteration count is derived from ``args.train_samples`` and stored back on
    ``args``:

    * without ``args.rampup_batch_size`` it is a floor division of the sample
      budget by ``args.global_batch_size``;
    * with ``args.rampup_batch_size`` the ramp-up schedule is replayed one
      iteration at a time through the microbatch calculator, and whatever
      sample budget is left afterwards is divided by ``args.global_batch_size``
      (a trailing partial batch is discarded).

    The resulting value is reported through ``print_rank_0``.
    """
    if args.train_iters:
        # Iteration-based training: the count is already known.
        return

    if args.rampup_batch_size is not None:
        # The third ramp-up entry is the sample threshold that ends the ramp-up phase.
        rampup_sample_limit = int(args.rampup_batch_size[2])
        step_count = 0
        samples_seen = 0

        # Ramp-up phase: advance one iteration at a time, asking the calculator
        # for the global batch size in effect at the current sample count.
        while samples_seen <= rampup_sample_limit and samples_seen <= args.train_samples:
            update_num_microbatches(samples_seen, consistency_check=False)
            samples_seen += get_current_global_batch_size()
            step_count += 1

        # Put the calculator back to its state at zero consumed samples.
        update_num_microbatches(0, consistency_check=False)

        # Constant phase: only whole batches count.
        samples_left = args.train_samples - samples_seen
        if samples_left > 0:
            step_count += samples_left // args.global_batch_size
        args.train_iters = step_count
    else:
        # Constant batch size for the whole run.
        args.train_iters = args.train_samples // args.global_batch_size

    print_rank_0(f'setting training iterations to {args.train_iters}')


def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True):
    """Build the model chunk(s) for this rank and optionally wrap them for data parallelism.

    The steps, in order, are: build the chunk(s) (optionally on the meta device),
    set default tensor-parallel attributes on every parameter, report the
    parameter count, move the chunks to the current CUDA device (unless deferred),
    wrap in ``Float16Module`` for fp16/bf16, materialize meta tensors, correct the
    FP8 amax history, and finally wrap each chunk in the selected data-parallel
    wrapper.

    Args:
        model_provider_func: Callable that constructs one model chunk. It is called
            with ``pre_process`` and ``post_process`` keyword arguments, plus
            ``vp_stage`` when virtual pipeline parallelism is in use.
        model_type: Value stored on ``args.model_type`` and on each built chunk's
            ``model_type`` attribute.
        wrap_with_ddp: When true, every chunk is wrapped in the data-parallel wrapper
            chosen from the global args (Torch FSDP2, Megatron FSDP, or DDP).

    Returns:
        A list of model chunks. The provider's result is wrapped in a list if it is
        not already one.
    """
    args = get_args()
    args.model_type = model_type

    # Build model.
    def build_model():
        # With virtual pipeline parallelism, one chunk is built per virtual stage;
        # otherwise a single module is returned.
        if (
            mpu.get_pipeline_model_parallel_world_size() > 1
            and args.virtual_pipeline_model_parallel_size is not None
        ):
            model = []
            for i in range(args.virtual_pipeline_model_parallel_size):
                # Set pre_process and post_process only after virtual rank is set.
                pre_process = mpu.is_pipeline_first_stage(ignore_virtual=False, vp_stage=i)
                post_process = mpu.is_pipeline_last_stage(ignore_virtual=False, vp_stage=i)
                this_model = model_provider_func(
                    pre_process=pre_process, post_process=post_process, vp_stage=i)
                this_model.model_type = model_type
                this_model.vp_stage = i
                model.append(this_model)
        else:
            pre_process = mpu.is_pipeline_first_stage()
            post_process = mpu.is_pipeline_last_stage()
            model = model_provider_func(pre_process=pre_process, post_process=post_process)
            model.model_type = model_type
        return model

    # Optionally construct the model under the meta device; the tensors are
    # materialized further below (or left to FSDP2).
    if args.init_model_with_meta_device:
        with torch.device('meta'):
            model = build_model()
    else:
        model = build_model()

    # From here on, `model` is always a list of chunks.
    if not isinstance(model, list):
        model = [model]

    # Set tensor model parallel attributes if not set.
    # Only parameters that are already tensor model parallel have these
    # attributes set for them. We should make sure the default attributes
    # are set for all params so the optimizer can use them.
    for model_module in model:
        for param in model_module.parameters():
            tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(param)

    # Print number of parameters.
    # The count covers all chunks held by this rank; it is also reused below to
    # derive the DDP bucket size when --ddp-num-buckets is given.
    num_parameters = sum(
        [sum([p.nelement() for p in model_module.parameters()]) for model_module in model]
    )
    if mpu.get_data_parallel_rank() == 0 and mpu.get_context_parallel_rank() == 0:
        print(
            ' > number of parameters on (tensor, pipeline) '
            'model parallel rank ({}, {}): {}'.format(
                mpu.get_tensor_model_parallel_rank(),
                mpu.get_pipeline_model_parallel_rank(),
                num_parameters,
            ),
            flush=True,
        )

    # GPU allocation.
    # For FSDP2, we don't allocate GPU memory here. We allocate GPU memory
    # in the fully_shard function of FSDP2 instead.
    # The move is also skipped for meta-device initialization, which is
    # materialized separately below.
    if (
        not (args.use_torch_fsdp2 and args.use_cpu_initialization)
        and not args.init_model_with_meta_device
    ):
        for model_module in model:
            model_module.cuda(torch.cuda.current_device())

    # Fp16 conversion.
    if args.fp16 or args.bf16:
        config = get_model_config(model[0])
        model = [Float16Module(config, model_module) for model_module in model]

    # Materialize tensors on meta device (GPU allocation) if not using FSDP2.
    if args.init_model_with_meta_device and not args.use_torch_fsdp2:
        # Every chunk is replaced by the result of to_empty_if_meta_device on CUDA.
        model = [to_empty_if_meta_device(model_module, device=torch.device("cuda")) for model_module in model]




    # Before TE2.x: The model_module.bfloat16()/model_module.half() above will call the inplace
    #               copy of TE's Float8Tensor, which will write an unwanted value (amax calculated
    #               from the current fp8 param) to its amax_history. The below function will correct
    #               the amax_history back.
    # After TE2.x: Below function is an empty function and does nothing.
    correct_amax_history_if_needed(model)

    if wrap_with_ddp:
        # Pick the data-parallel wrapper class.
        if args.use_torch_fsdp2:
            assert HAVE_FSDP2, "Torch FSDP2 requires torch>=2.4.0"
            DP = torch_FSDP
        elif args.use_megatron_fsdp:
            DP = megatron_FSDP
        else:
            DP = DDP

        config = get_model_config(model[0])

        # Build the wrapper configuration: Torch FSDP2 has its own config class,
        # while Megatron FSDP and DDP share DistributedDataParallelConfig.
        if getattr(args, "use_torch_fsdp2", False):
            reshard_after_forward = getattr(args, "torch_fsdp2_reshard_after_forward", True)
            ddp_config = TorchFullyShardedDataParallelConfig(reshard_after_forward=reshard_after_forward)
        else:
            # Start from every config field that has a same-named attribute on args,
            # then override the fields whose argument names differ.
            kwargs = {}
            for f in dataclasses.fields(DistributedDataParallelConfig):
                if hasattr(args, f.name):
                    kwargs[f.name] = getattr(args, f.name)
            kwargs['grad_reduce_in_fp32'] = args.accumulate_allreduce_grads_in_fp32
            kwargs['check_for_nan_in_grad'] = args.check_for_nan_in_loss_and_grad
            kwargs['check_for_large_grads'] = args.check_for_large_grads
            if args.ddp_num_buckets is not None:
                # A bucket count is turned into a bucket size using this rank's
                # parameter count.
                assert args.ddp_bucket_size is None, \
                    "Cannot specify both --ddp-num-buckets and --ddp-bucket-size"
                assert args.ddp_num_buckets > 0, \
                    "--ddp-num-buckets must be greater than 0"
                kwargs['bucket_size'] = num_parameters // args.ddp_num_buckets
            else:
                kwargs['bucket_size'] = args.ddp_bucket_size
            kwargs['pad_buckets_for_high_nccl_busbw'] = args.ddp_pad_buckets_for_high_nccl_busbw
            kwargs['average_in_collective'] = args.ddp_average_in_collective
            if args.use_megatron_fsdp and args.use_precision_aware_optimizer:
                kwargs["preserve_fp32_weights"] = False
            ddp_config = DistributedDataParallelConfig(**kwargs)

            # In the Megatron FSDP and DDP use path, we need to initialize the bucket size.
            # If bucket_size is not provided as an input, use sane default.
            # If using very large dp_sizes, make buckets larger to ensure that chunks used in NCCL
            # ring-reduce implementations are large enough to remain bandwidth-bound rather than
            # latency-bound.
            if ddp_config.bucket_size is None:
                ddp_config.bucket_size = max(
                    40000000, 1000000 * mpu.get_data_parallel_world_size(with_context_parallel=True)
                )
            # Set bucket_size to infinity (represented here as None) if overlap_grad_reduce
            # is False.
            if not ddp_config.overlap_grad_reduce:
                ddp_config.bucket_size = None

        # The wrappers are constructed while a freshly created CUDA stream is current.
        with torch.cuda.stream(torch.cuda.Stream()):
            model = [
                DP(
                    config=config,
                    ddp_config=ddp_config,
                    module=model_chunk,
                    # Turn off bucketing for model_chunk 2 onwards, since communication for these
                    # model chunks is overlapped with compute anyway.
                    disable_bucketing=(model_chunk_idx > 0)
                    or args.overlap_param_gather_with_optimizer_step,
                )
                for (model_chunk_idx, model_chunk) in enumerate(model)
            ]

        # Broadcast params from data parallel src rank to other data parallel ranks.
        if args.data_parallel_random_init:
            for model_module in model:
                model_module.broadcast_params()

    return model


def get_optimizer_param_scheduler(optimizer):
    """Create the ``OptimizerParamScheduler`` for ``optimizer`` from the global args.

    Every schedule length handed to the scheduler is in units of samples:
    iteration counts are multiplied by ``args.global_batch_size`` for
    iteration-based training, and sample counts are used directly for
    sample-based training.

    Side effects on the global args: ``lr_decay_iters`` (or ``lr_decay_samples``)
    is defaulted to the full training length when unset, and for sample-based
    training ``update_train_iters`` is invoked so ``train_iters`` is populated.

    Args:
        optimizer: The optimizer whose learning rate and weight decay are scheduled.

    Returns:
        The constructed ``OptimizerParamScheduler``.

    Raises:
        Exception: If neither ``args.train_iters`` nor ``args.train_samples`` is set.
    """
    args = get_args()

    if args.train_iters:
        # Iteration-based training: convert iteration counts to sample counts.
        samples_per_iter = args.global_batch_size
        if args.lr_decay_iters is None:
            args.lr_decay_iters = args.train_iters
        lr_decay_steps = args.lr_decay_iters * samples_per_iter
        wd_incr_steps = args.train_iters * samples_per_iter
        wsd_decay_steps = (
            None
            if args.lr_wsd_decay_iters is None
            else args.lr_wsd_decay_iters * samples_per_iter
        )
        lr_warmup_steps = (
            args.lr_warmup_iters * samples_per_iter
            if args.lr_warmup_fraction is None
            else args.lr_warmup_fraction * lr_decay_steps
        )
    elif args.train_samples:
        # Sample-based training. train_iters is needed later, so compute it now;
        # train_samples itself is left untouched even though the last batch may
        # be incomplete.
        update_train_iters(args)
        if args.lr_decay_samples is None:
            args.lr_decay_samples = args.train_samples
        lr_decay_steps = args.lr_decay_samples
        wd_incr_steps = args.train_samples
        wsd_decay_steps = args.lr_wsd_decay_samples
        lr_warmup_steps = (
            args.lr_warmup_samples
            if args.lr_warmup_fraction is None
            else args.lr_warmup_fraction * lr_decay_steps
        )
    else:
        raise Exception('either train-iters or train-samples should be provided.')

    scheduler_kwargs = dict(
        init_lr=args.lr_warmup_init,
        max_lr=args.lr,
        min_lr=args.min_lr,
        lr_warmup_steps=lr_warmup_steps,
        lr_decay_steps=lr_decay_steps,
        lr_decay_style=args.lr_decay_style,
        start_wd=args.start_weight_decay,
        end_wd=args.end_weight_decay,
        wd_incr_steps=wd_incr_steps,
        wd_incr_style=args.weight_decay_incr_style,
        use_checkpoint_opt_param_scheduler=args.use_checkpoint_opt_param_scheduler,
        override_opt_param_scheduler=args.override_opt_param_scheduler,
        wsd_decay_steps=wsd_decay_steps,
        lr_wsd_decay_style=args.lr_wsd_decay_style,
    )
    return OptimizerParamScheduler(optimizer, **scheduler_kwargs)


def setup_model_and_optimizer(
    model_provider_func,
    model_type,
    no_wd_decay_cond=None,
    scale_lr_cond=None,
    lr_mult=1.0,
    checkpointing_context=None,
):
    """Setup model and optimizer.

    Builds the model via ``get_model``, creates the optimizer and its parameter
    scheduler, and then brings them to their starting state by one of:

    * MoE upcycling (``args.moe_use_upcycling``): a dense model is built with
      temporarily overridden args, loaded and upcycled into the MoE model, and an
      upcycled checkpoint is saved;
    * loading a checkpoint when ``args.load`` or ``args.pretrained_checkpoint`` is
      set (and upcycling is off);
    * otherwise starting from iteration 0 with a zeroed FLOP counter.

    If ``args.ckpt_convert_format`` is set, the loaded state is re-saved in that
    format and the process exits instead of returning.

    Args:
        model_provider_func: Model-chunk factory forwarded to ``get_model``.
        model_type: Model type forwarded to ``get_model``.
        no_wd_decay_cond: Forwarded to ``get_megatron_optimizer``.
        scale_lr_cond: Forwarded to ``get_megatron_optimizer``.
        lr_mult: Forwarded to ``get_megatron_optimizer``.
        checkpointing_context: Forwarded to ``load_checkpoint``.

    Returns:
        Tuple ``(model, optimizer, opt_param_scheduler)``.

    Raises:
        AssertionError: If upcycling is requested but a checkpoint already exists
            at ``args.save``.
        SystemExit: After a checkpoint format conversion has been written.
    """
    args = get_args()
    timers = get_timers()
    one_logger = get_one_logger()

    model = get_model(model_provider_func, model_type)
    # Same chunks with the wrappers removed; used for upcycling and ICT init below.
    unwrapped_model = unwrap_model(model)

    one_logger and one_logger.log_metrics({"app_build_optimzer_start_time": one_logger_utils.get_timestamp_in_ms()})
    # Populate the optimizer config from every same-named attribute on args.
    kwargs = {}
    for f in dataclasses.fields(OptimizerConfig):
        if hasattr(args, f.name):
            kwargs[f.name] = getattr(args, f.name)
    config = OptimizerConfig(**kwargs)
    config.timers = timers
    optimizer = get_megatron_optimizer(
        config,
        model,
        no_wd_decay_cond,
        scale_lr_cond,
        lr_mult,
        use_gloo_process_groups=args.enable_gloo_process_groups,
        # If the user is asking for a non-zero embedding init std, skip weight decay for embeddings
        #  to avoid embeddings from shrinking to zero as recommended in https://arxiv.org/abs/2312.16903
        default_skip_embedding_weight_decay=args.embedding_init_method_std is not None,
    )
    opt_param_scheduler = get_optimizer_param_scheduler(optimizer)
    one_logger and one_logger.log_metrics({"app_build_optimzer_finish_time": one_logger_utils.get_timestamp_in_ms()})

    if args.moe_use_upcycling:
        torch.distributed.barrier()
        assert not checkpoint_exists(args.save), (
            "The upcycling destination directory already exists. "
            "Please check if --moe-use-upcycling is mistakenly enabled. "
            "Upcycling should only be set for the first run when converting the dense model. "
            "All subsequent runs should remove this flag. "
        )
        # before changing moe related global args, save them in local variables
        num_experts = args.num_experts
        expert_model_parallel_size = args.expert_model_parallel_size
        moe_ffn_hidden_size = args.ffn_hidden_size

        # set dense model related args in to global args before getting dense model
        args.num_experts = None
        args.expert_model_parallel_size = 1
        args.ffn_hidden_size = moe_ffn_hidden_size * args.moe_upcycling_granularity

        # get dense model
        dense_model_for_upcycling = get_model(model_provider_func, model_type)

        # recover moe upcycling related args in global args before executing upcycling
        args.num_experts = num_experts
        args.expert_model_parallel_size = expert_model_parallel_size
        args.ffn_hidden_size = moe_ffn_hidden_size

        # execute upcycling
        _, args.num_floating_point_operations_so_far = upcycling_utils.load_and_upcycle_model(
            load_checkpoint,
            unwrapped_model,
            dense_model_for_upcycling,
            load_kwargs={
                'model': dense_model_for_upcycling,
                'optimizer': None,
                'opt_param_scheduler': None,
            },
        )
        args.iteration = 1
        # Optimizer and scheduler state are not part of the upcycled checkpoint.
        save_checkpoint(
            args.iteration, model, None, None, args.num_floating_point_operations_so_far
        )
        torch.distributed.barrier()
        del dense_model_for_upcycling
        if (args.fp16 or args.bf16) and optimizer is not None:
            optimizer.reload_model_params()
        print_rank_0(f'Upcycled checkpoint saved to {args.save}')

    if (
        args.load is not None or args.pretrained_checkpoint is not None
    ) and not args.moe_use_upcycling:
        one_logger and one_logger.log_metrics(
            {'load_checkpoint_start_time': one_logger_utils.get_timestamp_in_ms()}
        )
        timers('load-checkpoint', log_level=0).start(barrier=True)

        args.iteration, args.num_floating_point_operations_so_far = load_checkpoint(
            model,
            optimizer,
            opt_param_scheduler,
            checkpointing_context=checkpointing_context,
            skip_load_to_model_and_opt=HAVE_FSDP2
            and getattr(args, "use_torch_fsdp2", False)
            and args.ckpt_format == "torch_dist",
        )
        timers('load-checkpoint').stop(barrier=True)
        timers.log(['load-checkpoint'])
        one_logger and one_logger.log_metrics(
            {
                'load_checkpoint_finish_time': one_logger_utils.get_timestamp_in_ms(),
                'load_checkpoint_time': timers('load-checkpoint').active_time(),
            }
        )
    else:
        # NOTE(review): this branch is also taken when upcycling is enabled, so the
        # iteration (1) and FLOP count set by the upcycling step above are reset to 0
        # here -- confirm this is intended.
        args.iteration = 0
        args.num_floating_point_operations_so_far = 0

    # On a fresh start with a single chunk that exposes `init_state_dict_from_bert`,
    # initialize it from a pretrained BERT model.
    if (
        args.iteration == 0
        and len(unwrapped_model) == 1
        and hasattr(unwrapped_model[0], 'init_state_dict_from_bert')
    ):
        print_rank_0("Initializing ICT from pretrained BERT model")
        unwrapped_model[0].init_state_dict_from_bert()
        # NOTE(review): only fp16 triggers a reload here, whereas the upcycling path
        # above reloads for fp16 or bf16 -- confirm whether bf16 should be included.
        if args.fp16:
            optimizer.reload_model_params()

    # Convert checkpoint format.
    if args.ckpt_convert_format is not None:
        load_ckpt_format = args.ckpt_format
        args.ckpt_format = args.ckpt_convert_format
        args.save = os.path.join(args.ckpt_convert_save, args.ckpt_convert_format)
        update_use_dist_ckpt(args)

        save_checkpoint(
            args.iteration,
            model,
            optimizer,
            opt_param_scheduler,
            args.num_floating_point_operations_so_far,
            preprocess_common_state_dict_fn=preprocess_common_state_dict,
        )

        print_rank_0("> converted checkpoint: %s -> %s." % (load_ckpt_format, args.ckpt_format))
        torch.distributed.barrier()
        # Conversion is a one-shot job: stop here. `sys.exit()` is used rather than
        # the `exit()` builtin, which is injected by the `site` module for
        # interactive use and is not guaranteed to exist (e.g. under `python -S`).
        sys.exit()

    return model, optimizer, opt_param_scheduler


def dummy_train_step(data_iterator):
    """Run one placeholder training step that only pulls batches from the iterator.

    For as long as the rerun state machine asks for a forward-backward pass, one
    batch per microbatch is fetched and then discarded. No model, loss, or
    optimizer is involved.

    Args:
        data_iterator: Source of batches; handed to the rerun state machine and to
            ``get_batch_on_this_tp_rank``.
    """
    microbatch_count = get_num_microbatches()
    state_machine = get_rerun_state_machine()
    while state_machine.should_run_forward_backward(data_iterator):
        for _ in range(microbatch_count):
            # Same helpers that get_batch() in pretrain_{gpt, mamba}.py relies on;
            # the resulting batch is intentionally dropped.
            tp_batch = get_batch_on_this_tp_rank(data_iterator)
            get_batch_on_this_cp_rank(tp_batch)


def train_step(forward_step_func, data_iterator, model, optimizer, opt_param_scheduler, config, forward_backward_func):
    """Single training step.

    Zeroes gradients, runs the forward-backward pass (repeating it for as long as
    the rerun state machine requests), steps the optimizer, advances the parameter
    scheduler on success, and reduces the per-microbatch losses for reporting.

    Args:
        forward_step_func: Forwarded to ``forward_backward_func``.
        data_iterator: Batch source; forwarded to ``forward_backward_func`` and to the
            rerun state machine.
        model: List of model chunks; each must provide ``zero_grad_buffer()``.
        optimizer: Optimizer whose ``step()`` returns
            ``(update_successful, grad_norm, num_zeros_in_grad)``.
        opt_param_scheduler: Scheduler advanced by the number of samples processed
            when the update succeeds.
        config: Not referenced in this function body.
        forward_backward_func: Callable that executes the forward and backward passes
            and returns a list of per-microbatch loss dicts.

    Returns:
        A 7-tuple ``(loss_reduced, skipped_iter, should_checkpoint, should_exit,
        exit_code, grad_norm, num_zeros_in_grad)``.

        * If the rerun state machine requests an exit, the tuple is
          ``({}, True, should_checkpoint, should_exit, exit_code, None, None)`` and no
          optimizer step is taken.
        * ``loss_reduced`` is populated only on the last pipeline stage; other stages
          return an empty dict.
        * ``skipped_iter`` is 0 when the optimizer update succeeded and 1 otherwise.

    Raises:
        ValueError: If a loss entry has neither one nor two elements.
    """
    args = get_args()
    timers = get_timers()

    rerun_state_machine = get_rerun_state_machine()
    while rerun_state_machine.should_run_forward_backward(data_iterator):
        # Set grad to zero.
        for model_chunk in model:
            model_chunk.zero_grad_buffer()
        optimizer.zero_grad()

        if has_nvidia_modelopt:
            # [ModelOpt]: Pipeline-parallel Distillation stacks student and teacher tensors
            adjust_tensor_shapes_fn = get_tensor_shapes_adjust_fn_for_distillation(
                model, args.seq_length, args.micro_batch_size, args.decoder_seq_length
            )
        else:
            adjust_tensor_shapes_fn = None

        # For the mxfp8_param with reuse_grad_buf_for_mxfp8_param_ag and dp_ag_overlap,
        # we need to call the _copy_main_params_to_param_buffer() after the grad buffer
        # is zeroed by zero_grad_buffer() because param and grad buffer are shared.
        # NOTE(review): this reads `optimizer.chained_optimizers`, so it presumably
        # expects a chained optimizer here -- confirm.
        if args.reuse_grad_buf_for_mxfp8_param_ag and args.overlap_param_gather:
            for optim_instance in optimizer.chained_optimizers:
                if isinstance(optim_instance, DistributedOptimizer):
                    optim_instance._copy_main_params_to_param_buffer()

        # Forward and backward pass (forward_only=False).
        losses_reduced = forward_backward_func(
            forward_step_func=forward_step_func,
            data_iterator=data_iterator,
            model=model,
            num_microbatches=get_num_microbatches(),
            seq_length=args.seq_length,
            micro_batch_size=args.micro_batch_size,
            decoder_seq_length=args.decoder_seq_length,
            forward_only=False,
            adjust_tensor_shapes_fn=adjust_tensor_shapes_fn,
        )
    # Ask the rerun state machine whether to checkpoint and/or stop; on exit, return
    # early without touching the optimizer.
    should_checkpoint, should_exit, exit_code = rerun_state_machine.should_checkpoint_and_exit()
    if should_exit:
        return {}, True, should_checkpoint, should_exit, exit_code, None, None

    # Empty unused memory.
    if args.empty_unused_memory_level >= 1:
        torch.cuda.empty_cache()

    # Vision gradients.
    if args.vision_pretraining and args.vision_pretraining_type == "dino":
        unwrapped_model = unwrap_model(model[0])
        unwrapped_model.cancel_gradients_last_layer(args.curr_iteration)

    # Update parameters.

    timers('optimizer', log_level=1).start(barrier=args.barrier_with_L1_time)
    update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
    timers('optimizer').stop()

    # when freezing sub-models we may have a mixture of successful and unsuccessful ranks,
    # so we must gather across mp ranks
    update_successful = logical_and_across_model_parallel_group(update_successful)
    # grad_norm and num_zeros_in_grad will be None on ranks without trainable params,
    # so we must gather across mp ranks
    grad_norm = reduce_max_stat_across_model_parallel_group(grad_norm)
    if args.log_num_zeros_in_grad:
        num_zeros_in_grad = reduce_max_stat_across_model_parallel_group(num_zeros_in_grad)

    # Vision momentum.
    if args.vision_pretraining and args.vision_pretraining_type == "dino":
        unwrapped_model = unwrap_model(model[0])
        unwrapped_model.update_momentum(args.curr_iteration)

    # Update learning rate.
    if update_successful:
        # The scheduler is advanced by the number of samples in this global batch.
        increment = get_num_microbatches() * args.micro_batch_size * args.data_parallel_size
        opt_param_scheduler.step(increment=increment)
        skipped_iter = 0
    else:
        skipped_iter = 1

    # Empty unused memory.
    if args.empty_unused_memory_level >= 2:
        torch.cuda.empty_cache()

    if mpu.is_pipeline_last_stage(ignore_virtual=True):
        # Average loss across microbatches.
        loss_reduced = {}

        for key in losses_reduced[0].keys():
            # One flattened tensor per microbatch for this loss key.
            val = [x[key].view(-1) for x in losses_reduced]
            if val[0].numel() == 2:
                # Two-element entries are used as (numerator, denominator) pairs below.
                if args.sft:
                    # in mcore the normalization happens on micro batch instead of global
                    val = torch.vstack(val)
                    val = val[:, 0] / val[:, 1]
                    val = val.mean()
                    torch.distributed.all_reduce(
                        val,
                        group=mpu.get_data_parallel_group(with_context_parallel=True)
                    )
                    val /= torch.distributed.get_world_size(
                        group=mpu.get_data_parallel_group(with_context_parallel=True)
                    )
                    loss_reduced[key] = val
                else:
                    # there is one dict per microbatch. in new reporting, we average
                    # over the total number of tokens across the global batch.
                    val = torch.vstack(val).sum(dim=0)
                    torch.distributed.all_reduce(
                        val,
                        group=mpu.get_data_parallel_group(with_context_parallel=True)
                    )
                    loss_reduced[key] = val[0] / val[1]
            elif val[0].numel() == 1:
                # legacy behavior, we average over the number of microbatches
                val = torch.cat(val).mean()
                loss_reduced[key] = val
            else:
                raise ValueError(f"Invalid value shape: {val[0].shape} for key {key}")
        return (
            loss_reduced,
            skipped_iter,
            should_checkpoint,
            should_exit,
            exit_code,
            grad_norm,
            num_zeros_in_grad,
        )
    # Not the last pipeline stage: no losses to report from this rank.
    return {}, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad


def training_log(
    loss_dict,
    total_loss_dict,
    learning_rate,
    decoupled_learning_rate,
    iteration,
    loss_scale,
    report_memory_flag,
    skipped_iter,
    grad_norm,
    params_norm,
    num_zeros_in_grad,
    moe_relu_sparsity=None,
    moe_relu_l1_reg_coeff=None,
):
    """Log training information such as losses, timing, ...."""
    args = get_args()
    timers = get_timers()
    writer = get_tensorboard_writer()
    wandb_writer = get_wandb_writer()
    one_logger = get_one_logger()
    energy_monitor = get_energy_monitor()

    # Advanced, skipped, and Nan iterations.
    advanced_iters_key = 'advanced iterations'
    skipped_iters_key = 'skipped iterations'
    nan_iters_key = 'nan iterations'
    # Advanced iterations.
    if not skipped_iter:
        total_loss_dict[advanced_iters_key] = total_loss_dict.get(advanced_iters_key, 0) + 1
    else:
        if advanced_iters_key not in total_loss_dict:
            total_loss_dict[advanced_iters_key] = 0
    # Skipped iterations.
    total_loss_dict[skipped_iters_key] = total_loss_dict.get(skipped_iters_key, 0) + skipped_iter
    # Update losses and set nan iterations
    got_nan = False
    for key in loss_dict:
        if not skipped_iter:
            total_loss_dict[key] = (
                total_loss_dict.get(key, torch.tensor([0.0], dtype=torch.float, device='cuda'))
                + loss_dict[key]
            )
        else:
            value = loss_dict[key].float().sum().item()
            is_nan = value == float('inf') or value == -float('inf') or value != value
            got_nan = got_nan or is_nan
    total_loss_dict[nan_iters_key] = total_loss_dict.get(nan_iters_key, 0) + int(got_nan)

    # Logging.
    timers_to_log = [
        'forward-backward',
        'forward-compute',
        'backward-compute',
        'batch-generator',
        'forward-recv',
        'forward-send',
        'backward-recv',
        'backward-send',
        'forward-send-forward-recv',
        'forward-send-backward-recv',
        'backward-send-forward-recv',
        'backward-send-backward-recv',
        'forward-backward-send-forward-backward-recv',
        'layernorm-grads-all-reduce',
        'embedding-grads-all-reduce',
        'all-grads-sync',
        'params-all-gather',
        'optimizer-copy-to-main-grad',
        'optimizer-unscale-and-check-inf',
        'optimizer-clip-main-grad',
        'optimizer-count-zeros',
        'optimizer-inner-step',
        'optimizer-copy-main-to-model-params',
        'optimizer',
    ]
    # Add timers from RL loop if needed.
    if getattr(args, 'perform_rl_step', False):
        timers_to_log.extend(['rollout-collection', 'inference-setup', 'collect-rollouts', 'postrollout-gc-collect',
                              'sync-rollouts', 'prepare-data-for-update', 'compute-group-stats',
                              'prepare-trajectories', 'get-ltor-masks-and-position-ids', 'create-logprobs-dataloader',
                              'compute-logprobs', 'compute-ref-logprobs', 'compute-prob-stats',
                              'prepare-advantages', 'create-dataloader', 'log-wandb-tb',
                              'offload-optimizer-before-inference', 'onload-kv-cache-before-inference',
                              'wait-for-decode-only', 'build-cuda-graphs', 'suspend-engine',
                              'offload-kv-cache-after-inference', 'onload-optimizer-after-inference'])

    # Calculate batch size.
    batch_size = args.micro_batch_size * args.data_parallel_size * get_num_microbatches()

    # Track app tag & app tag ID
    one_logger_utils.track_app_tag(batch_size, args.world_size, args.seq_length)

    total_iterations = total_loss_dict[advanced_iters_key] + total_loss_dict[skipped_iters_key]

    # learning rate will be None on ranks without trainable params, so we must gather across mp ranks
    learning_rate = reduce_max_stat_across_model_parallel_group(learning_rate)
    # Tensorboard values.
    if writer and (iteration % args.tensorboard_log_interval == 0):
        if wandb_writer:
            wandb_writer.log({'samples vs steps': args.consumed_train_samples}, iteration)
        writer.add_scalar('learning-rate', learning_rate, iteration)
        writer.add_scalar('learning-rate vs samples', learning_rate, args.consumed_train_samples)
        if wandb_writer:
            wandb_writer.log({'learning-rate': learning_rate}, iteration)
        if args.decoupled_lr is not None:
            writer.add_scalar('decoupled-learning-rate', decoupled_learning_rate, iteration)
        if args.skipped_train_samples > 0:
            writer.add_scalar('skipped-train-samples', args.skipped_train_samples, iteration)
            if wandb_writer:
                wandb_writer.log({'skipped-train-samples': args.skipped_train_samples}, iteration)
        writer.add_scalar('batch-size', batch_size, iteration)
        writer.add_scalar('batch-size vs samples', batch_size, args.consumed_train_samples)
        if wandb_writer:
            wandb_writer.log({'batch-size': batch_size}, iteration)
        for key in loss_dict:
            writer.add_scalar(key, loss_dict[key], iteration)
            writer.add_scalar(key + ' vs samples', loss_dict[key], args.consumed_train_samples)
            if wandb_writer:
                wandb_writer.log({key: loss_dict[key]}, iteration)
        if args.log_loss_scale_to_tensorboard:
            writer.add_scalar('loss-scale', loss_scale, iteration)
            writer.add_scalar('loss-scale vs samples', loss_scale, args.consumed_train_samples)
            if wandb_writer:
                wandb_writer.log({'loss-scale': loss_scale}, iteration)
        if args.log_world_size_to_tensorboard:
            writer.add_scalar('world-size', args.world_size, iteration)
            writer.add_scalar('world-size vs samples', args.world_size, args.consumed_train_samples)
            if wandb_writer:
                wandb_writer.log({'world-size': args.world_size}, iteration)
        if grad_norm is not None:
            writer.add_scalar('grad-norm', grad_norm, iteration)
            writer.add_scalar('grad-norm vs samples', grad_norm, args.consumed_train_samples)
            if wandb_writer:
                wandb_writer.log({'grad-norm': grad_norm}, iteration)
        if num_zeros_in_grad is not None:
            writer.add_scalar('num-zeros', num_zeros_in_grad, iteration)
            writer.add_scalar(
                'num-zeros vs samples', num_zeros_in_grad, args.consumed_train_samples
            )
            if wandb_writer:
                wandb_writer.log({'num-zeros': num_zeros_in_grad}, iteration)
        if params_norm is not None:
            writer.add_scalar('params-norm', params_norm, iteration)
            writer.add_scalar('params-norm vs samples', params_norm, args.consumed_train_samples)
            if wandb_writer:
                wandb_writer.log({'params-norm': params_norm}, iteration)
        if moe_relu_sparsity is not None:
            writer.add_scalar('moe_relu_sparsity', moe_relu_sparsity, iteration)
            writer.add_scalar('moe_relu_sparsity vs samples', moe_relu_sparsity, args.consumed_train_samples)
            if wandb_writer:
                wandb_writer.log({'moe_relu_sparsity': moe_relu_sparsity}, iteration)
        if moe_relu_l1_reg_coeff is not None:
            writer.add_scalar('moe_relu_l1_reg_coeff', moe_relu_l1_reg_coeff, iteration)
            writer.add_scalar('moe_relu_l1_reg_coeff vs samples', moe_relu_l1_reg_coeff, args.consumed_train_samples)
            if wandb_writer:
                wandb_writer.log({'moe_relu_l1_reg_coeff': moe_relu_l1_reg_coeff}, iteration)
        if getattr(args, 'perform_rl_step', False):
            grpo_collection_iteration = iteration // (args.grpo_iterations * ( ( args.grpo_samples_per_iteration )// args.global_batch_size ))
            writer.add_scalar('grpo_collection_iteration', grpo_collection_iteration, iteration)
            if wandb_writer:
                wandb_writer.log({'grpo_collection_iteration': grpo_collection_iteration}, iteration)
        if args.log_memory_to_tensorboard:
            mem_stats = torch.cuda.memory_stats()
            writer.add_scalar(
                "mem-reserved-bytes", mem_stats["reserved_bytes.all.current"], iteration
            )
            writer.add_scalar(
                "mem-allocated-bytes", mem_stats["allocated_bytes.all.current"], iteration
            )
            writer.add_scalar(
                "mem-max-allocated-bytes", mem_stats["allocated_bytes.all.peak"], iteration
            )
            writer.add_scalar("mem-allocated-count", mem_stats["allocation.all.current"], iteration)
    if args.num_experts is not None:
        moe_loss_scale = 1 / get_num_microbatches()
        # For DirVAE router, track all keys emitted by routers by passing track_names=None.
        # This ensures extended DirVAE metrics (avg_k, KL, recon, expert loads, etc.) are logged.
        if getattr(args, 'moe_router_type', None) == 'dirvae' or getattr(args, 'moe_vae_routing', False):
            dirvae_track_names = None # log all metrics
            setattr(args, '_dirvae_router_logged', True)
        else:
            track_list = []
            if "aux_loss" in args.moe_router_load_balancing_type:
                track_list.append("load_balancing_loss")
            if "seq_aux_loss" in args.moe_router_load_balancing_type:
                track_list.append("seq_load_balancing_loss")
            if "global_aux_loss" in args.moe_router_load_balancing_type:
                track_list.append("global_load_balancing_loss")
            if args.moe_z_loss_coeff is not None:
                track_list.append("z_loss")
            dirvae_track_names = track_list
        track_moe_metrics(
            loss_scale=moe_loss_scale,
            iteration=iteration,
            writer=writer,
            wandb_writer=wandb_writer,
            total_loss_dict=total_loss_dict,
            per_layer_logging=args.moe_per_layer_logging,
            force_initialize=True,
            track_names=dirvae_track_names,
            num_layers=args.num_layers,
            moe_layer_freq=args.moe_layer_freq,
            mtp_num_layers=args.mtp_num_layers,
        )
    if args.mtp_num_layers is not None:
        mtp_loss_scale = 1 / get_num_microbatches()
        MTPLossLoggingHelper.track_mtp_metrics(
            mtp_loss_scale, iteration, writer, wandb_writer, total_loss_dict
        )
    if iteration % args.log_interval == 0:
        if args.record_memory_history and is_last_rank():
            snapshot = torch.cuda.memory._snapshot()
            from pickle import dump

            with open(args.memory_snapshot_path, 'wb') as f:
                dump(snapshot, f)

        elapsed_time = timers('interval-time').elapsed(barrier=True)
        elapsed_time_per_iteration = elapsed_time / total_iterations

        throughput = num_floating_point_operations(args, batch_size) / (
            elapsed_time_per_iteration * 10**12 * args.world_size
        )

        one_logger_utils.track_e2e_metrics(args.log_throughput, throughput)

        if args.log_timers_to_tensorboard:
            if writer:
                writer.add_scalar('iteration-time', elapsed_time_per_iteration, iteration)
            if wandb_writer:
                wandb_writer.log({'iteration-time': elapsed_time_per_iteration}, iteration)
        log_string = f" [{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}]"
        log_string += ' iteration {:8d}/{:8d} |'.format(iteration, args.train_iters)
        log_string += ' consumed samples: {:12d} |'.format(args.consumed_train_samples)
        if args.skipped_train_samples > 0:
            log_string += ' skipped samples: {:12d} |'.format(args.skipped_train_samples)
        log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
            elapsed_time_per_iteration * 1000.0
        )
        if args.log_throughput:
            log_string += f' throughput per GPU (TFLOP/s/GPU): {throughput:.1f} |'
            if args.log_timers_to_tensorboard:
                if writer:
                    writer.add_scalar('throughput', throughput, iteration)
                if wandb_writer:
                    wandb_writer.log({'throughput': throughput}, iteration)
        if args.log_energy:
            energy = (energy_monitor.lap() / total_iterations) / args.world_size
            power = energy / elapsed_time_per_iteration
            log_string += f' energy per GPU (J/iter/GPU): {energy:.1f} |'
            log_string += f' power per GPU (W/GPU): {power:.1f} |'
            if writer:
                writer.add_scalar('iter-energy/gpu', energy, iteration)
                writer.add_scalar('power/gpu', power, iteration)
            if wandb_writer:
                wandb_writer.log({'iter-energy/gpu': energy}, iteration)
                wandb_writer.log({'power/gpu': power}, iteration)
        # decoupled_learning_rate should be non-None only on the first and last pipeline stages.
        log_string += f' learning rate: {learning_rate:.6E} |'
        if args.decoupled_lr is not None and (
            mpu.is_pipeline_first_stage(ignore_virtual=True)
            or mpu.is_pipeline_last_stage(ignore_virtual=True)
        ):
            assert decoupled_learning_rate is not None
            log_string += f' decoupled learning rate: {decoupled_learning_rate:.6E} |'
        else:
            assert decoupled_learning_rate is None
        log_string += f' global batch size: {batch_size:5d} |'
        if moe_relu_sparsity is not None:
            log_string += f' moe relu sparsity: {moe_relu_sparsity:.4f} |'
        if moe_relu_l1_reg_coeff is not None:
            log_string += f' moe relu l1 reg coeff: {moe_relu_l1_reg_coeff:.4E} |'
        for key in total_loss_dict:
            if key not in [advanced_iters_key, skipped_iters_key, nan_iters_key]:
                avg = total_loss_dict[key].item() / float(
                    max(1, total_loss_dict[advanced_iters_key])
                )
                if avg > 0.0:
                    log_string += ' {}: {:.6E} |'.format(key, avg)
                total_loss_dict[key] = torch.tensor([0.0], dtype=torch.float, device='cuda')
        log_string += f' loss scale: {loss_scale:.1f} |'
        if grad_norm is not None:
            log_string += f' grad norm: {grad_norm:.3f} |'
        if num_zeros_in_grad is not None:
            log_string += f' num zeros: {num_zeros_in_grad} |'
        if params_norm is not None:
            log_string += f' params norm: {params_norm:.3f} |'
        log_string += ' number of skipped iterations: {:3d} |'.format(
            total_loss_dict[skipped_iters_key]
        )
        log_string += ' number of nan iterations: {:3d} |'.format(total_loss_dict[nan_iters_key])
        total_loss_dict[advanced_iters_key] = 0
        total_loss_dict[skipped_iters_key] = 0
        total_loss_dict[nan_iters_key] = 0
        print_rank_last(log_string)
        if report_memory_flag:
            # Report memory after optimizer state has been initialized.
            if torch.distributed.get_rank() == 0:
                num_microbatches = get_num_microbatches()
                report_theoretical_memory(args, num_microbatches=num_microbatches, verbose=True)
            report_memory(f'(after {iteration} iterations)')
            report_memory_flag = False
        # Write timers to both writers (writer and wandb_writer) without resetting the counts.
        if args.log_timers_to_tensorboard:
            timers.write(timers_to_log, writer, iteration, normalizer=args.log_interval, reset=False)
            timers.write(timers_to_log, wandb_writer, iteration, normalizer=args.log_interval, reset=False)
        # Log timers to stdout
        timers.log(timers_to_log, normalizer=args.log_interval)

    return report_memory_flag


def compute_throughputs_and_append_to_progress_log(iteration, num_floating_point_operations_so_far):
    """Append a checkpoint entry with throughput statistics to the progress log.

    Returns immediately when ``args.save`` is None. Otherwise two per-GPU
    throughputs (in TFLOP/s) are computed and written, along with the total
    floating-point operation count and the token count, through
    ``append_to_progress_log``.

    Args:
        iteration: Iteration number reported in the log entry.
        num_floating_point_operations_so_far: Running total of floating-point
            operations, including work completed before this job started.
    """
    args = get_args()
    if args.save is None:
        return

    # Job throughput: args.num_floating_point_operations_so_far holds the count
    # completed at the start of the job, so the difference is this job's work.
    flops_this_job = (
        num_floating_point_operations_so_far - args.num_floating_point_operations_so_far
    )
    job_seconds = time.time() - _TRAIN_START_TIME
    job_throughput = flops_this_job / (job_seconds * 10**12 * args.world_size)

    # Cumulative throughput since jobs of this world size were launched: the
    # progress log supplies the start time and starting operation count of the
    # first such job.
    first_start_time, first_start_flops = get_start_time_from_progress_log()
    seconds_since_first_job = (datetime.now() - first_start_time).total_seconds()
    flops_since_first_job = num_floating_point_operations_so_far - first_start_flops
    cumulative_throughput = flops_since_first_job / (
        seconds_since_first_job * 10**12 * args.world_size
    )

    tokens_so_far = args.consumed_train_samples * args.seq_length
    if args.async_save:
        saved_ckpt_prefix = 'Saving async checkpoint'
    else:
        saved_ckpt_prefix = 'Saved checkpoint'

    # Fields of a progress-log entry are tab-separated.
    entry_fields = [
        f"{saved_ckpt_prefix}",
        f"Iteration: {iteration}",
        f"Job throughput: {job_throughput:.1f} TFLOP/s/GPU",
        f"Cumulative throughput: {cumulative_throughput:.1f} TFLOP/s/GPU",
        f"Floating-point operations: {num_floating_point_operations_so_far:.2e}",
        f"Tokens (in billions): {tokens_so_far / 10**9:.2f}",
    ]
    append_to_progress_log("\t".join(entry_fields))


def enable_forward_pre_hook(model_chunks):
    """Enable the forward pre-hook on every model chunk.

    Args:
        model_chunks: Iterable of model chunks; each one is asserted to be a
            ``DDP`` instance before its hook is enabled.
    """
    for chunk in model_chunks:
        # The check and the call are interleaved per chunk, so chunks preceding
        # an offending one have already been switched on when the assert fires.
        assert isinstance(chunk, DDP)
        chunk.enable_forward_pre_hook()


def disable_forward_pre_hook(model_chunks, param_sync=True):
    """Disable the forward pre-hook on every model chunk.

    Args:
        model_chunks: Iterable of model chunks; each one is asserted to be a
            ``DDP`` instance before its hook is disabled.
        param_sync: Forwarded unchanged to each chunk's
            ``disable_forward_pre_hook`` as the ``param_sync`` keyword.
    """
    for chunk in model_chunks:
        # The check and the call are interleaved per chunk, so chunks preceding
        # an offending one have already been switched off when the assert fires.
        assert isinstance(chunk, DDP)
        chunk.disable_forward_pre_hook(param_sync=param_sync)


def save_checkpoint_and_time(
    iteration,
    model,
    optimizer,
    opt_param_scheduler,
    num_floating_point_operations_so_far,
    checkpointing_context,
    non_persistent_ckpt=False,
    train_data_iterator=None,
):
    """Save a checkpoint, time the save, and keep it out of the training interval timer.

    The 'interval-time' timer and the energy monitor are paused for the duration
    of the save and restarted afterwards. The save itself is timed under
    'save-checkpoint' (or 'save-checkpoint-non-persistent') and the duration is
    reported to one-logger. For persistent checkpoints with ``args.log_progress``
    set, a throughput entry is also appended to the progress log.

    Args:
        iteration: Current training iteration.
        model: Model chunks, passed to ``save_checkpoint`` and to the forward
            pre-hook helpers.
        optimizer: Optimizer passed to ``save_checkpoint``.
        opt_param_scheduler: Scheduler passed to ``save_checkpoint``.
        num_floating_point_operations_so_far: Running floating-point operation total.
        checkpointing_context: Context object passed to ``save_checkpoint``.
        non_persistent_ckpt: Whether this is a non-persistent checkpoint.
        train_data_iterator: Training data iterator passed to ``save_checkpoint``.
    """
    args = get_args()
    timers = get_timers()
    energy_monitor = get_energy_monitor()

    # Checkpointing must not count towards the training interval, so stop the
    # interval timer and pause energy accounting first.
    timers('interval-time').stop()
    energy_monitor.pause()

    if non_persistent_ckpt:
        ckpt_timer_name = 'save-checkpoint-non-persistent'
    else:
        ckpt_timer_name = 'save-checkpoint'
    # Start with a barrier so that all ranks report the max time.
    timers(ckpt_timer_name, log_level=0).start(barrier=True)

    # E2E metrics snapshot before the save.
    one_logger_utils.track_e2e_metrics()

    if should_disable_forward_pre_hook(args):
        disable_forward_pre_hook(model)

    save_checkpoint(
        iteration,
        model,
        optimizer,
        opt_param_scheduler,
        num_floating_point_operations_so_far,
        checkpointing_context,
        non_persistent_ckpt=non_persistent_ckpt,
        train_data_iterator=train_data_iterator,
        preprocess_common_state_dict_fn=preprocess_common_state_dict,
    )

    if args.fp8:
        # Saving an fp8 model temporarily creates dequantized bf16 tensors;
        # collect garbage right away to free that memory.
        gc.collect()

    if should_disable_forward_pre_hook(args):
        enable_forward_pre_hook(model)

    timers(ckpt_timer_name).stop(barrier=True)
    timers.log([ckpt_timer_name])

    # E2E metrics snapshot after the save.
    one_logger_utils.track_e2e_metrics()
    # NOTE(review): elapsed() is read after timers.log([...]); confirm that log()
    # does not reset the timer, otherwise the duration reported here would be zero.
    save_checkpoint_duration = timers(ckpt_timer_name).elapsed()
    one_logger_utils.on_save_checkpoint_end(save_checkpoint_duration, iteration, args.async_save)

    if args.log_progress and not non_persistent_ckpt:
        compute_throughputs_and_append_to_progress_log(
            iteration, num_floating_point_operations_so_far
        )

    # Resume energy accounting and restart the training interval timer.
    energy_monitor.resume()
    timers('interval-time', log_level=0).start(barrier=True)


def post_training_step_callbacks(
    model,
    optimizer,
    opt_param_scheduler,
    iteration,
    prof,
    num_floating_point_operations_since_last_log_event,
):
    """Run the periodic actions that follow a training step.

    Depending on ``args`` and ``iteration`` this may: synchronize CUDA, emit a
    straggler report, verify parameter hashes across DP replicas, run the ADLR
    autoresume check, stop profiling, and run a manual garbage collection.

    Args:
        model: Model chunks (used for the weight-hash check and autoresume).
        optimizer: Optimizer passed to the autoresume check.
        opt_param_scheduler: Scheduler passed to the autoresume check.
        iteration: Iteration number used for all interval tests.
        prof: PyTorch profiler instance; must not be None when profiling with
            ``args.use_pytorch_profiler`` reaches ``args.profile_step_end``.
        num_floating_point_operations_since_last_log_event: Floating-point
            operations accumulated since the last straggler report.

    Returns:
        The counter to carry into the next step: 0.0 if a straggler report was
        emitted during this call, otherwise the value that was passed in.
        Rebinding the parameter inside this function does not change the
        caller's variable, so callers that want the reset must use this return
        value. Callers that ignore it behave as before.
    """
    args = get_args()

    # Bring CPU and GPU back in sync if on right iteration.
    if args.train_sync_interval and iteration % args.train_sync_interval == 0:
        torch.cuda.synchronize()

    # Straggler detector.
    if iteration % args.log_interval == 0 and args.log_straggler:
        stimer.report(num_floating_point_operations_since_last_log_event, args.log_interval)
        # Reset the counter; it is handed back to the caller via the return below.
        num_floating_point_operations_since_last_log_event = 0.0

    # Check weight hash across DP replicas.
    if (
        args.check_weight_hash_across_dp_replicas_interval is not None
        and iteration % args.check_weight_hash_across_dp_replicas_interval == 0
    ):
        # The forward pre-hook is disabled around the hash comparison and
        # re-enabled afterwards when the configuration calls for it.
        if should_disable_forward_pre_hook(args):
            disable_forward_pre_hook(model)
        assert check_param_hashes_across_dp_replicas(
            model, cross_check=True
        ), "Parameter hashes not matching across DP replicas"
        torch.distributed.barrier()
        print_rank_0(f">>> Weight hashes match after {iteration} iterations...")
        if should_disable_forward_pre_hook(args):
            enable_forward_pre_hook(model)

    # Autoresume.
    if args.adlr_autoresume and (iteration % args.adlr_autoresume_interval == 0):
        check_adlr_autoresume_termination(iteration, model, optimizer, opt_param_scheduler)

    # Profiling: stop on the configured end step, only on the profiled ranks.
    if (
        args.profile
        and iteration == args.profile_step_end
        and torch.distributed.get_rank() in args.profile_ranks
    ):
        if args.use_pytorch_profiler:
            assert prof is not None
            prof.stop()
        else:
            torch.cuda.cudart().cudaProfilerStop()

    # Manual garbage collection; an interval of 0 disables the periodic collection.
    if (
        args.manual_gc
        and args.manual_gc_interval != 0
        and iteration % args.manual_gc_interval == 0
    ):
        gc.collect()

    return num_floating_point_operations_since_last_log_event


def checkpoint_and_decide_exit(
    model,
    optimizer,
    opt_param_scheduler,
    iteration,
    num_floating_point_operations_so_far,
    checkpointing_context,
    train_data_iterator,
):
    """Save checkpoints as configured and report whether training should stop.

    Exit conditions are checked in this order: a received signal, the
    wall-clock limit (``args.exit_duration_in_mins``), and the iteration
    interval (``args.exit_interval``). Between the signal check and the other
    two, the regular persistent / non-persistent interval save runs. When an
    exit condition fires and ``args.save`` is set, a checkpoint is written
    unless one was already saved earlier in this call. The actual exit happens
    in the main training loop based on the return value.

    Returns:
        True if the caller should exit training, False otherwise.
    """
    args = get_args()
    # Fetched but not used below; left in place in case the accessor's own
    # checks are relied on. TODO confirm and drop.
    get_timers()

    def _save(non_persistent=False):
        # All saves in this function share the same arguments and differ only
        # in whether the checkpoint is non-persistent.
        save_checkpoint_and_time(
            iteration,
            model,
            optimizer,
            opt_param_scheduler,
            num_floating_point_operations_so_far,
            checkpointing_context,
            non_persistent_ckpt=non_persistent,
            train_data_iterator=train_data_iterator,
        )

    already_saved = False

    # Exit based on signal handler.
    if args.exit_signal_handler and any(get_signal_handler().signals_received()):
        if args.save:
            _save()
        print_datetime('exiting program after receiving SIGTERM.')
        return True

    # Regular save: the persistent interval takes precedence over the
    # non-persistent one when both match on the same iteration.
    if args.save:
        if args.save_interval and iteration % args.save_interval == 0:
            _save()
            already_saved = True
        elif (
            args.non_persistent_save_interval
            and iteration % args.non_persistent_save_interval == 0
        ):
            _save(non_persistent=True)
            already_saved = True

    # Exit based on duration.
    if args.exit_duration_in_mins:
        train_time = (time.time() - _TRAIN_START_TIME) / 60.0
        over_limit = train_time > args.exit_duration_in_mins
        done_cuda = torch.tensor([over_limit], dtype=torch.int, device='cuda')
        # MAX-reduce the flag so every rank takes the same decision as soon as
        # any rank is over the limit.
        torch.distributed.all_reduce(done_cuda, op=torch.distributed.ReduceOp.MAX)
        if done_cuda.item():
            if args.save and not already_saved:
                _save()
            print_datetime(f'exiting program after {train_time} minutes')
            return True

    # Exit based on iterations.
    if args.exit_interval and iteration % args.exit_interval == 0:
        if args.save and not already_saved:
            _save()
        print_datetime(f'exiting program at iteration {iteration}')
        return True

    return False


def train(
    forward_step_func,
    model,
    optimizer,
    opt_param_scheduler,
    train_data_iterator,
    valid_data_iterator,
    process_non_loss_data_func,
    config,
    checkpointing_context,
    non_loss_data_func,
):
    """Training function: run train_step desired number of times, run validation, checkpoint."""
    args = get_args()
    timers = get_timers()

    if getattr(args, 'perform_rl_step', False):
        assert has_rl_utils, "RL cannot run without the megatron.rl package"

    # Additional variable initialization for RL training
    if getattr(args, 'perform_rl_step', False):
        print_rank_0("> Loading pretrained checkpoint for reference weights in RL training...")
        load, finetune, no_load_optim = args.load, args.finetune, args.no_load_optim
        args.no_load_optim = True

        # Load pretrained checkpoint
        args.load = None
        args.finetune = True
        load_checkpoint(
                model,
                None,  # Don't load optimizer state
                None,  # Don't load scheduler state
                checkpointing_context=checkpointing_context,
                skip_load_to_model_and_opt=HAVE_FSDP2
                and getattr(args, "use_torch_fsdp2", False)
                and args.ckpt_format == "torch_dist",
            )
        ref_state_dict = {k: (v.cpu() if v is not None else v) for k, v in model[0].state_dict().items()}

        # Reload RL training checkpoint weights
        args.load = load
        args.finetune = finetune
        print_rank_0("> Reloading RL training checkpoint...")
        load_checkpoint(
                model,
                None,
                None,
                checkpointing_context=checkpointing_context,
                skip_load_to_model_and_opt=HAVE_FSDP2
                and getattr(args, "use_torch_fsdp2", False)
                and args.ckpt_format == "torch_dist",
            )

        args.no_load_optim = no_load_optim

    # IMPORTANT FIX: For RL training, reinitialize the microbatch calculator with the correct configuration
    if getattr(args, 'perform_rl_step', False):
        print_rank_0("> Reinitializing microbatch calculator for GRPO training...")
        from megatron.core.num_microbatches_calculator import (
            destroy_num_microbatches_calculator,
            init_num_microbatches_calculator
        )
        # First destroy the existing calculator
        destroy_num_microbatches_calculator()
        # Then initialize with the correct perform_rl_step=True context
        init_num_microbatches_calculator(
            args.rank,
            args.rampup_batch_size,
            args.global_batch_size,
            args.micro_batch_size,
            mpu.get_data_parallel_world_size(),
            args.decrease_batch_size_if_needed
        )
        print_rank_0(f"> GRPO training: num_microbatches set to {get_num_microbatches()}")

    energy_monitor = get_energy_monitor()
    one_logger = get_one_logger()

    if args.run_workload_inspector_server:
        try:
            from workload_inspector.utils.webserver import run_server
            import threading

            threading.Thread(
                target=run_server, daemon=True, args=(torch.distributed.get_rank(),)
            ).start()
        except ModuleNotFoundError:
            print_rank_0("workload inspector module not found.")

    # Write args to tensorboard
    write_args_to_tensorboard()

    # Turn on training mode which enables dropout.
    for model_module in model:
        model_module.train()

    # Tracking loss.
    total_loss_dict = {}

    # Iterations.
    iteration = args.iteration
    # Make sure rerun_state_machine has the right iteration loaded from checkpoint.
    rerun_state_machine = get_rerun_state_machine()
    if rerun_state_machine.current_iteration != iteration:
        print_rank_0(f"Overwriting rerun_state_machine.current_iteration from "
                     f"{rerun_state_machine.current_iteration} to {iteration}...")
        rerun_state_machine.current_iteration = iteration

    # Track E2E metrics at the start of training.
    one_logger_utils.on_train_start(
        iteration=iteration,
        consumed_train_samples=args.consumed_train_samples,
        train_samples=args.train_samples,
        seq_length=args.seq_length,
        train_iters=args.train_iters,
        save=args.save,
        async_save=args.async_save,
        log_throughput=args.log_throughput,
        num_floating_point_operations_so_far=args.num_floating_point_operations_so_far,
    )

    num_floating_point_operations_so_far = args.num_floating_point_operations_so_far

    # Setup some training config params.
    config.grad_scale_func = optimizer.scale_loss
    config.timers = timers
    if isinstance(model[0], (megatron_FSDP, DDP)) and args.overlap_grad_reduce:
        assert config.no_sync_func is None, (
            'When overlap_grad_reduce is True, config.no_sync_func must be None; '
            'a custom no_sync_func is not supported when overlapping grad-reduce'
        )
        config.no_sync_func = [model_chunk.no_sync for model_chunk in model]
        if len(model) == 1:
            config.no_sync_func = config.no_sync_func[0]
        if args.align_grad_reduce:
            config.grad_sync_func = [model_chunk.start_grad_sync for model_chunk in model]
            if len(model) == 1:
                config.grad_sync_func = config.grad_sync_func[0]
    if args.overlap_param_gather and args.align_param_gather:
        config.param_sync_func = [model_chunk.start_param_sync for model_chunk in model]
        if len(model) == 1:
            config.param_sync_func = config.param_sync_func[0]
    config.finalize_model_grads_func = finalize_model_grads

    if args.log_energy:
        energy_monitor.setup()
        energy_monitor.resume()

    timers('interval-time', log_level=0).start(barrier=True)
    print_datetime('before the start of training step')
    report_memory_flag = True
    pre_hook_enabled = False
    should_exit = False
    exit_code = 0

    if args.manual_gc:
        # Disable the default garbage collector and perform the collection manually.
        # This is to align the timing of garbage collection across ranks.
        assert (
            args.manual_gc_interval >= 0
        ), 'Manual garbage collection interval should be larger than or equal to 0'
        gc.disable()
        gc.collect()

    # Singleton initialization of straggler detector.
    if args.log_straggler:
        global stimer
        world = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()
        mmcnt = args.straggler_minmax_count
        stimer.configure(
            world,
            rank,
            mmcnt=mmcnt,
            enabled=not args.disable_straggler_on_startup,
            port=args.straggler_ctrlr_port,
        )
    num_floating_point_operations_since_last_log_event = 0.0

    num_microbatches = get_num_microbatches()
    eval_duration = 0.0
    eval_iterations = 0
    # Wrap forward_backward_func for Full iteration CUDA graph
    forward_backward_func = get_forward_backward_func()
    if args.enable_cuda_graph and args.cuda_graph_scope=="full_iteration":
        forward_backward_func = FullCudaGraphWrapper(forward_backward_func, cuda_graph_warmup_steps=args.cuda_graph_warmup_steps)

    def get_e2e_base_metrics():
        """Get base metrics values for one-logger to calculate E2E tracking metrics."""
        num_floating_point_operations_since_current_train_start = (
            num_floating_point_operations_so_far - args.num_floating_point_operations_so_far
        )
        return {
            'iteration': iteration,
            'train_duration': timers('interval-time').active_time(),
            'eval_duration': eval_duration,
            'eval_iterations': eval_iterations,
            'total_flops_since_current_train_start': num_floating_point_operations_since_current_train_start,
            'num_floating_point_operations_so_far': num_floating_point_operations_so_far,
            'consumed_train_samples': args.consumed_train_samples,
            'world_size': args.world_size,
            'seq_length': args.seq_length,
        }

    # Cache into one-logger for callback.
    if one_logger:
        with one_logger.get_context_manager():
            one_logger.store_set('get_e2e_base_metrics', get_e2e_base_metrics)

    prof = None
    if (
        args.profile
        and torch.distributed.get_rank() in args.profile_ranks
        and args.use_pytorch_profiler
    ):
        prof = torch.profiler.profile(
            schedule=torch.profiler.schedule(
                wait=max(args.profile_step_start - 1, 0),
                warmup=1 if args.profile_step_start > 0 else 0,
                active=args.profile_step_end - args.profile_step_start,
                repeat=1,
            ),
            on_trace_ready=torch.profiler.tensorboard_trace_handler(args.tensorboard_dir),
            record_shapes=True,
            with_stack=True,
        )
        prof.start()

    start_iteration = iteration
    # Disable forward pre-hook to start training to ensure that errors in checkpoint loading
    # or random initialization don't propagate to all ranks in first all-gather (which is a
    # no-op if things work correctly).
    if should_disable_forward_pre_hook(args):
        disable_forward_pre_hook(model, param_sync=False)
        # Also remove param_sync_func temporarily so that sync calls made in
        # `forward_backward_func` are no-ops.
        param_sync_func = config.param_sync_func
        config.param_sync_func = None
        pre_hook_enabled = False
    # Also, check weight hash across DP replicas to be very pedantic.
    if args.check_weight_hash_across_dp_replicas_interval is not None:
        assert check_param_hashes_across_dp_replicas(
            model, cross_check=True
        ), "Parameter hashes not matching across DP replicas"
        torch.distributed.barrier()
        print_rank_0(f">>> Weight hashes match after {iteration} iterations...")

    # Capture CUDA Graphs.
    if args.external_cuda_graph:
        cuda_graph_helper = TECudaGraphHelper(
            model=model,
            config=config,
            seq_length=args.seq_length,
            micro_batch_size=args.micro_batch_size,
            optimizers=[optimizer],
        )
        cuda_graph_helper.create_cudagraphs()

    # Run training iterations till done.
    buffered_rollouts = None
    while iteration < args.train_iters:
        if args.profile and torch.distributed.get_rank() in args.profile_ranks:
            if args.use_pytorch_profiler:
                prof.step()
            elif iteration == args.profile_step_start:
                torch.cuda.cudart().cudaProfilerStart()
                torch.autograd.profiler.emit_nvtx(record_shapes=True).__enter__()

        ft_integration.on_checkpointing_start()
        maybe_finalize_async_save(blocking=False)
        ft_integration.on_checkpointing_end(is_async_finalization=True)

        # Update number of microbatches first without consistency check to decide if a
        # checkpoint should be saved. If the number of microbatches is different
        # from the previous iteration, save a checkpoint. Then run consistency check
        # to make sure training configuration is still valid.
        update_num_microbatches(args.consumed_train_samples, consistency_check=False, verbose=True)
        if get_num_microbatches() != num_microbatches and iteration != 0:
            assert get_num_microbatches() > num_microbatches, (
                f"Number of microbatches should be increasing due to batch size rampup; "
                f"instead going from {num_microbatches} to {get_num_microbatches()}"
            )
            if args.save is not None:
                save_checkpoint_and_time(
                    iteration,
                    model,
                    optimizer,
                    opt_param_scheduler,
                    num_floating_point_operations_so_far,
                    checkpointing_context,
                    train_data_iterator=train_data_iterator,
                )
        num_microbatches = get_num_microbatches()
        update_num_microbatches(args.consumed_train_samples, consistency_check=True, verbose=True)

        # Completely skip iteration if needed.
        if iteration in args.iterations_to_skip:
            # Dummy train_step to fast forward train_data_iterator.
            dummy_train_step(train_data_iterator)
            iteration += 1
            batch_size = (
                mpu.get_data_parallel_world_size() * args.micro_batch_size * get_num_microbatches()
            )
            args.consumed_train_samples += batch_size
            args.skipped_train_samples += batch_size
            continue

        args.curr_iteration = iteration
        # For GRPO, we keep the data for a few epochs. DeepSeekMath paper calls this number $\mu$.
        # It is similar to a PPO epoch.

        if getattr(args, 'perform_rl_step', False):
            with torch.no_grad():
                # We collect new rollouts when we've gone over the collected data 'grpo_iterations' times.
                if iteration % (args.grpo_iterations * ((args.grpo_samples_per_iteration) // args.global_batch_size)) == 0:
                    buffered_rollouts = rl_utils.get_rollout_data_iterator(
                        model, optimizer, iteration, ref_state_dict,
                    )
                train_data_iterator = buffered_rollouts

        ft_integration.on_training_step_start()
        (
            loss_dict,
            skipped_iter,
            should_checkpoint,
            should_exit,
            exit_code,
            grad_norm,
            num_zeros_in_grad,
        ) = train_step(
            forward_step_func, train_data_iterator, model, optimizer, opt_param_scheduler, config, forward_backward_func
        )
        ft_integration.on_training_step_end()
        if should_checkpoint:
            save_checkpoint_and_time(
                iteration,
                model,
                optimizer,
                opt_param_scheduler,
                num_floating_point_operations_so_far,
                checkpointing_context,
                train_data_iterator=train_data_iterator,
            )
        if should_exit:
            break

        # Enable forward pre-hooks after first set of forward and backward passes.
        # When running in fp16, skip all NaN iterations until steady-state loss scaling value
        # is reached.
        if iteration == start_iteration:
            if skipped_iter:
                # Only enable forward pre-hook after a training step has successfully run. Relevant
                # for fp16 codepath where first XX iterations are skipped until steady-state loss
                # scale value is reached.
                start_iteration = iteration + 1
            else:
                # Enable forward pre-hook after training step has successfully run. All subsequent
                # forward passes will use the forward pre-hook / `param_sync_func` in
                # `forward_backward_func`.
                if should_disable_forward_pre_hook(args):
                    enable_forward_pre_hook(model)
                    config.param_sync_func = param_sync_func
                    pre_hook_enabled = True
                    # Set the manual hooks when CUDA Graphs are used.
                    if args.external_cuda_graph:
                        cuda_graph_helper.cuda_graph_set_manual_hooks()

        iteration += 1
        batch_size = (
            mpu.get_data_parallel_world_size() * args.micro_batch_size * get_num_microbatches()
        )
        args.consumed_train_samples += batch_size
        num_skipped_samples_in_batch = (
            get_current_global_batch_size() - get_current_running_global_batch_size()
        )
        if args.decrease_batch_size_if_needed:
            assert num_skipped_samples_in_batch >= 0
        else:
            assert num_skipped_samples_in_batch == 0
        args.skipped_train_samples += num_skipped_samples_in_batch
        num_floating_point_operations_in_batch = num_floating_point_operations(args, batch_size)
        num_floating_point_operations_so_far += num_floating_point_operations_in_batch
        num_floating_point_operations_since_last_log_event += num_floating_point_operations_in_batch

        # Logging.
        if not optimizer.is_stub_optimizer:
            loss_scale = optimizer.get_loss_scale().item()
        else:
            loss_scale = 1.0
        params_norm = None

        if args.log_params_norm:
            params_norm = calc_params_l2_norm(model)
        learning_rate = None
        decoupled_learning_rate = None
        for param_group in optimizer.param_groups:
            if len(param_group['params']) == 0:
                continue
            if param_group['is_decoupled_lr']:
                decoupled_learning_rate = param_group['lr']
            else:
                learning_rate = param_group['lr']
        # Prepare and update ReLU router metrics/coeff if enabled
        moe_relu_sparsity = None
        moe_relu_l1_reg_coeff = None
        if getattr(config, 'moe_relu_routing', False) or getattr(config, 'moe_router_type', None) == 'relu':
            if hasattr(config, 'moe_relu_sparsity') and hasattr(config, 'moe_relu_l1_reg_coeff'):
                # Average sparsity across layers and microbatches in this global batch
                try:
                    config.moe_relu_sparsity /= (config.num_layers * num_microbatches)
                except Exception:
                    pass
                # Average across DP, then sum over PP to get global average
                try:
                    torch.distributed.all_reduce(
                        config.moe_relu_sparsity,
                        group=mpu.get_data_parallel_group(),
                        op=torch.distributed.ReduceOp.SUM,
                    )
                    config.moe_relu_sparsity /= torch.distributed.get_world_size(group=mpu.get_data_parallel_group())
                    torch.distributed.all_reduce(
                        config.moe_relu_sparsity,
                        group=mpu.get_pipeline_model_parallel_group(),
                        op=torch.distributed.ReduceOp.SUM,
                    )
                except Exception:
                    pass

                # Zeroth-order update of L1 coeff toward target sparsity
                try:
                    target_sparsity = 1 - (config.moe_router_topk / config.num_moe_experts)
                    if config.moe_relu_sparsity < target_sparsity:
                        config.moe_relu_l1_reg_coeff *= config.moe_relu_l1_reg_coeff_multiplier
                    else:
                        config.moe_relu_l1_reg_coeff /= config.moe_relu_l1_reg_coeff_multiplier
                except Exception:
                    pass

                # Read values for logging
                try:
                    moe_relu_sparsity = config.moe_relu_sparsity.item()
                except Exception:
                    moe_relu_sparsity = float(config.moe_relu_sparsity)
                try:
                    moe_relu_l1_reg_coeff = config.moe_relu_l1_reg_coeff.item()
                except Exception:
                    moe_relu_l1_reg_coeff = float(config.moe_relu_l1_reg_coeff)

        report_memory_flag = training_log(
            loss_dict,
            total_loss_dict,
            learning_rate,
            decoupled_learning_rate,
            iteration,
            loss_scale,
            report_memory_flag,
            skipped_iter,
            grad_norm,
            params_norm,
            num_zeros_in_grad,
            moe_relu_sparsity=moe_relu_sparsity,
            moe_relu_l1_reg_coeff=moe_relu_l1_reg_coeff,
        )

        # Evaluation.
        if args.eval_interval and iteration % args.eval_interval == 0 and args.do_valid:
            if args.log_energy:
                energy_monitor.pause()
            timers('interval-time').stop()
            if should_disable_forward_pre_hook(args):
                disable_forward_pre_hook(model)
                pre_hook_enabled = False
            if args.manual_gc and args.manual_gc_eval:
                # Collect all objects.
                gc.collect()
            prefix = f'iteration {iteration}'
            timers('eval-time', log_level=0).start(barrier=True)
            if getattr(args, 'perform_rl_step', False):
                rl_utils.evaluate_and_print_results_rl(valid_data_iterator, model, optimizer,
                                       iteration, write_to_tensorboard=True)
            else:
                evaluate_and_print_results(prefix, forward_step_func,
                                       valid_data_iterator, model,
                                       iteration, process_non_loss_data_func,
                                       config, verbose=False, write_to_tensorboard=True,
                                       non_loss_data_func=non_loss_data_func)

            eval_duration += timers('eval-time').elapsed()
            eval_iterations += sum(args.eval_iters) if isinstance(args.eval_iters, list) else args.eval_iters
            timers('eval-time').stop()
            one_logger_utils.track_e2e_metrics()

            if args.manual_gc and args.manual_gc_eval:
                # Collect only the objects created and used in evaluation.
                gc.collect(generation=0)
            if should_disable_forward_pre_hook(args):
                enable_forward_pre_hook(model)
                pre_hook_enabled = True
            timers('interval-time', log_level=0).start(barrier=True)
            if args.log_energy:
                energy_monitor.resume()

        # Zero the sparsity buffer if using ReLU router
        if getattr(config, 'moe_relu_routing', False) and hasattr(config, 'moe_relu_sparsity'):
            try:
                config.moe_relu_sparsity.zero_()
            except Exception:
                pass

        # Miscellaneous post-training-step functions (e.g., FT heartbeats, GC).
        # Some of these only happen at specific iterations.
        post_training_step_callbacks(
            model,
            optimizer,
            opt_param_scheduler,
            iteration,
            prof,
            num_floating_point_operations_since_last_log_event,
        )

        # Checkpoint and decide whether to exit.
        should_exit = checkpoint_and_decide_exit(
            model,
            optimizer,
            opt_param_scheduler,
            iteration,
            num_floating_point_operations_so_far,
            checkpointing_context,
            train_data_iterator,
        )
        if should_exit:
            break

    one_logger_utils.track_e2e_metrics()

    # Flush TensorBoard, WandB writers and one-logger.
    writer = get_tensorboard_writer()
    if writer:
        writer.flush()

    # Close out pre-hooks if using distributed optimizer and overlapped param gather.
    if pre_hook_enabled:
        disable_forward_pre_hook(model)

    ft_integration.on_checkpointing_start()
    # This will finalize all unfinalized async request and terminate
    # a persistent async worker if persistent ckpt worker is enabled
    maybe_finalize_async_save(blocking=True, terminate=True)
    ft_integration.on_checkpointing_end(is_async_finalization=True)
    if args.enable_ft_package and ft_integration.get_rank_monitor_client() is not None:
        ft_integration.get_rank_monitor_client().shutdown_workload_monitoring()

    if args.log_energy:
        energy_monitor.lap()
        total_energy = energy_monitor.get_total()
        print_rank_0(f"Total training energy (GPU): {total_energy / 1e6} MJ")
        energy_monitor.shutdown()

    # If any exit conditions (signal handler, duration, iterations) have been reached, exit.
    if should_exit:
        wandb_writer = get_wandb_writer()
        if wandb_writer:
            wandb_writer.finish()
        ft_integration.shutdown()
        one_logger_utils.finish()
        sys.exit(exit_code)

    return iteration, num_floating_point_operations_so_far


def evaluate(
    forward_step_func,
    data_iterator,
    model,
    process_non_loss_data_func,
    config,
    verbose=False,
    non_loss_data_func=None,
    eval_iters=None,
):
    """Run forward-only evaluation and return the averaged losses.

    Args:
        forward_step_func: Forward step callable passed through to the
            forward/backward function.
        data_iterator: Iterator over the evaluation data.
        model: Iterable of model chunks. Each chunk is switched to eval mode
            before the loop and back to train mode after it.
        process_non_loss_data_func: When not None (and `non_loss_data_func` is
            None), the last rank runs one extra forward pass with
            `collect_non_loss_data=True`.
        config: Config object; its `timers` attribute is set to None around
            every evaluation step and restored afterwards.
        verbose: When True, progress is reported through `print_rank_0`.
        non_loss_data_func: Optional callable invoked as
            `non_loss_data_func(model)` after the loop; takes precedence over
            `process_non_loss_data_func`.
        eval_iters: Number of evaluation iterations; defaults to
            `args.eval_iters`.

    Returns:
        A tuple `(total_loss_dict, collected_non_loss_data, timelimit_hit)`.
        `total_loss_dict` maps each loss key to `numerator / denominator`.
        If `args.exit_duration_in_mins` is exceeded during evaluation the
        function returns `(None, None, True)` instead.
    """
    args = get_args()
    timers = get_timers()

    timers('evaluate', log_level=0).start(barrier=True)

    if args.vision_pretraining and args.vision_pretraining_type == "dino":
        from megatron.legacy.model.vision.knn_monitor import compute_feature_bank

        compute_feature_bank(model)

    # Turn on evaluation mode which disables dropout.
    for model_module in model:
        model_module.eval()

    # Disable result validation during evaluation; the previous mode is
    # remembered so it can be restored on every exit path.
    rerun_state_machine = get_rerun_state_machine()
    rerun_mode = rerun_state_machine.get_mode()
    rerun_state_machine.set_mode(RerunMode.DISABLED)

    total_loss_dict = {}

    # Make validation batch size independent from training batch size.
    eval_batch_size = args.global_batch_size
    eval_num_microbatches = eval_batch_size // (args.micro_batch_size * args.data_parallel_size)
    forward_backward_func = get_forward_backward_func()
    if args.enable_cuda_graph and args.cuda_graph_scope == "full_iteration":
        forward_backward_func = FullCudaGraphWrapper(forward_backward_func, cuda_graph_warmup_steps=args.cuda_graph_warmup_steps)

    if eval_iters is None:
        eval_iters = args.eval_iters

    with torch.no_grad():
        iteration = 0
        if verbose:
            print_rank_0(f'Evaluating on {eval_iters * eval_batch_size} samples')
        while iteration < eval_iters:
            iteration += 1
            if verbose:
                print_rank_0(f'Evaluating iter {iteration}/{eval_iters}')

            # Don't care about timing during evaluation
            config.timers = None
            ft_integration.on_eval_step_start()
            loss_dicts = forward_backward_func(
                forward_step_func=forward_step_func,
                data_iterator=data_iterator,
                model=model,
                num_microbatches=eval_num_microbatches,
                seq_length=args.seq_length,
                micro_batch_size=args.micro_batch_size,
                decoder_seq_length=args.decoder_seq_length,
                forward_only=True,
            )
            ft_integration.on_eval_step_end()
            config.timers = get_timers()

            # Empty unused memory
            if args.empty_unused_memory_level >= 1:
                torch.cuda.empty_cache()

            if mpu.is_pipeline_last_stage(ignore_virtual=True):
                # Reduce across processes. Every key accumulates a
                # [numerator, denominator] pair that is divided at the end.
                for key in loss_dicts[0].keys():
                    if key not in total_loss_dict:
                        total_loss_dict[key] = torch.tensor(
                            [0.0, 0.0], dtype=torch.float
                        ).cuda()
                    val = [x[key].view(-1) for x in loss_dicts]

                    if val[0].numel() == 2:
                        if args.sft:
                            # Normalize over micro batch instead of global:
                            # take the per-microbatch ratio, average it, then
                            # average that across the data-parallel group.
                            val = torch.vstack(val)
                            val = val[:, 0] / val[:, 1]
                            val = val.mean()
                            torch.distributed.all_reduce(
                                val,
                                group=mpu.get_data_parallel_group(with_context_parallel=True)
                            )
                            val /= torch.distributed.get_world_size(
                                group=mpu.get_data_parallel_group(with_context_parallel=True)
                            )
                            total_loss_dict[key][0] += val
                            total_loss_dict[key][1] += 1
                        else:
                            # Sum both elements of the pair over microbatches
                            # and over the data-parallel group.
                            val = torch.vstack(val).sum(dim=0)
                            torch.distributed.all_reduce(
                                val,
                                group=mpu.get_data_parallel_group(with_context_parallel=True)
                            )
                            total_loss_dict[key] += val
                    elif val[0].numel() == 1:
                        # Single-element values: sum over microbatches and use
                        # the microbatch count as the denominator.
                        val = torch.cat(val).sum()
                        total_loss_dict[key][0] += val
                        total_loss_dict[key][1] += len(loss_dicts)
                    else:
                        raise ValueError(f"Invalid value shape: {val[0].shape} for key {key}")

            args.consumed_valid_samples += eval_batch_size

            if args.exit_duration_in_mins:
                train_time = (time.time() - _TRAIN_START_TIME) / 60.0
                # MAX-reduce so that all ranks agree once any rank is over time.
                done_cuda = torch.tensor(
                    [train_time > args.exit_duration_in_mins], dtype=torch.int, device='cuda'
                )
                torch.distributed.all_reduce(done_cuda, op=torch.distributed.ReduceOp.MAX)
                done = done_cuda.item()
                if done:
                    # NOTE(review): this early return restores the rerun mode
                    # but skips the train-mode restore and the
                    # `timers('evaluate').stop()` that follow the loop.
                    rerun_state_machine.set_mode(rerun_mode)
                    print_rank_0('Exiting during evaluation, timelimit reached')
                    return None, None, True

        collected_non_loss_data = None
        if non_loss_data_func is not None:
            collected_non_loss_data = non_loss_data_func(model)
        elif process_non_loss_data_func is not None and is_last_rank():
            collected_non_loss_data = forward_backward_func(
                forward_step_func=forward_step_func,
                data_iterator=data_iterator,
                model=model,
                num_microbatches=get_num_microbatches(),
                seq_length=args.seq_length,
                micro_batch_size=args.micro_batch_size,
                decoder_seq_length=args.decoder_seq_length,
                forward_only=True,
                collect_non_loss_data=True,
            )

    # Move model back to the train mode.
    for model_module in model:
        model_module.train()

    # Collapse every accumulated [numerator, denominator] pair into its ratio.
    for key in total_loss_dict:
        numerator, denominator = total_loss_dict[key]
        total_loss_dict[key] = numerator / denominator

    timers('evaluate').stop()
    timers.log(['evaluate'])

    # Restore the rerun mode that was active before evaluation (done once; the
    # original repeated this call back-to-back).
    rerun_state_machine.set_mode(rerun_mode)

    return total_loss_dict, collected_non_loss_data, False


def evaluate_and_print_results(
    prefix,
    forward_step_func,
    data_iterator,
    model,
    iteration,
    process_non_loss_data_func,
    config,
    verbose=False,
    write_to_tensorboard=True,
    non_loss_data_func=None,
):
    """Helper function to evaluate and dump results on screen.

    Runs `evaluate` once per validation set, prints the resulting losses and
    perplexities through `print_rank_last`, and logs them to the TensorBoard
    writer (and, nested under it, the WandB writer) when one is available.

    Args:
        prefix: Text inserted into the printed summary line (e.g. the
            iteration description).
        forward_step_func: Forwarded to `evaluate`.
        data_iterator: A single iterator, or a sequence of iterators when
            `args.multiple_validation_sets` is set.
        model: Forwarded to `evaluate`.
        iteration: Step value used for the scalar logs.
        process_non_loss_data_func: Forwarded to `evaluate`; also called with
            the collected non-loss data on the last rank when a writer exists.
        config: Forwarded to `evaluate`.
        verbose: Forwarded to `evaluate`.
        write_to_tensorboard: When False no TensorBoard writer is used.
        non_loss_data_func: Forwarded to `evaluate`.

    Returns:
        None. Returns early, without printing, if `evaluate` reports that the
        time limit was hit.
    """
    args = get_args()
    writer = get_tensorboard_writer() if write_to_tensorboard else None
    wandb_writer = get_wandb_writer()

    # Normalize to parallel lists: one (iterator, eval_iters) entry per
    # validation set.
    if args.multiple_validation_sets:
        data_iterators = data_iterator
        eval_iters = args.eval_iters
    else:
        data_iterators = [data_iterator]
        eval_iters = [args.eval_iters]

    if args.full_validation:
        assert len(eval_iters) == len(data_iterators)

        # With full validation we need to distribute eval_iters to all ranks.
        if mpu.get_tensor_model_parallel_rank() == 0:
            # Build the tensor from the list form. Using the raw
            # `args.eval_iters` int (single validation set) would create a
            # 0-dim tensor whose `.tolist()` is a scalar, breaking the
            # indexing and `zip` below.
            eval_iters = torch.tensor(eval_iters, dtype=torch.long, device='cuda')
        else:
            eval_iters = torch.tensor([0] * len(eval_iters), dtype=torch.long, device='cuda')
        torch.distributed.broadcast(eval_iters, 0)
        eval_iters = eval_iters.tolist()
        args.eval_iters = eval_iters if args.multiple_validation_sets else eval_iters[0]

    for index, (iterator, iterations) in enumerate(zip(data_iterators, eval_iters)):
        suffix = f"-{index}" if args.multiple_validation_sets else ""
        total_loss_dict, collected_non_loss_data, timelimit = evaluate(
            forward_step_func,
            iterator,
            model,
            process_non_loss_data_func,
            config,
            verbose,
            non_loss_data_func,
            eval_iters=iterations,
        )
        # Timelimit hit during evaluation
        if timelimit:
            return
        string = f' validation{suffix} loss at {prefix} | '
        for key in total_loss_dict:
            # Read the value once; every `.item()` call is a separate
            # device-to-host transfer.
            loss_value = total_loss_dict[key].item()
            # Clamp the exponent at 20 to keep the reported perplexity finite.
            ppl = math.exp(min(20, loss_value))
            string += '{} value: {:.6E} | '.format(key, loss_value)
            string += '{} PPL: {:.6E} | '.format(key, ppl)
            if writer:
                writer.add_scalar('{} validation{}'.format(key, suffix), loss_value, iteration)
                writer.add_scalar(
                    '{} validation{} vs samples'.format(key, suffix),
                    loss_value,
                    args.consumed_train_samples,
                )
                if args.log_validation_ppl_to_tensorboard:
                    writer.add_scalar('{} validation{} ppl'.format(key, suffix), ppl, iteration)
                    writer.add_scalar(
                        '{} validation{} ppl vs samples'.format(key, suffix), ppl, args.consumed_train_samples
                    )
                # WandB logging is intentionally left nested under `writer`,
                # as in the original code.
                if wandb_writer and is_last_rank():
                    wandb_writer.log(
                        {'{} validation{}'.format(key, suffix): loss_value}, iteration
                    )

        if process_non_loss_data_func is not None and writer and is_last_rank():
            process_non_loss_data_func(collected_non_loss_data, iteration, writer)

        length = len(string) + 1
        print_rank_last('-' * length)
        print_rank_last(string)
        print_rank_last('-' * length)


def cyclic_iter(iter):
    """Yield the items of `iter` repeatedly, restarting it after every pass.

    Args:
        iter: An object that can be iterated more than once (e.g. a data
            loader). The name shadows the builtin but is kept so existing
            keyword callers continue to work.

    Yields:
        The items of `iter`, pass after pass.

    The generator finishes once a complete pass produces no items (an empty
    container, or a one-shot iterator that is already exhausted). Without
    this guard the outer loop would spin forever without ever yielding.
    """
    while True:
        produced_any = False
        for x in iter:
            produced_any = True
            yield x
        if not produced_any:
            return


def get_train_valid_test_num_samples():
    """Compute the target sample counts for the train/valid/test splits.

    Returns:
        A tuple `(train_samples, eval_samples, test_samples)`:
          * train: `args.train_samples` when set, otherwise
            `args.train_iters * args.global_batch_size`.
          * eval: None under `args.full_validation`; otherwise the samples
            consumed by `train_iters // eval_interval + 1` evaluation rounds
            of `args.eval_iters` iterations each.
          * test: `args.eval_iters * args.global_batch_size`.
    """
    args = get_args()
    global_batch = args.global_batch_size

    # An explicit sample budget wins over the iteration-derived one.
    num_train = args.train_samples or args.train_iters * global_batch

    if not args.full_validation:
        eval_rounds = args.train_iters // args.eval_interval + 1
        num_valid = eval_rounds * args.eval_iters * global_batch
    else:
        num_valid = None

    num_test = args.eval_iters * global_batch

    return (num_train, num_valid, num_test)


def build_train_valid_test_datasets(build_train_valid_test_datasets_provider, train_valid_test_num_samples=None):
    """Build the pretraining datasets through the given provider.

    Args:
        build_train_valid_test_datasets_provider: Callable that receives the
            target sizes and returns the datasets.
        train_valid_test_num_samples: Indexable holding the train, validation
            and test target sizes. When None, the sizes come from
            `get_train_valid_test_num_samples()`.

    Returns:
        Whatever the provider returns.
    """
    target_sizes = train_valid_test_num_samples
    if target_sizes is None:
        target_sizes = get_train_valid_test_num_samples()

    print_rank_0(' > datasets target sizes (minimum size):')
    print_rank_0(f'    train:      {target_sizes[0]}')
    print_rank_0(f'    validation: {target_sizes[1]}')
    print_rank_0(f'    test:       {target_sizes[2]}')

    return build_train_valid_test_datasets_provider(target_sizes)


def build_train_valid_test_data_loaders(build_train_valid_test_datasets_provider):
    """Build pretraining data loaders.

    The datasets and loaders are only built on ranks where the provider is
    marked `is_distributed` or the tensor-model-parallel rank is 0; all other
    ranks return `(None, None, None)`. The do_train/do_valid/do_test flags are
    broadcast from rank 0 and stored on `args`.

    Args:
        build_train_valid_test_datasets_provider: Callable used to build the
            datasets; an optional `is_distributed` attribute controls on
            which ranks the data pipeline is constructed.

    Returns:
        A tuple `(train_dataloader, valid_dataloaders, test_dataloader)`, where
        `valid_dataloaders` is a list on the ranks that build the loaders.

    Raises:
        NotImplementedError: If `args.multiple_validation_sets` is used
            without `args.full_validation` (and without `args.skip_train`).
    """

    args = get_args()

    (train_dataloader, valid_dataloaders, test_dataloader) = (None, None, None)

    print_rank_0('> building train, validation, and test datasets ...')

    # Backward compatibility, assume fixed batch size.
    if args.iteration > 0 and args.consumed_train_samples == 0:
        assert (
            args.train_samples is None
        ), 'Only backward compatiblity support for iteration-based training'
        args.consumed_train_samples = args.iteration * args.global_batch_size
    if args.iteration > 0 and args.consumed_valid_samples == 0:
        if args.train_samples is None:
            args.consumed_valid_samples = (
                (args.iteration // args.eval_interval) * args.eval_iters * args.global_batch_size
            )

    # Rely on distributed-aware core datasets, temporary
    is_distributed = getattr(build_train_valid_test_datasets_provider, "is_distributed", False)

    # Construct the data pipeline
    if is_distributed or mpu.get_tensor_model_parallel_rank() == 0:

        # Build datasets. RL runs request a placeholder size of one sample
        # per split.
        train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
            build_train_valid_test_datasets_provider, (1, 1, 1) if getattr(args, 'perform_rl_step', False) else None
        )
        valid_ds = [valid_ds] if not isinstance(valid_ds, list) else valid_ds

        # Build dataloaders.
        train_dataloader = build_pretraining_data_loader(train_ds, args.consumed_train_samples)

        valid_dataloaders = []
        for valid_d in valid_ds:
            if args.skip_train or args.full_validation:
                # Validation always starts from the beginning in these modes.
                valid_dataloaders.append(build_pretraining_data_loader(valid_d, 0))
            else:
                if args.multiple_validation_sets:
                    # TODO(bnorick): for multiple validation sets without full validation, args.consumed_valid_samples is not
                    # correct and needs to be calculated/set per validation set
                    raise NotImplementedError("--multiple-validation-sets currently requires --full-validation")
                valid_dataloaders.append(build_pretraining_data_loader(valid_d, args.consumed_valid_samples))
        if not args.multiple_validation_sets:
            assert len(valid_dataloaders) == 1
        test_dataloader = build_pretraining_data_loader(test_ds, 0)

        # Flags to know if we need to do training/validation/testing.
        do_train = train_dataloader is not None and args.train_iters > 0
        do_valid = valid_dataloaders is not None and (args.full_validation or args.eval_iters > 0)
        do_test = test_dataloader is not None and (args.full_validation or args.eval_iters > 0)
        flags = torch.tensor(
            [int(do_train), int(do_valid), int(do_test)], dtype=torch.long, device='cuda'
        )
    else:
        flags = torch.tensor([0, 0, 0], dtype=torch.long, device='cuda')

    torch.distributed.broadcast(flags, 0)

    # A flag that is already truthy on `args` is kept; otherwise the
    # broadcast value is used.
    args.do_train = getattr(args, "do_train", False) or flags[0].item()
    args.do_valid = getattr(args, "do_valid", False) or flags[1].item()
    args.do_test = getattr(args, "do_test", False) or flags[2].item()
    if getattr(args, 'perform_rl_step', False):
        # Disable testing for RL runs. The original assigned `args.to_test`,
        # a misspelling that left `args.do_test` untouched.
        args.do_test = False

    return train_dataloader, valid_dataloaders, test_dataloader


def build_train_valid_test_data_iterators(build_train_valid_test_datasets_provider):
    """Build pretraining data iterators.

    Builds the data loaders and wraps each one in a `RerunDataIterator`
    according to `args.dataloader_type` ('single', 'cyclic' or 'external').
    Under `args.full_validation`, `args.eval_iters` is overwritten with the
    length of the validation loader(s).

    Args:
        build_train_valid_test_datasets_provider: Forwarded to
            `build_train_valid_test_data_loaders`.

    Returns:
        A tuple `(train_data_iterator, valid_data_iterators,
        test_data_iterator)`. Entries are None where no loader exists;
        `valid_data_iterators` is a list when `args.multiple_validation_sets`
        is set.
    """

    args = get_args()

    # Build loaders.
    train_dataloader, valid_dataloaders, test_dataloader = build_train_valid_test_data_loaders(
        build_train_valid_test_datasets_provider
    )

    # Build iterators.
    dl_type = args.dataloader_type
    assert dl_type in ['single', 'cyclic', 'external']

    def _get_iterator(dataloader_type, dataloader):
        """Return dataset iterator."""
        if dataloader_type == "single":
            return RerunDataIterator(iter(dataloader))
        elif dataloader_type == "cyclic":
            return RerunDataIterator(iter(cyclic_iter(dataloader)))
        elif dataloader_type == "external":
            # External dataloader is passed through. User is expected to define how to iterate.
            if isinstance(dataloader, list):
                return [RerunDataIterator(d) for d in dataloader]
            else:
                return RerunDataIterator(dataloader)
        else:
            raise RuntimeError("unexpected dataloader type")

    if train_dataloader is not None:
        train_data_iterator = _get_iterator(dl_type, train_dataloader)
    else:
        train_data_iterator = None

    if valid_dataloaders is not None:
        # when using full validation, we need to override eval iters with the correct
        # number of iterations on tp rank 0 so that it can be distributed to the other
        # ranks later
        if args.full_validation:
            if args.multiple_validation_sets:
                if valid_dataloaders[0] is None:
                    args.eval_iters = [None] * len(valid_dataloaders)
                else:
                    args.eval_iters = [len(dl) for dl in valid_dataloaders]
            else:
                args.eval_iters = len(valid_dataloaders[0])

        if args.multiple_validation_sets:
            if valid_dataloaders[0] is None:
                valid_data_iterators = [None] * len(valid_dataloaders)
            else:
                # Full validation over multiple sets always iterates cyclically.
                valid_dl_type = "cyclic" if args.full_validation else dl_type
                # Join the lengths first, then prefix the tag. The original
                # relied on implicit literal concatenation, which turned the
                # tag into the join separator and garbled the message.
                loader_lengths = ", ".join(
                    f"{idx}: {len(dl)}" for idx, dl in enumerate(valid_dataloaders)
                )
                print(f"[VALID DATA LOADER LENGTHS] {loader_lengths}")
                valid_data_iterators = [
                    _get_iterator(valid_dl_type, dl) for dl in valid_dataloaders
                ]
        elif valid_dataloaders[0] is not None:
            valid_data_iterators = _get_iterator(dl_type, valid_dataloaders[0])
        else:
            valid_data_iterators = None
    else:
        valid_data_iterators = None

    if test_dataloader is not None:
        test_data_iterator = _get_iterator(dl_type, test_dataloader)
    else:
        test_data_iterator = None

    return train_data_iterator, valid_data_iterators, test_data_iterator


def should_disable_forward_pre_hook(args):
    """Tell whether the forward pre-hook has to be toggled for this configuration.

    The result is falsy whenever `args.use_megatron_fsdp` is set; otherwise it
    is truthy only if both `args.use_distributed_optimizer` and
    `args.overlap_param_gather` are set.
    """
    if args.use_megatron_fsdp:
        return False
    return args.use_distributed_optimizer and args.overlap_param_gather
