# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import enum
import functools
import os
import queue
import re
import shutil
import threading
import time
from concurrent.futures import Future
from typing import Any

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.nn as nn
from torch.distributed.checkpoint import (
    HuggingFaceStorageReader,
    HuggingFaceStorageWriter,
)
from torch.distributed.checkpoint._consolidate_hf_safetensors import (
    consolidate_safetensors_files_on_every_rank,
)
from torch.distributed.checkpoint.staging import DefaultStager, StagingOptions
from torch.distributed.checkpoint.state_dict import (
    get_model_state_dict,
    set_model_state_dict,
    StateDictOptions,
)
from torch.distributed.checkpoint.state_dict_saver import AsyncCheckpointerType
from torch.distributed.checkpoint.stateful import Stateful

from torchtitan.components.dataloader import BaseDataLoader
from torchtitan.components.ft import FTManager
from torchtitan.components.lr_scheduler import LRSchedulersContainer
from torchtitan.components.optimizer import OptimizersContainer
from torchtitan.config import Checkpoint as CheckpointConfig, TORCH_DTYPE_MAP
from torchtitan.protocols import BaseStateDictAdapter
from torchtitan.tools.logging import logger
from torchtitan.tools.utils import GarbageCollection


# Keys under which CheckpointManager registers each checkpointable component
# in its ``states`` dict.
MODEL = "model"
OPTIMIZER = "optimizer"
LR_SCHEDULER = "lr_scheduler"
DATALOADER = "dataloader"
TRAIN_STATE = "train_state"


class AsyncMode(str, enum.Enum):
    """Checkpoint save modes understood by ``CheckpointManager``.

    Members also subclass ``str``, so they compare equal to the lower-cased
    config strings they are matched against.
    """

    # Blocking save through ``dcp.save``.
    DISABLED = "disabled"
    # ``dcp.async_save`` on the manager's dedicated process group.
    ASYNC = "async"
    # ``dcp.async_save`` with ``AsyncCheckpointerType.PROCESS`` and a stager.
    ASYNC_WITH_PINNED_MEM = "async_with_pinned_mem"


class ModelWrapper(Stateful):
    def __init__(self, model: nn.Module | list[nn.Module]) -> None:
        self.model = [model] if isinstance(model, nn.Module) else model
        self.cache_state_dict = self._get_state_dict()

    def _get_state_dict(self) -> dict[str, Any]:
        state_dict = {
            k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items()
        }
        return state_dict

    def state_dict(self) -> dict[str, Any]:
        return self.cache_state_dict

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        func = functools.partial(
            set_model_state_dict,
            model_state_dict=state_dict,
            options=StateDictOptions(strict=False),
        )
        list(map(func, self.model))
        # `set_model_state_dict()` does change the keys of the input state_dict,
        # we will need to reinitialize the cache_state_dict.
        self.cache_state_dict = self._get_state_dict()


class Terminate:
    """Sentinel put on a worker queue to tell the consumer to stop."""


class SaveDone:
    """Marker type for queue-based signalling.

    NOTE(review): presumably signals that a background save has completed;
    not referenced in the visible part of this file -- confirm against users.
    """


class _OptimizerStateLoadShim(Stateful):
    """Wrap optimizers to normalize state for backward-compatible loading.

    ``state_dict()`` hides selected entries of the wrapped optimizer's state
    and ``load_state_dict()`` restores them (or fills in fallbacks) before
    forwarding to the wrapped optimizer:

    * top-level entries whose key contains any "drop token" are withheld and
      cached until the next ``load_state_dict()``;
    * every key in ``drop_param_group_keys`` is removed from list-style
      ``param_groups`` and from flattened entries whose key contains
      ``.<key>``, and is re-injected on load.

    Attributes not defined on the shim are delegated to the wrapped optimizer.
    """

    def __init__(
        self,
        optimizer: OptimizersContainer,
        *,
        drop_projector_state: bool = False,
        drop_projector_tokens: tuple[str, ...] | None = None,
        requires_projector_state: bool = False,
        drop_param_group_keys: tuple[str, ...] | None = None,
        drop_state_keys: tuple[str, ...] | None = None,
    ) -> None:
        """Build the shim around ``optimizer``.

        Args:
            optimizer: The optimizer container to wrap.
            drop_projector_state: When True, also withhold the built-in
                projector tokens (``projector_meta``, ``projector_basis``,
                ``initial_projector``).
            drop_projector_tokens: Extra substrings; state entries whose key
                contains one of them are withheld.
            requires_projector_state: Enables ``maybe_warm_projector_state()``.
            drop_param_group_keys: Param-group attributes to strip from the
                exposed state and to restore after loading.
            drop_state_keys: Further substrings treated like drop tokens.
        """
        self._optimizer = optimizer
        self._drop_tokens: tuple[str, ...] = tuple(drop_projector_tokens or ())
        self._requires_projector_state = requires_projector_state
        self._drop_param_group_keys: tuple[str, ...] = tuple(
            drop_param_group_keys or ()
        )
        if drop_state_keys:
            # Coerce so that a list argument does not break tuple concatenation.
            self._drop_tokens = self._drop_tokens + tuple(drop_state_keys)
        if drop_projector_state:
            self._drop_tokens = self._drop_tokens + (
                "projector_meta",
                "projector_basis",
                "initial_projector",
            )
        # Entries withheld by the last state_dict() call, keyed by state key.
        self._cached_tokens: dict[str, Any] = {}
        # Per-group values stripped by the last state_dict() call.
        self._cached_param_group_values: list[dict[str, Any]] = []
        # Per-group values captured at construction time, used as fallbacks.
        self._param_group_fallbacks: list[dict[str, Any]] = []
        self._optimizer_defaults: dict[str, Any] = getattr(optimizer, "defaults", {})
        if self._drop_param_group_keys:
            for group in getattr(optimizer, "param_groups", []):
                fallback: dict[str, Any] = {}
                for key in self._drop_param_group_keys:
                    if key in group:
                        fallback[key] = group[key]
                    elif key in self._optimizer_defaults:
                        fallback[key] = self._optimizer_defaults[key]
                self._param_group_fallbacks.append(fallback)

    def _resolve_param_group_fallback(self, group_idx: int, key: str) -> Any:
        """Return the fallback for ``key`` in param group ``group_idx``.

        Lookup order: the value captured for that group at construction time,
        then the optimizer's ``defaults``, then ``False``.
        """
        if group_idx < len(self._param_group_fallbacks):
            fallback = self._param_group_fallbacks[group_idx]
            if key in fallback:
                return fallback[key]
        if key in self._optimizer_defaults:
            return self._optimizer_defaults[key]
        return False

    def state_dict(self) -> dict[str, Any]:
        """Return the wrapped optimizer's state with dropped entries removed.

        The removed values are cached on the shim so ``load_state_dict()`` can
        restore them. A non-dict state is returned untouched.
        """
        state_dict = self._optimizer.state_dict()
        if not isinstance(state_dict, dict):
            return state_dict

        # Strip optional GaLore-only param_group metadata so older checkpoints load.
        if self._drop_param_group_keys and isinstance(
            state_dict.get("param_groups"), list
        ):
            self._cached_param_group_values = []
            param_groups: list[dict[str, Any]] = []
            for group in state_dict["param_groups"]:
                cached_values = {
                    key: group[key]
                    for key in self._drop_param_group_keys
                    if key in group
                }
                self._cached_param_group_values.append(cached_values)
                patched_group = dict(group)
                for key in self._drop_param_group_keys:
                    patched_group.pop(key, None)
                param_groups.append(patched_group)

            # Shallow copy so the optimizer's own dict is left unmodified.
            patched_state_dict = dict(state_dict)
            patched_state_dict["param_groups"] = param_groups
        else:
            patched_state_dict = state_dict

        # Also drop any flattened param_group entries that carry the same keys (e.g. DCP
        # flattens param_groups.*.*.use_error_feedback).
        if self._drop_param_group_keys:
            patched_state_dict = {
                key: value
                for key, value in patched_state_dict.items()
                if not any(
                    f".{drop_key}" in key for drop_key in self._drop_param_group_keys
                )
            }

        if not self._drop_tokens:
            return patched_state_dict

        filtered_state: dict[str, Any] = {}
        self._cached_tokens = {}
        for key, value in patched_state_dict.items():
            if any(token in key for token in self._drop_tokens):
                # Keep the withheld value so it can be re-added on load.
                self._cached_tokens[key] = value
                continue
            filtered_state[key] = value
        return filtered_state

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        """Restore withheld entries, then load into the wrapped optimizer.

        Afterwards every dropped param-group key is set on the live param
        groups, preferring the loaded value, then a value already present on
        the group, then the value cached by ``state_dict()``, then the
        construction-time fallback.
        """
        if self._cached_tokens:
            patched_state_dict = dict(state_dict)
            for key, value in self._cached_tokens.items():
                # Values coming from the checkpoint win over cached ones.
                patched_state_dict.setdefault(key, value)
            state_dict = patched_state_dict
            self._cached_tokens = {}

        # Inject missing GaLore param_group metadata with safe defaults so unflattening succeeds
        # even when older checkpoints do not carry these keys.
        if self._drop_param_group_keys:
            # Update list-style param_groups entries.
            if isinstance(state_dict.get("param_groups"), list):
                for idx, group in enumerate(state_dict["param_groups"]):
                    for key in self._drop_param_group_keys:
                        group.setdefault(
                            key, self._resolve_param_group_fallback(idx, key)
                        )

            # Update flattened param_group entries.
            extra_entries: dict[str, Any] = {}
            group_indices: set[int] = set()
            for key in list(state_dict.keys()):
                if not key.startswith("param_groups."):
                    continue
                base, _, attr = key.rpartition(".")
                if not base:
                    continue
                # Second dotted component; only treated as a group index when numeric.
                index_str = base.split(".")[1] if "." in base else ""
                if index_str.isdigit():
                    group_indices.add(int(index_str))
                if attr in self._drop_param_group_keys:
                    continue
                for drop_key in self._drop_param_group_keys:
                    missing_key = f"{base}.{drop_key}"
                    if missing_key not in state_dict:
                        group_idx = int(index_str) if index_str.isdigit() else 0
                        extra_entries[missing_key] = self._resolve_param_group_fallback(
                            group_idx,
                            drop_key,
                        )
            if not group_indices and isinstance(state_dict.get("param_groups"), list):
                group_indices = set(range(len(state_dict["param_groups"])))
            for idx in sorted(group_indices):
                base_key = f"param_groups.{idx}"
                for drop_key in self._drop_param_group_keys:
                    missing_key = f"{base_key}.{drop_key}"
                    if missing_key not in state_dict:
                        extra_entries[missing_key] = self._resolve_param_group_fallback(
                            idx,
                            drop_key,
                        )
            if extra_entries:
                patched = dict(state_dict)
                patched.update(extra_entries)
                state_dict = patched
        self._optimizer.load_state_dict(state_dict)

        if not self._drop_param_group_keys:
            return

        # ``state_dict`` is known to support ``.get`` here (it was used above);
        # a missing or ``None`` entry is treated as "no loaded groups".
        loaded_param_groups = state_dict.get("param_groups") or []

        for idx, group in enumerate(getattr(self._optimizer, "param_groups", [])):
            cached_values = (
                self._cached_param_group_values[idx]
                if idx < len(self._cached_param_group_values)
                else {}
            )
            loaded_values = (
                loaded_param_groups[idx] if idx < len(loaded_param_groups) else {}
            )
            for key in self._drop_param_group_keys:
                if key in loaded_values:
                    group[key] = loaded_values[key]
                    continue
                if key in group:
                    continue
                if key in cached_values:
                    group[key] = cached_values[key]
                    continue
                group[key] = self._resolve_param_group_fallback(idx, key)

        self._cached_param_group_values = []

    def maybe_warm_projector_state(self) -> None:
        """Ask GaLoreGlobal optimizers to repair projector state after a load.

        Does nothing unless the shim was built with
        ``requires_projector_state=True``, the wrapped object exposes a
        non-empty ``optimizers`` attribute, and ``GaLoreGlobal`` is importable.
        A failure inside the repair call is logged and that optimizer skipped.
        """
        if not self._requires_projector_state:
            return
        optimizers = getattr(self._optimizer, "optimizers", None)
        if not optimizers:
            return
        try:
            from torchtitan.experiments.fl.optimizers.galore_global import GaLoreGlobal
        except Exception:  # pragma: no cover - optional dependency
            return

        for inner in optimizers:
            if not isinstance(inner, GaLoreGlobal):
                continue
            try:
                inner._repair_projector_states()
            except Exception as exc:  # pragma: no cover - safety net
                logger.warning(
                    "Unable to refresh GaLoreGlobal projector state after load: %s",
                    exc,
                )
                continue

            total_params = 0
            projector_ready = 0
            for group in inner.param_groups:
                for param in group.get("params", []):
                    if not isinstance(param, torch.nn.Parameter):
                        continue
                    total_params += 1
                    param_state = inner.state.get(param)
                    if not param_state:
                        continue
                    basis = param_state.get("projector_basis")
                    if basis is None:
                        continue
                    # Debug level only: the basis is a full tensor and is only
                    # formatted when debug logging is enabled.
                    logger.debug(
                        "[GaLoreGlobal][checkpoint] parameter has projector basis: %s",
                        basis,
                    )
                    projector_ready += 1

            logger.info(
                "[GaLoreGlobal][checkpoint] projector bases present "
                "for %d/%d parameters after load",
                projector_ready,
                total_params,
            )

    def __getattr__(self, name: str) -> Any:  # pragma: no cover - simple delegation
        # ``__getattr__`` only runs when normal lookup fails. If ``_optimizer``
        # itself is missing (e.g. during copy/unpickling, before ``__init__``),
        # delegating would recurse forever, so fail fast instead.
        if name == "_optimizer":
            raise AttributeError(name)
        return getattr(self._optimizer, name)


class _OptimizerStateSaveShim(Stateful):
    """Wrap optimizers to drop local-only optimizer state before saving.

    Only ``state_dict()`` is overridden; every other attribute access is
    delegated to the wrapped optimizer.
    """

    def __init__(
        self, optimizer: OptimizersContainer, drop_state_keys: tuple[str, ...]
    ) -> None:
        """Store the wrapped ``optimizer`` and the key substrings to drop."""
        self._optimizer = optimizer
        self._drop_state_keys = drop_state_keys

    def state_dict(self) -> dict[str, Any]:
        """Return the optimizer state without entries matching a drop key.

        An entry is dropped when its key contains any of the configured
        substrings. A non-dict state, or an empty drop list, is passed through
        unchanged.
        """
        state_dict = self._optimizer.state_dict()
        if not isinstance(state_dict, dict) or not self._drop_state_keys:
            return state_dict

        return {
            key: value
            for key, value in state_dict.items()
            if not any(token in key for token in self._drop_state_keys)
        }

    def __getattr__(self, name: str) -> Any:  # pragma: no cover - simple delegation
        # ``__getattr__`` only runs when normal lookup fails. If ``_optimizer``
        # itself is missing (e.g. during copy/unpickling, before ``__init__``),
        # delegating would recurse forever, so fail fast instead.
        if name == "_optimizer":
            raise AttributeError(name)
        return getattr(self._optimizer, name)


def _optimizer_requires_projector_basis(
    optimizer: OptimizersContainer,
) -> bool:
    """Return True when optimizer needs projector bases preserved for resume.

    That is the case when ``optimizer.optimizers`` holds at least one
    ``GaLoreGlobal`` instance. A missing or empty ``optimizers`` attribute, or
    a failed import of ``GaLoreGlobal``, yields False.
    """
    inner_optimizers = getattr(optimizer, "optimizers", None)
    if inner_optimizers:
        # Imported lazily to avoid a hard dependency when GaLore experiments are absent.
        try:
            from torchtitan.experiments.fl.optimizers.galore_global import GaLoreGlobal
        except Exception:  # pragma: no cover - optional dependency
            return False

        for candidate in inner_optimizers:
            if isinstance(candidate, GaLoreGlobal):
                return True
    return False


def purge_thread(purge_queue: queue.Queue):
    """Worker loop that deletes old checkpoint directories.

    This is only used when keep_latest_k > 0.

    Blocks on ``purge_queue`` and removes each received path recursively
    (errors ignored) until a ``Terminate`` instance arrives.

    Args:
        purge_queue (queue.Queue): The queue to receive the path to purge and Terminate signal.
    """
    try:
        while not isinstance(item := purge_queue.get(), Terminate):
            assert isinstance(item, str)
            logger.info("Checkpointer is deleting %s.", item)
            started = time.monotonic()
            shutil.rmtree(item, ignore_errors=True)
            elapsed = time.monotonic() - started
            logger.info("Checkpointer deleted %s in %.2f seconds.", item, elapsed)
    finally:
        # Runs on normal termination as well as on an unexpected exception.
        logger.info("Destroying the purge thread.")


class CheckpointManager:
    """This class manages the checkpointing logic for the TorchTitan trainer.


    Note: Pipeline Parallelism and Virtual Stages

    1. even for simple PP schedules, there is a separate optimizer for each PP rank.
    rank0's optimizer would have a param_group[0] which refers to layers.0 in the original
    model.  rank1's would _also_ have a param_group[0], since it's index based, but
    referring to layers.1.  When saving, these collide and one of them is lost.  Then when
    reloading, only one stage can restore its optimizer states, others will error.

        The solution to this problem is optimizer flattening: it landed in #127071 and is
        enabled in TorchTitan by passing the 'flatten_optimizer_state_dict' kwarg to DCP
        functions called in the OptimizerContainer.
        See PR #127071 (https://github.com/pytorch/pytorch/pull/127071) for the example of
        a flattening state_dict.

    2. With complex PP schedules, we have multiple model chunks per pp rank. This compounds
    challenge (1) by also requiring us to reason about multiple 'optim' objects locally.

        We solve this in the Model and Optimizer wrapper classes by flattening the state dicts
        from each object into one state dict before saving/loading. We rely on the individual
        state_dicts to not collide, which is guaranteed for the model by correct pipeline
        splitting and for the optimizer by the flattening support described in (1).

    3. LR schedulers also index model states like optimizers. Here we flatten the lr_schedulers
    with the assumption that all lr_schedulers have the same state_dict.

    Note: TorchFT checkpointing flow

    There are two types of checkpoints when TorchFT is enabled: 1) the full persistent
    checkpoint, 2) the per-replica checkpoint.

    The full persistent checkpoint is saved by the replica with
    ``ft_manager.participating_rank() == 0``. It contains everything including the model,
    optimizer, lr_scheduler, dataloader, and train_state. Right now the full persistent
    checkpoint is loaded by all replicas. However, we can optimize it to only load if
    there are no other alive replicas.

    The per-replica checkpoint contains only the dataloader and is saved/loaded by all
    replicas to/from its own folder. The folder name is prefixed with the ft_replica_id.

    Args:
        dataloader (DataLoader): The dataloader used to load the data.
        model_parts (List[nn.Module]): List of model parts to be optimized.
        optimizers (OptimizersContainer): The optimizers used to optimize the model.
        lr_schedulers (LRSchedulersContainer): The lr schedulers used to optimize the model.
        states (Dict[str, Any]): The states that need to be saved, other than the
            previous 4 components.
        checkpoint_config (Checkpoint): The config used to configure the checkpointing.
        base_folder (str): The base folder to save the checkpoint. Will be concatenated
            with checkpoint_config.folder
        sd_adapter (Optional[type[BaseStateDictAdapter]]): The adapter used to convert model state
            dicts between native format and other formats.
        ft_manager (Optional[ft.Manager]): The FTManager from TorchFT.

    """

    def __init__(
        self,
        dataloader: BaseDataLoader | None,
        model_parts: list[nn.Module],
        optimizers: OptimizersContainer,
        lr_schedulers: LRSchedulersContainer,
        states: dict[str, Any],
        checkpoint_config: CheckpointConfig,
        sd_adapter: BaseStateDictAdapter | None,
        base_folder: str = "",
        ft_manager: FTManager | None = None,
    ) -> None:
        """Set up checkpoint state, policies and background workers.

        When neither checkpointing nor TorchFT is enabled, only ``enable``,
        ``load_only``, ``ft_manager`` and ``enable_staging`` are set and the
        constructor returns early.

        Raises:
            ValueError: If ``keep_latest_k`` is 1, or if
                ``checkpoint_config.async_mode`` is not a known ``AsyncMode``.
        """
        self.enable = checkpoint_config.enable
        self.load_only = checkpoint_config.load_only

        # Keep the underlying TorchFT manager only when fault tolerance is enabled.
        self.ft_manager = (
            ft_manager.manager if ft_manager and ft_manager.enabled else None
        )
        if self.ft_manager:
            optimizers.init_cache_state_dict()

            # Callbacks handed to TorchFT. They read ``self.states`` lazily;
            # that attribute is assigned further below in this constructor.
            def state_dict():
                ret = {}
                for k, v in self.states.items():
                    # The dataloader is excluded here; it is tracked
                    # separately in ``self.ft_states``.
                    if k in {
                        MODEL,
                        OPTIMIZER,
                        LR_SCHEDULER,
                        TRAIN_STATE,
                    }:
                        ret[k] = v.state_dict()
                return ret

            def load_state_dict(state_dict):
                assert state_dict is not None
                for k, v in state_dict.items():
                    self.states[k].load_state_dict(v)

            self.ft_manager.set_state_dict_fns(load_state_dict, state_dict)
            self.ft_replica_id = ft_manager.replica_id

        async_mode = checkpoint_config.async_mode.lower()
        # NOTE(review): because of the trailing ``or``, this holds the manager
        # object (not a bool) when TorchFT is active and the first operand is falsy.
        self.enable_staging = (
            self.enable and async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
        ) or self.ft_manager

        # Nothing else to set up when neither checkpointing nor TorchFT is active.
        if not self.enable and self.ft_manager is None:
            return

        # The caller's ``states`` dict is extended in place with the core components.
        self.states = states
        self.states.update(
            {
                MODEL: ModelWrapper(model_parts),
                OPTIMIZER: optimizers,
                DATALOADER: dataloader,
                LR_SCHEDULER: lr_schedulers,
            }
        )
        self._optimizer_state_shim: _OptimizerStateLoadShim | None = None
        # State for the per-replica TorchFT checkpoint: the dataloader only.
        self.ft_states = {DATALOADER: dataloader}
        self.ft_dataloader_loaded = False

        self.staging = False
        self.sending_to_checkpoint_mp = False
        self.staging_id = None
        self.cpu_offload_state_dict = None
        self.stager = None

        self.folder = os.path.join(base_folder, checkpoint_config.folder)

        # Checkpoint policy related fields.
        self.initial_load_model_only = checkpoint_config.initial_load_model_only
        self.initial_load_in_hf = checkpoint_config.initial_load_in_hf
        self.initial_load_path = checkpoint_config.initial_load_path
        self.last_save_model_only = checkpoint_config.last_save_model_only
        self.last_save_in_hf = checkpoint_config.last_save_in_hf
        if self.last_save_in_hf:
            assert (
                sd_adapter is not None
            ), "job_config.checkpoint.last_save_in_hf is True, but sd_adapter is not provided."
        self.sd_adapter = sd_adapter
        self.export_dtype = TORCH_DTYPE_MAP[checkpoint_config.export_dtype]
        self.exclude_from_loading = checkpoint_config.exclude_from_loading
        self.interval = checkpoint_config.interval
        self.enable_first_step_checkpoint = (
            checkpoint_config.enable_first_step_checkpoint
        )

        # Async checkpoint related fields.
        # NOTE(review): ``async_mode`` was already computed above; this
        # recomputation yields the same value.
        async_mode = checkpoint_config.async_mode.lower()
        if (
            async_mode == AsyncMode.ASYNC
            or async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
            or self.ft_manager
        ):
            # Dedicated gloo process group, passed to ``dcp.async_save``.
            self.pg = dist.new_group(backend="gloo")

        self.keep_latest_k = checkpoint_config.keep_latest_k
        if self.keep_latest_k > 0:
            if self.keep_latest_k == 1:
                raise ValueError(
                    "We need to maintain at least 2 checkpoint replicas, "
                    "as the last one may be in the process of being saved."
                )
            # Old checkpoints are deleted on a daemon thread fed through this queue.
            self.purge_queue = queue.Queue()
            self.purge_thread = threading.Thread(
                target=purge_thread, args=(self.purge_queue,), daemon=True
            )
            self.purge_thread.start()
        else:
            self.purge_thread = None

        self.mp = None
        self.staging_future = None
        self.save_future = None
        # Validate the configured mode and store it as an ``AsyncMode`` member.
        if async_mode == AsyncMode.DISABLED:
            self.async_mode = AsyncMode.DISABLED
        elif async_mode == AsyncMode.ASYNC:
            self.async_mode = AsyncMode.ASYNC
        elif async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
            self.async_mode = AsyncMode.ASYNC_WITH_PINNED_MEM
        else:
            raise ValueError(
                f"Unknown checkpoint async_mode {checkpoint_config.async_mode}"
            )

        logger.info(
            f"Checkpointing active. Checkpoints will be loaded from and saved to {self.folder}"
        )

    def __del__(self):
        """Release background resources by calling ``close()`` on finalization."""
        self.close()

    def close(self):
        """Shut down the background workers owned by this manager.

        Does nothing unless checkpointing is enabled. Otherwise a live
        checkpoint process and a live purge thread are each sent a
        ``Terminate`` message and joined, and the stager (if any) is closed.

        Every attribute is looked up defensively because this also runs from
        ``__del__``, possibly on an instance whose ``__init__`` returned early
        or raised part-way through. The stager reference is cleared after
        closing so that calling ``close()`` again does not close it twice.
        """
        if not getattr(self, "enable", False):
            return

        mp = getattr(self, "mp", None)
        if mp and mp.is_alive():
            self.mp_queue_send.put(Terminate())
            mp.join()

        purge = getattr(self, "purge_thread", None)
        if purge and purge.is_alive():
            self.purge_queue.put(Terminate())
            purge.join()

        stager = getattr(self, "stager", None)
        if stager is not None:
            stager.close()
            self.stager = None

    @torch.no_grad()
    def dcp_save(
        self,
        state_dict: dict[str, Any],
        checkpoint_id: str,
        async_mode: AsyncMode,
        enable_garbage_collection: bool = False,
        to_hf: bool = False,
    ) -> Future | None:
        """Write ``state_dict`` with DCP, optionally in HF safetensors format.

        Args:
            state_dict (dict): The state dict to save.
            checkpoint_id (str): The checkpoint id (target path) to save to.
            async_mode (AsyncMode): Selects ``dcp.save`` or one of the two
                ``dcp.async_save`` variants.
            enable_garbage_collection (bool): Whether to run a garbage
                collection once the save call has returned.
            to_hf (bool): Whether to convert the state dict with the state dict
                adapter and write it through ``HuggingFaceStorageWriter``.

        Returns:
            Whatever the underlying ``dcp.save`` / ``dcp.async_save`` call
            returned.
        """
        hf_writer: HuggingFaceStorageWriter | None = None
        target_id: str | None = None

        if not to_hf:
            target_id = checkpoint_id
        else:
            assert (
                self.sd_adapter is not None
            ), "trying to save checkpoint in HF safetensors format, but sd_adapter is not provided."
            state_dict = self.sd_adapter.to_hf(state_dict)

            index_mapping = self.sd_adapter.fqn_to_index_mapping
            if not index_mapping:
                # Without a mapping all fqns go to one file, so one rank can
                # consolidate and the writer handles it itself. With a mapping,
                # consolidate_safetensors_files_on_every_rank is used below.
                hf_writer = HuggingFaceStorageWriter(
                    path=checkpoint_id,
                    save_distributed=True,
                    enable_consolidation=True,
                )
            else:
                hf_writer = HuggingFaceStorageWriter(
                    path=os.path.join(checkpoint_id, "sharded"),
                    save_distributed=True,
                    fqn_to_index_mapping=index_mapping,
                    enable_consolidation=False,
                )

        # Keyword arguments shared by every save variant.
        shared_kwargs = {"storage_writer": hf_writer, "checkpoint_id": target_id}

        outcome: Future | None = None
        if async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
            outcome = dcp.async_save(
                state_dict,
                process_group=self.pg,
                async_checkpointer_type=AsyncCheckpointerType.PROCESS,
                async_stager=self.stager,
                **shared_kwargs,
            )
        elif async_mode == AsyncMode.ASYNC:
            outcome = dcp.async_save(
                state_dict,
                process_group=self.pg,
                **shared_kwargs,
            )
        else:
            outcome = dcp.save(state_dict, **shared_kwargs)

        if to_hf:
            post_save_mapping = self.sd_adapter.fqn_to_index_mapping
            if post_save_mapping:
                # Merge the files written under "sharded" into checkpoint_id.
                consolidate_safetensors_files_on_every_rank(
                    input_dir=os.path.join(checkpoint_id, "sharded"),
                    output_dir=checkpoint_id,
                    fqn_to_index_mapping=post_save_mapping,
                    num_threads=5,
                )

        if enable_garbage_collection:
            GarbageCollection.collect("GC collection invoked by checkpointer.")

        return outcome

    def dcp_load(
        self,
        state_dict: dict[str, Any],
        checkpoint_id: str,
        from_hf: bool,
    ) -> None:
        """Read a checkpoint with DCP and push the result into the model.

        Args:
            state_dict (dict): The state dict to load into.
            checkpoint_id (str): The checkpoint id (source path) to load from.
            from_hf (bool): Whether the checkpoint is a HuggingFace safetensors
                checkpoint with its own model definition; it is then converted
                to and from HF format with the state dict adapter.
        """
        if not from_hf:
            dcp.load(state_dict, checkpoint_id=checkpoint_id)

            # TODO: Since we flatten the model states in state_dict, we need to
            # manually call load_state_dict() for the model. Need to fix this.
            if MODEL in self.states:
                self.states[MODEL].load_state_dict(state_dict)
            return

        assert (
            self.sd_adapter is not None
        ), "trying to load checkpoint in HF safetensors format, but sd_adapter is not provided."

        # Convert to the HF layout, load into it, then convert back.
        hf_state_dict = self.sd_adapter.to_hf(state_dict)
        hf_reader = HuggingFaceStorageReader(path=checkpoint_id)
        dcp.load(hf_state_dict, storage_reader=hf_reader)

        native_state_dict = self.sd_adapter.from_hf(hf_state_dict)
        self.states[MODEL].load_state_dict(native_state_dict)

    @torch.no_grad()
    def save(self, curr_step: int, last_step: bool = False) -> None:
        """Save the checkpoint for the current step.

        This function will save the checkpoint for the current step. If ``last_step`` is
        true, it will save the checkpoint even if the interval has not been reached.
        This only happens when train_state.step == job_config.training.steps, or
        for initial seed checkpoint.

        With TorchFT enabled, the per-replica save runs first on every call,
        and the full checkpoint is written only by the replica whose
        ``participating_rank()`` is 0.

        Args:
            curr_step (int): The current step.
            last_step (bool, optional): Whether this is the last step of training.

        Returns:
            None
        """

        if self.ft_manager:
            self._ft_save(curr_step)

        if not self._should_save(curr_step, last_step):
            return

        begin = time.monotonic()
        if not self.ft_manager or self.ft_manager.participating_rank() == 0:
            logger.info("Saving the checkpoint (or staging if async is enabled).")
            checkpoint_id = self._create_checkpoint_id(curr_step)
            # Wait for any in-flight async save before starting a new one.
            self._async_wait()
            if last_step:
                self._save_last_step(curr_step)
                return

            states = self._flattened_model_states_sd()
            if self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
                # GC runs here, after _async_wait(), because it is useless to
                # do GC right after async_save -- the CPU memory is not able
                # to be freed until _async_wait().
                GarbageCollection.collect("GC collection invoked by checkpointer.")
                # The stager is created lazily on the first pinned-memory save.
                if self.stager is None:
                    self.stager = DefaultStager(StagingOptions(True, True, True, True))
                result = self.dcp_save(
                    states,
                    checkpoint_id=checkpoint_id,
                    async_mode=self.async_mode,
                )
                self.save_future = result.upload_completion
                self.staging_future = result.staging_completion
                self.staging = True
            elif self.async_mode == AsyncMode.ASYNC:
                # Same reasoning as above: collect once the previous async save
                # has been waited on.
                GarbageCollection.collect("GC collection invoked by checkpointer.")
                self.save_future = self.dcp_save(
                    states, checkpoint_id=checkpoint_id, async_mode=self.async_mode
                )
                GarbageCollection.collect("GC collection invoked by checkpointer.")
            else:
                self.dcp_save(
                    states,
                    checkpoint_id=checkpoint_id,
                    async_mode=AsyncMode.DISABLED,
                    enable_garbage_collection=True,
                )
            self._purge_stale_checkpoints()

            # The first literal ends with a space so the two adjacent literals
            # do not run together ("enabled)in").
            logger.info(
                "Finished saving the checkpoint (or staging if async is enabled) "
                f"in {time.monotonic() - begin:.2f} seconds."
            )
        elif self.ft_manager:
            logger.info(
                "Replica %d doesn't save checkpoint.",
                self.ft_manager.participating_rank(),
            )

    @torch.no_grad()
    def load(self, step: int = -1) -> bool:
        """Load the checkpoint for the given step.

        This function will load the checkpoint for the given step. If ``step`` is -1, it
        will load the latest checkpoint. If the checkpoint does not exist, it will return
        False and load nothing.

        When ``self.folder`` does not exist yet, the initial-load settings pick the
        source: ``initial_load_path`` if set, otherwise the state-dict adapter's
        ``hf_assets_path`` when ``initial_load_in_hf`` is set. When ``self.folder``
        exists, those settings are ignored (with a warning) and the checkpoint is
        taken from ``self.folder``; a step-0 checkpoint is loaded as model-only.

        If a fault-tolerance manager is configured, ``_ft_load`` is invoked before
        the regular load, even when checkpointing is disabled or nothing is found.

        Args:
            step (int, optional): The step to load the checkpoint for. Defaults to -1.

        Returns:
            bool: Whether the checkpoint was loaded successfully.

        Raises:
            ValueError: If the initial load path or the HF assets path is not a
                directory.
            FileNotFoundError: If the resolved checkpoint directory does not exist.
            AssertionError: If HF loading is requested without model-only loading.
        """

        target_step: int | None = None
        model_only = False
        from_hf = False
        checkpoint_id: str | None = None
        should_return_false = False

        if self.enable:
            if not os.path.exists(self.folder):
                # Fresh run: nothing has been saved yet, so only the configured
                # initial-load sources are candidates.
                model_only = self.initial_load_model_only
                from_hf = self.initial_load_in_hf
                if from_hf:
                    assert (
                        model_only
                    ), "Only model can be loaded when loading from HF's safetensors checkpoint."
                if self.initial_load_path:
                    checkpoint_id = self.initial_load_path
                    if not os.path.isdir(checkpoint_id):
                        raise ValueError(
                            "checkpoint.initial_load_path is specified but the path is not valid."
                        )
                    if from_hf:
                        logger.info(
                            f"loading from HF safetensors from --checkpoint.initial_load_path: {self.initial_load_path}"
                        )
                elif from_hf:
                    checkpoint_id = self.sd_adapter.hf_assets_path
                    if not os.path.isdir(checkpoint_id):
                        # Adjacent literals (not a backslash continuation) keep the
                        # source indentation out of the user-facing message.
                        raise ValueError(
                            "model.hf_assets_path is being used to load HF weights but the path is not valid. "
                            "Either make sure hf_assets_path is correct or provide a valid checkpoint.initial_load_path"
                        )
                    logger.info(
                        f"loading HF safetensors from --model.hf_assets_path: {self.sd_adapter.hf_assets_path}"
                    )
                else:
                    should_return_false = True
            else:
                # Resuming: the checkpoint folder wins over any initial-load setting.
                if self.initial_load_path:
                    logger.warning(
                        "checkpoint.initial_load_path is provided but the checkpoint.folder exists. "
                        f"Checkpointer will use the checkpoints from the checkpoint.folder {self.folder}."
                    )
                if self.initial_load_in_hf:
                    logger.warning(
                        "checkpoint.initial_load_in_hf is True but the checkpoint.folder exists. "
                        "Checkpointer will not load from HF safetensors"
                    )
                step = self._find_load_step() if step == -1 else step
                if step == -1:
                    should_return_false = True
                else:
                    target_step = step
                    # A step-0 checkpoint is treated as a model-only checkpoint.
                    model_only = step == 0
                    checkpoint_id = self._create_checkpoint_id(step)

                    if not os.path.isdir(checkpoint_id):
                        raise FileNotFoundError(
                            f"--checkpoint.load_step={step} but checkpoint {checkpoint_id} is not found."
                        )

        # The FT checkpoint is loaded regardless of whether a regular checkpoint
        # was resolved; target_step stays None when none was.
        if self.ft_manager:
            self._ft_load(target_step)

        if not self.enable or should_return_false:
            return False

        assert checkpoint_id is not None

        if not os.path.exists(self.folder):
            # checkpoint_id points to the requested initial load path; ensure it exists
            if not os.path.isdir(checkpoint_id):
                raise FileNotFoundError(
                    f"checkpoint initial load path {checkpoint_id} is not found."
                )

        logger.info(f"Loading the checkpoint from {checkpoint_id}.")
        begin = time.monotonic()
        states = self._states_to_load(model_only)
        # Diagnostic detail only: the key list can be very long, so it is kept
        # at DEBUG level rather than INFO.
        logger.debug("[RESUME DEBUG] States to load keys: %s", list(states.keys()))
        logger.debug("[RESUME DEBUG] Model only: %s", model_only)
        states = self._prepare_states_for_load(states)
        self.dcp_load(
            states,
            checkpoint_id=checkpoint_id,
            from_hf=from_hf,
        )
        # _prepare_states_for_load may have installed an optimizer load shim;
        # give it its post-load hook and then drop the reference.
        shim = getattr(self, "_optimizer_state_shim", None)
        if isinstance(shim, _OptimizerStateLoadShim):
            shim.maybe_warm_projector_state()
            self._optimizer_state_shim = None
        GarbageCollection.collect("GC collection for checkpoint loading.")
        logger.info(
            f"Finished loading the checkpoint in {time.monotonic() - begin:.2f} seconds."
        )
        return True

    def maybe_wait_for_staging(self) -> None:
        """Block until an in-flight staging operation completes.

        Nothing happens unless staging is enabled and a staging operation is
        currently marked as in flight. Staging is only enabled with
        ``async_checkpoint_with_pinned_memory``.
        """
        staging_in_flight = self.enable_staging and self.staging
        if not staging_in_flight:
            return

        self.staging_future.result()
        # Clear the flag so later calls do not wait on the same future again.
        self.staging = False

    def _find_load_step(self, folder: str = "", *, max_step: int | None = None) -> int:
        """Find the step to load the checkpoint for.

        Args:
            folder (str, optional): The folder to find the checkpoint for. If ``folder``
            is "", then ``self.folder`` will be used.
            max_step (int, optional): Upper bound for the checkpoint step to consider.

        Returns:
            int: The step to load the checkpoint for.
        """
        folder = folder if folder else self.folder
        pattern = r"step-(\d+)"
        step_counts = []

        if not os.path.isdir(folder):
            return -1

        for filename in os.listdir(folder):
            match = re.search(pattern, filename)

            # Drop local-only optimizer tensors (e.g., error_feedback) before saving so they are
            # neither synchronized nor required when loading.
            if not match:
                continue
            step_value = int(match.group(1))
            if max_step is not None and step_value > max_step:
                continue
            dcp_metadata_probe = os.path.join(folder, filename, ".metadata")
            safetensors_metadata_probe = os.path.join(
                folder, filename, "model.safetensors.index.json"
            )
            if os.path.isfile(dcp_metadata_probe) or os.path.isfile(
                safetensors_metadata_probe
            ):
                step_counts.append(step_value)
        if not step_counts:
            return -1
        return max(step_counts)

    def _ft_folder(self) -> str:
        return os.path.join(self.folder, f"ft-replicat-{self.ft_replica_id}")

    def _create_checkpoint_id(self, step: int, folder: str = "") -> str:
        folder = folder if folder else self.folder
        return os.path.join(folder, f"step-{step}")

    def _ft_save(self, step: int) -> None:
        """Start an asynchronous save of the fault-tolerance states for ``step``.

        The checkpoint is written under the per-replica FT folder and the
        resulting future is stored in ``self.save_future``.

        Args:
            step (int): The training step used to name the checkpoint.
        """
        started_at = time.monotonic()

        # Wait for any previously issued save before replacing self.save_future.
        self._async_wait()

        ft_checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
        self.save_future = self.dcp_save(
            self.ft_states,
            checkpoint_id=ft_checkpoint_id,
            async_mode=AsyncMode.ASYNC,
        )

        elapsed = time.monotonic() - started_at
        logger.info(f"Staging ft checkpoint took {elapsed} secs.")

    def _ft_load(self, target_step: int | None = None) -> None:
        """Load the fault-tolerance checkpoint for this replica, if one exists.

        ``self.ft_dataloader_loaded`` is reset to False on entry and set to True
        only after a checkpoint has been loaded.

        Args:
            target_step (int | None): When given, the newest FT checkpoint whose
                step does not exceed it is preferred. If none qualifies, the
                newest FT checkpoint overall is used instead (with a warning).
        """
        self.ft_dataloader_loaded = False
        ft_dir = self._ft_folder()

        # With target_step=None this is an unbounded search for the latest step.
        step = self._find_load_step(folder=ft_dir, max_step=target_step)
        if step == -1 and target_step is not None:
            logger.warning(
                "FT checkpoint matching step %s not found; falling back to latest available.",
                target_step,
            )
            step = self._find_load_step(folder=ft_dir)

        if step == -1:
            logger.info("No FT checkpoint found to load.")
            return

        started_at = time.monotonic()
        logger.info(f"Loading the FT checkpoint at step {step}.")
        self.dcp_load(
            self.ft_states,
            checkpoint_id=self._create_checkpoint_id(step, folder=ft_dir),
            # FT checkpoints are always DCP because FT checkpoint currently only save/load dataloader.
            from_hf=False,
        )
        GarbageCollection.collect("GC collection for checkpoint loading.")
        self.ft_dataloader_loaded = True
        logger.info(
            f"Finished loading the ft checkpoint in {time.monotonic() - started_at:.2f} seconds."
        )

    def _flattened_model_states_sd(
        self,
        state_dict: dict[str, Any] | None = None,
        *,
        wrap_optimizer_for_save: bool = True,
    ) -> dict[str, Any]:
        """Return the states with the model entry expanded into top-level keys.

        Every non-model entry is copied over under its own key; the model
        entry's ``state_dict()`` items are then merged in at the top level.
        Other states, such as optimizer states, are not flattened.

        Args:
            state_dict (dict[str, Any] | None): States to flatten. Defaults to
                ``self.states`` when None.
            wrap_optimizer_for_save (bool): When True, an ``OptimizersContainer``
                stored under the optimizer key is wrapped in a save shim
                configured with the ``error_feedback`` key.

        Returns:
            dict[str, Any]: The flattened state mapping.
        """
        source = self.states if state_dict is None else state_dict

        flattened: dict[str, Any] = {}
        for name, entry in source.items():
            if name == MODEL:
                # The model entry is merged in flattened form after the loop.
                continue
            shim_optimizer = (
                wrap_optimizer_for_save
                and name == OPTIMIZER
                and isinstance(entry, OptimizersContainer)
            )
            flattened[name] = (
                _OptimizerStateSaveShim(entry, ("error_feedback",))
                if shim_optimizer
                else entry
            )

        if MODEL in source:
            flattened.update(source[MODEL].state_dict())
        return flattened

    def _states_to_load(self, model_only: bool) -> dict[str, Any]:
        """Select the states that should be populated from a checkpoint.

        Args:
            model_only (bool): When True, only the model's state dict is returned.

        Returns:
            Dict[str, Any]: The model state dict when ``model_only`` is set;
            otherwise all states except those in ``self.exclude_from_loading``,
            with the model flattened into top-level keys. The dataloader entry
            is removed when the FT checkpoint has already loaded it.

        Raises:
            ValueError: If a key in ``self.exclude_from_loading`` is not a known state.
        """
        # For the first step, we will only load the model.
        if model_only:
            return self.states[MODEL].state_dict()

        excluded = self.exclude_from_loading
        unknown = [name for name in excluded if name not in self.states]
        if unknown:
            # Report the first offending key, in configuration order.
            raise ValueError(f"{unknown[0]} not found in state_dict.")

        selected = {
            name: entry for name, entry in self.states.items() if name not in excluded
        }
        flattened = self._flattened_model_states_sd(
            selected,
            wrap_optimizer_for_save=False,
        )

        # The FT checkpoint owns the dataloader state once it has been loaded.
        if self.ft_manager and self.ft_dataloader_loaded:
            flattened.pop(DATALOADER, None)

        return flattened

    def _prepare_states_for_load(self, states: dict[str, Any]) -> dict[str, Any]:
        """Apply backward-compat shims when loading checkpoints.

        If the optimizer entry is an ``OptimizersContainer``, a shallow copy of
        ``states`` is returned with that entry wrapped in a load shim, and the
        shim is also stored in ``self._optimizer_state_shim``. Otherwise
        ``states`` is returned unchanged and ``self._optimizer_state_shim`` is
        set to None.

        Args:
            states (dict[str, Any]): The states about to be loaded.

        Returns:
            dict[str, Any]: The states to hand to the checkpoint loader.
        """
        optimizer_state = states.get(OPTIMIZER)
        self._optimizer_state_shim = None
        if not isinstance(optimizer_state, OptimizersContainer):
            return states

        needs_basis = _optimizer_requires_projector_basis(optimizer_state)

        # Param-group keys handed to the shim as droppable; the GaLore variants
        # contribute two additional ones.
        group_keys_to_drop = {"use_error_feedback"}
        galore_class_names = {"GaLore", "GaLoreGlobal"}
        for opt in getattr(optimizer_state, "optimizers", []):
            if opt.__class__.__name__ in galore_class_names:
                group_keys_to_drop.update(("vs", "qhm_outside_projection"))
                break

        if needs_basis:
            projector_tokens = tuple()
        else:
            projector_tokens = ("initial_projector", "initial_projector_mode")

        load_shim = _OptimizerStateLoadShim(
            optimizer_state,
            drop_projector_state=not needs_basis,
            drop_projector_tokens=projector_tokens,
            requires_projector_state=needs_basis,
            drop_param_group_keys=tuple(sorted(group_keys_to_drop)),
            drop_state_keys=("error_feedback",),
        )
        self._optimizer_state_shim = load_shim

        # Shallow copy so the caller's mapping is left untouched.
        prepared = dict(states)
        prepared[OPTIMIZER] = load_shim
        return prepared

    def _save_last_step(self, curr_step: int) -> None:
        """Synchronously save the final checkpoint of the run.

        Depending on ``self.last_save_model_only`` this is either the model
        weights alone (cast to ``self.export_dtype`` when that is not float32)
        or the full flattened training state.

        Args:
            curr_step (int): The final training step, used to name the checkpoint.

        Raises:
            AssertionError: If HF output is requested for a full (non model-only)
                checkpoint.
        """
        # A model-only save is only considered at the end of training, so it has
        # no effect on preemption or resume. Dtype conversion is likewise only
        # allowed for this model-only export, and only when the export dtype
        # differs from float32.
        if not self.last_save_model_only:
            logger.info(f"Saving a full checkpoint at last step, step {curr_step}.")
            states = self._flattened_model_states_sd()
        else:
            states = self.states[MODEL].state_dict()
            if self.export_dtype != torch.float32:
                states = {
                    name: tensor.to(self.export_dtype)
                    for name, tensor in states.items()
                }
            logger.info(
                f"Saving a model only checkpoint in {self.export_dtype} at last step, step {curr_step}."
            )

        if self.last_save_in_hf:
            assert (
                self.last_save_model_only
            ), "Only model can be saved when saving in HF safetensors format."

        final_checkpoint_id = self._create_checkpoint_id(curr_step)
        self.dcp_save(
            states,
            checkpoint_id=final_checkpoint_id,
            async_mode=AsyncMode.DISABLED,
            enable_garbage_collection=True,
            to_hf=self.last_save_in_hf,
        )

    def _should_save(self, curr_step: int, last_step: bool = False) -> bool:
        if not self.enable or self.load_only:
            return False

        if curr_step == 1 and self.enable_first_step_checkpoint:
            return True

        if last_step:
            return True

        if curr_step % self.interval == 0:
            return True

        return False

    def _async_wait(self) -> None:
        """Wait for the pending asynchronous save, if there is one.

        In pinned-memory async mode the future is awaited but left in
        ``self.save_future``. In plain async mode, or whenever a fault-tolerance
        manager is present, it is awaited and then cleared.

        Raises:
            RuntimeError: If a future is pending although neither an async mode
                nor fault tolerance is active.
        """
        pinned = self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
        waits_allowed = (
            pinned or self.async_mode == AsyncMode.ASYNC or self.ft_manager is not None
        )
        pending = self.save_future

        if not waits_allowed:
            if pending is not None:
                raise RuntimeError(
                    "self.save_future is not None, but self.async_mode is not enabled and fault tolerance is not active."
                )
            return

        if pending is None:
            return
        pending.result()
        if not pinned:
            # Only the non-pinned paths drop the completed future.
            self.save_future = None

    def _purge_stale_checkpoints(self):
        if (
            self.keep_latest_k > 0
            and dist.get_rank() == 0
            and os.path.isdir(self.folder)
            and (not self.ft_manager or self.ft_manager.participating_rank() == 0)
        ):
            discovered_checkpoints = []
            for filename in os.listdir(self.folder):
                match = re.search(r"step-(\d+)", filename)
                if match:
                    path = os.path.join(self.folder, filename)
                    discovered_checkpoints.append((int(match.group(1)), path))

            discovered_checkpoints.sort()
            to_delete = discovered_checkpoints[: -1 * self.keep_latest_k]

            for _, path in to_delete:
                assert self.purge_thread is not None
                self.purge_queue.put(path)
