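"""Input/output checkpointing.

Megatron-Core version-aware checkpoint save/load routines that additionally
persist each rank's ``train_data_consuming_progresses``.
"""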

import contextlib
import os
import random
import shutil
import sys
import threading
from enum import Enum, auto
from time import time
from logging import getLogger
from pathlib import Path
from packaging import version

import numpy as np
import torch

try:
    from modelopt.torch.opt.plugins import (
        save_modelopt_state,
        save_sharded_modelopt_state,
        restore_modelopt_state,
        restore_sharded_modelopt_state,
    )
except Exception:
    pass

from megatron.core import __version__
from megatron.core import mpu, tensor_parallel, dist_checkpointing
from megatron.core.dist_checkpointing.mapping import ShardedObject
from megatron.core.dist_checkpointing.serialization import get_default_load_sharded_strategy
from megatron.core.dist_checkpointing.strategies.fully_parallel import \
    FullyParallelSaveStrategyWrapper, FullyParallelLoadStrategyWrapper
from megatron.core.num_microbatches_calculator import update_num_microbatches
from megatron.core.rerun_state_machine import get_rerun_state_machine
from megatron.training.async_utils import schedule_async_save, is_empty_async_queue
from megatron.training.global_vars import get_args, get_one_logger
from megatron.training.utils import unwrap_model, print_rank_0, append_to_progress_log, is_last_rank
from megatron.core.dist_checkpointing.serialization import \
    get_default_save_sharded_strategy
from megatron.training.one_logger_utils import on_save_checkpoint_start, on_save_checkpoint_success
from megatron.training import wandb_utils
from megatron.training import ft_integration
from megatron.training.checkpointing import  \
    logger, get_distributed_optimizer_checkpoint_name, cleanup_old_non_persistent_checkpoint, _NON_PERSISTENT_CKPT_SUBDIR, \
    CheckpointType, get_rng_state, get_checkpoint_name, maybe_save_dataloader_state, ensure_directory_exists, has_nvidia_modelopt, \
    generate_state_dict, get_checkpoint_tracker_filename, checkpoint_exists, _load_base_checkpoint, fix_fp8_params_lose_precision_when_loading_dist_ckpt, \
    set_checkpoint_version, get_checkpoint_version, check_checkpoint_args, fix_query_key_value_ordering, read_metadata, \
    has_nvidia_modelopt
try:
    from megatron.training.checkpointing import _build_sharded_state_dict_metadata
    from megatron.core.msc_utils import open_file
except ImportError:
    _build_sharded_state_dict_metadata = None
    open_file = None


               
def save_user_train_data_consuming_progresses(args):
    """Collect every rank's data-consumption progress into ``args``.

    Does nothing unless ``args`` has a non-None
    ``train_data_consuming_progresses`` mapping. Otherwise each rank
    contributes the entry stored under its own global rank (or None) through
    ``torch.distributed.all_gather_object``. The gathered values are written
    back into the mapping keyed by rank index, so every rank ends up holding
    the full table.

    This is a collective call: all ranks must reach it.

    Args:
        args: Training-arguments namespace; its progress mapping is updated
            in place.
    """
    if "train_data_consuming_progresses" not in args:
        return
    progress_table = args.train_data_consuming_progresses
    if progress_table is None:
        return

    my_rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()
    gathered = [None] * world_size
    torch.distributed.all_gather_object(gathered, progress_table.get(my_rank, None))

    # Index i of `gathered` holds rank i's contribution.
    progress_table.update(enumerate(gathered))
    print_rank_0(f"saving train_data_consuming_progresses {args.train_data_consuming_progresses}")


def load_user_train_data_consuming_progresses(state_dict, args):
    """Restore per-rank data-consumption progress from a loaded checkpoint.

    Acts only when ``state_dict['args']`` exists and carries a non-None
    ``train_data_consuming_progresses``. In that case the progress mapping on
    ``args`` is reset: it is created if missing or None, otherwise cleared in
    place. Then, unless ``args.finetune`` or
    ``args.px_clear_train_data_consuming_progresses`` is set, the mapping is
    refilled with the checkpoint's non-None entries.

    Args:
        state_dict: Loaded checkpoint dictionary.
        args: Current training-arguments namespace (mutated in place).
    """
    if 'args' not in state_dict:
        return
    ckpt_args = state_dict['args']
    if 'train_data_consuming_progresses' not in ckpt_args \
            or ckpt_args.train_data_consuming_progresses is None:
        return

    # The attribute may exist but be None (the save path treats None as
    # "unset"), so `hasattr` alone is not enough before calling `.clear()`.
    if getattr(args, 'train_data_consuming_progresses', None) is None:
        args.train_data_consuming_progresses = {}
    # Clear in place rather than rebinding, so the existing dict object is reused.
    args.train_data_consuming_progresses.clear()

    if args.finetune or getattr(args, 'px_clear_train_data_consuming_progresses', False):
        print_rank_0("train_data_consuming_progresses in checkpoint ignored")
        return

    for rank, progress in ckpt_args.train_data_consuming_progresses.items():
        if progress is not None:
            args.train_data_consuming_progresses[rank] = progress
    print_rank_0(f"use train_data_consuming_progresses in checkpoint {args.train_data_consuming_progresses}")
             


def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floating_point_operations_so_far,
                    checkpointing_context=None, pipeline_rank=None, expert_rank=None, tensor_rank=None, pipeline_parallel=None, expert_parallel=None, non_persistent_ckpt=False,
                    train_data_iterator=None, preprocess_common_state_dict_fn = None):
    """Save a model, optimizer and optionally dataloader checkpoint.

    Checkpointing context is used to persist some checkpointing state
    throughout a single job. Must be initialized externally (not used if None).

    If non_persistent_ckpt is True,
    the checkpoint will be saved with special functionality for removing old checkpoints.
    There are several types of non-persistent checkpoints:
    "global" - Saved as a standard checkpoint (e.g., on Lustre) with old checkpoints being removed.
    "local" - Each rank saves a portion of the checkpoint locally (e.g., on SSD/ramdisk).

    Dataloader checkpoint is only saved if the dataloader supports it. Currently this applies only
    to the Megatron Energon dataloader (multimodal) and not the built-in Megatron dataloader (text-only).

    The installed Megatron-Core version (``megatron.core.__version__``)
    selects the call forms used below:
    - 0.13.0 and later pass sharded-state-dict ``metadata`` / ``content_metadata``.
    - Older versions pass ``use_dist_ckpt`` / ``sharding_type``.

    Before the state dict is generated from ``args``, every rank's
    ``train_data_consuming_progresses`` entry is gathered into ``args``.
    This is a collective call.

    Args:
        iteration: Training iteration being checkpointed.
        model: Model (chunks) to save; passed through ``unwrap_model`` first.
        optimizer: Optimizer, or None.
        opt_param_scheduler: Optimizer parameter scheduler.
        num_floating_point_operations_so_far: Stored in the state dict as-is.
        checkpointing_context: Optional dict kept across saves in one job.
            Caches ``save_strategy``; holds the local checkpoint manager and cache.
        pipeline_rank, expert_rank, tensor_rank, pipeline_parallel,
        expert_parallel: Optional overrides forwarded to ``get_checkpoint_name``.
        non_persistent_ckpt: Save as a non-persistent ("global"/"local") checkpoint.
        train_data_iterator: Used for rerun-state and dataloader-state saving.
        preprocess_common_state_dict_fn: Forwarded to ``dist_checkpointing.save``
            as ``preprocess_common_before_consistancy_check``.

    Raises:
        NotImplementedError: Raised in either of these cases:
            - the non-persistent checkpoint type is unknown;
            - async save is requested for a legacy or non-``torch_dist`` checkpoint.
    """
    start_ckpt = time()
    args = get_args()

    # Several helper signatures changed in Megatron-Core 0.13; branch on it below.
    mcore_version_ge_0_13 = version.parse(__version__) >= version.parse('0.13.0')

    if mcore_version_ge_0_13:
        if args.async_save and not is_empty_async_queue():
            print_rank_0('WARNING: Starting a checkpoint save before previous has finished. Consider increasing the checkpoint interval.')
    else:
        if not is_empty_async_queue():
            print_rank_0('WARNING: Starting a checkpoint save before previous has finished. Consider increasing the checkpoint interval.')

    # one_logger: record the start of this checkpoint save.
    productive_metrics = on_save_checkpoint_start(args.async_save)

    # Fault-tolerance hook, paired with on_checkpointing_end at the bottom.
    ft_integration.on_checkpointing_start()

    model = unwrap_model(model)

    # Resolve checkpoint type and target directory.
    # Non-persistent checkpoints override both.
    ckpt_type = CheckpointType.GLOBAL if args.use_dist_ckpt else CheckpointType.LEGACY
    save_dir = args.save
    if non_persistent_ckpt:
        if args.non_persistent_ckpt_type == 'global':
            ckpt_type = CheckpointType.GLOBAL
            save_dir = (
                args.non_persistent_global_ckpt_dir
                if args.non_persistent_global_ckpt_dir
                else os.path.join(save_dir, _NON_PERSISTENT_CKPT_SUBDIR)
            )
            # Clean up older non-persistent checkpoints in this directory (leave_ckpt_num=1).
            cleanup_old_non_persistent_checkpoint(
                save_dir, leave_ckpt_num=1, do_async=args.async_save
            )
        elif args.non_persistent_ckpt_type == 'local':
            ckpt_type = CheckpointType.LOCAL
            save_dir = checkpointing_context['local_checkpoint_manager'].local_ckpt_dir
        else:
            raise NotImplementedError(f"Please use local or global non-persistent checkpoints (got: {args.non_persistent_ckpt_type})")

    # Only GLOBAL checkpoints use args.ckpt_format; LEGACY/LOCAL are reported as 'torch'.
    ckpt_format = args.ckpt_format if ckpt_type == CheckpointType.GLOBAL else 'torch'
    print_rank_0('saving checkpoint at iteration {:7d} to {} in {} format'.format(
        iteration, save_dir, ckpt_format))

    # RNG state (argument differs across Megatron-Core versions).
    if mcore_version_ge_0_13:
        rng_state = get_rng_state(args.ckpt_format)
    else:
        rng_state = get_rng_state(ckpt_type != CheckpointType.LEGACY)

    # Rerun state machine state.
    rerun_state_machine = get_rerun_state_machine()
    if version.parse(__version__) >= version.parse('0.12.0'):
        rerun_state = rerun_state_machine.state_dict(
            data_iterator=train_data_iterator, ckpt_format=args.ckpt_format,
        )
    else:
        rerun_state = rerun_state_machine.state_dict(
            data_iterator=train_data_iterator, use_dist_ckpt=ckpt_type != CheckpointType.LEGACY
        )

    # Non-legacy checkpoints are addressed by their base directory.
    return_base_dir = (ckpt_type != CheckpointType.LEGACY)
    checkpoint_name = get_checkpoint_name(save_dir, iteration, release=False, pipeline_parallel=pipeline_parallel,
        tensor_rank=tensor_rank, pipeline_rank=pipeline_rank, expert_parallel=expert_parallel, expert_rank=expert_rank, return_base_dir=return_base_dir)

    # Collective: gather each rank's data-consumption progress into args
    # before the state dict is generated from args.
    save_user_train_data_consuming_progresses(args)

    # Dataloader state, if the dataloader supports it.
    maybe_save_dataloader_state(train_data_iterator, iteration, getattr(args, "dataloader_save", None))

    # Legacy format: distributed optimizer parameter state goes to its own file.
    if (
        args.use_distributed_optimizer
        and not args.no_save_optim
        and optimizer is not None
        and ckpt_type == CheckpointType.LEGACY
    ):
        optim_checkpoint_name = \
            get_distributed_optimizer_checkpoint_name(checkpoint_name)
        ensure_directory_exists(optim_checkpoint_name)
        if not optimizer.is_stub_optimizer:
            optimizer.save_parameter_state(optim_checkpoint_name)

    async_save_request = None
    if args.async_save:
        if ckpt_type == CheckpointType.LEGACY:
            raise NotImplementedError('Async checkpoint save not implemented for legacy checkpoints')
        elif ckpt_type == CheckpointType.GLOBAL and args.ckpt_format != 'torch_dist':
            raise NotImplementedError(f'Async checkpoint save not implemented for {args.ckpt_format} distributed checkpoint format')

    rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0

    # Legacy checkpoints are written only by expert-data-parallel rank 0;
    # every rank takes part in distributed/local checkpoints.
    if not torch.distributed.is_initialized() \
            or mpu.get_expert_data_parallel_rank() == 0 \
            or ckpt_type != CheckpointType.LEGACY:

        if mcore_version_ge_0_13:
            if ckpt_type != CheckpointType.LEGACY:
                assert _build_sharded_state_dict_metadata is not None
                sharded_sd_metadata = _build_sharded_state_dict_metadata(args)
                if args.use_distributed_optimizer:
                    print_rank_0(f'Storing distributed optimizer sharded state of type'
                                f' {sharded_sd_metadata["distrib_optim_sharding_type"]}')
            else:
                sharded_sd_metadata = None
            state_dict_kwargs = {
                "optim_sd_kwargs": dict(metadata=sharded_sd_metadata),
                "model_sd_kwargs": dict(metadata=sharded_sd_metadata),
            }
        else:
            optim_sd_kwargs = {}
            if ckpt_type != CheckpointType.LEGACY and args.use_distributed_optimizer:
                optim_sd_kwargs['sharding_type'] = ('fully_sharded_model_space'
                                                    if args.ckpt_fully_parallel_save
                                                    else 'dp_zero_gather_scatter')
                print_rank_0(f'Storing distributed optimizer sharded state of type {optim_sd_kwargs["sharding_type"]}')
            state_dict_kwargs = {
                "use_dist_ckpt": ckpt_type != CheckpointType.LEGACY,
                "optim_sd_kwargs": optim_sd_kwargs
            }

        state_dict = generate_state_dict(
            args,
            model,
            optimizer,
            opt_param_scheduler,
            rng_state,
            iteration=iteration,
            rerun_state=rerun_state,
            **state_dict_kwargs,
        )

        state_dict['num_floating_point_operations_so_far'] = num_floating_point_operations_so_far

        # Decide whether this save goes through dist_checkpointing.save.
        # Compare ckpt_type explicitly. A bare enum member such as
        # CheckpointType.GLOBAL is truthy, which would route LEGACY and LOCAL
        # checkpoints down this path as well.
        if mcore_version_ge_0_13:
            use_dist_ckpt_save = ckpt_type == CheckpointType.GLOBAL and ckpt_format == "torch_dist"
        else:
            use_dist_ckpt_save = ckpt_type == CheckpointType.GLOBAL
        if use_dist_ckpt_save:
            if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
                # A single rank creates the checkpoint directory.
                ensure_directory_exists(checkpoint_name, check_parent=False)
            if checkpointing_context is not None and 'save_strategy' in checkpointing_context:
                save_strategy = checkpointing_context['save_strategy']
                # Reusing a cached strategy: skip integrity validation when the
                # checkpoint structure is assumed constant.
                validate_sharding_integrity = not args.ckpt_assume_constant_structure
            else:
                validate_sharding_integrity = True
                save_strategy = get_default_save_sharded_strategy(args.ckpt_format)
                if args.ckpt_assume_constant_structure and args.ckpt_format == 'torch_dist':
                    save_strategy.use_cached_ckpt_structure = args.ckpt_assume_constant_structure
                    if checkpointing_context is not None and 'load_strategy' in checkpointing_context:
                        cached_global_metadata = getattr(checkpointing_context['load_strategy'], 'cached_global_metadata', None)
                        if cached_global_metadata is not None:
                            logger.debug("Plugging in the read metadata from the load strategy...")
                            save_strategy.cached_global_metadata = cached_global_metadata
                        else:
                            logger.debug("Failed to plug in the read metadata from the load strategy...")

                if args.ckpt_fully_parallel_save:
                    save_strategy = FullyParallelSaveStrategyWrapper(save_strategy, mpu.get_data_parallel_group(with_context_parallel=True),
                                                                     args.ckpt_assume_constant_structure)
            # Cache the strategy for later saves in this job.
            if checkpointing_context is not None:
                checkpointing_context['save_strategy'] = save_strategy
            end_ckpt = time()
            logger.debug(f"rank: {rank}, takes {end_ckpt - start_ckpt} to prepare state dict for ckpt ")
            dist_save_kwargs = {}
            if mcore_version_ge_0_13:
                dist_save_kwargs = {"content_metadata": sharded_sd_metadata}
            async_save_request = dist_checkpointing.save(state_dict, checkpoint_name, save_strategy,
                                                         async_sharded_save=args.async_save,
                                                         validate_access_integrity=validate_sharding_integrity,
                                                         preprocess_common_before_consistancy_check=preprocess_common_state_dict_fn,
                                                         **dist_save_kwargs)
            # ModelOpt: store modelopt state next to the sharded checkpoint.
            if has_nvidia_modelopt:
                save_sharded_modelopt_state(model, checkpoint_name, (args.ckpt_format, 1))
        elif mcore_version_ge_0_13 and ckpt_type == CheckpointType.GLOBAL and ckpt_format == "torch_dcp":
            if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
                # A single rank creates the checkpoint directory.
                ensure_directory_exists(checkpoint_name, check_parent=False)

            fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter(checkpoint_name)
            torch.distributed.checkpoint.save(
                state_dict=state_dict,
                storage_writer=fs_storage_writer,
            )
        else:
            # ModelOpt: inject modelopt state into the (non-sharded) state dict.
            if has_nvidia_modelopt:
                if ckpt_type == CheckpointType.LOCAL:
                    print_rank_0('WARNING: Local checkpointing does not support nvidia_modelopt.')
                else:
                    save_modelopt_state(model, state_dict)

            end_ckpt = time()
            logger.debug(f"rank: {rank}, takes {end_ckpt - start_ckpt} to prepare state dict for ckpt ")
            if ckpt_type == CheckpointType.LOCAL:
                try:
                    from megatron.core.dist_checkpointing.tensor_aware_state_dict import MCoreTensorAwareStateDict
                except ModuleNotFoundError:
                    raise RuntimeError("The 'nvidia_resiliency_ext' module is required for local "
                                       "checkpointing but was not found. Please ensure it is installed.")

                algo = args.non_persistent_local_ckpt_algo
                cached_metadata = None
                if args.ckpt_assume_constant_structure and 'local_checkpoint_cache' in checkpointing_context:
                    cached_metadata = checkpointing_context['local_checkpoint_cache']
                state_dict_for_save, cacheable_metadata = MCoreTensorAwareStateDict.from_state_dict(
                    state_dict, algo=algo, cached_metadata=cached_metadata,
                    parallelization_group=mpu.get_data_parallel_group(with_context_parallel=True)
                )
                async_save_request = checkpointing_context['local_checkpoint_manager'].save(
                    state_dict_for_save, iteration, is_async=bool(args.async_save)
                )
                checkpointing_context['local_checkpoint_cache'] = cacheable_metadata
            else:
                assert ckpt_type == CheckpointType.LEGACY
                # Legacy format: a single torch.save per writing rank.
                ensure_directory_exists(checkpoint_name)
                torch.save(state_dict, checkpoint_name)
    start_misc = time()
    if ckpt_type != CheckpointType.LOCAL:
        if not args.async_save:
            assert async_save_request is None
            # Synchronous save: wait until all ranks are done writing.
            if torch.distributed.is_initialized():
                torch.distributed.barrier()

    # Rank 0 finalizes: updates the tracker file (global/legacy) or just logs (local).
    if not torch.distributed.is_initialized() \
            or torch.distributed.get_rank() == 0:
        tracker_filename = get_checkpoint_tracker_filename(save_dir)

        if ckpt_type == CheckpointType.LOCAL:
            def iter_finalize_fn():
                print_rank_0('  successfully saved local checkpoint from iteration {:7d}'
                             .format(iteration))
                if args.log_progress and args.async_save:
                    append_to_progress_log(f'Saved async local checkpoint\tIteration: {iteration}',
                                           barrier=False)
        else:
            def iter_finalize_fn():
                if mcore_version_ge_0_13:
                    assert open_file is not None
                    open_func = open_file
                else:
                    open_func = open
                with open_func(tracker_filename, 'w') as f:
                    f.write(str(iteration))
                # Report save_dir, not args.save: non-persistent global checkpoints
                # may be written to a different directory.
                print_rank_0(f'  successfully saved checkpoint from iteration {int(iteration):7d} to {save_dir} '
                             f'[ t {(tensor_rank if tensor_rank is not None else mpu.get_tensor_model_parallel_rank()) + 1}/{mpu.get_tensor_model_parallel_world_size()}, '
                             f'p {(pipeline_rank if pipeline_rank is not None else mpu.get_pipeline_model_parallel_rank()) + 1}/{mpu.get_pipeline_model_parallel_world_size()} ]')
                if args.log_progress and args.async_save:
                    append_to_progress_log(f'Saved async checkpoint\tIteration: {iteration}',
                                           barrier=False)

        # For async saves the callback is attached to the request instead of run now.
        if args.async_save:
            assert async_save_request is not None
            async_save_request.add_finalize_fn(iter_finalize_fn)
        else:
            iter_finalize_fn()

    # one_logger success callback (last rank).
    if not torch.distributed.is_initialized() \
       or is_last_rank():
        def onelogger_finalize_fn():
            on_save_checkpoint_success(productive_metrics, args.async_save)
        if args.async_save:
            assert async_save_request is not None
            async_save_request.add_finalize_fn(onelogger_finalize_fn)
        else:
            onelogger_finalize_fn()

    # wandb success callback (last rank).
    if not torch.distributed.is_initialized() \
       or is_last_rank():
        def wandb_finalize_fn():
            wandb_utils.on_save_checkpoint_success(checkpoint_name, get_checkpoint_tracker_filename(save_dir), save_dir, iteration)
        if args.async_save:
            assert async_save_request is not None
            async_save_request.add_finalize_fn(wandb_finalize_fn)
        else:
            wandb_finalize_fn()

    if args.async_save:
        schedule_async_save(async_save_request)
        print_rank_0('  scheduled an async checkpoint save at iteration {:7d} to {}' \
                     .format(iteration, save_dir))

    # Wait so every rank leaves save_checkpoint together.
    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    end_misc = time()
    logger.debug(f"rank: {rank}, takes {end_misc - start_misc} to finalize ckpt save ")

    ft_integration.on_checkpointing_end(is_async_finalization=False)


def load_checkpoint_lt_0_13(model, optimizer, opt_param_scheduler, load_arg='load', strict=True,
                            checkpointing_context=None, skip_load_to_model_and_opt=False):
    """Load a model checkpoint and return the iteration.
    strict (bool): whether to strictly enforce that the keys in
        :attr:`state_dict` of the checkpoint match the names of
        parameters and buffers in model.
    skip_load_to_model_and_opt (bool): whether to call `load_state_dict`
        for :attr:`model` and :attr:`optimizer`. In case of running FSDP2
        or other torch features that uses DTensor in state dict, the tensors
        are already loaded in-place by `_load_base_checkpoint`.
    """
    args = get_args()
    load_dir = getattr(args, load_arg)

                            
    pretrained_dir = getattr(args, 'pretrained_checkpoint', None)
    if pretrained_dir is not None and not checkpoint_exists(load_dir):
        print_rank_0(
            f'Checkpoint file not found in load directory {load_dir} attempting to finetune with checkpoint in {pretrained_dir}'
        )
        load_dir = pretrained_dir
        if not checkpoint_exists(load_dir):
            raise FileNotFoundError("No checkpoint found in load directory or pretrained directory")
        args.finetune = True
    model = unwrap_model(model)
    load_kwargs = {}
    is_dist_ckpt = False
    if (
        args.auto_detect_ckpt_format
        or args.use_dist_ckpt
        or args.non_persistent_save_interval is not None
    ):
        state_dict, checkpoint_name, release, ckpt_type = _load_base_checkpoint(
            load_dir,
            args,
            rank0=True,
            checkpointing_context=checkpointing_context,
        )
        is_dist_ckpt = (
            ckpt_type == CheckpointType.LOCAL
            or dist_checkpointing.check_is_distributed_checkpoint(checkpoint_name)
        )
        if is_dist_ckpt:
            ckpt_tp_pp = (
                state_dict['args'].tensor_model_parallel_size,
                state_dict['args'].pipeline_model_parallel_size,
                getattr(state_dict['args'], 'encoder_tensor_model_parallel_size', 0),
                getattr(state_dict['args'], 'encoder_pipeline_model_parallel_size', 0),
            )
            run_tp_pp = (
                args.tensor_model_parallel_size,
                args.pipeline_model_parallel_size,
                                                                                                
                getattr(args, 'encoder_tensor_model_parallel_size', 0),
                getattr(args, 'encoder_pipeline_model_parallel_size', 0),
            )
            mismatch_msg = "(TP, PP, encoder TP, encoder PP) mismatch after resume ({} vs {} from checkpoint)".format(
                run_tp_pp, ckpt_tp_pp
            )
                                                   
            if (ckpt_tp_pp == run_tp_pp and not release and not args.finetune and not args.no_load_rng
                    and not getattr(state_dict['args'], 'no_save_rng', False)):
                gen_sd_rng_state = get_rng_state(True)                             
            else:
                gen_sd_rng_state = None
                if ckpt_tp_pp != run_tp_pp:
                    print_rank_0("{}: RNG state will be ignored".format(mismatch_msg))
            optim_sd_kwargs = dict(is_loading=True)
                                                         
            if (not release and not args.finetune and not args.no_load_optim
                    and not getattr(state_dict['args'], 'no_save_optim', False)):
                gen_sd_optim = optimizer
                gen_sd_opt_param_scheduler = opt_param_scheduler
                if args.use_distributed_optimizer:
                    optim_sd_kwargs['sharding_type'] = ('fully_sharded_model_space'
                                                        if getattr(state_dict['args'], 'ckpt_fully_parallel_save', False)
                                                        else 'dp_zero_gather_scatter')
                                                                                                                              
                    for maybe_dist_opt_optim_state in (state_dict['optimizer'], *state_dict['optimizer'].values()):
                        if 'param_state_sharding_type' in maybe_dist_opt_optim_state:
                            if maybe_dist_opt_optim_state['param_state_sharding_type'] == 'fully_sharded_bucket_space':
                                print_rank_0('Detected deprecated `fully_sharded_bucket_space` DistributedOptimizer checkpoint format')
                                optim_sd_kwargs['sharding_type'] = maybe_dist_opt_optim_state['param_state_sharding_type']
                            break
                    if ckpt_tp_pp != run_tp_pp and optim_sd_kwargs['sharding_type'] != 'fully_sharded_model_space':
                        raise RuntimeError(f"{mismatch_msg}: not supported for DistributedOptimizer with sharding type {optim_sd_kwargs['sharding_type']}."
                                           f" Please use `--ckpt-fully-parallel-save` flag during checkpoint saving.")
            else:
                gen_sd_optim = None
                gen_sd_opt_param_scheduler = None
                                                     
            if (ckpt_tp_pp == run_tp_pp and not release and not args.finetune):
                rerun_state_machine = get_rerun_state_machine()
                if version.parse(__version__) >= version.parse('0.12.0'):
                    gen_sd_rerun_state = rerun_state_machine.state_dict(
                        data_iterator=None, ckpt_format=args.ckpt_format,
                    )
                else:
                    gen_sd_rerun_state = rerun_state_machine.state_dict(
                        data_iterator=None, use_dist_ckpt=True
                    )
            else:
                gen_sd_rerun_state = None
                if ckpt_tp_pp != run_tp_pp:
                    print_rank_0("{}: Rerun state will be ignored".format(mismatch_msg))
                                                                                                
                                                                                                   
            if has_nvidia_modelopt:
                if ckpt_type == CheckpointType.LOCAL:
                    print_rank_0('WARNING: Local checkpointing does not support nvidia_modelopt.')
                elif ckpt_type == CheckpointType.GLOBAL:
                    restore_modelopt_state(model, state_dict)
                else:
                    restore_sharded_modelopt_state(model, checkpoint_name)
                                                                                                    
                                                                                                    
                                                                                                        
            with contextlib.ExitStack() as stack:                                                         
                if args.finetune and hasattr(model[0], "hide_loss_modules"):
                    for m in model:
                        stack.enter_context(m.hide_loss_modules())
                load_kwargs['sharded_state_dict'] = generate_state_dict(
                    args, model, gen_sd_optim, gen_sd_opt_param_scheduler, gen_sd_rng_state,
                    use_dist_ckpt=True, optim_sd_kwargs=optim_sd_kwargs, rerun_state=gen_sd_rerun_state
                )
                                                                                           
            fix_fp8_params_lose_precision_when_loading_dist_ckpt(load_kwargs['sharded_state_dict'])
    state_dict, checkpoint_name, release, ckpt_type = _load_base_checkpoint(
        load_dir, args, rank0=False, checkpointing_context=checkpointing_context,
        **load_kwargs
    )
                            
    if state_dict is None:
                                                                          
        return 0, 0
 
                             
    set_checkpoint_version(state_dict.get('checkpoint_version', 0))

                   
                                                                             
    load_user_train_data_consuming_progresses(state_dict, args)
                 

                    
    if args.finetune or release:
        iteration = 0
    else:
        try:
            iteration = state_dict['iteration']
        except KeyError:
            try:                                              
                iteration = state_dict['total_iters']
            except KeyError:
                print_rank_0('A metadata file exists but unable to load '
                             'iteration from checkpoint {}, exiting'.format(checkpoint_name))
                sys.exit()
    num_floating_point_operations_so_far = state_dict.get('num_floating_point_operations_so_far', 0)

                      
    assert args.consumed_train_samples == 0
    assert args.skipped_train_samples == 0
    assert args.consumed_valid_samples == 0
    if 'args' in state_dict and not args.finetune:
        checkpoint_args = state_dict['args']
        check_checkpoint_args(checkpoint_args)
        args.consumed_train_samples = getattr(checkpoint_args,
                                              'consumed_train_samples', 0)
        args.skipped_train_samples = getattr(checkpoint_args,
                                             'skipped_train_samples', 0)
        update_num_microbatches(consumed_samples=args.consumed_train_samples, verbose=True)
        args.consumed_valid_samples = getattr(checkpoint_args,
                                              'consumed_valid_samples', 0)
    else:
        print_rank_0('could not find arguments in the checkpoint ...')

            
    strict = False if args.retro_add_retriever else strict
    if not skip_load_to_model_and_opt:
        if len(model) == 1:
            model[0].load_state_dict(state_dict['model'], strict=strict)
        else:
            for i in range(len(model)):
                mpu.set_virtual_pipeline_model_parallel_rank(i)
                model[i].load_state_dict(state_dict['model%d' % i], strict=strict)

                                                       
    checkpoint_version = get_checkpoint_version()
    print_rank_0(f' checkpoint version {checkpoint_version}')
    fix_query_key_value_ordering(model, checkpoint_version)

                
    if not release and not args.finetune and not args.no_load_optim:
        try:
                              
            if not skip_load_to_model_and_opt and optimizer is not None and not optimizer.is_stub_optimizer:
                optimizer.load_state_dict(state_dict['optimizer'])

                                                                  
                                                                                     
            if args.use_distributed_optimizer and not is_dist_ckpt:
                                                                  
                                                                                               
                assert not is_dist_ckpt
                tracker_filename = get_checkpoint_tracker_filename(load_dir)
                iteration, release = read_metadata(tracker_filename)
                model_checkpoint_name = \
                    get_checkpoint_name(load_dir, iteration, release)
                optim_checkpoint_name = \
                    get_distributed_optimizer_checkpoint_name(
                        model_checkpoint_name)
                optimizer.load_parameter_state(optim_checkpoint_name,
                                               update_legacy_format=args.ckpt_convert_update_legacy_dist_opt_format)

                             
            if opt_param_scheduler is not None:
                if 'lr_scheduler' in state_dict:                        
                    opt_param_scheduler.load_state_dict(state_dict['lr_scheduler'])
                else:
                    opt_param_scheduler.load_state_dict(state_dict['opt_param_scheduler'])
        except KeyError as e:
            print_rank_0('Unable to load optimizer from checkpoint {}. '
                         'Specify --no-load-optim or --finetune to prevent '
                         'attempting to load the optimizer state, '
                         'exiting ...'.format(checkpoint_name))
            raise e
    else:
        if (args.fp16 or args.bf16) and optimizer is not None:
            optimizer.reload_model_params()

                 
    try:
        if 'rerun_state_machine' in state_dict:
            get_rerun_state_machine().load_state_dict(state_dict['rerun_state_machine'])
    except Exception as e:
        print(f"Unable to restore RerunMachine from checkpoint: {e}")
        sys.exit()

                 
    if not release and not args.finetune and not args.no_load_rng:
        try:
            if 'rng_state' in state_dict:
                                                         
                if args.data_parallel_random_init:
                    rng_state = state_dict['rng_state'][mpu.get_data_parallel_rank()]
                else:
                    rng_state = state_dict['rng_state'][0]
                random.setstate(rng_state['random_rng_state'])
                np.random.set_state(rng_state['np_rng_state'])
                torch.set_rng_state(rng_state['torch_rng_state'])
                torch.cuda.set_rng_state(rng_state['cuda_rng_state'])
                                              
                if not rng_state['rng_tracker_states']:
                    raise KeyError
                tensor_parallel.get_cuda_rng_tracker().set_states(
                    rng_state['rng_tracker_states'])
            else:                          
                random.setstate(state_dict['random_rng_state'])
                np.random.set_state(state_dict['np_rng_state'])
                torch.set_rng_state(state_dict['torch_rng_state'])
                torch.cuda.set_rng_state(state_dict['cuda_rng_state'])
                                              
                if not state_dict['rng_tracker_states']:
                    raise KeyError
                tensor_parallel.get_cuda_rng_tracker().set_states(
                    state_dict['rng_tracker_states'])
        except KeyError:
            print_rank_0('Unable to load rng state from checkpoint {}. '
                         'Specify --no-load-rng or --finetune to prevent '
                         'attempting to load the rng state, '
                         'exiting ...'.format(checkpoint_name))
            sys.exit()

                                                                                    
    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    print_rank_0(f'  successfully loaded checkpoint from {load_dir} '
                 f'[ t {mpu.get_tensor_model_parallel_rank() + 1}/{mpu.get_tensor_model_parallel_world_size()}, '
                 f'p {mpu.get_pipeline_model_parallel_rank() + 1}/{mpu.get_pipeline_model_parallel_world_size()} ] '
                 f'at iteration {iteration}')

                                               
    if not torch.distributed.is_initialized() \
       or is_last_rank():
        wandb_utils.on_load_checkpoint_success(checkpoint_name, load_dir)

    torch.cuda.empty_cache()

    if iteration > 0:
                                                 
        is_local_chkpt = (ckpt_type == CheckpointType.LOCAL)
        ft_integration.on_checkpoint_loaded(is_local_chkpt=is_local_chkpt)

    return iteration, num_floating_point_operations_so_far


def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, load_arg='load', strict=True,
                    checkpointing_context=None, skip_load_to_model_and_opt=False):
    """Load a model checkpoint and return the iteration.

    Args:
        ddp_model: list of (possibly wrapped) model chunks. ``unwrap_model`` is applied
            before building sharded state dicts; ``load_state_dict`` is called on the
            entries of this list directly.
        optimizer: optimizer whose state is restored unless ``--finetune``,
            ``--no-load-optim`` or a release checkpoint disables it.
        opt_param_scheduler: optimizer-parameter scheduler restored alongside the optimizer.
        load_arg (str): name of the ``args`` attribute holding the load directory.
        strict (bool): whether to strictly enforce that the keys in
            :attr:`state_dict` of the checkpoint match the names of
            parameters and buffers in model. Forced to False when
            ``args.retro_add_retriever`` is set.
        checkpointing_context: forwarded to ``_load_base_checkpoint``.
        skip_load_to_model_and_opt (bool): whether to call `load_state_dict`
            for :attr:`model` and :attr:`optimizer`. In case of running FSDP2 with mcore distributed
            checkpointing, the tensors are already loaded in-place by `_load_base_checkpoint`.

    Returns:
        tuple: ``(iteration, num_floating_point_operations_so_far)``. ``(0, 0)`` when no
        checkpoint is found; ``iteration`` is 0 for finetune and release loads.

    Side effects:
        May set ``args.finetune`` (pretrained fallback), sets the consumed/skipped sample
        counters on ``args``, the checkpoint version, the rerun state machine state and the
        Python/NumPy/torch/CUDA/tensor-parallel RNG states. Exits the process with status 1
        when the iteration, rerun state or RNG state cannot be restored.
    """
    args = get_args()
    load_dir = getattr(args, load_arg)

    # ModelOpt-format checkpoints are handled by a dedicated loader.
    if getattr(args, 'load_model_opt_format', False):
        print_rank_0(f'Loading checkpoint using ModelOpt format from {load_dir}')
        from megatron.post_training.checkpointing import load_modelopt_checkpoint

        load_modelopt_checkpoint(
            ddp_model,
            optimizer=optimizer,
            opt_param_scheduler=opt_param_scheduler,
            strict=strict,
            load_arg=load_arg
        )

        # The loader's return value is not used; the iteration comes from the tracker file
        # (0 for release checkpoints, a missing tracker, or a non-distributed run).
        if torch.distributed.is_initialized():
            tracker_filename = get_checkpoint_tracker_filename(load_dir)
            if os.path.isfile(tracker_filename):
                iteration, release = read_metadata(tracker_filename)
                if release:
                    iteration = 0
            else:
                iteration = 0
        else:
            iteration = 0

        # The FLOPs counter is not restored on this path.
        return iteration, 0

    # No checkpoint in the load directory: fall back to the pretrained checkpoint and
    # switch to finetune mode.
    pretrained_dir = getattr(args, 'pretrained_checkpoint', None)
    if pretrained_dir is not None and not checkpoint_exists(load_dir):
        print_rank_0(
            f'Checkpoint file not found in load directory {load_dir} attempting to finetune with checkpoint in {pretrained_dir}'
        )
        load_dir = pretrained_dir
        if not checkpoint_exists(load_dir):
            raise FileNotFoundError("No checkpoint found in load directory or pretrained directory")
        args.finetune = True

    model = unwrap_model(ddp_model)

    ckpt_format = args.ckpt_format
    if args.auto_detect_ckpt_format or ckpt_format == "torch_dist":
        # Preliminary load (rank0=True) used to detect the checkpoint type and to read the
        # saving run's args; the full load happens below.
        state_dict, checkpoint_name, release, ckpt_type = _load_base_checkpoint(
            load_dir,
            args,
            rank0=True,
            checkpointing_context=checkpointing_context,
        )

        ckpt_format = None
        if ckpt_type == CheckpointType.TORCH_DCP:
            ckpt_format = "torch_dcp"
        elif ckpt_type == CheckpointType.LEGACY:
            ckpt_format = "torch"
        elif ckpt_type in [CheckpointType.LOCAL, CheckpointType.GLOBAL]:
            ckpt_format = "torch_dist"
        elif ckpt_type is None:
            pass  # No checkpoint detected; ckpt_format stays None, so no sharded state dict is built.
        else:
            # Report the detected type: ckpt_format was just reset to None above.
            raise NotImplementedError(f"checkpoint type {ckpt_type} not supported")

    load_kwargs = {}
    if ckpt_format == "torch_dist":
        # Parallel layout of the saving run vs. the current run.
        ckpt_tp_pp = (
            state_dict['args'].tensor_model_parallel_size,
            state_dict['args'].pipeline_model_parallel_size,
            getattr(state_dict['args'], 'encoder_tensor_model_parallel_size', 0),
            getattr(state_dict['args'], 'encoder_pipeline_model_parallel_size', 0),
        )
        run_tp_pp = (
            args.tensor_model_parallel_size,
            args.pipeline_model_parallel_size,
            # encoder_* may be absent from args; default to 0 like the checkpoint side.
            getattr(args, 'encoder_tensor_model_parallel_size', 0),
            getattr(args, 'encoder_pipeline_model_parallel_size', 0),
        )
        mismatch_msg = "(TP, PP, encoder TP, encoder PP) mismatch after resume ({} vs {} from checkpoint)".format(
            run_tp_pp, ckpt_tp_pp
        )

        # RNG state is only requested when the parallel layout is unchanged.
        if (ckpt_tp_pp == run_tp_pp and not release and not args.finetune and not args.no_load_rng
                and not getattr(state_dict['args'], 'no_save_rng', False)):
            gen_sd_rng_state = get_rng_state(args.ckpt_format)  # we can load the rng state
        else:
            gen_sd_rng_state = None
            if ckpt_tp_pp != run_tp_pp:
                print_rank_0("{}: RNG state will be ignored".format(mismatch_msg))

        sharded_sd_metadata = dist_checkpointing.load_content_metadata(preloaded_state_dict=state_dict)
        print_rank_0(f'sharded_state_dict metadata loaded from the checkpoint: {sharded_sd_metadata}')
        # Optimizer/scheduler state is only requested for a non-finetune, non-release resume.
        if (not release and not args.finetune and not args.no_load_optim
                and not getattr(state_dict['args'], 'no_save_optim', False)):
            gen_sd_optim = optimizer
            gen_sd_opt_param_scheduler = opt_param_scheduler

            if args.use_distributed_optimizer:
                if sharded_sd_metadata is None:
                    # No content metadata in the checkpoint: infer the distributed-optimizer
                    # sharding type from the saving run's `ckpt_fully_parallel_save` flag.
                    sharded_sd_metadata = {
                        'distrib_optim_sharding_type': ('fully_sharded_model_space'
                                                        if getattr(state_dict['args'], 'ckpt_fully_parallel_save', False)
                                                        else 'dp_zero_gather_scatter'),
                    }
                if ckpt_tp_pp != run_tp_pp and sharded_sd_metadata['distrib_optim_sharding_type'] != 'fully_sharded_model_space':
                    raise RuntimeError(f"{mismatch_msg}: not supported for DistributedOptimizer with sharding type"
                                       f" {sharded_sd_metadata['distrib_optim_sharding_type']}."
                                       f" Please use `--ckpt-fully-parallel-save` flag during checkpoint saving.")
        else:
            gen_sd_optim = None
            gen_sd_opt_param_scheduler = None

        optim_sd_kwargs = dict(metadata=sharded_sd_metadata, is_loading=True)
        model_sd_kwargs = dict(metadata=sharded_sd_metadata)

        # Rerun-state-machine state is only requested when the parallel layout is unchanged.
        if (
            ckpt_tp_pp == run_tp_pp
            and not release
            and not args.finetune
            and 'rerun_state_machine' in state_dict
        ):
            rerun_state_machine = get_rerun_state_machine()
            gen_sd_rerun_state = rerun_state_machine.state_dict(
                data_iterator=None, ckpt_format=ckpt_format,
            )
        else:
            gen_sd_rerun_state = None
            if ckpt_tp_pp != run_tp_pp:
                print_rank_0("{}: Rerun state will be ignored".format(mismatch_msg))

        # Restore ModelOpt state before building the sharded state dict below
        # (presumably so the model structure matches what was saved — confirm).
        if has_nvidia_modelopt:
            if ckpt_type == CheckpointType.LOCAL:
                print_rank_0('WARNING: Local checkpointing does not support nvidia_modelopt.')
            elif ckpt_type == CheckpointType.GLOBAL:
                restore_modelopt_state(model, state_dict)
            else:
                restore_sharded_modelopt_state(model, checkpoint_name)

        # When finetuning, models exposing `hide_loss_modules` keep those modules hidden
        # while the template sharded state dict is generated.
        with contextlib.ExitStack() as stack:
            if args.finetune and hasattr(model[0], "hide_loss_modules"):
                for m in model:
                    stack.enter_context(m.hide_loss_modules())
            load_kwargs['sharded_state_dict'] = generate_state_dict(
                args, model, gen_sd_optim, gen_sd_opt_param_scheduler, gen_sd_rng_state,
                optim_sd_kwargs=optim_sd_kwargs, model_sd_kwargs=model_sd_kwargs,
                rerun_state=gen_sd_rerun_state
            )

        fix_fp8_params_lose_precision_when_loading_dist_ckpt(load_kwargs['sharded_state_dict'])
    elif args.ckpt_format == "torch_dcp":
        # Template state dict that the torch DCP load fills in.
        model_sd = model[0].state_dict()
        optimizer_sd = optimizer.state_dict(is_loading=True)
        sharded_state_dict = {
            "model": model_sd,
            "optimizer": optimizer_sd,
            "args": None,
            "iteration": 1,
            "rng_state": get_rng_state(args.ckpt_format),
            "checkpoint_version": None,
            "opt_param_scheduler": opt_param_scheduler.state_dict(),
            "num_floating_point_operations_so_far": 0,
        }
        load_kwargs["sharded_state_dict"] = sharded_state_dict

    state_dict, checkpoint_name, release, ckpt_type = _load_base_checkpoint(
        load_dir, args, rank0=False, checkpointing_context=checkpointing_context,
        **load_kwargs
    )

    # No checkpoint found: start from scratch.
    if state_dict is None:
        return 0, 0

    set_checkpoint_version(state_dict.get('checkpoint_version', 0))

    load_user_train_data_consuming_progresses(state_dict, args)

    # Legacy (torch) checkpoint loaded into a torch_dcp run: convert model tensors to DTensors.
    if ckpt_type == CheckpointType.LEGACY and args.ckpt_format == "torch_dcp":
        dtensor_state_dict = _to_dtensor(ddp_model, state_dict["model"])
        state_dict["model"] = dtensor_state_dict

    # Iteration to resume from (always 0 when finetuning or loading a release checkpoint).
    if args.finetune or release:
        iteration = 0
    else:
        try:
            iteration = state_dict['iteration']
        except KeyError:
            try:  # Fallback key used by some checkpoints.
                iteration = state_dict['total_iters']
            except KeyError:
                print_rank_0('A metadata file exists but unable to load '
                             'iteration from checkpoint {}, exiting'.format(checkpoint_name))
                # Non-zero status so launchers do not treat this failure as success.
                sys.exit(1)
    num_floating_point_operations_so_far = state_dict.get('num_floating_point_operations_so_far', 0)

    # Data-consumption counters must still be at 0 before being restored from the
    # saving run's args (skipped when finetuning).
    assert args.consumed_train_samples == 0
    assert args.skipped_train_samples == 0
    assert args.consumed_valid_samples == 0
    if 'args' in state_dict and not args.finetune:
        checkpoint_args = state_dict['args']
        check_checkpoint_args(checkpoint_args)
        args.consumed_train_samples = getattr(checkpoint_args,
                                              'consumed_train_samples', 0)
        args.skipped_train_samples = getattr(checkpoint_args,
                                             'skipped_train_samples', 0)
        update_num_microbatches(consumed_samples=args.consumed_train_samples, verbose=True)
        args.consumed_valid_samples = getattr(checkpoint_args,
                                              'consumed_valid_samples', 0)
    else:
        print_rank_0('could not find arguments in the checkpoint ...')

    def load_model_state_dict(module, state_dict, strict: bool):
        """Load ``state_dict`` into ``module`` with a fallback for missing extra states.

        A failing strict load is retried with ``strict=False`` and the retry's return
        value is printed. A failing non-strict load is re-raised: it signals a real
        incompatibility, and swallowing it would leave the module silently unloaded.
        """
        try:
            module.load_state_dict(state_dict, strict=strict)
        except Exception as e:
            if not strict:
                raise
            print(f"Strict load_state_dict failed ({e}); retrying with strict=False")
            load_return = module.load_state_dict(state_dict, strict=False)
            print(f"load_return: {load_return}")

    strict = False if args.retro_add_retriever else strict
    if not skip_load_to_model_and_opt:
        if len(ddp_model) == 1:
            load_model_state_dict(ddp_model[0], state_dict['model'], strict)
        else:
            for i, module in enumerate(ddp_model):
                key = 'model%d' % i
                # A chunk with no entry in the checkpoint has nothing to load; skip it.
                if key not in state_dict:
                    continue
                load_model_state_dict(module, state_dict[key], strict)

    checkpoint_version = get_checkpoint_version()
    print_rank_0(f' checkpoint version {checkpoint_version}')
    fix_query_key_value_ordering(model, checkpoint_version)

    # Optimizer and scheduler state (skipped for release, finetune, or --no-load-optim).
    if not release and not args.finetune and not args.no_load_optim:
        try:
            if not skip_load_to_model_and_opt and optimizer is not None and not optimizer.is_stub_optimizer:
                optimizer.load_state_dict(state_dict['optimizer'])

            # Distributed-optimizer parameter state lives in a separate file for non-torch_dist
            # formats; for torch_dist it was already part of the sharded load above.
            is_torch_dist = ckpt_format == "torch_dist"
            if args.use_distributed_optimizer and not is_torch_dist:
                # NOTE(review): this re-reads iteration/release from the tracker file and
                # overwrites the values derived from state_dict above.
                tracker_filename = get_checkpoint_tracker_filename(load_dir)
                iteration, release = read_metadata(tracker_filename)
                model_checkpoint_name = \
                    get_checkpoint_name(load_dir, iteration, release)
                optim_checkpoint_name = \
                    get_distributed_optimizer_checkpoint_name(
                        model_checkpoint_name)
                optimizer.load_parameter_state(optim_checkpoint_name,
                                               update_legacy_format=args.ckpt_convert_update_legacy_dist_opt_format)

            if opt_param_scheduler is not None:
                if 'lr_scheduler' in state_dict:  # Older key name.
                    opt_param_scheduler.load_state_dict(state_dict['lr_scheduler'])
                else:
                    opt_param_scheduler.load_state_dict(state_dict['opt_param_scheduler'])
        except KeyError as e:
            print_rank_0('Unable to load optimizer from checkpoint {}. '
                         'Specify --no-load-optim or --finetune to prevent '
                         'attempting to load the optimizer state, '
                         'exiting ...'.format(checkpoint_name))
            raise e
    else:
        if (args.fp16 or args.bf16) and optimizer is not None:
            optimizer.reload_model_params()

    # Rerun state machine.
    try:
        if 'rerun_state_machine' in state_dict:
            get_rerun_state_machine().load_state_dict(state_dict['rerun_state_machine'])
    except Exception as e:
        print(f"Unable to restore RerunMachine from checkpoint: {e}")
        sys.exit(1)

    # RNG states (skipped for release, finetune, or --no-load-rng).
    if not release and not args.finetune and not args.no_load_rng:
        try:
            if 'rng_state' in state_dict:
                # One entry per data-parallel rank when data_parallel_random_init is set.
                if args.data_parallel_random_init:
                    rng_state = state_dict['rng_state'][mpu.get_data_parallel_rank()]
                else:
                    rng_state = state_dict['rng_state'][0]
                random.setstate(rng_state['random_rng_state'])
                np.random.set_state(rng_state['np_rng_state'])
                torch.set_rng_state(rng_state['torch_rng_state'])
                torch.cuda.set_rng_state(rng_state['cuda_rng_state'])
                # An empty tracker state is treated as missing.
                if not rng_state['rng_tracker_states']:
                    raise KeyError
                tensor_parallel.get_cuda_rng_tracker().set_states(
                    rng_state['rng_tracker_states'])
            else:  # Flat (top-level) RNG keys.
                random.setstate(state_dict['random_rng_state'])
                np.random.set_state(state_dict['np_rng_state'])
                torch.set_rng_state(state_dict['torch_rng_state'])
                torch.cuda.set_rng_state(state_dict['cuda_rng_state'])
                # An empty tracker state is treated as missing.
                if not state_dict['rng_tracker_states']:
                    raise KeyError
                tensor_parallel.get_cuda_rng_tracker().set_states(
                    state_dict['rng_tracker_states'])
        except KeyError:
            print_rank_0('Unable to load rng state from checkpoint {}. '
                         'Specify --no-load-rng or --finetune to prevent '
                         'attempting to load the rng state, '
                         'exiting ...'.format(checkpoint_name))
            sys.exit(1)

    # Wait for all ranks to finish loading before reporting success.
    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    print_rank_0(f'  successfully loaded checkpoint from {load_dir} '
                 f'[ t {mpu.get_tensor_model_parallel_rank() + 1}/{mpu.get_tensor_model_parallel_world_size()}, '
                 f'p {mpu.get_pipeline_model_parallel_rank() + 1}/{mpu.get_pipeline_model_parallel_world_size()} ] '
                 f'at iteration {iteration}')

    # wandb notification happens only on the last rank (or when not distributed).
    if not torch.distributed.is_initialized() \
       or is_last_rank():
        wandb_utils.on_load_checkpoint_success(checkpoint_name, load_dir)

    torch.cuda.empty_cache()

    if iteration > 0:
        # Notify fault-tolerance integration of a resumed (non-initial) checkpoint.
        is_local_chkpt = (ckpt_type == CheckpointType.LOCAL)
        ft_integration.on_checkpoint_loaded(is_local_chkpt=is_local_chkpt)

    return iteration, num_floating_point_operations_so_far
