"""Utility functions for FL and experiment management.

They ensure compatibility with the Flower and wandb APIs.
"""

# ruff: noqa: ERA001
import ast
import copy
import threading
import time
import types
import warnings
from collections import OrderedDict, defaultdict
from collections.abc import Generator, Hashable, Iterable
from contextlib import contextmanager
from dataclasses import dataclass, field
from functools import partial, reduce
from itertools import chain, starmap
from logging import DEBUG, ERROR, WARNING
from pathlib import Path
from queue import Queue
from typing import Any, cast

import numpy as np
import psutil
import torch
from composer import Timestamp, Trainer
from composer.utils import dist
from flwr.common import (
    Config,
    NDArray,
    NDArrays,
    log,
)
from llmfoundry.command_utils.train import (
    validate_config,
)
from llmfoundry.utils.builders import (
    build_composer_model,
    build_tokenizer,
)
from llmfoundry.utils.config_utils import (
    TRAIN_CONFIG_KEYS,
    TrainConfig,
    make_dataclass_and_log_config,
    process_init_device,
)
from omegaconf import DictConfig, OmegaConf
from torch import device as device_type
from torch.distributed.checkpoint.state_dict import (
    ListDictValueType,
    OptimizerStateType,
    StateDictOptions,
    get_model_state_dict,
    set_model_state_dict,
    set_optimizer_state_dict,
)
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
from torch.optim import Optimizer
from torch.optim.optimizer import StateDict

import ray
import wandb
from repo.conf.base_schema import BaseConfig
from repo.file_utils import load_model_parameters_from_file
from repo.masks_utils import ModelStateNames
from ray._private.internal_api import (
    free as ray_free,  # noqa: PLC2701  # type: ignore[reportPrivateImportUsage]
)


@dataclass
class ClientState:
    """Dataclass for client state.

    Bundles two integer step counters with a dictionary snapshot derived from a
    composer ``Timestamp`` state dict.
    """

    # Required field (no default). Presumably the running total of local steps
    # across rounds -- TODO confirm against the callers that update it.
    local_steps_cumulative: int
    # Defaults to the state dict of a freshly constructed ``Timestamp``, keeping
    # only the entries whose ``repr`` is accepted by ``is_literal_for_ast``.
    # NOTE(review): the filter presumably guarantees that the state can be parsed
    # back with ``ast.literal_eval`` -- confirm against ``is_literal_for_ast``.
    local_timestamp: dict[str, Any] = field(
        default_factory=lambda: {
            k: v
            for k, v in Timestamp().state_dict().items()
            if is_literal_for_ast(repr(v))
        },
    )
    # Step counter that starts at 0.
    steps_done: int = 0


class ParameterShapeMismatchError(Exception):
    """Error signalling that a remote and a local parameter differ in shape."""

    def __init__(
        self,
        remote_name: str,
        remote_shape: torch.Size,
        local_shape: torch.Size,
    ) -> None:
        """Compose the error message from the parameter name and both shapes.

        Parameters
        ----------
        remote_name : str
            Name of the remote parameter whose shape is being compared.
        remote_shape : torch.Size
            Shape of the remote parameter.
        local_shape : torch.Size
            Shape of the local parameter it was compared against.

        """
        description = (
            f"Shapes don't match for {remote_name}: {remote_shape} != {local_shape}"
        )
        super().__init__(description)


class NoOpContextManager:
    """A context manager that does nothing."""

    def __enter__(self) -> None:
        """Do nothing."""
        return

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: types.TracebackType | None,
    ) -> None:
        """Do nothing."""
        return


def create_partial_id_map(
    saved_groups: list[dict[str, Any]], groups: list[dict[str, Any]],
) -> dict:
    """Create a mapping between parameter IDs in saved groups and original groups.

    This function creates a dictionary mapping between parameter IDs in saved_groups
    and a subset of parameter IDs in groups. For each group in saved_groups, it uses
    the indices in the "params" field to select which parameters to keep from the
    corresponding group in groups, then builds a mapping between the original parameter
    IDs and the selected ones.

    Parameters
    ----------
    saved_groups : list[dict[str, Any]]
        A list of dictionaries containing parameter indices in their "params" field.
        These indices indicate which parameters to keep from the original groups.
    groups : list[dict[str, Any]]
        A list of dictionaries containing the original parameter IDs in their "params"
        field.

    Returns
    -------
    dict
        A dictionary mapping parameter IDs from saved_groups to the selected parameter
        IDs from groups.

    Raises
    ------
    ValueError
        If an index in saved_groups["params"] is out of range for the corresponding
        group["params"] in the original groups.

    """
    selected_groups: list[dict[str, Any]] = []
    for s_g, g in zip(saved_groups, groups, strict=True):
        indeces_to_keep = s_g["params"]
        new_params = []
        for i in indeces_to_keep:
            if i < len(g["params"]):
                new_params.append(g["params"][i])
            else:
                msg = (
                    f"index {i} out of range for original"
                    f" params of length {len(g['params'])}"
                )
                raise ValueError(msg)
        selected_groups.append({"params": new_params})
    return dict(
        zip(
            chain.from_iterable(g["params"] for g in saved_groups),
            chain.from_iterable(g["params"] for g in selected_groups),
            strict=True,
        ),
    )


def _cast(
    param: torch.Tensor,
    value: torch.Tensor | dict | Iterable,
    param_id: int,
    param_groups: list[dict[Any, Any]],
    key: Hashable = None,
) -> torch.Tensor | dict | Iterable:
    """Process optimizer state values according to parameter policies.

    This function recursively processes optimizer state values to ensure they are
    correctly cast and transformed according to the parameter's policy. It handles
    tensors, dictionaries, and other iterables by applying appropriate transformations
    to each element or nested structure.

    Parameters
    ----------
    param : torch.Tensor
        The parameter tensor associated with the optimizer state.
    value : torch.Tensor | dict | Iterable
        The value to process, which can be a tensor, dictionary, or other iterable.
    param_id : int
        The ID of the parameter in the optimizer's state.
    param_groups : list[dict[Any, Any]]
        List of parameter groups from the optimizer.
    key : Hashable, optional
        An optional key used when processing dictionary values, by default None.

    Returns
    -------
    torch.Tensor | dict | Iterable
        The processed value with the same structure as the input but potentially
        modified values according to the parameter policy.

    Notes
    -----
    This is a helper function primarily used by `load_partial_state_dict` to handle
    the casting of optimizer state values when loading a partial state dictionary.
    For tensors, it delegates to PyTorch's internal
    `_process_value_according_to_param_policy`. For dictionaries and iterables, it
    recursively processes each element.

    """
    if isinstance(value, torch.Tensor):
        return Optimizer._process_value_according_to_param_policy(  # noqa: SLF001
            param,
            value,
            param_id,
            param_groups,
            key,
        )
    if isinstance(value, dict):
        return {
            k: _cast(
                param,
                v,
                param_id=param_id,
                param_groups=param_groups,
                key=k,
            )
            for k, v in value.items()
        }
    if isinstance(value, Iterable):
        return type(value)(
            _cast(param, v, param_id=param_id, param_groups=param_groups) for v in value  # type: ignore[call-arg, reportCallIssue]
        )
    return None


@torch._disable_dynamo  # noqa: SLF001  # type: ignore[reportPrivateImportUsage]
def load_partial_state_dict(optimizer: Optimizer, state_dict: StateDict) -> None:
    """Load a partial state dictionary into an optimizer.

    This function is a specialized version of the PyTorch optimizer's load_state_dict
    method that supports loading a state dictionary that contains only a subset of the
    optimizer's parameters. This is particularly useful with FSDP models where different
    parameters may be sharded across different processes, and the standard
    load_state_dict method might fail due to shape mismatches or missing parameters.

    The function handles parameter mapping between the saved state dictionary and the
    current optimizer state, preserving the optimizer's existing parameter structure
    while updating state values where available in the state dictionary.

    The steps performed are, in order: run the optimizer's registered load-state-dict
    pre-hooks, build the saved-ID -> live-parameter map, re-key and cast the saved
    state, install the saved parameter groups (keeping the optimizer's own "params"),
    merge the new state into the optimizer and finally run the post-hooks.

    Parameters
    ----------
    optimizer : Optimizer
        The optimizer instance whose state should be partially updated.
    state_dict : StateDict
        The state dictionary containing the subset of optimizer states to load.
        Must have 'param_groups' and 'state' keys conforming to PyTorch's
        optimizer state dict format. The caller's dictionary is shallow-copied
        before use.

    Raises
    ------
    ValueError
        If the loaded state dictionary has a different number of parameter groups
        than the optimizer, or if `create_partial_id_map` rejects an index of the
        saved parameter groups.

    Notes
    -----
    This function is used by the `set_optim_with_partial_state_dict` function to
    support updating optimizer states in complex sharding patterns when working with
    FSDP models. It temporarily replaces the optimizer's native load_state_dict method
    during state updates.

    Unlike the standard PyTorch optimizer's load_state_dict method, this function
    uses `create_partial_id_map` to establish a mapping between parameter IDs in the
    saved state dictionary and the current optimizer parameters, allowing for partial
    updates where only a subset of parameters have state information.

    """
    # shallow copy, to be consistent with module API
    state_dict = state_dict.copy()

    # A pre-hook may return a replacement state dict; a ``None`` result keeps the
    # current one.
    for (
        pre_hook
    ) in optimizer._optimizer_load_state_dict_pre_hooks.values():  # noqa: SLF001
        hook_result = pre_hook(optimizer, state_dict)
        if hook_result is not None:
            state_dict = hook_result

    # Validate the state_dict
    groups = optimizer.param_groups

    # Deepcopy as we write into saved_groups later to update state
    saved_groups = copy.deepcopy(state_dict["param_groups"])

    if len(groups) != len(saved_groups):
        msg = "loaded state dict has a different number of parameter groups"
        raise ValueError(msg)
    # Maps each ID listed in the saved groups to the live parameter it selects.
    id_map = create_partial_id_map(saved_groups, groups)

    # Copy state assigned to params (and cast tensors to appropriate types).
    # State that is not assigned to params is copied as is (needed for
    # backward compatibility).
    state: defaultdict[torch.Tensor, dict[Any, Any]] = defaultdict(dict)
    for k, v in state_dict["state"].items():
        if k in id_map:
            # Re-key the entry from the saved ID to the live parameter object.
            param = id_map[k]
            state[param] = _cast(  # type: ignore[assignment, reportArgumentType]
                param,
                v,
                param_id=k,
                param_groups=state_dict["param_groups"],
            )
        else:
            state[k] = v

    # Update parameter groups, setting their 'params' value
    def update_group(
        group: dict[str, Any],
        new_group: dict[str, Any],
    ) -> dict[str, Any]:
        # ``new_group`` (a deep-copied saved group) is modified in place: it takes
        # over the live group's parameters and, when it has none of its own, the
        # live group's parameter names.
        new_group["params"] = group["params"]
        if "param_names" in group and "param_names" not in new_group:
            new_group["param_names"] = group["param_names"]
        return new_group

    # Both lists were checked to have the same length above, so the non-strict zip
    # drops nothing.
    param_groups = list(starmap(update_group, zip(groups, saved_groups, strict=False)))
    optimizer.__setstate__({"param_groups": param_groups})
    # Merge instead of replacing: optimizer state entries whose keys do not appear
    # in ``state`` are left in place.
    optimizer.__dict__["state"].update(state)
    for (
        post_hook
    ) in optimizer._optimizer_load_state_dict_post_hooks.values():  # noqa: SLF001
        post_hook(optimizer)


@contextmanager
def custom_ray_garbage_collector(
    garbage_queue: Queue[ray.ObjectRef],
    list_of_threads: list[threading.Thread],
    timeout: float = 10,
    process_name: str = "CustomRayGarbageCollector",
    *,
    join_at_the_end: bool = True,
) -> Generator[None, Any, None]:
    """Context manager for custom Ray garbage collection.

    When ``list_of_threads`` is empty, a background thread is started (and appended
    to the list) that periodically drains ``garbage_queue`` and frees every Ray
    object reference it finds. When the list is not empty, no new thread is started
    and the threads already in the list are reused.

    On exit with ``join_at_the_end=True`` every collector thread started by this
    function is asked to stop, performs one final drain of the queue and is then
    joined. The list is cleared afterwards so that a later use of the context
    manager starts a fresh collector.

    Parameters
    ----------
    garbage_queue : Queue[ray.ObjectRef]
        A queue containing Ray object references to be freed.
    list_of_threads : list[threading.Thread]
        A list of threads to append the garbage collection thread to. All threads
        in the list are joined on exit when ``join_at_the_end`` is True.
    timeout : float, optional
        The time in seconds to wait between garbage collection cycles, by default 10.
    process_name : str, optional
        The name of the process, used in log messages, by default
        "CustomRayGarbageCollector".
    join_at_the_end : bool, optional
        Whether to stop and join the garbage collection thread at the end, by
        default True. When False the thread keeps running after the context exits.

    Yields
    ------
    None
        The context manager yields control back to the caller.

    Notes
    -----
    Exceptions raised inside the ``with`` body are logged at ERROR level and
    re-raised. Errors raised while freeing an object are logged and do not stop the
    collector.

    Example
    -------
    >>> garbage_queue = Queue()
    >>> threads = []
    >>> with custom_ray_garbage_collector(garbage_queue, threads):
    >>>     # Your code here
    >>>     pass

    """

    def drain_queue() -> None:
        while not garbage_queue.empty():
            try:
                obj = garbage_queue.get(block=True, timeout=0.1)
                ray_free([obj])
            except Exception as e:  # noqa: BLE001
                log(ERROR, "Error in custom RayGC-%s.", process_name, exc_info=e)

    def collect_ray_garbage(stop_event: threading.Event) -> None:
        stop_requested = False
        while not stop_requested:
            # ``wait`` doubles as the sleep between cycles and returns True as soon
            # as a stop is requested, so shutdown does not have to wait a full
            # ``timeout``. The queue is drained one last time before leaving.
            stop_requested = stop_event.wait(timeout)
            drain_queue()

    if not list_of_threads:
        stop_event = threading.Event()
        thread = threading.Thread(
            target=collect_ray_garbage,
            args=(stop_event,),
            name="CustomRayGarbageCollector",
        )
        # The event travels with the thread so that a later invocation reusing
        # ``list_of_threads`` can still stop a collector it did not start itself.
        thread.raygc_stop_event = stop_event  # type: ignore[attr-defined]
        list_of_threads.append(thread)
        thread.start()

    try:
        yield
    except Exception as e:
        log(
            ERROR,
            "Error in the main scoped caught by RayGC-%s context manager.",
            process_name,
            exc_info=e,
        )
        raise
    finally:
        if join_at_the_end:
            log(DEBUG, "RayGC-%s: Closing.", process_name)
            for thread in list_of_threads:
                # Without this signal the collector loops forever and ``join``
                # would never return.
                pending_stop = getattr(thread, "raygc_stop_event", None)
                if pending_stop is not None:
                    pending_stop.set()
                thread.join()
            # Joined threads cannot be restarted; drop them so that the next use
            # of this context manager starts a new collector.
            list_of_threads.clear()


def get_random_model_from_config(
    llm_cfg: DictConfig,
) -> torch.nn.Module:
    """Create a randomly initialized model from the provided configuration.

    The configuration is resolved, copied, stripped of the global entries that the
    train-config dataclass builder cannot interpret, and validated. The tokenizer is
    then built and the model is created with ``init_device`` forced to ``"cpu"``.

    Parameters
    ----------
    llm_cfg : DictConfig
        The configuration object used to create the randomly initialized model.
        Note that interpolations are resolved and the struct flag is disabled on
        this object in place; every other modification happens on a deep copy.

    Returns
    -------
    torch.nn.Module
        A randomly initialized model instance created according to the configuration.

    Raises
    ------
    TypeError
        If model_config is not a dict or name is not a string.

    Example
    -------
    >>> from omegaconf import OmegaConf
    >>> cfg = OmegaConf.create({...})
    >>> model = get_random_model_from_config(cfg)

    """
    # Resolve all interpolation variables as early as possible
    OmegaConf.resolve(llm_cfg)
    OmegaConf.set_struct(llm_cfg, value=False)
    # Deep copy the configuration to prevent any dangerous modification
    internal_cfg = copy.deepcopy(llm_cfg)

    # Filter deprecation warning from torch internal usage.
    # NOTE: The pattern is a regex matched against the start of the warning text,
    # so the words must be separated exactly as in the emitted message ("private
    # function and will be deprecated").
    warnings.filterwarnings(
        action="ignore",
        category=UserWarning,
        message=(
            "torch.distributed.*_base is a private function"
            " and will be deprecated.*"
        ),
    )

    # NOTE: We need to extract a set of global variables from the configuration object
    # to prevent the dataclass creator to crash
    internal_cfg.pop("data_local", None)
    internal_cfg.pop("data_remote", None)
    internal_cfg.pop("global_seed", None)
    # internal_cfg.pop("local_steps", None)
    internal_cfg.pop("name", None)
    # NOTE: This contains OUR global parameters for the ICL tasks that the
    # `make_dataclass_and_log_config` cannot interpret, so we need to pop it
    internal_cfg.pop("icl_tasks_config", None)
    _logged_cfg, train_cfg = make_dataclass_and_log_config(
        internal_cfg,
        TrainConfig,
        TRAIN_CONFIG_KEYS,
        transforms="all",
        icl_tasks_required=internal_cfg.get("icl_tasks", None) is not None,
    )
    # Check for incompatibilities between the model and data loaders
    validate_config(train_cfg)
    # Build tokenizer
    tokenizer_name = train_cfg.tokenizer["name"]
    tokenizer_kwargs = train_cfg.tokenizer.get("kwargs", {})
    tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs)
    # Get model while forcing cpu to prevent any GPU allocation
    model_config = train_cfg.model
    if not isinstance(model_config, dict):
        msg = f"Expected model_config to be a dict, got {type(model_config)}"
        raise TypeError(msg)

    model_config["init_device"] = "cpu"
    name = model_config.pop("name")
    if not isinstance(name, str):
        msg = f"Expected name to be a string, got {type(name)}"
        raise TypeError(msg)

    init_context = process_init_device(model_config, None, None)
    model = build_composer_model(
        name=name,
        tokenizer=tokenizer,
        init_context=init_context,
        master_weights_dtype=model_config.pop("master_weights_dtype", None),
        cfg=model_config,
    )
    # Force model to cpu
    model.cpu()
    return model


def get_random_model_states_ndarrays(
    cfg: DictConfig,
    *,
    return_names: bool = False,
    only_requires_grad: bool = False,
) -> NDArrays | tuple[NDArrays, list[str]]:
    """Build the initial model states of a randomly initialized model.

    A model is instantiated on CPU from ``cfg`` and its parameters, sorted by name,
    are converted to numpy arrays. The resulting list is followed by two blocks of
    zero arrays with the shapes of the parameters, standing in for the optimizer's
    first and second moments.

    Parameters
    ----------
    cfg : DictConfig
        The configuration object used to create the randomly initialized model.
    return_names : bool, optional
        If True, the sorted parameter names are returned along with the arrays
        (default is False).
    only_requires_grad : bool, optional
        If True, only parameters that require gradients are considered
        (default is False).

    Returns
    -------
    NDArrays | tuple[NDArrays, list[str]]
        The list of state arrays (parameters, then zeros for both momenta). When
        ``return_names`` is True, a tuple of that list and the list of parameter
        names.

    Example
    -------
    >>> from omegaconf import OmegaConf
    >>> cfg = OmegaConf.create({...})
    >>> params = get_random_model_states_ndarrays(cfg)
    >>> params, names = get_random_model_states_ndarrays(cfg, return_names=True)

    """
    random_model = get_random_model_from_config(cfg)
    # NOTE: Sorting keeps the array order aligned with the sorted names below
    named_tensors = get_parameters_dict(
        random_model,
        only_requires_grad=only_requires_grad,
        sort_dict=True,
    )
    model_states = [
        tensor.detach().to("cpu").numpy() for tensor in named_tensors.values()
    ]
    # One zero array per parameter; the same arrays are appended twice, once for
    # each momentum
    momentum_placeholders = [np.zeros_like(array) for array in model_states]
    model_states.extend(momentum_placeholders * 2)

    if not return_names:
        return model_states

    # NOTE: Names are sorted as well so that they line up with the arrays
    sorted_names = get_list_of_parameters_names(
        model=random_model,
        only_requires_grad=only_requires_grad,
        sort_list=True,
    )
    return model_states, sorted_names


def get_initial_parameters(
    cfg: BaseConfig,
) -> tuple[NDArrays, tuple[tuple[str, "ModelStateNames"], ...]]:
    """Retrieve the initial model states for the federated learning server.

    A randomly initialized model is always built from ``cfg.llm_config`` in order
    to obtain the trainable layer names and the expected number of states. When
    ``cfg.pretrained_model_path`` is set, the states are loaded from that file and
    returned in place of the random ones.

    Parameters
    ----------
    cfg : BaseConfig
        The configuration object containing the pretrained model path and LLM config.
        ``cfg.llm_config`` is resolved and has its struct flag disabled in place.

    Returns
    -------
    'tuple[NDArrays, tuple[tuple[str, ModelStateNames], ...]]'
        The initial model states and, for every state, the layer name paired with
        its state type. All parameters are listed first, followed by the first and
        then the second moments.

    Raises
    ------
    ValueError
        If the number of pretrained states does not match the number of states of
        the randomly initialized model.

    """
    llm_config = cfg.llm_config
    OmegaConf.resolve(llm_config)
    OmegaConf.set_struct(llm_config, value=False)

    # NOTE: Only the trainable parameters are of interest here
    random_states_and_names = get_random_model_states_ndarrays(
        copy.deepcopy(llm_config),
        return_names=True,
        only_requires_grad=True,
    )
    initial_model_states, raw_names = cast(
        "tuple[NDArrays, list[str]]",
        random_states_and_names,
    )
    # Drop the first "model." occurrence from every name
    stripped_names = [name.replace("model.", "", 1) for name in raw_names]
    state_kinds = (
        ModelStateNames.PARAMETERS,
        ModelStateNames.EXP_AVG,
        ModelStateNames.EXP_AVG_SQ,
    )
    layer_names_and_types: tuple[tuple[str, ModelStateNames], ...] = tuple(
        (name, kind) for kind in state_kinds for name in stripped_names
    )

    if not cfg.pretrained_model_path:
        return initial_model_states, layer_names_and_types

    log(
        DEBUG,
        "FL server is loading pretrained model from %s",
        cfg.pretrained_model_path,
    )
    pretrained_model_states = load_model_parameters_from_file(
        Path(cfg.pretrained_model_path),
    )
    expected_count = len(initial_model_states)
    loaded_count = len(pretrained_model_states)
    if loaded_count != expected_count:
        msg = f"Expected {expected_count} parameters, but got {loaded_count}"
        raise ValueError(msg)
    return pretrained_model_states, layer_names_and_types


@ray.remote
def batch_convert_and_put_float16(array: np.ndarray) -> ray.ObjectRef:
    """Convert an array to float16 and put it in the Ray object store.

    The function runs as a Ray remote task: the array is cast to float16 precision
    and the converted copy is stored in the Ray object store.

    Parameters
    ----------
    array : np.ndarray
        The NumPy array to be converted and stored.

    Returns
    -------
    ray.ObjectRef
        The Ray object reference to the converted array in the object store.

    """
    half_precision = array.astype(np.float16)
    return ray.put(half_precision)


@ray.remote
def convert_to_float32(arr: NDArray) -> NDArray:
    """Return a float32 copy of an array.

    The function runs as a Ray remote task so that several conversions can be
    executed in parallel, e.g. while loading parameters.

    Parameters
    ----------
    arr : NDArray
        The array to be converted to float32 precision.

    Returns
    -------
    NDArray
        A new array holding the input values in float32 precision.

    """
    as_ndarray = np.array(arr)
    return as_ndarray.astype(np.float32)


def get_parameters_from_state(
    _config: Config,
    trainer: Trainer,
    parameter_names: Iterable[str] | None = None,
) -> NDArrays:
    """Extract model parameters from the trainer state as numpy arrays.

    Parameters
    ----------
    _config : Config
        The configuration (unused).
    trainer : Trainer
        The trainer whose ``state.model`` provides the parameters.
    parameter_names : Iterable[str] | None, optional
        The names of the parameters to return, by default None. Each name is
        matched as a substring against the model's parameter names and the first
        parameter whose name contains it is selected. When None, all parameters
        are returned.

    Returns
    -------
    NDArrays
        The selected parameters, in the order of ``parameter_names``, or all
        parameters in the order of the parameter dictionary.

    Raises
    ------
    KeyError
        If a requested name is not contained in any model parameter name.

    """
    model_parameters_dict = get_parameters_dict(trainer.state.model)
    if parameter_names is None:
        return [val.detach().to("cpu").numpy() for val in model_parameters_dict.values()]

    selected: NDArrays = []
    for name in parameter_names:
        matched_name = next(
            (full_name for full_name in model_parameters_dict if name in full_name),
            None,
        )
        if matched_name is None:
            # Report the name that was actually requested instead of failing on a
            # placeholder key.
            msg = f"No model parameter name contains {name!r}"
            raise KeyError(msg)
        selected.append(
            model_parameters_dict[matched_name].detach().to("cpu").numpy(),
        )
    return selected


def get_parameters_dict(
    model: torch.nn.Module,
    *,
    only_requires_grad: bool = False,
    sort_dict: bool = True,
    no_detach_and_clone: bool = False,
) -> dict[str, torch.nn.Parameter] | dict[str, torch.Tensor]:
    """Collect the named parameters of a model into a dictionary.

    When ``model.model`` is exactly a ``FullyShardedDataParallel`` instance, the
    full parameters are summoned and wrapper-related fragments are removed from the
    parameter names. Otherwise the parameters of ``model`` are collected under
    their original names. A ``dist.barrier()`` is issued before returning.

    Parameters
    ----------
    model : torch.nn.Module
        The model.
    only_requires_grad : bool, optional
        If True, only parameters that require gradients are collected
        (default is False).
    sort_dict : bool, optional
        Whether to sort the dictionary by parameter name, by default True.
    no_detach_and_clone : bool, optional
        If True, the parameters are stored as they are; otherwise a detached clone
        of each parameter is stored. By default False.

    Returns
    -------
    dict[str, torch.nn.Parameter] | dict[str, torch.Tensor]
        The collected parameters keyed by name.

    Raises
    ------
    ValueError
        If the model is None.

    """
    # Fragments of the FQNs to remove when the model is wrapped
    wrapper_fragments = (
        "model.",
        "module.",
        "_fsdp_wrapped_",
        "_checkpoint_wrapped_",
    )

    def strip_wrapper_fragments(fqn: str) -> str:
        for fragment in wrapper_fragments:
            fqn = fqn.replace(fragment, "")
        return fqn

    def keep_name(fqn: str) -> str:
        return fqn

    def collect(
        source: torch.nn.Module,
        rename: Any,
    ) -> dict[str, torch.nn.Parameter] | dict[str, torch.Tensor]:
        collected: dict[str, Any] = {}
        for fqn, tensor in source.named_parameters():
            if only_requires_grad and not tensor.requires_grad:
                continue
            collected[rename(fqn)] = (
                tensor if no_detach_and_clone else tensor.detach().clone()
            )
        return collected

    # NOTE: This function is weird because the encapsulation done to support FSDP and
    # DDP is weird. Since they are both likely to change, we MUST maintain this very
    # well and implement as many checkers as we can.
    is_fsdp_wrapped = (
        hasattr(model, "model") and type(model.model) is FullyShardedDataParallel
    )
    if is_fsdp_wrapped:
        if model.model is None:  # type: ignore[reportUnnecessaryComparison]
            error_message = "Model is None"
            raise ValueError(error_message)
        fsdp_module = model.model
        # NOTE: This doesn't work in the case in use_orig_params is True if the FSDP
        # configuration as the tensors returned are flattened breaking some
        # assumptions of the rest of the codebase
        with FullyShardedDataParallel.summon_full_params(
            fsdp_module,
            recurse=True,
            writeback=False,
            rank0_only=True,
            offload_to_cpu=True,
            with_grads=False,
        ):
            # NOTE: With the arguments above (recurse=True, writeback=False,
            # rank0_only=True, offload_to_cpu=True, with_grads=False) the dictionary
            # will be complete only on rank 0. The other ranks will have zero-shaped
            # tensors for those layers that are not "living" in there.
            # NOTE: If the FSDP configuration uses the original parameters
            # (use_orig_params=true), then the tensors in rank 0 have the correct
            # original shape. In the other ranks they are flattened anyway.
            # NOTE: On ranks > 0 the dictionary won't be empty. It will contain the
            # parameters that are "living" in that rank and will have zero-shaped
            # tensors for the others.
            # The collection (including the cloning) must stay inside this context.
            params_dict = collect(fsdp_module, strip_wrapper_fragments)
    else:
        params_dict = collect(model, keep_name)

    if sort_dict:
        params_dict = {key: params_dict[key] for key in sorted(params_dict)}
    dist.barrier()
    return params_dict


def match_shapes(
    remote_param: torch.Tensor,
    local_param: torch.nn.Parameter | torch.Tensor,
) -> torch.Tensor:
    """Match the shape of a remote parameter to a local parameter if possible.

    This function attempts to match the shape of `remote_param` to `local_param`.
    If the shapes are already identical, `remote_param` is returned unchanged.
    If the total number of elements in both tensors differs, `remote_param`
    is returned unchanged as their shapes are fundamentally incompatible for reshaping.
    If the number of elements is the same but shapes differ:
    - If `local_param` is 1-dimensional, `remote_param` is flattened.
    - If `local_param` is N-dimensional (N > 1 or N = 0 for scalar), `remote_param`
      is reshaped to `local_param.shape`.

    This is useful in scenarios like FSDP model loading, where remote parameters
    (e.g., from a sharded state_dict, potentially flattened) need to conform to
    the local model's parameter shapes.

    Parameters
    ----------
    remote_param : torch.Tensor
        The tensor whose shape needs to be matched.
    local_param : torch.nn.Parameter | torch.Tensor
        The tensor that serves as the shape reference.

    Returns
    -------
    torch.Tensor
        The `remote_param` tensor, potentially reshaped or flattened to match
        `local_param.shape`, or the original `remote_param` if shapes
        cannot be matched due to differing numbers of elements.

    """
    # 1. If shapes already match, return remote_param unchanged.
    if remote_param.shape == local_param.shape:
        return remote_param

    # 2. If the total number of elements differs, shapes cannot be matched.
    #    Return remote_param unchanged.
    if remote_param.numel() != local_param.numel():
        return remote_param

    # At this point, numel() is the same, but shapes are different.
    # 3. If local_param is 1D, flatten remote_param.
    #    Since numel() is the same, remote_param.flatten() will have the same
    #    shape as local_param.
    if local_param.ndim == 1:
        return remote_param.flatten()

    # 4. If local_param is 0D (scalar) or N-D (N > 1), reshape remote_param.
    #    This is safe because numel() is the same.
    #    local_param.ndim == 0 or local_param.ndim > 1
    return remote_param.view(local_param.shape)


def set_optimizer_state_base(
    optimizer: torch.optim.Optimizer,
    local_param: torch.nn.Parameter | torch.Tensor,
    local_param_name: str,
    step: int,
    momentum_type_dict: dict[str, dict[str, NDArray]],
) -> None:
    """Set the optimizer state for a specific parameter.

    This function overwrites the step counter and every momentum tensor that the
    optimizer holds for `local_param`. The remote momenta are looked up under a
    cleaned version of `local_param_name`, converted to tensors, reshaped with
    `match_shapes` against the existing local state tensor, and assigned to the
    `.data` of that state tensor.

    If the optimizer holds no ``"step"`` entry for `local_param`, a warning is
    logged and the function returns without touching the optimizer state.

    Parameters
    ----------
    optimizer : torch.optim.Optimizer
        The optimizer instance whose state is being updated.
    local_param : torch.nn.Parameter | torch.Tensor
        The local parameter or tensor whose optimizer state should be updated.
    local_param_name : str
        The name of the parameter in the model's state_dict.
    step : int
        The step count to set in the optimizer state.
    momentum_type_dict : dict[str, dict[str, NDArray]]
        A dictionary mapping momentum types (e.g., 'exp_avg', 'exp_avg_sq') to
        dictionaries that map parameter names to their corresponding momentum values.

    Raises
    ------
    KeyError
        If the cleaned parameter name is missing from one of the momentum
        dictionaries, or if the optimizer state has no entry for one of the
        momentum types.

    Notes
    -----
    The lookup name is built by removing every occurrence (not only leading ones)
    of the substrings 'model.', 'module.', and '_fsdp_wrapped_' from
    `local_param_name`. The remote momentum is created with the device and dtype
    of `local_param`.

    """
    # `optimizer.state` is indexed exactly once; the resulting per-parameter
    # dictionary is reused for every read and write below
    param_state = optimizer.state[local_param]
    if "step" not in param_state:
        log(WARNING, "Step not found in optimizer state for %s", local_param_name)
        # Without a step entry there is nothing to overwrite: skip this parameter
        # instead of failing on the lookup below
        return
    # Set the step to the current server steps, keeping the device and dtype of
    # the existing step tensor
    step_tensor = param_state["step"]
    step_tensor.data = torch.as_tensor(
        step,
        device=step_tensor.device,
        dtype=step_tensor.dtype,
    )
    # Remove potential DDP-related prefixes and the "module." prefix
    lookup_name = (
        local_param_name.replace("model.", "")
        .replace(
            "module.",
            "",
        )
        .replace(
            "_fsdp_wrapped_",
            "",
        )
    )
    # Overwrite every momentum type; a missing name raises KeyError
    for momentum_type, momentum_dict in momentum_type_dict.items():
        # Get the new momentum value
        new_momentum = torch.as_tensor(
            momentum_dict[lookup_name],
            device=local_param.device,
            dtype=local_param.dtype,
        )
        local_tensor: torch.Tensor = param_state[momentum_type]
        # Reshape the remote tensor to match the local state tensor's shape
        new_momentum = match_shapes(new_momentum, local_tensor)
        # Set the momentum value
        local_tensor.data = new_momentum


def set_param_base(
    remote_param: torch.Tensor,
    local_param: torch.nn.Parameter | torch.Tensor,
) -> bool:
    """Set a local parameter's data with a remote parameter's value.

    This function handles the base case of setting a parameter's value, taking care of
    shape matching and device placement. It's used by higher-level parameter setting
    functions to handle individual parameter assignments. It also sets the
    `requires_grad` flag of the local parameter to match the remote parameter.

    Parameters
    ----------
    remote_param : torch.Tensor
        The remote parameter tensor containing the values to assign.
    local_param : torch.nn.Parameter | torch.Tensor
        The local parameter or tensor whose data will be updated.

    Returns
    -------
    bool
        True if the assignment was successful or not required (empty local parameter),
        False if the assignment failed due to shape mismatch.

    """
    # Skip the setting if the remote parameter is None or the local parameter is empty.
    # The first case happens if the parameter is not found in the remote model, the
    # second case happens if the parameter is not found in the local model, e.g., the
    # current shard doesn't have the parameter.
    if local_param.shape != torch.Size([0]):
        # Try to match the shapes
        remote_param = match_shapes(remote_param, local_param)
        # Set the parameters only if the shape matches
        if local_param.shape == remote_param.shape:
            local_param.data.copy_(remote_param.to(device=local_param.device))
            # Set the requires_grad flag to match the remote parameter
            local_param.requires_grad = remote_param.requires_grad
        # Return False if the assignment failed
        return local_param.shape == remote_param.shape
    # Return True if the assignment was not required (empty local parameter)
    return True


def set_params_with_state_dict(
    modules_buffer: dict[str, tuple[torch.nn.Module, str, torch.Tensor]],
) -> None:
    """Set parameters for modules using state dictionaries.

    The buffer is first regrouped by module so that each module is processed only
    once. For every module, its state dict is fetched with `get_model_state_dict`,
    the buffered tensors are written into the matching state-dict entries through
    `set_param_base`, and the state dict is loaded back with
    `set_model_state_dict`.

    Only state-dict entries that have a buffered tensor are modified; all other
    entries are loaded back unchanged.

    Parameters
    ----------
    modules_buffer : dict[str, tuple[torch.nn.Module, str, torch.Tensor]]
        A dictionary mapping keys to tuples containing:
        - The module instance to update
        - The parameter name as it appears in the module's state dict
        - The tensor containing the new parameter values

    Notes
    -----
    Both the read and the write use
    ``StateDictOptions(full_state_dict=True, cpu_offload=True,
    broadcast_from_rank0=True)``. A warning is logged when a buffered tensor
    cannot be assigned because of a shape mismatch, and when a non-empty state
    dict does not contain some of the buffered names.

    """
    # Construct the internal mapping that allows us to iterate over each module in the
    # modules_buffer only once
    internal_buffer: dict[torch.nn.Module, dict[str, torch.Tensor]] = defaultdict(dict)
    for (
        internal_local_module,
        remote_param_name,
        remote_param,
    ) in modules_buffer.values():
        internal_buffer[internal_local_module][remote_param_name] = remote_param
    # The same options are used to read and to write every module
    options = StateDictOptions(
        full_state_dict=True,
        cpu_offload=True,
        broadcast_from_rank0=True,
    )
    # Iterate over the internal buffer and set the parameters
    for (
        internal_local_module,
        remote_param_dict,
    ) in internal_buffer.items():
        # Get the state dict of the local module
        state = get_model_state_dict(internal_local_module, options=options)
        # Set the parameters
        for local_param_name, local_param in state.items():
            remote_param = remote_param_dict.get(local_param_name)
            # The state dict holds every entry of the module, while the buffer only
            # holds the ones that still need to be set: leave the others untouched
            if remote_param is None:
                continue
            assert isinstance(local_param, torch.Tensor)
            if not set_param_base(remote_param, local_param):
                log(
                    WARNING,
                    "set_params_with_state_dict: shape mismatch for parameter %s,"
                    " it won't be set",
                    local_param_name,
                )
        # Report buffered names that the state dict doesn't know about. An empty
        # state dict is not reported, so that it doesn't produce a warning for
        # every buffered name
        unmatched = [name for name in remote_param_dict if name not in state]
        if state and unmatched:
            log(
                WARNING,
                "set_params_with_state_dict: buffered parameters %s not found in"
                " the module state dict and won't be set",
                unmatched,
            )
        # Set the state dict of the local module
        set_model_state_dict(internal_local_module, state, options=options)


def print_all_keys(d: dict, indent: int = 0) -> list[str]:
    """Collect the keys of a dictionary and of all dictionaries nested inside it.

    The dictionary is walked depth-first: each key is emitted as ``str(key)``
    and, when its value is itself a dictionary, the keys of that value follow
    immediately. Despite the name, nothing is printed.

    Parameters
    ----------
    d : dict
        The dictionary to walk. Values may be nested dictionaries.
    indent : int, optional
        Nesting counter, by default 0. It grows by 2 per nesting level but has
        no effect on the returned list.

    Returns
    -------
    list[str]
        The string form of every key found, as a flat list in depth-first order.

    Example
    -------
    >>> nested_dict = {'a': 1, 'b': {'c': 2, 'd': {'e': 3}}}
    >>> print_all_keys(nested_dict)
    ['a', 'b', 'c', 'd', 'e']

    """

    def _walk(mapping: dict, depth: int) -> Generator[str, None, None]:
        # Emit the key first, then descend into its value if it is a dictionary
        for name, child in mapping.items():
            yield str(name)
            if isinstance(child, dict):
                yield from _walk(child, depth + 2)

    return list(_walk(d, indent))


def set_optim_with_state_dict(
    trainer: Trainer,
    momenta_buffer: dict[str, dict[str, torch.Tensor]],
    step_value: int,
) -> float:
    """Load buffered momenta into the trainer's optimizers via a full state dict.

    On global rank 0 an optimizer state dictionary is assembled from
    `momenta_buffer`; every other rank keeps an empty dictionary. The dictionary
    is then handed to `set_optimizer_state_dict` together with the trainer's
    model and optimizers.

    Parameters
    ----------
    trainer : Trainer
        The trainer whose `state.model` and `state.optimizers` are updated.
    momenta_buffer : dict[str, dict[str, torch.Tensor]]
        Maps a parameter name to a dictionary of momentum tensors keyed by
        momentum type (e.g., 'exp_avg', 'exp_avg_sq').
    step_value : int
        The step count stored, as a float32 tensor, for every parameter that has
        at least one momentum entry.

    Returns
    -------
    float
        The elapsed wall-clock time of the whole call, in seconds.

    Notes
    -----
    Parameter names are prefixed with ``model.`` before being used as keys of
    the ``"state"`` section. Parameters whose momentum dictionary is empty are
    left out. Each entry of ``"param_groups"`` copies the hyperparameters of the
    matching param group of the first optimizer and lists all collected
    parameter names under ``"params"``.

    """
    began_ns = time.time_ns()
    payload: dict[
        str,
        dict[str, dict[str, torch.Tensor]] | list[dict[str, Any]],
    ] = {}
    if dist.get_global_rank() == 0:
        per_param: dict[str, dict[str, torch.Tensor]] = defaultdict(dict)
        for short_name, momenta in momenta_buffer.items():
            qualified_name = f"model.{short_name}"
            # An entry is only created once a first momentum tensor is stored
            for momentum_kind, momentum_value in momenta.items():
                per_param[qualified_name][momentum_kind] = momentum_value
        # Every collected parameter receives its own step tensor
        for entry in per_param.values():
            entry["step"] = torch.as_tensor(step_value, dtype=torch.float32)
        qualified_names = list(per_param)
        # Hyperparameters come from the existing groups, "params" is replaced by
        # the names collected above
        payload["param_groups"] = cast(
            "ListDictValueType",
            [
                {
                    **{opt: val for opt, val in group.items() if opt != "params"},
                    "params": qualified_names,
                }
                for group in trainer.state.optimizers[0].param_groups
            ],
        )
        payload["state"] = per_param
    set_optimizer_state_dict(
        optim_state_dict=cast("OptimizerStateType", payload),
        model=trainer.state.model,
        optimizers=trainer.state.optimizers,
        options=StateDictOptions(
            full_state_dict=True,
            cpu_offload=True,
            strict=False,
        ),
    )
    return (time.time_ns() - began_ns) * 1e-9


def set_optim_with_partial_state_dict(
    trainer: Trainer,
    momenta_buffer: dict[str, dict[str, torch.Tensor]],
    step_value: int,
) -> float:
    """Set optimizer states using a partial state dictionary for FSDP models.

    On global rank 0 an optimizer state dictionary is assembled from
    `momenta_buffer`; every other rank keeps an empty dictionary. While
    `set_optimizer_state_dict` runs, the `load_state_dict` attribute of the first
    optimizer is replaced by ``partial(load_partial_state_dict, optim)``. The
    replacement is always undone afterwards, also when loading raises.

    Parameters
    ----------
    trainer : Trainer
        The trainer object containing the model and optimizer to update.
    momenta_buffer : dict[str, dict[str, torch.Tensor]]
        A nested dictionary mapping parameter names to dictionaries of momentum values.
        The inner dictionaries map momentum types (e.g., 'exp_avg', 'exp_avg_sq') to
        their corresponding tensor values.
    step_value : int
        The step count to set in the optimizer state for each parameter.

    Returns
    -------
    float
        The time in seconds taken to set the optimizer states.

    Notes
    -----
    Only the first optimizer of the trainer is updated. Parameter names are
    prefixed with ``model.`` in the ``"state"`` section, and every param group
    lists all collected parameter names under ``"params"``.

    The previous `load_state_dict` is not deep-copied: deep-copying a bound
    method also deep-copies the object it is bound to, so the restored method
    would act on a copy of the optimizer instead of the optimizer itself.

    """
    optim = trainer.state.optimizers[0]
    start_time_to_set = time.time_ns()
    state: dict[str, dict[str, dict[str, torch.Tensor]] | list[dict[str, Any]]] = {}
    if dist.get_global_rank() == 0:
        my_state: dict[str, dict[str, torch.Tensor]] = defaultdict(dict)
        for param_name, momenta_dict in momenta_buffer.items():
            real_param_name = f"model.{param_name}"
            for momentum_type, momentum_tensor in momenta_dict.items():
                my_state[real_param_name][momentum_type] = momentum_tensor
        # Every collected parameter receives its own step tensor
        for param_state in my_state.values():
            param_state["step"] = torch.as_tensor(step_value, dtype=torch.float32)
        list_of_parameter_name = list(my_state)
        local_param_groups: list[dict[str, Any]] = []
        for pg in optim.param_groups:
            local_param_group_dict = {k: v for k, v in pg.items() if k != "params"}
            local_param_group_dict["params"] = list_of_parameter_name
            local_param_groups.append(local_param_group_dict)
        state["param_groups"] = cast("ListDictValueType", local_param_groups)
        state["state"] = my_state
    # Remember whether the instance already carried its own `load_state_dict`
    # attribute, so that exactly the previous situation can be restored
    not_overridden = object()
    previous_override = vars(optim).get("load_state_dict", not_overridden)
    optim.load_state_dict = partial(load_partial_state_dict, optim)
    try:
        set_optimizer_state_dict(
            optim_state_dict=cast("OptimizerStateType", state),
            model=trainer.state.model,
            optimizers=optim,
            options=StateDictOptions(
                full_state_dict=True,
                cpu_offload=True,
                strict=False,
            ),
        )
    finally:
        if previous_override is not_overridden:
            # Drop the temporary instance attribute: the class-level method
            # becomes visible again
            del optim.load_state_dict
        else:
            optim.load_state_dict = previous_override
    return (time.time_ns() - start_time_to_set) * 1e-9


def set_params_fsdp(
    local_module_name: str,
    local_module: FullyShardedDataParallel,
    parameters_dict: OrderedDict[str, torch.Tensor],
) -> None:
    """Set parameters for a FullyShardedDataParallel module.

    Inside ``summon_full_params(local_module, writeback=True, recurse=False)``,
    every parameter directly owned by a submodule of `local_module` is looked up
    in `parameters_dict` under its remote name and assigned with
    `set_param_base`. Parameters whose assignment fails are collected into a
    leftover buffer, which is handed to `set_params_with_state_dict` once the
    context manager has been left.

    Parameters
    ----------
    local_module_name : str
        The name of the FSDP module whose parameters are to be set.
    local_module : FullyShardedDataParallel
        The FSDP module instance whose parameters are to be set.
    parameters_dict : OrderedDict[str, torch.Tensor]
        An ordered dictionary mapping parameter names to their tensor values.
        Every parameter found here is removed from the dictionary, whether or
        not its direct assignment succeeds.

    Notes
    -----
    The remote name is ``<module path>.<parameter name>``, where the module path
    is ``local_module_name + "." + <submodule name>`` with every occurrence of
    ``model.``, ``_fsdp_wrapped_module.``, ``._fsdp_wrapped_module`` and
    ``._checkpoint_wrapped_module`` removed. A warning is logged for local
    parameters that have no remote counterpart.

    Leftover entries carry the parameter name relative to the owning submodule,
    which is the name `set_params_with_state_dict` looks up in that submodule's
    state dict.

    """
    # Initialize the leftover buffer for parameters that are sharded across multiple
    # ranks not with full shards
    leftover_buffer: dict[
        str,
        tuple[torch.nn.Module, str, torch.Tensor],
    ] = {}
    # Open the FSDP context manager to summon the parameters
    with FullyShardedDataParallel.summon_full_params(
        local_module,
        writeback=True,
        recurse=False,
    ):
        # Iterate over the named modules of the local module
        for (
            internal_local_module_name,
            internal_local_module,
        ) in local_module.named_modules():
            # The cleaned module path is the same for all parameters of the
            # current internal local module: build it once
            remote_module_path = (
                (local_module_name + "." + internal_local_module_name)
                .replace("model.", "")
                .replace("_fsdp_wrapped_module.", "")
                .replace("._fsdp_wrapped_module", "")
                .replace("._checkpoint_wrapped_module", "")
            )
            # Iterate over the named parameters of the current internal local module
            for local_param_name, local_param in internal_local_module.named_parameters(
                recurse=False,
            ):
                remote_param_name = remote_module_path + "." + local_param_name
                if remote_param_name not in parameters_dict:
                    log(
                        WARNING,
                        "set_params_fsdp: Parameter %s not found in the list of remote"
                        " parameters and won't be set",
                        remote_param_name,
                    )
                    continue
                # Get the remote parameter from the OrderedDict
                remote_params = parameters_dict.pop(remote_param_name)
                # Append to the leftover buffer if the assignment failed
                if not set_param_base(remote_params, local_param):
                    leftover_buffer[remote_param_name] = (
                        internal_local_module,
                        # Name relative to `internal_local_module`, as expected by
                        # `set_params_with_state_dict`
                        local_param_name,
                        remote_params,
                    )
    # Set the leftover parameters through the state dict
    set_params_with_state_dict(leftover_buffer)


def set_optimizer_state_base_fsdp(
    optim: torch.optim.Optimizer,
    local_names: tuple[str, str, str],
    buffers: tuple[
        dict[str, torch.nn.Module],
        dict[str, torch.nn.Parameter | torch.Tensor],
        dict[str, dict[str, torch.Tensor]],
    ],
    local_objects: tuple[FullyShardedDataParallel, torch.nn.Parameter | torch.Tensor],
    remote_objects: tuple[dict[str, dict[str, NDArray]], int],
) -> None:
    """Write remote momenta and the step count into one parameter's optimizer state.

    The remote name of the parameter is derived from `local_names`. If the local
    parameter has shape ``[0]`` nothing happens. If the remote name is absent
    from the momentum dictionary a DEBUG message is logged and nothing else
    happens. Otherwise the entry is removed from the momentum dictionary and each
    of its momentum tensors is reshaped with `match_shapes` against the local
    state tensor: on a shape match it is copied in place, on a mismatch the
    parameter is recorded in `buffers`. Finally the step tensor is overwritten.

    Parameters
    ----------
    optim : torch.optim.Optimizer
        The optimizer whose state for the local parameter is updated.
    local_names : tuple[str, str, str]
        The local module name, the internal local module name, and the local
        parameter name, in this order.
    buffers : tuple[
        dict[str, torch.nn.Module],
        dict[str, torch.nn.Parameter | torch.Tensor],
        dict[str, dict[str, torch.Tensor]],
    ]
        Three dictionaries, all keyed by remote parameter name, that are filled
        when a momentum tensor cannot be copied directly:
        - the local FSDP module
        - the local parameter
        - the momentum tensors by momentum type (the entry for a new name must
          be created on access, e.g., by a ``defaultdict(dict)``)
    local_objects : tuple[FullyShardedDataParallel, torch.nn.Parameter | torch.Tensor]
        The local FSDP module and the local parameter.
    remote_objects : tuple[dict[str, dict[str, NDArray]], int]
        The dictionary mapping remote parameter names to their momenta by
        momentum type, and the step count. The dictionary is modified: the entry
        of a processed parameter is popped.

    Notes
    -----
    The remote name is ``<module path>.<parameter name>``, where the module path
    is the local module name joined by a dot with the internal module name, with
    every occurrence of ``model.``, ``_fsdp_wrapped_module.``,
    ``._fsdp_wrapped_module`` and ``._checkpoint_wrapped_module`` removed. The
    step is set even when some momentum tensors were deferred to the buffers.

    """
    remote_momenta, step_count = remote_objects
    fsdp_module, param = local_objects
    deferred_modules, deferred_params, deferred_momenta = buffers
    module_name, submodule_name, param_name = local_names
    # Strip the wrapper-related fragments from the module path, one at a time
    module_path = module_name + "." + submodule_name
    for fragment in (
        "model.",
        "_fsdp_wrapped_module.",
        "._fsdp_wrapped_module",
        "._checkpoint_wrapped_module",
    ):
        module_path = module_path.replace(fragment, "")
    remote_param_name = module_path + "." + param_name
    # A parameter of shape [0] is skipped entirely
    if param.shape == torch.Size([0]):
        return
    if remote_param_name not in remote_momenta:
        log(
            DEBUG,
            "set_optimizer_state_base_fsdp: couldn't find %s in momentum_dict",
            remote_param_name,
        )
        return
    # The entry is consumed: it is removed from the caller's dictionary
    received = remote_momenta.pop(remote_param_name)
    for momentum_type, momentum_state in received.items():
        incoming = torch.as_tensor(momentum_state)
        current: torch.Tensor = optim.state[param][momentum_type]
        incoming = match_shapes(incoming, current)
        if incoming.shape == current.shape:
            # Direct in-place copy, converted to the local device and dtype
            current.data.copy_(incoming.to(current.device, dtype=current.dtype))
            continue
        # The shapes still differ: record everything needed to set this
        # momentum through another method later
        deferred_modules[remote_param_name] = fsdp_module
        deferred_params[remote_param_name] = param
        deferred_momenta[remote_param_name][momentum_type] = incoming
    # Overwrite the step, keeping the device and dtype of the existing tensor
    step_tensor = optim.state[param]["step"]
    step_tensor.data = torch.as_tensor(
        step_count,
        device=step_tensor.device,
        dtype=step_tensor.dtype,
    )


def set_optimizer_state_fsdp(
    optim: torch.optim.Optimizer,
    local_module: FullyShardedDataParallel,
    local_module_name: str,
    momentum_dict: dict[str, dict[str, NDArray]],
    step: int,
) -> tuple[
    dict[str, torch.nn.Module],
    dict[str, torch.nn.Parameter | torch.Tensor],
    dict[str, dict[str, torch.Tensor]],
]:
    """Update the optimizer state of every parameter of an FSDP module.

    Inside ``summon_full_params(local_module, writeback=True, recurse=False)``,
    `set_optimizer_state_base_fsdp` is called once for each parameter directly
    owned by each submodule of `local_module`. The three buffers that the base
    setter fills for parameters it could not update directly are returned.

    Parameters
    ----------
    optim : torch.optim.Optimizer
        The optimizer whose states are updated.
    local_module : FullyShardedDataParallel
        The FSDP module whose parameters are visited.
    local_module_name : str
        The name of the FSDP module in the model hierarchy.
    momentum_dict : dict[str, dict[str, NDArray]]
        Maps parameter names to their momenta by momentum type (e.g., 'exp_avg',
        'exp_avg_sq'). It is passed as is to the base setter, which pops the
        entries it processes.
    step : int
        The step count to set in the optimizer state of each parameter.

    Returns
    -------
    tuple[
        dict[str, torch.nn.Module],
        dict[str, torch.nn.Parameter | torch.Tensor],
        dict[str, dict[str, torch.Tensor]],
    ]
        Three dictionaries keyed by remote parameter name, describing the
        parameters that could not be updated directly:
        - the FSDP module they were visited under
        - the parameter tensors
        - the momentum tensors by momentum type (a ``defaultdict(dict)``)

    """
    deferred_modules: dict[str, torch.nn.Module] = {}
    deferred_params: dict[str, torch.nn.Parameter | torch.Tensor] = {}
    deferred_momenta: dict[str, dict[str, torch.Tensor]] = defaultdict(dict)
    # The same three dictionaries are shared by every call of the base setter
    collected = (deferred_modules, deferred_params, deferred_momenta)
    full_params = FullyShardedDataParallel.summon_full_params(
        local_module,
        writeback=True,
        recurse=False,
    )
    with full_params:
        for submodule_name, submodule in local_module.named_modules():
            # Only the parameters owned directly by the submodule are visited
            # here; nested ones are reached through their own submodule
            for param_name, param in submodule.named_parameters(recurse=False):
                set_optimizer_state_base_fsdp(
                    optim=optim,
                    local_names=(local_module_name, submodule_name, param_name),
                    buffers=collected,
                    local_objects=(local_module, param),
                    remote_objects=(momentum_dict, step),
                )
    return collected


def collect_fsdp_modules(
    module_name: str,
    module: torch.nn.Module,
    maps: tuple[dict[str, FullyShardedDataParallel], dict[str, torch.nn.Module]],
) -> tuple[dict[str, FullyShardedDataParallel], dict[str, torch.nn.Module]]:
    """Register a module and all of its descendants in two name-to-module maps.

    The hierarchy is visited depth-first, parents before children and children
    in `named_children` order. Child names are the parent name and the child's
    own name joined by a dot; when the parent name is empty the child's own name
    is used alone.

    Parameters
    ----------
    module_name : str
        The name under which `module` itself is registered.
    module : torch.nn.Module
        The root of the hierarchy to visit.
    maps : tuple[dict[str, FullyShardedDataParallel], dict[str, torch.nn.Module]]
        The two dictionaries to fill, which are modified in place:
        - the first receives every visited `FullyShardedDataParallel` instance
        - the second receives every visited `torch.nn.Module` instance

    Returns
    -------
    tuple[dict[str, FullyShardedDataParallel], dict[str, torch.nn.Module]]
        The same two dictionaries that were passed in `maps`.

    Notes
    -----
    NOTE(review): FSDP modules are `torch.nn.Module` instances too, so they end
    up in both dictionaries. Confirm whether the second map is meant to hold
    only non-FSDP modules before changing this.

    """
    fsdp_map, module_map = maps
    # Explicit stack instead of recursion; the children are pushed in reverse so
    # that they are popped in their original order
    pending: list[tuple[str, torch.nn.Module]] = [(module_name, module)]
    while pending:
        current_name, current = pending.pop()
        if isinstance(current, FullyShardedDataParallel):
            fsdp_map[current_name] = current
        if isinstance(current, torch.nn.Module):
            module_map[current_name] = current
        children = [
            (f"{current_name}.{child_name}" if current_name else child_name, child)
            for child_name, child in current.named_children()
        ]
        pending.extend(reversed(children))
    return (fsdp_map, module_map)


def is_leaf_fsdp_module(
    module_name: str,
    maps: tuple[dict[str, FullyShardedDataParallel], dict[str, torch.nn.Module]],
) -> bool:
    """Check whether the first-child chain of a module is free of FSDP modules.

    Starting from the module registered under `module_name`, the first child
    yielded by `named_children` is inspected: if it is a
    `FullyShardedDataParallel` instance the answer is False, otherwise the check
    continues from that child. A module without children ends the walk with
    True.

    Parameters
    ----------
    module_name : str
        The name of the module to check.
    maps : tuple[dict[str, FullyShardedDataParallel], dict[str, torch.nn.Module]]
        Two name-to-module dictionaries. Each visited name is looked up in the
        first one and, if absent there, in the second one.

    Returns
    -------
    bool
        False if an FSDP module is met while following first children, True
        otherwise.

    Raises
    ------
    KeyError
        If a visited name is in neither dictionary.

    Notes
    -----
    NOTE(review): only the first child at each level is ever inspected, so FSDP
    modules among later siblings do not make the result False. Callers may rely
    on this; confirm the intended meaning of "leaf" before changing it.

    """
    fsdp_map, module_map = maps
    current_name = module_name
    while True:
        if current_name in fsdp_map:
            current = fsdp_map[current_name]
        else:
            current = module_map[current_name]
        # Only the first child decides how the walk goes on
        first_child = next(iter(current.named_children()), None)
        if first_child is None:
            return True
        child_name, child = first_child
        if isinstance(child, FullyShardedDataParallel):
            return False
        current_name = f"{current_name}.{child_name}" if current_name else child_name


def set_trainer_trainable_params_dict(
    trainer: Trainer,
    parameters_dict: OrderedDict[str, torch.Tensor],
) -> None:
    """Load remote parameter values into the trainable parameters of a trainer.

    When the trainer's model exposes a `model` attribute that is an FSDP (Fully
    Sharded Data Parallel) instance, the module hierarchy is collected and every
    FSDP module reported as a leaf is handed to `set_params_fsdp`. Otherwise,
    each local parameter that requires gradients is looked up by name in
    `parameters_dict` and written through `set_param_base`. `dist.barrier()` is
    called once all the parameters have been processed.

    Parameters
    ----------
    trainer : Trainer
        The trainer object containing the model whose parameters are to be set.
    parameters_dict : OrderedDict[str, torch.Tensor]
        An ordered dictionary mapping parameter names to their tensor values.

    Raises
    ------
    ParameterShapeMismatchError
        If `set_param_base` reports a failure for a parameter of a non-FSDP
        model, i.e., the remote and local shapes don't match.

    """
    target_model = trainer.state.model
    fsdp_by_name: dict[str, FullyShardedDataParallel] = {}
    other_by_name: dict[str, torch.nn.Module] = {}

    # Only walk the hierarchy when the inner model is wrapped by FSDP
    wraps_fsdp = hasattr(target_model, "model") and isinstance(
        target_model.model,
        FullyShardedDataParallel,
    )
    if wraps_fsdp:
        target_model = target_model.model
        fsdp_by_name, other_by_name = collect_fsdp_modules(
            "model",
            target_model,
            (fsdp_by_name, other_by_name),
        )

    if fsdp_by_name:
        module_maps = (fsdp_by_name, other_by_name)
        for fsdp_name, fsdp_module in fsdp_by_name.items():
            # Non-leaf FSDP modules are skipped
            if not is_leaf_fsdp_module(fsdp_name, module_maps):
                continue
            set_params_fsdp(fsdp_name, fsdp_module, parameters_dict)
    else:
        # Regular (non-FSDP) case
        for local_name, local_tensor in target_model.named_parameters():
            # Frozen parameters are never overwritten
            if not local_tensor.requires_grad:
                continue
            # Drop every "model." and "module." occurrence (e.g., DDP prefixes)
            remote_name = local_name.replace("model.", "")
            remote_name = remote_name.replace("module.", "")
            if remote_name not in parameters_dict:
                log(
                    WARNING,
                    "Parameter %s not found in the list of remote parameters and"
                    " won't be set.",
                    local_name,
                )
                continue
            remote_tensor = parameters_dict[remote_name]
            # A falsy result signals that the shapes don't match
            if not set_param_base(remote_tensor, local_tensor):
                raise ParameterShapeMismatchError(
                    remote_name=remote_name,
                    remote_shape=remote_tensor.shape,
                    local_shape=local_tensor.shape,
                )
    dist.barrier()


# NOTE: This is unused but it is kept for reference
def set_trainable_params_dict(
    model: torch.nn.Module,
    parameters_dict: OrderedDict[str, torch.Tensor],
) -> None:
    """Set the trainable parameters of a model.

    Every parameter with `requires_grad=True` has its `.data` replaced by the
    tensor found in `parameters_dict`, moved to the device of the parameter.
    If `model.model` is exactly a `FullyShardedDataParallel` instance, the
    assignment happens inside `summon_full_params`; otherwise, the parameters
    of `model` are set directly, dropping the "module." text from the names
    that start with it. `dist.barrier()` is called at the end.

    Parameters
    ----------
    model : torch.nn.Module
        The model.
    parameters_dict : OrderedDict[str, torch.Tensor]
        The dictionary of parameters.

    Raises
    ------
    ValueError
        If `model.model` is None. NOTE: this check lives inside the branch that
        is only entered when `type(model.model) is FullyShardedDataParallel`,
        so it is not expected to trigger.
    KeyError
        If a trainable parameter name is missing from `parameters_dict`.

    """
    # NOTE: This function is weird because the encapsulation done to support FSDP and
    # DDP is weird. Since they are both likely to change, we MUST maintain this very
    # well and implement as many checkers as we can.
    if hasattr(model, "model") and type(model.model) is FullyShardedDataParallel:
        if model.model is None:  # type: ignore[reportUnnecessaryComparison]
            error_message = "Model is None"
            raise ValueError(error_message)
        inner_model = model.model
        # NOTE: This doesn't work when `use_orig_params` is True in the FSDP
        # configuration, as the tensors returned are flattened, breaking some
        # assumptions of the rest of the codebase
        with FullyShardedDataParallel.summon_full_params(
            inner_model,
            recurse=True,
            # Writing back is not compatible with rank 0 only
            writeback=True,
            rank0_only=False,
            # Prevent moving to CPU device
            offload_to_cpu=False,
            with_grads=False,
        ):
            # NOTE: !!! THIS REQUIRES INVESTIGATION AS IT DOESN'T WORK AS EXPECTED !!!
            # NOTE: The parameter dict obtained using the above arguments, i.e.,
            # (recurse=True, writeback=True, rank0_only=False, offload_to_cpu=False,
            # with_grads=False,), won't be complete in any rank if the model is
            # sharded. Each rank will have zero-size tensors for those layers that are
            # not "living" in there and the flattened/unflattened complete tensors for
            # those blocks living there.
            # NOTE: If the FSDP configuration uses the original parameters
            # (use_orig_params=true), then the tensors in rank 0 have the correct
            # original shape. In the other ranks they are flattened anyway.
            for name, param in inner_model.named_parameters():
                # Set the parameters only if they require gradients
                # NOTE: the tensor size is NOT checked here, so zero-size local
                # tensors are overwritten as well
                if param.requires_grad:
                    param.data = parameters_dict[name].to(param.device)
    else:
        for name, param in model.named_parameters():
            # Set the parameters only if they require gradients
            if param.requires_grad:
                # NOTE: DDP pre-pends "module." to the name of the parameter.
                # `replace` drops every "module." occurrence, not just the prefix
                if name.startswith("module."):
                    param.data = parameters_dict[name.replace("module.", "")].to(
                        param.device,
                    )
                # Single GPU
                else:
                    param.data = parameters_dict[name].to(param.device)
    dist.barrier()


def set_trainer_params_from_ndarrays(  # noqa: PLR0913
    parameters: NDArrays,
    trainer: Trainer,
    key_to_filter: str = "transformer",
    *,
    parameters_names: list[str] | None = None,
    filter_keys: bool = True,
    excluded_layers: list[str] | None = None,
    frozen_layers: list[str] | None = None,
) -> None:
    """Load a list of NDArrays into the model of a trainer.

    The arrays are first paired with the alphabetically sorted parameter names.
    If that attempt fails with a `ValueError` whose message contains
    "Shapes don't match", the failure is logged and a second attempt is made
    pairing the arrays with the names in the order they were given.

    Parameters
    ----------
    parameters : NDArrays
        The list of NDArrays representing the model parameters.
    trainer : Trainer
        The trainer object whose model parameters are to be set.
    key_to_filter : str, optional
        The key to filter the parameters, by default "transformer".
    parameters_names : list[str] | None, optional
        The list of parameter names, by default None. When None, the names are
        obtained from the trainer's model.
    filter_keys : bool, optional
        Whether to filter the keys, by default True.
    excluded_layers : list[str] | None, optional
        Layers that should not be set, by default None.
    frozen_layers : list[str] | None, optional
        Layers that should be frozen, by default None.

    Raises
    ------
    ValueError
        If setting the parameters fails for a reason other than a shape
        mismatch, or if the unordered retry fails as well.

    """
    excluded_layers = excluded_layers or []
    if parameters_names is None:
        parameters_names = get_list_of_parameters_names(trainer.state.model)

    def _build_and_apply(names: list[str]) -> None:
        # Pair `names` with the arrays and push the result into the trainer
        candidate_dict = construct_parameters_dict(
            names,
            parameters,
            filter_keys=filter_keys,
            key_to_filter=key_to_filter,
            excluded_layers=excluded_layers,
            frozen_layers=frozen_layers,
        )
        set_trainer_trainable_params_dict(trainer, candidate_dict)

    alphabetical_names = sorted(parameters_names)
    try:
        # First attempt: assume the arrays follow the sorted names
        _build_and_apply(alphabetical_names)
    except ValueError as e:
        # Anything but a shape mismatch is not recoverable here
        if "Shapes don't match" not in str(e):
            raise
        log(
            ERROR,
            "Error trying to set the parameters as ordered, trying unordered",
            exc_info=e,
            stack_info=True,
        )
        # Second attempt: keep the names in the order they were provided
        _build_and_apply(parameters_names)


def get_wte_parameters_from_trainer(trainer: Trainer) -> NDArray:
    """Extract the single WTE parameter array from the model of a trainer.

    The parameter names of the model are paired with the arrays returned by
    `get_parameters_from_state`; the entries whose name contains "wte" are
    kept.

    Parameters
    ----------
    trainer : Trainer
        The trainer object.

    Returns
    -------
    NDArray
        The parameters of the WTE layer.

    Raises
    ------
    ValueError
        If there are no WTE parameters or if the WTE parameters are not unique.

    """
    names = get_list_of_parameters_names(trainer.state.model)
    arrays = get_parameters_from_state({}, trainer)

    # Keyed by name, so repeated names collapse into a single entry
    wte_by_name: dict[str, NDArray] = {}
    for name, array in zip(names, arrays, strict=False):
        if "wte" in name:
            wte_by_name[name] = array

    if not wte_by_name:
        msg = "There are no WTE parameters"
        raise ValueError(msg)
    if len(wte_by_name) > 1:
        msg = "WTE parameters are not unique"
        raise ValueError(msg)
    return next(iter(wte_by_name.values()))


def set_wte_parameters_to_trainer(trainer: Trainer, wte_parameters: NDArray) -> None:
    """Overwrite the WTE (weights tokens embeddings) parameters of a trainer.

    The current parameters are read with `get_parameters_from_state`, every
    array whose name contains "wte" is swapped for `wte_parameters`, and the
    resulting list is written back with `set_trainer_params_from_ndarrays`.

    Parameters
    ----------
    trainer : Trainer
        The trainer object whose model is updated.
    wte_parameters : NDArray
        The array to use for the parameters whose name contains "wte".

    """
    names = get_list_of_parameters_names(trainer.state.model)
    current_arrays = get_parameters_from_state({}, trainer)
    # Keep every array untouched except for the WTE ones
    patched_arrays: list[NDArray] = [
        wte_parameters if "wte" in name else array
        for name, array in zip(names, current_arrays, strict=False)
    ]
    set_trainer_params_from_ndarrays(patched_arrays, trainer, excluded_layers=[])


def get_list_of_parameters_names(
    model: torch.nn.Module,
    *,
    only_requires_grad: bool = True,
    sort_list: bool = True,
) -> list[str]:
    """Return the cleaned names of the parameters of a model.

    Wrapper-related text ("model.", "module.", "_fsdp_wrapped_" and
    "_checkpoint_wrapped_") is removed from every name, wherever it occurs.

    Parameters
    ----------
    model : torch.nn.Module
        The model.
    only_requires_grad : bool, optional
        Whether to keep only the parameters with `requires_grad=True`, by
        default True. If False, all parameters (even those that are not
        trainable) are included.
    sort_list : bool, optional
        Whether to sort the names alphabetically, by default True.
        If False, the order of `model.named_parameters()` is kept.

    Returns
    -------
    list[str]
        The list of parameters names.

    """
    # The order matters: dropping "module." first turns "_fsdp_wrapped_module."
    # into "_fsdp_wrapped_", which is then removed by the third token
    wrapper_tokens = ("model.", "module.", "_fsdp_wrapped_", "_checkpoint_wrapped_")

    def _strip_wrapper_tokens(raw_name: str) -> str:
        # Apply the removals one after the other, in the order listed above
        return reduce(
            lambda text, token: text.replace(token, ""),
            wrapper_tokens,
            raw_name,
        )

    stripped_names = [
        _strip_wrapper_tokens(raw_name)
        for raw_name, tensor in model.named_parameters()
        if tensor.requires_grad or not only_requires_grad
    ]
    if sort_list:
        stripped_names.sort()
    return stripped_names


def construct_parameters_dict(  # noqa: PLR0913
    parameters_names: list[str],
    parameters: NDArrays,
    *,
    filter_keys: bool = True,
    key_to_filter: str = "transformer",
    excluded_layers: list[str],
    frozen_layers: list[str] | None = None,
) -> OrderedDict[str, torch.Tensor]:
    """Construct a dictionary of parameters.

    Parameters
    ----------
    parameters_names : list[str]
        The full list of parameters names for the model under consideration.
    parameters : NDArrays
        The list of model parameters onto which we need to construct the OrderedDict.
    filter_keys : bool, optional
        Whether to filter the keys, by default True.
    key_to_filter : str, optional
        The key to filter, by default "transformer".
    excluded_layers : list[str]
        The personalized layers.
    frozen_layers : list[str] | None
        The frozen layers, by default None.

    Returns
    -------
    OrderedDict[str, torch.Tensor]
        The dictionary of parameters.

    """
    if filter_keys or excluded_layers:
        parameters_names = [
            name
            for name in parameters_names
            if key_to_filter in name
            and not any(layer in name or name in layer for layer in excluded_layers)
        ]
        if excluded_layers:
            log(
                DEBUG,
                "Filtered excluded layers: %s from: %s",
                excluded_layers,
                parameters_names,
            )

    zipped_lists = zip(parameters_names, parameters, strict=True)
    parameters_dict = OrderedDict({k: torch.as_tensor(v) for k, v in zipped_lists})
    # Setting the `requires_grad` attribute for propagating it to the local model
    for name, param in parameters_dict.items():
        param.requires_grad = name not in frozen_layers if frozen_layers else True
    return parameters_dict


def wandb_init(
    wandb_enabled: bool,  # noqa: FBT001
    *args: dict,
    **kwargs: dict,
) -> NoOpContextManager | Any | None:  # noqa: ANN401
    """Start a wandb run for the server, or return a no-op context manager.

    When wandb is enabled, the `name` keyword argument (empty string if
    missing) gets the "_server" suffix and everything is forwarded to
    `wandb.init`.

    Parameters
    ----------
    wandb_enabled : bool
        Whether wandb is enabled.
    args : dict
        The positional arguments forwarded to `wandb.init`.
    kwargs : dict
        The keyword arguments forwarded to `wandb.init`.

    Returns
    -------
    NoOpContextManager | Any | None
        The value returned by `wandb.init` when wandb is enabled, a
        `NoOpContextManager` otherwise.

    Raises
    ------
    ValueError
        If the name is not a string.

    """
    # Nothing to initialize when wandb is disabled
    if not wandb_enabled:
        return NoOpContextManager()

    run_name = kwargs.pop("name", "")
    if type(run_name) is not str:
        error = f"Name must be a string, not {type(run_name)}"
        raise ValueError(error)
    # Add server suffix to the name of the run
    server_run_name = f"{run_name}_server"
    return wandb.init(*args, **kwargs, name=server_run_name)  # type: ignore[arg-type,misc]


def sum_of_squares(arrays: NDArrays) -> float:
    """Add up the squared entries of every array in a list.

    Parameters
    ----------
    arrays : NDArrays
        List of arrays to compute the sum of squares of.

    Returns
    -------
    float
        The sum of squares of the list of arrays (0 for an empty list).

    """
    total = 0
    for arr in arrays:
        # Per-array contribution: element-wise square followed by a full reduction
        total += np.sum(np.square(arr))
    return total


def l2_norm(arrays: NDArrays) -> float:
    """Return the L2 norm of a list of arrays taken as a single vector.

    Parameters
    ----------
    arrays : NDArrays
        List of arrays to compute the L2 norm of.

    Returns
    -------
    float
        The square root of the sum of squares of all the arrays.

    """
    squared_total = sum_of_squares(arrays)
    return float(np.sqrt(squared_total))


def get_device() -> device_type:
    """Pick the device to use for PyTorch.

    CUDA is preferred when available, then MPS (when both available and
    built), and finally the CPU.

    Returns
    -------
    device_type
        The device name ("cuda", "mps" or "cpu"). NOTE: the value is a plain
        string that is only cast to the device type for the type checker.

    """
    if torch.cuda.is_available():
        return cast("device_type", "cuda")
    mps_backend = torch.backends.mps
    if mps_backend.is_available() and mps_backend.is_built():
        return cast("device_type", "mps")
    return cast("device_type", "cpu")


def get_n_cuda_devices() -> int:
    """Count the CUDA devices available.

    Returns
    -------
    int
        The value of `torch.cuda.device_count()` when the device selected by
        `get_device` is a CUDA one, 0 otherwise.

    """
    device_name = str(get_device())
    return torch.cuda.device_count() if "cuda" in device_name else 0


def get_n_cpu_cores() -> int | None:
    """Count the CPU cores available.

    The size of the CPU affinity list of the current process is used; when
    querying it raises `AttributeError`, `psutil.cpu_count()` is returned
    instead.

    Returns
    -------
    int | None
        The number of CPU cores available.

    """
    try:
        affinity = psutil.Process().cpu_affinity()
    except AttributeError:
        # Fall back to the overall count reported by psutil
        return psutil.cpu_count()
    else:
        return len(affinity)  # type: ignore[reportArgumentType]


def merge_freq_dicts(
    a: dict[int, int],
    b: dict[int, int],
) -> dict[int, int]:
    """Combine two frequency dictionaries by summing the counts key by key.

    Neither input is modified. The keys of `a` come first in the result,
    followed by the keys found only in `b`.

    Parameters
    ----------
    a : dict[int, int]
        The first frequency dictionary.
    b : dict[int, int]
        The second frequency dictionary.

    Returns
    -------
    dict[int, int]
        The merged frequency dictionary.

    """
    merged = dict(a)
    for key, count in b.items():
        # Keys missing from `a` start from a count of zero
        merged[key] = a.get(key, 0) + count
    return merged


def get_unigram_probabilities_tensor(
    stream_freq_dict: dict[int, int],
) -> torch.Tensor:
    """Get the unigram probabilities tensor.

    Each count is divided by the total number of tokens and stored at the
    position given by its token id. The token ids that don't appear in the
    dictionary get a probability of zero.

    Parameters
    ----------
    stream_freq_dict : dict[int, int]
        The frequency dictionary, mapping token ids to their counts.

    Returns
    -------
    torch.Tensor
        The dense unigram probabilities tensor, of length `max token id + 1`.

    Raises
    ------
    ValueError
        If the frequency dictionary is empty.
    ZeroDivisionError
        If the counts sum to zero.

    """
    if not stream_freq_dict:
        msg = "Cannot compute unigram probabilities from an empty dictionary"
        raise ValueError(msg)

    total_tokens = float(sum(stream_freq_dict.values()))
    # Plain Python division: raises ZeroDivisionError when the total is zero
    probabilities = [count / total_tokens for count in stream_freq_dict.values()]

    # Dense tensor covering every id up to the max token id
    probabilities_tensor = torch.zeros(max(stream_freq_dict) + 1)
    token_ids = torch.tensor(list(stream_freq_dict), dtype=torch.long)
    # Single indexed assignment instead of one tensor write per token id
    probabilities_tensor[token_ids] = torch.tensor(
        probabilities,
        dtype=probabilities_tensor.dtype,
    )
    return probabilities_tensor


def is_literal_for_ast(s: str) -> bool:
    """Check if the given str can be evaluated as a literal.

    Parameters
    ----------
    s : str
        The string to check.

    Returns
    -------
    bool
        Whether `ast.literal_eval` accepts the string.

    """
    try:
        ast.literal_eval(s)
    # `literal_eval` signals malformed input with any of these exceptions, not
    # only with ValueError (e.g., "foo bar" or "" raise SyntaxError)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        return False
    return True
