"""Utility functions for model parameter handling and configurations in FL.

This module provides various utility functions to manage model parameters, configure
training settings, and handle data loading for federated learning. It includes functions
to set trainer timestamps, print trainable parameters, retrieve model parameters,
randomize and personalize model layers, and more.

Functions
---------
- set_trainer_timestamp(trainer: Trainer, timestamp: int) -> None
    Set the trainer's timestamp to a specific batch value.


- get_initial_parameters(cfg: BaseConfig) -> Parameters
    Retrieve the initial parameters for the federated learning server model.

- get_random_model_states_ndarrays(
        cfg: DictConfig,
        *,
        return_names: bool = False,
        only_requires_grad: bool = False,
    ) -> NDArrays | tuple[NDArrays, list[str]]
    Retrieve randomly initialized model parameters as numpy arrays.

- randomize_layers(
        parameters: NDArrays,
        dummy_config: DictConfig,
        names: list[str],
        random_layers: list[str],
        cid: int = 0,
        *,
        server_round: int = 1,
        truly_random_init: bool,
    ) -> None
    Randomize specified layers of the model parameters.

- personalize_layers(
        parameters: NDArrays,
        initial_trainer_parameters: NDArrays,
        personalized_layers: list[str],
        names: list[str],
        unfrozen_names: list[str]
    ) -> None
    Personalize specified layers of the model parameters.

Imports
-------
- copy
- random
- warnings
- torch
- logging.DEBUG
- typing.cast
- pathlib.Path
- composer.Trainer
- composer.utils.reproducibility
- flwr.common.logger.log
- flwr.common.typing.NDArrays
- flwr.common.parameters_to_ndarrays
- flwr.common.ndarrays_to_parameters
- flwr.common.Parameters
- llmfoundry.utils.builders.build_tokenizer
- llmfoundry.utils.builders.build_composer_model
- llmfoundry.utils.config_utils.process_init_device
- llmfoundry.utils.config_utils.make_dataclass_and_log_config
- llmfoundry.utils.config_utils.TRAIN_CONFIG_KEYS
- llmfoundry.utils.config_utils.TrainConfig
- llmfoundry.command_utils.train.validate_config
- omegaconf.DictConfig
- omegaconf.OmegaConf
- repo.conf.base_schema.BaseConfig
- repo.utils.get_list_of_parameters_names
- repo.utils.get_parameters_dict
- repo.utils.load_model_parameters_from_file

Example Usage
-------------
>>> from omegaconf import OmegaConf
>>> from composer import Trainer
>>> cfg = OmegaConf.create({...})
>>> trainer = Trainer(...)
>>> set_trainer_timestamp(trainer, timestamp=100)
>>> initial_params = get_initial_parameters(cfg)
>>> raw_params = get_random_model_states_ndarrays(cfg)
>>> randomize_layers(
...     parameters, dummy_config, names, random_layers,
...     truly_random_init=True, cid=1, server_round=2,
... )
>>> personalize_layers(
...     parameters, initial_trainer_parameters,
...     personalized_layers, names, unfrozen_names,
... )
"""

# ruff: noqa: ERA001
import atexit
import copy
import gc
import random
import time
from collections import OrderedDict, defaultdict
from dataclasses import asdict
from logging import DEBUG
from typing import TYPE_CHECKING, cast

import numpy as np
import streaming
from composer import Trainer
from composer.optim import ADOPT, QHADOPT, DecoupledAdamW
from composer.trainer.trainer import _clear_incomplete_train_states  # noqa: PLC2701
from composer.utils import dist, reproducibility
from composer.utils.iter_helpers import ensure_tuple
from flwr.common.logger import log
from flwr.common.typing import NDArrays, Scalar
from omegaconf import DictConfig
from streaming.base.shared.memory import SharedMemory, shared_memory_list
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    _init_optim_state,  # noqa: PLC2701
    get_optimizer_state_dict,
)
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel

from repo.clients.configs import FitConfig
from repo.clients.llm_config_functions import set_client_checkpoints_path
from repo.conf.base_schema import BaseConfig
from repo.constants import CLIENT_STATE_ACCUMULATOR
from repo.shm.utils import compress_with_strict
from repo.utils import (
    ClientState,
    ModelStateNames,
    collect_fsdp_modules,
    get_initial_parameters,
    get_list_of_parameters_names,
    get_parameters_from_state,
    is_leaf_fsdp_module,
    is_literal_for_ast,
    set_optim_with_partial_state_dict,
    set_optimizer_state_base,
    set_optimizer_state_fsdp,
    sum_of_squares,
)

if TYPE_CHECKING:
    import torch

    from repo.conf.base_schema import BaseConfig


def modify_aggregation_mask_with_frozen_layers(
    client_config: FitConfig,
) -> None:
    """Exclude frozen layers from the aggregation mask, in place.

    When ``client_config.frozen_layers`` is not None, every layer listed there is
    removed from ``client_config.aggregation_mask`` so that it is not sent over.
    For each frozen layer, the matching ``True`` entry of the boolean mask is
    flipped to ``False``. The layer is also dropped from the layer-name and
    layer-type lists. Non-frozen layers keep their original order.

    The code assumes the aggregation mask is a ``(bools, layer_names, layer_types)``
    triple. It also assumes the i-th entry of ``layer_names``/``layer_types``
    corresponds to the i-th ``True`` entry of ``bools``.

    Parameters
    ----------
    client_config : FitConfig
        The client configuration containing the aggregation mask and frozen layers.
        Its ``aggregation_mask`` attribute is replaced with the trimmed mask. When
        ``frozen_layers`` is None, the configuration is left untouched.

    Raises
    ------
    AssertionError
        If the boolean part of the aggregation mask is None.

    """
    if client_config.frozen_layers is None:
        return
    # NOTE: We need to modify the aggregation mask in place
    # to exclude frozen layers, so they won't be sent over
    tuple_of_bools, layer_names, layer_types = client_config.aggregation_mask
    assert tuple_of_bools is not None, "Bools in the aggregation mask are None"
    # Set for O(1) membership checks instead of scanning the list for every layer
    frozen_layers = set(client_config.frozen_layers)
    list_of_bools = list(tuple_of_bools)
    # Positions of the True entries: the i-th layer name maps to true_indices[i]
    true_indices = [i for i, x in enumerate(list_of_bools) if x]
    trimmed_layer_names: list[str] = []
    trimmed_layer_types: list[ModelStateNames] = []
    for i, (layer_name, layer_type) in enumerate(
        zip(layer_names, layer_types, strict=True),
    ):
        if layer_name in frozen_layers:
            # Drop the layer from the mask so it is not aggregated/sent
            list_of_bools[true_indices[i]] = False
        else:
            trimmed_layer_names.append(layer_name)
            trimmed_layer_types.append(layer_type)
    client_config.aggregation_mask = (
        tuple(list_of_bools),
        trimmed_layer_names,
        trimmed_layer_types,
    )


def get_client_state_struct(fit_config: FitConfig) -> ClientState:
    """Look up the state of the client described by ``fit_config``.

    Parameters
    ----------
    fit_config : FitConfig
        Configuration carrying the per-client states (keyed by the string form of
        the client id), the client id and the cumulative server steps.

    Returns
    -------
    ClientState
        The state stored under ``str(fit_config.cid)``.

    Raises
    ------
    ValueError
        If ``client_states`` or ``server_steps_cumulative`` is None.
        Errors raised when indexing ``client_states`` with the client id propagate
        unchanged.

    """
    all_client_states = fit_config.client_states
    if all_client_states is None:
        msg = "Client states must be provided."
        raise ValueError(msg)
    # The cumulative server steps are not read here; their presence is only
    # validated.
    if fit_config.server_steps_cumulative is None:
        msg = "Server steps must be provided."
        raise ValueError(msg)

    client_key = str(fit_config.cid)
    return all_client_states[client_key]


def set_initial_config_from_fit_config(
    fit_config: FitConfig,
    llm_config: DictConfig,
) -> tuple[bool, bool, int | None]:
    """Set the initial configuration for a client based on the fit configuration.

    ``llm_config`` is modified in place:

    - Unless ``fit_config.reset_checkpoint`` is set, the client checkpoint paths are
      configured via ``set_client_checkpoints_path`` for the step
      ``server_steps_cumulative + n_local_steps``.
    - ``load_ignore_keys`` is reset to ignore the scheduler state. It is extended to
      also ignore the optimizer state (``reset_optimizer``) and/or the dataset
      state (``reset_dataset_state``). When the optimizer is reset,
      ``save_ignore_keys`` is also set to skip the optimizer state on save.
    - If the model config has a ``vocab_size`` entry and ``fit_config.resize_vocab``
      is not None, the vocabulary size is overridden.

    Parameters
    ----------
    fit_config : FitConfig
        The configuration containing the client states and server steps.
    llm_config : DictConfig
        The configuration dictionary for the local model.

    Returns
    -------
    tuple[bool, bool, int | None]
        A tuple containing:
        - Whether to skip the iteration (always False when the checkpoint is reset).
        - Whether a checkpoint exists (always False when the checkpoint is reset).
        - The cumulative server steps taken from ``fit_config``.

    Raises
    ------
    ValueError
        If the cumulative server steps are None while the checkpoint is NOT reset
        (they are needed to set the checkpoint path).

    """
    # Number of local steps the current client performs
    # num_batches_trained = int(str(llm_config["local_steps"]).replace("ba", ""))
    num_batches_trained = fit_config.n_local_steps

    server_steps_cumulative = fit_config.server_steps_cumulative

    skip_iteration = False
    checkpoint_exists = False
    if not fit_config.reset_checkpoint:
        # The checkpoint path is derived from server_steps_cumulative + local steps,
        # so the cumulative server steps are required on this branch.
        if server_steps_cumulative is None:
            msg = (
                "Server steps cumulative is None, but it is required to set the "
                "checkpoint path when the checkpoint is not reset."
            )
            raise ValueError(msg)

        skip_iteration, checkpoint_exists = set_client_checkpoints_path(
            llm_config,
            fit_config.cid,
            server_steps_cumulative + num_batches_trained,
        )
    # Never load the scheduler state from a checkpoint
    llm_config.load_ignore_keys = ["*scheduler*"]  # type: ignore[union-attr]
    if fit_config.reset_optimizer:
        # Ignoring the optimizer state if loading a checkpoint
        llm_config.load_ignore_keys += ["*optim*"]  # type: ignore[union-attr]
        # Ignoring the optimizer state when saving a checkpoint
        llm_config.save_ignore_keys = ["*optim*"]  # type: ignore[union-attr]

    if fit_config.reset_dataset_state:
        # Ignoring the dataset state if loading a checkpoint
        llm_config.load_ignore_keys += ["*dataset_state*"]  # type: ignore[union-attr]

    # NOTE: The following, when re-loading from a checkpoint, returns a weird error
    # if not skip_iteration:
    #     # Ignoring loading the model as we need to set it from the server
    #     cfg.load_ignore_keys += ["*model*"]
    # Optionally override the vocabulary size of the model
    if (
        "vocab_size" in llm_config.model
        and (vocab_size := fit_config.resize_vocab) is not None
    ):
        llm_config.model.vocab_size = vocab_size
    return (
        skip_iteration,
        checkpoint_exists,
        server_steps_cumulative,
    )


def set_optimizer_states(
    momentum_vectors_set: tuple[NDArrays, list[str], list[str]],
    trainer: Trainer,
    fit_config: FitConfig,
) -> float:
    """Set optimizer states with momentum vectors for model training.

    This function loads the given momentum vectors into the trainer's optimizer
    state. It handles both FSDP (Fully Sharded Data Parallel) and non-FSDP models.

    - FSDP path: taken when ``trainer.state.model.model`` is a
      ``FullyShardedDataParallel`` instance. The momenta are grouped as
      ``{momentum_name: {momentum_type: array}}``. ``set_optimizer_state_fsdp`` is
      called for every leaf FSDP module. The merged per-parameter momenta buffers
      are then applied through ``set_optim_with_partial_state_dict``.
    - Non-FSDP path: the momenta are grouped as
      ``{momentum_type: {momentum_name: array}}``. ``set_optimizer_state_base`` is
      then called for every named parameter of ``trainer.state.model``.

    In both cases ``fit_config.server_steps_cumulative`` is passed to the helpers as
    the step value. ``dist.barrier()`` is called before and after the work, so
    presumably every rank is expected to enter this function.

    Parameters
    ----------
    momentum_vectors_set : tuple[NDArrays, list[str], list[str]]
        A tuple of three position-aligned sequences:
        - NDArrays of momentum values
        - List of momentum parameter names
        - List of momentum types (e.g., 'exp_avg', 'exp_avg_sq')
    trainer : Trainer
        The trainer object containing the model and optimizer.
    fit_config : FitConfig
        The configuration containing server steps information.

    Returns
    -------
    float
        The time in seconds taken to set the optimizer state, measured between the
        two barriers.

    Raises
    ------
        AssertionError
            If server_steps_cumulative is None, the optimizer is not one of ADOPT,
            DecoupledAdamW or QHADOPT, or the optimizer has more than one
            parameter group.

    """
    log(DEBUG, "Setting optimizer state")
    assert fit_config.server_steps_cumulative is not None, (
        "Server steps cumulative is None"
    )
    dist.barrier()
    astart_time = time.time_ns()

    # Maps from module name to FSDP-wrapped / non-FSDP modules; both stay empty
    # when the model is not wrapped in FSDP
    fsdp_modules: dict[str, FullyShardedDataParallel] = {}
    non_fsdp_modules: dict[str, torch.nn.Module] = {}

    # Collect all FSDP modules in the model hierarchy. The FSDP wrapper is looked
    # up on the inner ``.model`` attribute of ``trainer.state.model``
    if hasattr(trainer.state.model, "model") and isinstance(
        trainer.state.model.model,
        FullyShardedDataParallel,
    ):
        # First pass: collect all FSDP modules
        (fsdp_modules, non_fsdp_modules) = collect_fsdp_modules(
            "model",
            trainer.state.model.model,
            (fsdp_modules, non_fsdp_modules),
        )

    # Get the optimizer from the trainer object (only the first one is used)
    optimizer: torch.optim.Optimizer = ensure_tuple(trainer.state.optimizers)[0]
    # log(DEBUG, "Optimizer is %s", optimizer)
    assert isinstance(
        optimizer,
        ADOPT | DecoupledAdamW | QHADOPT,
    ), "The optimizer is not supported for setting the states."
    # NOTE: It seems necessary to leave this call. The optimizer results incomplete with
    # partial or missing states otherwise. It does not seem to represent a big time
    # overhead as far as we have experimented. The log is left as monitor.
    _clear_incomplete_train_states(trainer.state)
    _init_optim_state(optimizer)
    assert len(optimizer.param_groups) == 1, "Only one parameter group is supported."
    # Helper container for the incoming momenta. Its nesting differs per branch:
    # FSDP uses [name][type], non-FSDP uses [type][name]
    state_dict: dict[str, dict[str, np.ndarray]] = cast(
        "dict[str, dict[str, np.ndarray]]",
        defaultdict(lambda: defaultdict(dict)),
    )
    if fsdp_modules:
        # Get the leaf FSDP modules from the model
        leaf_fsdp_modules = {
            module_name: module
            for module_name, module in fsdp_modules.items()
            if is_leaf_fsdp_module(module_name, (fsdp_modules, non_fsdp_modules))
        }
        # Fill the helper dict as state_dict[momentum_name][momentum_type]

        for momentum, momentum_name, momentum_type in zip(
            *momentum_vectors_set,
            strict=True,
        ):
            state_dict[momentum_name][momentum_type] = momentum
        # Set the optimizer state for the FSDP modules. Add to the buffers the
        # information regarding the sharded states
        # NOTE(review): modules_buffer and params_leftovers are accumulated but not
        # read afterwards in this function.
        modules_buffer: dict[str, torch.nn.Module] = {}
        params_leftovers: dict[str, torch.nn.Parameter | torch.Tensor] = {}
        momenta_buffer: dict[str, dict[str, torch.Tensor]] = defaultdict(dict)
        # Process every leaf FSDP module and merge the buffers it returns
        for local_module_name, local_module in leaf_fsdp_modules.items():
            (c_modules_buffer, c_params_leftovers, c_momenta_buffer) = (
                set_optimizer_state_fsdp(
                    optim=optimizer,
                    local_module=local_module,
                    local_module_name=local_module_name,
                    momentum_dict=state_dict,
                    step=fit_config.server_steps_cumulative,
                )
            )
            modules_buffer.update(c_modules_buffer)
            params_leftovers.update(c_params_leftovers)
            # Merge per-parameter momenta rather than overwriting whole entries
            for c_momentum_param_name, c_momentum_dict in c_momenta_buffer.items():
                momenta_buffer[c_momentum_param_name].update(c_momentum_dict)
        # Set the sharded states using the state dict
        set_optim_with_partial_state_dict(
            trainer=trainer,
            step_value=fit_config.server_steps_cumulative,
            momenta_buffer=momenta_buffer,
        )
    else:
        # Fill the helper dict as state_dict[momentum_type][momentum_name]
        for momentum, momentum_name, momentum_type in zip(
            *momentum_vectors_set,
            strict=True,
        ):
            state_dict[momentum_type][momentum_name] = momentum

        for local_param_name, local_param in trainer.state.model.named_parameters():
            # Set the optimizer state for the current parameter
            set_optimizer_state_base(
                optimizer=optimizer,
                local_param=local_param,
                local_param_name=local_param_name,
                step=fit_config.server_steps_cumulative,
                momentum_type_dict=state_dict,
            )

    dist.barrier()
    time_to_set = (time.time_ns() - astart_time) * 1e-9
    log(DEBUG, "Optimizer state set in %s seconds", time_to_set)
    return time_to_set


def get_ndarrays_and_names_from_payload(
    payload: NDArrays,
    layer_names: list[str],
    layer_types: list[ModelStateNames],
    model_state_type: ModelStateNames,
) -> tuple[NDArrays, list[str]]:
    """Select the arrays, and their layer names, of a single model-state type.

    ``payload``, ``layer_names`` and ``layer_types`` are parallel sequences. The
    i-th array of ``payload`` belongs to layer ``layer_names[i]``, and
    ``layer_types[i]`` identifies which kind of state it is (model parameters or an
    optimizer state). Only entries whose type equals ``model_state_type`` are kept,
    in their original order.

    Parameters
    ----------
    payload : NDArrays
        The payload containing the model parameters and/or optimizer states.
    layer_names : list[str]
        The layer name of every entry of ``payload``.
    layer_types : list[ModelStateNames]
        The model-state type of every entry of ``payload``.
    model_state_type : ModelStateNames
        The model-state type to keep.

    Returns
    -------
    tuple[NDArrays, list[str]]
        A tuple containing:
        - The selected arrays from the payload.
        - The corresponding layer names.

    """
    # One boolean per entry; reused for both the arrays and the names
    keep_mask = [kind == model_state_type for kind in layer_types]
    # compress_with_strict(strict=True) presumably rejects length mismatches
    selected_arrays = list(compress_with_strict(payload, keep_mask, strict=True))
    selected_names = list(compress_with_strict(layer_names, keep_mask, strict=True))
    return selected_arrays, selected_names


def manipulate_pre_training_ndarrays(
    payload: NDArrays,
    trainer: Trainer,
    configs: tuple[FitConfig, DictConfig],
    client_state_struct: ClientState,
    train_metrics: dict[str, Scalar],
) -> None:
    """Prepare the trainer with the server payload before local training.

    The payload is split using ``fit_config.transmission_mask`` into model
    parameters, first momenta (``EXP_AVG``) and second momenta (``EXP_AVG_SQ``).

    - If any momenta were received, they are loaded into the trainer's optimizer
      via ``set_optimizer_states``. The elapsed time is recorded in
      ``train_metrics`` under ``"client/local_adopt/set_optimizer_state_time"``.
    - ``randomize_layers`` is called on the received parameters only when all of
      these hold: ``fit_config.random_layers`` is non-empty,
      ``fit_config.random_init_freq`` is positive, and the client's cumulative
      local steps are a multiple of it.

    NOTE(review): the parameters passed to ``randomize_layers`` are a new list built
    from ``payload``, and this function does not load them into the trainer.
    Whether the randomization takes effect depends on ``randomize_layers`` and on
    the caller; confirm.

    Parameters
    ----------
    payload : NDArrays
        The arrays received from the server, aligned with the transmission mask.
    trainer : Trainer
        The trainer instance used for training.
    configs : tuple[FitConfig, DictConfig]
        A tuple containing the fit configuration and the local model configuration.
    client_state_struct : ClientState
        The state structure of the client.
    train_metrics : dict[str, Scalar]
        A dictionary to store training metrics (updated in place).

    Raises
    ------
    ValueError
        If the payload and selector lengths do not match.
    AssertionError
        If the transmission mask or the cumulative server steps are None.

    """
    fit_config, dummy_config = configs

    # Get parameter names from trainer
    names = get_list_of_parameters_names(model=trainer.state.model)

    assert fit_config.transmission_mask is not None, "Transmission mask is None"

    _, layer_names, layer_types = fit_config.transmission_mask

    if len(layer_types) != len(payload):
        msg = "Payload and selector lengths do not match"
        raise ValueError(msg)

    initial_parameters, _ = get_ndarrays_and_names_from_payload(
        payload,
        layer_names=layer_names,
        layer_types=layer_types,
        model_state_type=ModelStateNames.PARAMETERS,
    )

    first_momentum, first_momentum_names = get_ndarrays_and_names_from_payload(
        payload,
        layer_names=layer_names,
        layer_types=layer_types,
        model_state_type=ModelStateNames.EXP_AVG,
    )

    second_momentum, second_momentum_names = get_ndarrays_and_names_from_payload(
        payload,
        layer_names=layer_names,
        layer_types=layer_types,
        model_state_type=ModelStateNames.EXP_AVG_SQ,
    )

    assert fit_config.server_steps_cumulative is not None
    # Concatenate the momenta once: first moments followed by second moments, with
    # names and type labels aligned position by position
    momenta = first_momentum + second_momentum
    if momenta:
        momentum_names = first_momentum_names + second_momentum_names
        momentum_types = [ModelStateNames.EXP_AVG.value.lower()] * len(
            first_momentum_names,
        ) + [ModelStateNames.EXP_AVG_SQ.value.lower()] * len(second_momentum_names)
        time_to_set = set_optimizer_states(
            trainer=trainer,
            momentum_vectors_set=(momenta, momentum_names, momentum_types),
            fit_config=fit_config,
        )
        log(
            DEBUG,
            "manipulate_pre_training_ndarrays::set opt state, time_to_set: %s",
            time_to_set,
        )
        train_metrics |= {"client/local_adopt/set_optimizer_state_time": time_to_set}
    # Synchronize ranks whether or not optimizer states were set
    dist.barrier()

    # Get list of random layers from config
    random_layers: list[str] = (
        fit_config.random_layers if fit_config.random_layers is not None else []
    )
    # Randomize layers every `random_init_freq` cumulative local steps
    if (
        random_layers
        and fit_config.random_init_freq > 0
        and client_state_struct.local_steps_cumulative % fit_config.random_init_freq
        == 0
    ):
        randomize_layers(
            parameters=initial_parameters,
            dummy_config=dummy_config,
            names=names,
            random_layers=random_layers,
            truly_random_init=fit_config.truly_random_init,
            cid=int(fit_config.cid),
            server_round=fit_config.server_round,
        )


def extract_l2_norm_for_model_state(
    train_metrics: dict[str, Scalar],
    masked_baseline_state: NDArrays,
    returned_state: NDArrays,
    metrics_keys: tuple[str, str, str, str],
) -> None:
    """Compute and record L2 norms of a model state and of its update.

    Each model state is a list of numpy arrays for a single kind of model state
    (e.g. parameters or one optimizer-state type). Arrays are paired position by
    position. The following entries are written into ``train_metrics``:

      - ``client/layer/{i}/{layer_wise_key}``: L2 norm of
        ``returned_state[i] - masked_baseline_state[i]``.
      - ``pre_key``: L2 norm of the whole baseline state.
      - ``post_key``: L2 norm of the whole returned state.
      - ``delta_key``: L2 norm of the whole difference.

    If either state is empty, nothing is recorded and a debug message is logged.

    Parameters
    ----------
    train_metrics : dict[str, Scalar]
        Dictionary to update with computed metrics.
    masked_baseline_state : NDArrays
        The baseline model state (list of numpy arrays) to compare against.
        The function assumes that this has already been masked to have the same length
        as the current state.
    returned_state : NDArrays
        The current model state (list of numpy arrays) to compare.
    metrics_keys : tuple[str, str, str, str]
        ``(layer_wise_key, pre_key, post_key, delta_key)`` as described above.

    Raises
    ------
    ValueError
        If the lengths of masked_baseline_state and returned_state do not match.
        This indicates that the two states are not compatible for comparison.

    """
    if not masked_baseline_state or not returned_state:
        log(DEBUG, "No model state to compute metrics")
        return

    if len(masked_baseline_state) != len(returned_state):
        msg = """Masked baseline state
         and returned state lengths do not match, this is not currently supported."""
        raise ValueError(msg)

    layer_wise_key, pre_key, post_key, delta_key = metrics_keys
    total_delta = 0.0
    for i, (base, curr) in enumerate(
        zip(masked_baseline_state, returned_state, strict=True),
    ):
        diff = curr - base
        # Compute the sum of squares for the difference array.
        layer_delta_sum_of_squares = sum_of_squares([diff])
        train_metrics[f"client/layer/{i}/{layer_wise_key}"] = float(
            np.sqrt(layer_delta_sum_of_squares),
        )
        # Accumulate squared norms; the square root is taken once at the end
        total_delta += layer_delta_sum_of_squares

    # sum_of_squares yields a squared norm (the per-layer metric above takes its
    # square root), so take the root here too to record actual L2 norms.
    train_metrics[pre_key] = float(
        np.sqrt(sum_of_squares(masked_baseline_state)),
    )
    train_metrics[post_key] = float(
        np.sqrt(sum_of_squares(returned_state)),
    )
    train_metrics[delta_key] = float(
        np.sqrt(total_delta),
    )


def post_process_client_result(  # noqa: PLR0913, PLR0917
    train_metrics: dict[str, Scalar],
    client_state_struct: ClientState,
    llm_config: DictConfig,
    trainer: Trainer,
    payload: NDArrays,
    fit_config: FitConfig,
) -> tuple[NDArrays, int]:
    """Collect the client result and training metrics after local training.

    This function performs the following steps:

    - Removes frozen layers from the aggregation mask (in place on ``fit_config``).
    - Retrieves the model parameters listed in the aggregation mask.
    - Updates ``client_state_struct`` (steps done, cumulative local steps, local
      timestamp).
    - Extracts the optimizer states listed in the aggregation mask.
    - Records training metrics into ``train_metrics``.

    On global rank 0 it also records L2-norm metrics comparing the returned states
    with the server payload. It also stores the serialized client state under
    ``CLIENT_STATE_ACCUMULATOR``.

    Parameters
    ----------
    train_metrics : dict[str, Scalar]
        A dictionary to store training metrics (updated in place).
    client_state_struct : ClientState
        The state structure of the client (updated in place).
    llm_config : DictConfig
        The configuration dictionary for the local model.
    trainer : Trainer
        The trainer instance used for training.
    payload : NDArrays
        The server payload, used as the baseline for the L2-norm metrics. It is
        filtered with the transmission mask before use.
    fit_config : FitConfig
        The configuration containing the client states and server steps.

    Returns
    -------
    tuple[NDArrays, int]
        A tuple containing:
        - The model parameters after training, followed by the optimizer-state
          arrays grouped by state type. Optimizer states are only extracted on
          rank 0, so other ranks return the parameters only.
        - The number of samples trained (n_local_steps * global_train_batch_size).

    """
    # Retrieve model parameters
    start_time = time.time_ns()

    # Modify the aggregation mask in place in case there are any frozen layers
    modify_aggregation_mask_with_frozen_layers(fit_config)
    assert fit_config.transmission_mask is not None
    (mask_transmission, layer_names_transmission, layer_types_transmission) = (
        fit_config.transmission_mask
    )
    _, layer_names_aggregation, layer_types_aggregation = fit_config.aggregation_mask

    model_parameters = get_parameters_from_state(
        {},
        trainer,
        parameter_names=(
            name
            for name, layer_type in zip(
                layer_names_aggregation,
                layer_types_aggregation,
                strict=True,
            )
            if layer_type == ModelStateNames.PARAMETERS
        ),
    )

    # NOTE: We removed here the parameter checkers for efficiency reasons.
    # Re-add them if needed

    train_metrics |= {
        "client/fit_get_parameters_time": (time.time_ns() - start_time) * 1e-9,
    }
    # Retrieve the global train batch size
    global_train_batch_size = int(llm_config["global_train_batch_size"])
    # Get the number of local steps done by the current client
    num_batches_trained = fit_config.n_local_steps
    client_state_struct.steps_done = num_batches_trained
    # NOTE: We assume all the clients train with the same batch size and that
    # n_local_steps is the number of steps actually trained in this round, so the
    # number of samples is derived from the local steps only.
    n_samples_trained: int = num_batches_trained * global_train_batch_size

    client_state_struct.local_steps_cumulative += num_batches_trained

    # Keep only timestamp entries whose repr is an AST literal
    client_state_struct.local_timestamp = {
        k: v
        for k, v in trainer.state.timestamp.copy().state_dict().items()
        if is_literal_for_ast(repr(v))
    }

    # NOTE: We need to add the optimizer state to the model parameters
    opt_states = get_optimizer_state_from_trainer(
        trainer,
        layer_names=layer_names_aggregation,
        layer_types=layer_types_aggregation,
    )
    # Retrieve training metrics
    train_metrics |= {
        k: v.detach().cpu().item()  # type: ignore[attr-defined]
        for k, v in trainer.state.train_metric_values.items()
    }

    # Only rank 0 collects metrics related to the pseudo gradients
    if dist.get_global_rank() == 0:
        start_time = time.time_ns()
        # Restrict the server payload to the transmitted entries once; the result
        # is reused for the parameters and for every optimizer-state type below.
        transmitted_payload = list(
            compress_with_strict(
                data=payload,
                selectors=mask_transmission,
                strict=True,
            ),
        )
        # NOTE: If either the updates model state
        # or the server state are empty
        # we do not compute metrics
        extract_l2_norm_for_model_state(
            train_metrics,
            masked_baseline_state=get_ndarrays_and_names_from_payload(
                payload=transmitted_payload,
                layer_names=layer_names_transmission,
                layer_types=layer_types_transmission,
                model_state_type=ModelStateNames.PARAMETERS,
            )[0],
            returned_state=model_parameters,
            metrics_keys=(
                "l2_norm_of_pseudo_gradient",
                "l2_norm_of_model",
                "client/l2_norm_update",
                "client/l2_norm_pseudo_gradient",
            ),
        )
        for acc_key, acc in opt_states.items():
            extract_l2_norm_for_model_state(
                train_metrics,
                masked_baseline_state=get_ndarrays_and_names_from_payload(
                    payload=transmitted_payload,
                    layer_names=layer_names_transmission,
                    layer_types=layer_types_transmission,
                    model_state_type=ModelStateNames[acc_key.upper()],
                )[0],
                returned_state=acc,
                metrics_keys=(
                    f"l2_norm_of_pseudo_gradient_{acc_key}",
                    f"client/local_adopt/l2_norm/pre_{acc_key}",
                    f"client/local_adopt/l2_norm/post_{acc_key}",
                    f"client/local_adopt/l2_norm/delta_{acc_key}",
                ),
            )

        train_metrics |= {
            "client/fit_metrics_collection_time": (time.time_ns() - start_time) * 1e-9,
        }
        train_metrics |= {
            CLIENT_STATE_ACCUMULATOR: str(
                {fit_config.cid: asdict(client_state_struct)},
            ),
        }

    # Append optimizer states after the parameters, grouped by state type
    for acc in opt_states.values():
        model_parameters.extend(acc)
    log(DEBUG, "Client payload size: %s", len(model_parameters))
    return model_parameters, n_samples_trained


def get_optimizer_state_from_trainer(
    trainer: Trainer,
    layer_names: list[str],
    layer_types: list[ModelStateNames],
) -> dict[str, NDArrays]:
    """Extract optimizer state tensors for specified layers from the trainer.

    This function retrieves the optimizer state for the given layers and types. It
    converts the matching tensors (e.g. momenta) to NumPy arrays and groups them by
    state type.

    The optimizer state dict is obtained in one of two ways:

    - FSDP: if ``trainer.state.model`` or its ``.model`` attribute is a
      ``FullyShardedDataParallel`` instance, the dict comes from
      ``get_optimizer_state_dict`` with full, CPU-offloaded state.
    - Otherwise: it comes from ``trainer.state.get_optim_state_dict()``.

    Only global rank 0 fills the result; other ranks return an empty dict. The
    function calls ``dist.barrier()`` before and after, so presumably every rank is
    expected to call it.

    Parameters
    ----------
    trainer : Trainer
        The trainer object containing the optimizer state.
    layer_names : list[str]
        The names of the layers for which to extract optimizer states.
    layer_types : list[ModelStateNames]
        The state type of each entry of ``layer_names`` (e.g., EXP_AVG, EXP_AVG_SQ).
        Types not present in a layer's optimizer state are skipped.

    Returns
    -------
    dict[str, NDArrays]
        A dictionary where keys are lowercased model state types (e.g., 'exp_avg') and
        values are lists of NumPy arrays containing the optimizer states, in the
        order of ``layer_names``.

    Raises
    ------
    ValueError
        If a parameter is not found in the optimizer state (rank 0 only).

    """
    # NOTE: for some data parallelism implementations
    # the optimizer state dict may be empty
    # for all workers except rank 0
    # thus it only makes sense to add optimizer
    # states for rank 0
    acc_per_typ: dict[str, NDArrays] = OrderedDict()

    dist.barrier()
    start_time = time.time_ns()
    # Check if the model is wrapped in FSDP. The wrapper may sit on the model itself
    # or on its inner ``.model`` attribute (the latter is what set_optimizer_states
    # checks), so accept both.
    model = trainer.state.model
    is_fsdp = isinstance(model, FullyShardedDataParallel) or (
        hasattr(model, "model") and isinstance(model.model, FullyShardedDataParallel)
    )

    # Get the optimizer from the trainer object
    optimizer: torch.optim.Optimizer = ensure_tuple(trainer.state.optimizers)[0]

    # Get the optimizer state dict based on whether FSDP is used
    if is_fsdp:
        optim_state_dict = get_optimizer_state_dict(
            model=trainer.state.model,
            optimizers=optimizer,
            # submodules=None,
            options=StateDictOptions(
                full_state_dict=True,
                cpu_offload=True,
            ),
        )
    else:
        optim_state_dict = trainer.state.get_optim_state_dict()[
            type(optimizer).__qualname__
        ]

    if dist.get_global_rank() == 0:
        state = cast("dict[str, dict[str, torch.Tensor]]", optim_state_dict["state"])

        for name, typ in zip(layer_names, layer_types, strict=True):
            # Prefer an exact key match so that a different key merely containing
            # `name` cannot be picked by mistake; fall back to the first key that
            # contains the name (presumably to tolerate wrapper prefixes).
            real_name = (
                name
                if name in state
                else next((rn for rn in state if name in rn), None)
            )

            if real_name is None:
                msg = f"Parameter {name} not found in optimizer state"
                raise ValueError(msg)

            if (lowercased_typ := typ.value.lower()) in state[real_name]:
                if lowercased_typ not in acc_per_typ:
                    acc_per_typ[lowercased_typ] = []

                acc_per_typ[lowercased_typ].append(
                    state[real_name][lowercased_typ].detach().to("cpu").numpy(),
                )
    dist.barrier()
    log(
        DEBUG,
        "Optimizer state extraction time: %s",
        (time.time_ns() - start_time) * 1e-9,
    )
    return acc_per_typ


def streaming_shms_clean_up() -> None:
    """Release shared-memory segments that would otherwise leak.

    The steps run in this order:
    - ``SharedMemory.cleanup`` is called on every entry of ``shared_memory_list``,
      and after each call ``SharedMemory.cleanup`` is unregistered from ``atexit``.
    - ``shared_memory_list`` is emptied.
    - The ``streaming`` library is asked to remove stale shared memory.
    - A garbage-collection pass is forced to free up memory.
    """
    release_segment = SharedMemory.cleanup

    # NOTE: release every segment we still track, dropping the atexit hook as we go
    for leaked_segment in shared_memory_list:
        release_segment(leaked_segment)
        atexit.unregister(release_segment)
    shared_memory_list.clear()

    # Let the streaming library sweep any stale shared memory it knows about
    streaming.base.util.clean_stale_shared_memory()  # type: ignore[reportAttributeAccessIssue]

    # Force a collection so dropped references are reclaimed now
    gc.collect()


def randomize_layers(  # noqa: PLR0913
    parameters: NDArrays,
    dummy_config: DictConfig,
    names: list[str],
    random_layers: list[str],
    cid: int = 0,
    *,
    server_round: int = 1,
    truly_random_init: bool,
) -> None:
    """Randomize specified layers of the model parameters in place.

    A fresh model is built from a deep copy of ``dummy_config``, with
    ``model.pretrained`` forced to False. The arrays of the layers listed in
    ``random_layers`` are then copied from that model into ``parameters``.
    All other entries of ``parameters`` are left untouched.

    Parameters
    ----------
    parameters : NDArrays
        The list of model parameters to be randomized. The selected entries are
        overwritten in place with ``np.copyto``.
    dummy_config : DictConfig
        The configuration used to create the dummy model. It is deep-copied, so
        the caller's object is not modified.
    names : list[str]
        Parameter names, aligned index-by-index with ``parameters``. If a name
        appears more than once, its first occurrence is used.
    random_layers : list[str]
        Names (entries of ``names``) of the layers to be randomized.
    cid : int, optional
        The client ID mixed into the seed when ``truly_random_init`` is True
        (default is 0).
    server_round : int, optional
        The current server round (default is 1). When ``truly_random_init`` is
        True, this many values are drawn from the global ``random`` generator and
        only the last one is used. Different rounds therefore get different
        seeds.
    truly_random_init : bool
        If True, a new seed is derived from the global ``random`` generator,
        ``cid`` and ``server_round``. It is written to ``global_seed`` and
        ``seed`` of the copied config, and all RNGs are re-seeded through
        ``reproducibility.seed_all``. This is a process-wide side effect. If
        False, the copied config keeps its original seed fields.

    Raises
    ------
    ValueError
        If any entry of ``random_layers`` is not present in ``names``. This is
        checked before the dummy model is built.

    Example
    -------
    >>> from omegaconf import OmegaConf
    >>> parameters = [...]
    >>> dummy_config = OmegaConf.create({...})
    >>> names = ["layer1.weight", "layer1.bias", "layer2.weight"]
    >>> random_layers = ["layer1.weight"]
    >>> randomize_layers(
    >>>     parameters, dummy_config, names, random_layers,
    >>>     truly_random_init=True, cid=1, server_round=2
    >>> )

    """
    # Resolve target indices first, in O(len(names)). A misspelled layer then
    # fails fast, before the (expensive) dummy model is built. setdefault keeps
    # the first occurrence of a duplicated name, which mirrors list.index.
    name_to_index: dict[str, int] = {}
    for idx, name in enumerate(names):
        name_to_index.setdefault(name, idx)
    missing = [key for key in random_layers if key not in name_to_index]
    if missing:
        msg = f"Layers to randomize not found in parameter names: {missing}"
        raise ValueError(msg)
    indices = [name_to_index[key] for key in random_layers]

    new_dummy_config = copy.deepcopy(dummy_config)
    if truly_random_init:
        max_seed = 2**32 - 1
        new_seed = 51550  # only kept when server_round < 1 (the loop never runs)
        for _ in range(server_round):
            # Mask to 32 bits: Composer's seed_all and NumPy only accept seeds in
            # [0, 2**32 - 1]. A negative or very large cid (e.g. a 64-bit node id)
            # would otherwise push the XOR result outside that range. For any cid
            # in [0, 2**32 - 1] the mask is a no-op.
            new_seed = (random.randint(0, max_seed) ^ cid) & max_seed  # noqa: S311
        new_dummy_config.global_seed = new_seed
        new_dummy_config.seed = new_seed
        reproducibility.seed_all(new_seed)
        log(DEBUG, "Randomizing layers with seed %s", new_seed)

    # Guarantee it is false for pre-trained models
    new_dummy_config.model.pretrained = False
    tmp_dummy_config: BaseConfig = cast(
        "BaseConfig",
        DictConfig(
            {
                "pretrained_model_path": None,
                "llm_config": new_dummy_config,
            },
        ),
    )
    # Lazy %-formatting: the (potentially large) config is only rendered to a
    # string when DEBUG logging is actually enabled.
    log(DEBUG, "Creating random model with this config: %s", tmp_dummy_config)

    random_parameters, _ = get_initial_parameters(tmp_dummy_config)

    for i in indices:
        np.copyto(parameters[i], random_parameters[i])

    log(DEBUG, "Randomized layers: %s with indices: %s", random_layers, indices)
