# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""DES-LOC integration utilities for the FL experiments."""

from __future__ import annotations

import logging
import math
import os
import sys
from collections import defaultdict
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from datetime import timedelta
from fnmatch import fnmatch
from types import ModuleType
from typing import Any, Literal, TYPE_CHECKING

import torch
from torch import nn
from torch.optim import Optimizer

try:  # Optional dependency; skip if unavailable.
    import wandb
except ImportError:  # pragma: no cover - wandb may not be installed
    wandb = None

try:  # pragma: no cover - optional dependency in some environments
    from torch.distributed.tensor import DTensor
except ImportError:  # pragma: no cover - DTensor is optional
    DTensor = None  # type: ignore[assignment]

from torchtitan.components.optimizer import FTOptimizersContainer

try:  # pragma: no cover - optional GaLore dependency
    from torchtitan.experiments.fl.optimizers.galore_global import (
        CODE_TO_PROJ,
        FULL_PROJ,
        GaLoreGlobal,
        LEFT_PROJ,
        RIGHT_PROJ,
        STD_PROJ,
    )
except Exception:  # pragma: no cover - GaLore optional
    GaLoreGlobal = None  # type: ignore[assignment]
    LEFT_PROJ = "left"
    RIGHT_PROJ = "right"
    FULL_PROJ = "full"
    CODE_TO_PROJ: dict[int, str] = {}
    STD_PROJ = "std"

# Make sure ``sys.modules`` holds an entry for this module and mirror the names
# bound so far into it. Under a regular import the entry already exists and the
# lookup below simply returns it.
# NOTE(review): presumably this supports executing the file outside the normal
# import machinery - confirm. A freshly created proxy only receives the names
# bound above this point, not the definitions that follow.
_MODULE_PROXY = sys.modules.get(__name__)
if _MODULE_PROXY is None:
    _MODULE_PROXY = ModuleType(__name__)
    sys.modules[__name__] = _MODULE_PROXY
_MODULE_PROXY.__dict__.update(globals())


def _flatten_for_svd(tensor: torch.Tensor) -> torch.Tensor:
    if tensor.ndim == 0:
        return tensor.reshape(1, 1)
    if tensor.ndim == 1:
        return tensor.reshape(1, -1)
    return tensor.reshape(tensor.shape[0], -1)


def _singular_values_from_tensor(tensor: torch.Tensor) -> torch.Tensor:
    """Return the singular values of ``tensor`` in descending order.

    The tensor is flattened to a float32 matrix and the singular values are
    obtained as square roots of the eigenvalues of the smaller Gram matrix
    (``A^T A`` when rows >= columns, ``A A^T`` otherwise). Negative
    eigenvalues are clamped to zero before the square root.
    """
    mat = _flatten_for_svd(tensor).to(dtype=torch.float32)
    rows, cols = mat.shape
    mat_t = mat.transpose(-1, -2)
    # Pick whichever product yields the smaller square matrix.
    gram = mat_t @ mat if rows >= cols else mat @ mat_t
    spectrum = torch.linalg.eigvalsh(gram).clamp(min=0.0).sqrt()
    return torch.sort(spectrum, descending=True).values


def _stable_rank_from_singular_values(singular_values: torch.Tensor) -> torch.Tensor:
    if singular_values.numel() == 0:
        return torch.tensor(float("nan"), device=singular_values.device, dtype=singular_values.dtype)
    fro_sq = singular_values.square().sum()
    spectral_sq = singular_values.max().square()
    if spectral_sq == 0:
        return torch.tensor(float("nan"), device=singular_values.device, dtype=singular_values.dtype)
    return fro_sq / spectral_sq


def _spectral_gap_from_singular_values(singular_values: torch.Tensor, rank: int | None) -> torch.Tensor:
    if rank is None or rank <= 0:
        return torch.tensor(float("nan"), device=singular_values.device, dtype=singular_values.dtype)
    if singular_values.numel() <= rank:
        return torch.tensor(float("nan"), device=singular_values.device, dtype=singular_values.dtype)
    sigma = torch.sort(singular_values, descending=True).values
    s_r = sigma[rank - 1]
    s_r1 = sigma[rank] if sigma.numel() > rank else torch.tensor(0.0, device=sigma.device, dtype=sigma.dtype)
    return s_r - s_r1


def _powerlaw_alpha_from_singular_values(singular_values: torch.Tensor) -> torch.Tensor:
    if singular_values.numel() < 2:
        return torch.tensor(float("nan"), device=singular_values.device, dtype=singular_values.dtype)
    sigma = torch.sort(singular_values, descending=True).values
    sigma = torch.clamp(sigma, min=1e-12)
    k = torch.arange(1, sigma.numel() + 1, device=sigma.device, dtype=sigma.dtype)
    log_k = torch.log(k)
    log_sigma = torch.log(sigma)
    mean_log_k = log_k.mean()
    mean_log_sigma = log_sigma.mean()
    var_log_k = torch.sum((log_k - mean_log_k) ** 2)
    if var_log_k <= 0:
        return torch.tensor(float("nan"), device=sigma.device, dtype=sigma.dtype)
    cov = torch.sum((log_k - mean_log_k) * (log_sigma - mean_log_sigma))
    slope = cov / var_log_k
    return -slope


def _spectrum_metrics(tensor: torch.Tensor, rank: int | None) -> dict[str, float]:
    """Compute spectrum diagnostics for ``tensor`` and keep the finite ones.

    The result may contain the stable rank, the spectral gap at ``rank``, the
    gap relative to the ``rank``-th singular value, and the power-law decay
    exponent. Any metric that is NaN or infinite is omitted.
    """
    sigma = _singular_values_from_tensor(tensor)
    gap = _spectral_gap_from_singular_values(sigma, rank)

    relative_gap = torch.tensor(float("nan"), device=sigma.device, dtype=sigma.dtype)
    if rank is not None and 0 < rank <= sigma.numel():
        pivot = sigma[rank - 1]
        if torch.isfinite(pivot) and pivot != 0:
            relative_gap = gap / pivot

    candidates = (
        ("pseudo_grad_stable_rank", _stable_rank_from_singular_values(sigma)),
        ("pseudo_grad_spectral_gap", gap),
        ("pseudo_grad_relative_gap", relative_gap),
        ("pseudo_grad_powerlaw_alpha", _powerlaw_alpha_from_singular_values(sigma)),
    )
    return {
        key: float(value.item())
        for key, value in candidates
        if torch.isfinite(value)
    }


def _log_wandb_metrics(metrics: dict[str, float], *, step: int | None = None) -> None:
    """Log DES-LOC metrics to wandb using the provided optimization step.

    Logging is best-effort: nothing happens when ``metrics`` is empty, wandb is
    not installed, or no wandb run is active. Errors raised by wandb are
    swallowed so that diagnostics never interrupt training.

    Args:
        metrics: Mapping of metric name to value.
        step: Optional step forwarded to ``wandb.log``; omitted when ``None``.
    """

    if not metrics or wandb is None:
        return
    try:
        if getattr(wandb, "run", None) is not None:
            if step is None:
                wandb.log(metrics)
            else:
                wandb.log(metrics, step=step)
    except Exception:  # pragma: no cover - defensive guard
        # Keep the failure non-fatal but record the cause for debugging.
        logger.debug(
            "Skipping wandb logging for DES-LOC spectrum metrics due to runtime error.",
            exc_info=True,
        )

if TYPE_CHECKING:
    from collections.abc import Callable, Iterable, Iterator, Sequence

    from torch.distributed.distributed_c10d import Work
    from torch.utils.hooks import RemovableHandle

    from torchtitan.components.ft.manager import FTManager
from torchtitan.experiments.fl.configs.optimizers import (
    DesLocConfig,
    DesLocOuterOptimizerConfig,
    DesLocStreamingConfig,
)

# Module logger; helpers defined earlier in this file resolve it at call time.
logger = logging.getLogger(__name__)
# Name of an environment variable related to TorchFT bucketization.
# NOTE(review): it is not read in the visible part of this file - confirm usage.
USE_BUCKETIZATION_ENV = "TORCHFT_USE_BUCKETIZATION"


@dataclass(frozen=True)
class ParameterFragmentConfig:
    """Configuration for synchronizing model parameters via DES-LOC.

    Consumed by ``_ParameterFragment``; the field comments describe how that
    class uses each value.
    """

    # Object providing ``allreduce(tensor)`` and ``register_state_dict_fn(...)``.
    manager: Any
    # Module whose named parameters are backed up and synchronized.
    model: nn.Module
    # Number of local steps between synchronizations; must be positive.
    sync_every: int
    # Device for the backup copies; ``None`` keeps them on each parameter's device.
    backup_device: torch.device | None
    # Pin CPU backups when CUDA is available.
    pin_memory: bool
    # Prefix of the keys used when registering state-dict functions.
    name_prefix: str
    # Outer optimizer instance, or a config from which one is built over the
    # trainable parameters; ``None`` selects plain averaging.
    # NOTE(review): the annotation allows a list of optimizers, but
    # ``_ParameterFragment`` raises ``TypeError`` for a list - confirm intent.
    outer_optimizer: DesLocOuterOptimizerConfig | Optimizer | list[
        Optimizer
    ] | None = None
    # Local optimizer; the low-rank path is only enabled for a GaLoreGlobal instance.
    local_optimizer: Optimizer | None = None
    # Emit outer-optimizer metrics through ``metrics_logger`` after each outer step.
    log_outer_metrics: bool = False
    # Callback receiving a dict of metric name to value.
    metrics_logger: Callable[[dict[str, float]], None] | None = None
    # Register the outer optimizer's state dict with the manager.
    checkpoint_outer_optimizer: bool = True
    # Request low-rank projector updates after each sync (requires GaLoreGlobal).
    low_rank_server_update: bool = False
    # Stored as a flag when an outer optimizer exists.
    # NOTE(review): its effect is not visible in this part of the file.
    outer_optimizer_low_rank: bool = False
    # Enable projector error feedback; only effective when the low-rank path is enabled.
    low_rank_projector_error_feedback: bool = False
    # Tensor used to refresh projectors; any other value raises ``ValueError``.
    low_rank_projector_source: Literal["pseudo_grad", "full_rank_grad"] = "pseudo_grad"
    # Fraction in (0, 1) of pseudo-gradient entries kept by magnitude; only
    # applied when no outer optimizer is configured.
    pseudo_grad_top_k: float | None = None


@dataclass(frozen=True)
class OptimizerFragmentConfig:
    """Configuration for synchronizing optimizer state tensors.

    NOTE(review): the consumer of this config is outside the visible part of
    the file; the field comments below are inferred from names and should be
    confirmed against it.
    """

    # Fault-tolerance manager used for communication (presumably; verify).
    manager: Any
    # Model whose parameters index the optimizer state.
    model: nn.Module
    # Optimizer whose state is synchronized.
    optimizer: Optimizer
    # Presumably the per-parameter optimizer state entry to synchronize - confirm.
    state_key: str
    # Number of local steps between synchronizations.
    sync_every: int
    # Device for backup copies; ``None`` presumably keeps them in place.
    backup_device: torch.device | None
    # Prefix used when naming registered state.
    name_prefix: str


@dataclass(frozen=True)
class StreamingOptimizerFragmentConfig:
    """Configuration for streaming optimizer state synchronization.

    NOTE(review): the consumer of this config is outside the visible part of
    the file; the field comments below are inferred from names and should be
    confirmed against it.
    """

    # Fault-tolerance manager used for communication (presumably; verify).
    manager: Any
    # Index of the fragment this config describes.
    fragment_id: int
    # Prefix used when naming registered state.
    name_prefix: str
    # ``(name, parameter)`` pairs belonging to this fragment.
    param_entries: list[tuple[str, nn.Parameter]]
    # Optimizer whose state is synchronized.
    optimizer: Optimizer
    # Presumably the per-parameter optimizer state entry to synchronize - confirm.
    state_key: str
    # Number of local steps between synchronizations.
    sync_every: int
    # Device for backup copies.
    backup_device: torch.device | None
    # Whether CPU backups should be pinned.
    pin_memory: bool
    # Whether tensors are grouped into buckets for communication (presumably).
    use_bucketization: bool
    # Bucket size cap in megabytes; ``None`` presumably selects a default.
    bucket_cap_mb: float | None
    # Whether communicated tensors are quantized (presumably; verify).
    should_quantize: bool


@dataclass(frozen=True)
class DesLocControllerConfig:
    """Configuration payload for :class:`DesLocController`.

    NOTE(review): ``DesLocController`` is outside the visible part of the
    file. Fields that share a name with ``ParameterFragmentConfig`` presumably
    carry the same meaning - confirm against the controller.
    """

    # Fault-tolerance manager used for communication (presumably; verify).
    manager: Any
    # Model being trained.
    model: nn.Module
    # Local (inner) optimizer.
    optimizer: Optimizer
    # Number of local steps between parameter synchronizations.
    param_sync_every: int
    # Synchronization period(s) for optimizer state: one value, a list, or a
    # mapping. NOTE(review): what the list positions and dict keys refer to is
    # not visible here - confirm.
    optimizer_sync_every: int | list[int] | dict[str, int] | None
    # Device for backup copies.
    backup_device: torch.device | None
    # Whether CPU backups should be pinned.
    pin_memory: bool
    # Prefix used when naming registered state.
    name_prefix: str
    # Quorum timeout, in seconds.
    quorum_timeout_seconds: int
    # Outer optimizer instance or config; ``None`` presumably means plain averaging.
    outer_optimizer: DesLocOuterOptimizerConfig | Optimizer | None = None
    # Whether outer-optimizer metrics are logged.
    log_outer_metrics: bool = False
    # Callback receiving a dict of metric name to value.
    metrics_logger: Callable[[dict[str, float]], None] | None = None
    # Whether the outer optimizer state is checkpointed.
    checkpoint_outer_optimizer: bool = True
    # Disable optimizer state synchronization entirely (presumably; verify).
    disable_optimizer_state_sync: bool = False
    # Low-rank options mirroring ``ParameterFragmentConfig``.
    low_rank_server_update: bool = False
    outer_optimizer_low_rank: bool = False
    low_rank_projector_error_feedback: bool = False
    low_rank_projector_source: Literal["pseudo_grad", "full_rank_grad"] = "pseudo_grad"
    # Fractional top-k sparsification of pseudo-gradients.
    pseudo_grad_top_k: float | None = None


@dataclass(frozen=True)
class DesLocFTOptimizersConfig:
    """Configuration for constructing :class:`DesLocFTOptimizersContainer`.

    NOTE(review): the container is outside the visible part of the file; the
    field comments below are inferred from names and types - confirm.
    """

    # Model parts for which optimizers are built.
    model_parts: list[nn.Module]
    # Optimizer class to instantiate.
    optimizer_cls: type[torch.optim.Optimizer]
    # Keyword arguments for ``optimizer_cls``.
    optimizer_kwargs: dict[str, Any]
    # Fault-tolerance manager handle.
    ft_manager: Any
    # DES-LOC settings.
    desloc_config: DesLocConfig
    # Presumably toggles the fault-tolerant optimizer wrapper - confirm.
    use_ft_optimizer: bool = True
    # Optional explicit parameter groups.
    param_groups: list[dict[str, Any]] | None = None
    # Outer optimizer instance(s) or config; ``None`` presumably means none.
    outer_optimizer: (
        DesLocOuterOptimizerConfig | Optimizer | list[Optimizer] | None
    ) = None
    # Optional streaming synchronization settings.
    streaming: DesLocStreamingConfig | None = None


def _extract_local_tensor(tensor: torch.Tensor) -> torch.Tensor:
    """Return a detached clone of ``tensor``'s local data.

    For a ``DTensor`` (when that type is importable) the local shard obtained
    via ``to_local()`` is cloned; any other tensor is cloned directly.
    """
    is_distributed = DTensor is not None and isinstance(tensor, DTensor)
    source = tensor.to_local() if is_distributed else tensor
    return source.detach().clone()


def _copy_into_tensor(param: torch.Tensor, value: torch.Tensor) -> None:
    """Copy ``value`` into ``param`` handling ``DTensor`` transparently.

    Args:
        param: Destination tensor, modified in place. May be a ``DTensor``.
        value: Source tensor. When ``param`` is a ``DTensor`` this is treated
            as the local shard and wrapped with ``param``'s mesh and
            placements before copying.
    """
    if DTensor is not None and isinstance(param, DTensor):  # pragma: no cover - DTensor
        # ``DTensor.from_local`` expects ``shape`` and ``stride`` to be given
        # together; passing only one of them is rejected.
        param.copy_(
            DTensor.from_local(
                value,
                param.device_mesh,
                param.placements,
                shape=param.shape,
                stride=param.stride(),
            )
        )
    else:
        param.copy_(value)


def _zero_optimizer_grads(optimizer: Optimizer | None) -> None:
    """Zero gradients on the provided optimizer, preferring ``set_to_none=True``."""
    if optimizer is None:
        return
    try:
        optimizer.zero_grad(set_to_none=True)
    except TypeError:  # pragma: no cover - optimizer signature variance
        optimizer.zero_grad()


def _apply_topk_sparsity(tensor: torch.Tensor, topk: float | None) -> torch.Tensor:
    """Apply fractional top-k masking (0 < topk < 1) to ``tensor`` in-place.

    Any value outside ``(0, 1)`` leaves the tensor unchanged. Dtype/device are
    preserved; mask is magnitude-based.
    """

    if topk is None:
        return tensor

    try:
        k_value = float(topk)
    except (TypeError, ValueError):  # pragma: no cover - defensive
        return tensor

    if not (0.0 < k_value < 1.0):
        return tensor

    numel = tensor.numel()
    if numel == 0:
        return tensor

    k = max(1, min(numel, int(math.ceil(k_value * numel))))
    if k >= numel:
        return tensor

    flat = tensor.view(-1)
    _, topk_indices = torch.topk(flat.abs(), k, sorted=False)
    mask = torch.zeros_like(flat)
    mask.scatter_(0, topk_indices, 1)
    flat.mul_(mask)
    return tensor


def _partition_named_parameters(
    model: nn.Module,
    fragments: int,
    *,
    strategy: str = "strided",
    custom_fragments: Sequence[Sequence[str]] | None = None,
) -> list[list[tuple[str, nn.Parameter]]]:
    """Split the model's named parameters into fragments.

    ``fragments`` is capped at the number of parameters. When
    ``custom_fragments`` is given it takes precedence over ``strategy``;
    otherwise ``strategy`` (case-insensitive) selects one of ``"strided"``,
    ``"sequential"`` or ``"balanced"``.

    Raises:
        ValueError: If ``fragments`` is not positive or ``strategy`` is unknown.
    """
    if fragments <= 0:
        msg = "desloc.streaming.fragments must be a positive integer."
        raise ValueError(msg)

    named_params = list(model.named_parameters())
    if not named_params:
        return []

    fragments = min(max(1, fragments), len(named_params))
    if custom_fragments is not None:
        return _partition_from_custom_spec(named_params, fragments, custom_fragments)

    strategy = strategy.lower()
    partitioners = {
        "strided": _partition_strided,
        "sequential": _partition_sequential,
        "balanced": _partition_balanced,
    }
    partitioner = partitioners.get(strategy)
    if partitioner is None:
        msg = f"Unknown DES-LOC streaming fragment strategy '{strategy}'."
        raise ValueError(msg)
    return partitioner(named_params, fragments)


# Alias: a list of parameter groups, each group being a list of (name, parameter) pairs.
_GroupedParams = list[list[tuple[str, nn.Parameter]]]


def _partition_strided(
    named_params: list[tuple[str, nn.Parameter]],
    fragments: int,
) -> _GroupedParams:
    """Distribute per-layer parameter groups round-robin over ``fragments`` buckets.

    Parameters without a layer index are all collected into one leading
    bucket. Layer groups are assigned to ``fragments`` buckets by their
    position modulo ``fragments``. Empty buckets are dropped, so the result
    can hold up to ``fragments + 1`` buckets.
    """
    layered: list[list[tuple[str, nn.Parameter]]] = []
    unlayered: list[list[tuple[str, nn.Parameter]]] = []
    for group in _group_parameters_for_striding(named_params):
        first_name = group[0][0]
        destination = unlayered if _extract_layer_index(first_name) is None else layered
        destination.append(group)

    slot_count = max(1, fragments)
    leading = [entry for group in unlayered for entry in group]

    layer_buckets: list[list[tuple[str, nn.Parameter]]] = [
        [] for _ in range(slot_count)
    ]
    for position, group in enumerate(layered):
        layer_buckets[position % slot_count].extend(group)

    return [bucket for bucket in [leading, *layer_buckets] if bucket]


def _partition_sequential(
    named_params: list[tuple[str, nn.Parameter]],
    fragments: int,
) -> list[list[tuple[str, nn.Parameter]]]:
    bucket_size = math.ceil(len(named_params) / fragments)
    ordered = [
        named_params[idx : idx + bucket_size]
        for idx in range(0, len(named_params), bucket_size)
    ]
    return [bucket for bucket in ordered if bucket]


def _partition_balanced(
    named_params: list[tuple[str, nn.Parameter]],
    fragments: int,
) -> list[list[tuple[str, nn.Parameter]]]:
    if fragments == 1:
        return [named_params]

    buckets: list[list[tuple[int, str, nn.Parameter]]] = [[] for _ in range(fragments)]
    bucket_sizes = [0 for _ in range(fragments)]

    indexed = [(idx, name, param) for idx, (name, param) in enumerate(named_params)]
    indexed.sort(key=lambda item: item[2].numel(), reverse=True)

    for original_idx, name, param in indexed:
        slot = min(range(fragments), key=lambda i: bucket_sizes[i])
        buckets[slot].append((original_idx, name, param))
        bucket_sizes[slot] += int(param.numel())

    ordered: list[list[tuple[str, nn.Parameter]]] = []
    for bucket in buckets:
        if not bucket:
            continue
        bucket.sort(key=lambda item: item[0])
        ordered.append([(name, param) for _, name, param in bucket])

    return ordered


def _partition_from_custom_spec(
    named_params: list[tuple[str, nn.Parameter]],
    fragments: int,
    custom_fragments: Sequence[Sequence[str]],
) -> list[list[tuple[str, nn.Parameter]]]:
    buckets_spec = [tuple(fragment) for fragment in custom_fragments if fragment]
    if not buckets_spec:
        msg = "desloc.streaming.custom_fragments must contain at least one selector."
        raise ValueError(msg)
    if len(buckets_spec) != fragments:
        msg = "desloc.streaming.custom_fragments must match desloc.streaming.fragments."
        raise ValueError(msg)

    param_map = dict(named_params)
    remaining = set(param_map.keys())

    partitions: list[list[tuple[str, nn.Parameter]]] = []
    for bucket_idx, selectors in enumerate(buckets_spec):
        bucket: list[tuple[str, nn.Parameter]] = []
        for selector in selectors:
            matches = [name for name in list(remaining) if fnmatch(name, selector)]
            if not matches:
                msg = (
                    f"DES-LOC custom fragment {bucket_idx} selector '{selector}' "
                    "did not match any parameter."
                )
                raise ValueError(msg)
            for name in sorted(matches):
                bucket.append((name, param_map[name]))
                remaining.remove(name)
        if bucket:
            partitions.append(bucket)

    if remaining:
        unused = ", ".join(sorted(remaining)[:3])
        msg = (
            "DES-LOC custom fragments must cover every parameter; "
            f"remaining parameters include: {unused}..."
        )
        raise ValueError(msg)
    return partitions


def _group_parameters_for_striding(
    named_params: list[tuple[str, nn.Parameter]],
) -> _GroupedParams:
    """Group consecutive parameters that share a layer index.

    Parameters without a layer index each form their own single-element group
    and end any layer group that is being built. A run of parameters with the
    same layer index forms one group; a change of index starts a new group.
    """
    groups: list[list[tuple[str, nn.Parameter]]] = []
    pending: list[tuple[str, nn.Parameter]] = []
    active_layer: int | None = None

    for name, param in named_params:
        entry = (name, param)
        layer_idx = _extract_layer_index(name)

        if layer_idx is None:
            # Flush the open layer group, then emit this parameter on its own.
            if pending:
                groups.append(pending)
                pending = []
            active_layer = None
            groups.append([entry])
            continue

        if active_layer is not None and layer_idx != active_layer:
            groups.append(pending)
            pending = []
        active_layer = layer_idx
        pending.append(entry)

    if pending:
        groups.append(pending)

    return groups


def _extract_layer_index(param_name: str) -> int | None:
    token = "layers."
    idx = param_name.find(token)
    if idx == -1:
        return None
    remainder = param_name[idx + len(token) :]
    digits: list[str] = []
    for char in remainder:
        if char.isdigit():
            digits.append(char)
        else:
            break
    if not digits:
        return None
    try:
        return int("".join(digits))
    except ValueError:  # pragma: no cover - defensive
        return None


def _contains_layer_params(partition: list[tuple[str, nn.Parameter]]) -> bool:
    return any(name.startswith("layers.") for name, _ in partition)


def _merge_non_layer_partition(
    partitions: list[list[tuple[str, nn.Parameter]]],
) -> list[list[tuple[str, nn.Parameter]]]:
    """Fold the first layer-free partition into the first partition holding layers.

    The merged partition lists the layer-free parameters first. ``partitions``
    is modified in place and returned; it is returned untouched when there is
    no non-empty layer-free partition or no other partition with layer
    parameters.
    """
    source_idx: int | None = None
    for idx, partition in enumerate(partitions):
        if partition and not _contains_layer_params(partition):
            source_idx = idx
            break
    if source_idx is None:
        return partitions

    dest_idx: int | None = None
    for idx, partition in enumerate(partitions):
        if idx != source_idx and _contains_layer_params(partition):
            dest_idx = idx
            break
    if dest_idx is None:
        return partitions

    # Merge before deleting so the destination index is still valid.
    partitions[dest_idx] = partitions[source_idx] + partitions[dest_idx]
    del partitions[source_idx]
    return partitions


def _component_key_from_name(param_name: str) -> str:
    if param_name.startswith("layers."):
        parts = param_name.split(".")
        if len(parts) >= 3:
            return ".".join(parts[:3])
        return ".".join(parts[: len(parts)])
    return param_name.split(".")[0]


def _format_fragment_membership(names: Sequence[str], limit: int = 8) -> str:
    """Return a short string describing which tensors belong to a fragment."""
    if not names:
        return "none"
    if len(names) <= limit:
        return ", ".join(names)
    remaining = len(names) - limit
    return f"{', '.join(names[:limit])}, ... (+{remaining} more)"


class _BaseFragment:
    def __init__(self, sync_every: int) -> None:
        if sync_every <= 0:
            message = "sync_every must be a positive integer"
            raise ValueError(message)
        self.sync_every = sync_every
        self._local_step = 0

    def tick(self) -> bool:
        self._local_step += 1
        return self._local_step >= self.sync_every

    def reset(self) -> None:
        self._local_step = 0

    def prepare_sync(self) -> list[Any]:
        raise NotImplementedError

    def perform_sync(self) -> None:
        raise NotImplementedError

    def save_state(self) -> None:
        raise NotImplementedError

    def restore_state(self) -> None:
        raise NotImplementedError


class _ParameterFragment(_BaseFragment):
    """Handles parameter state replication and synchronization."""

    def __init__(self, config: ParameterFragmentConfig) -> None:
        super().__init__(config.sync_every)
        self._manager = config.manager
        self._model = config.model
        self._backup_device = config.backup_device
        self._pin_memory = config.pin_memory
        self._name_prefix = config.name_prefix

        self._param_map = dict(self._model.named_parameters())
        self._original_parameters: dict[str, torch.Tensor] = {}
        self._averaged_parameters: list[tuple[str, torch.Tensor]] = []
        self._pseudo_grad_top_k = config.pseudo_grad_top_k

        outer_spec = config.outer_optimizer
        self._outer_optimizer: Optimizer | None = None
        self._reference_synced = outer_spec is None
        self._reference_pending: list[tuple[str, torch.Tensor]] = []
        self._log_outer_metrics = config.log_outer_metrics
        self._metrics_logger = config.metrics_logger
        self._step_ctx: int | None = None
        self._checkpoint_outer_optimizer = config.checkpoint_outer_optimizer
        if isinstance(outer_spec, Optimizer):
            self._outer_optimizer = outer_spec
        elif isinstance(outer_spec, DesLocOuterOptimizerConfig):
            optimizer_cls = outer_spec.resolve_optimizer_cls()
            params = [p for p in self._model.parameters() if p.requires_grad]
            if not params:
                msg = (
                    "DES-LOC outer optimizer requires at least one trainable parameter."
                )
                raise ValueError(msg)
            self._outer_optimizer = optimizer_cls(params, **outer_spec.kwargs)
        elif outer_spec is not None:
            msg = "outer_optimizer must be an Optimizer, DesLocOuterOptimizerConfig, or None."
            raise TypeError(msg)

        self._local_optimizer = config.local_optimizer
        self._low_rank_enabled = bool(
            config.low_rank_server_update
            and GaLoreGlobal is not None
            and isinstance(self._local_optimizer, GaLoreGlobal)
        )
        self._outer_low_rank_enabled = bool(
            config.outer_optimizer_low_rank and self._outer_optimizer is not None
        )
        self._error_feedback_enabled = bool(
            self._low_rank_enabled and config.low_rank_projector_error_feedback
        )
        self._low_rank_projector_source = config.low_rank_projector_source
        if self._low_rank_projector_source not in ("pseudo_grad", "full_rank_grad"):
            msg = (
                "desloc.low_rank_projector_source must be 'pseudo_grad' or "
                f"'full_rank_grad'; received {self._low_rank_projector_source!r}."
            )
            raise ValueError(msg)
        self._projector_error_feedback: dict[str, torch.Tensor] = {}
        self._pre_sync_parameters: dict[str, torch.Tensor] = {}
        self._averaged_gradients: dict[str, torch.Tensor] = {}

        self._init_backup_storage()
        self.save_state()
        if self._outer_optimizer is not None:
            self._reference_synced = True

    def set_metrics_logger(
        self, logger_fn: Callable[[dict[str, float]], None] | None
    ) -> None:
        self._metrics_logger = logger_fn

    def set_step_context(self, step: int) -> None:
        self._step_ctx = step

    def _init_backup_storage(self) -> None:
        for name, param in self._model.named_parameters():
            local_tensor = _extract_local_tensor(param.data)
            device = (
                self._backup_device
                if self._backup_device is not None
                else local_tensor.device
            )
            backup = torch.empty_like(local_tensor, device=device)
            if (
                self._pin_memory
                and backup.device.type == "cpu"
                and torch.cuda.is_available()
            ):
                backup = backup.pin_memory()
            self._original_parameters[name] = backup

    def save_state(self) -> None:
        with torch.no_grad():
            for name, param in self._model.named_parameters():
                self._original_parameters[name].copy_(
                    _extract_local_tensor(param.data), non_blocking=True
                )

    def restore_state(self) -> None:
        with torch.no_grad():
            for name, param in self._model.named_parameters():
                _copy_into_tensor(param.data, self._original_parameters[name])

    def prepare_sync(self) -> list[Any]:
        """Launch the collectives for one synchronization round.

        For every named parameter a detached local copy (or, in sparse mode, a
        top-k sparsified pseudo-gradient ``backup - current``) is passed to
        ``self._manager.allreduce``. The reduced tensors are remembered in
        ``self._averaged_parameters`` for ``perform_sync``.

        Returns:
            The objects returned by ``self._manager.allreduce``, in launch
            order. NOTE(review): presumably async work handles the caller
            must wait on before ``perform_sync`` - confirm.
        """
        if self._outer_optimizer is not None and not self._reference_synced:
            # Ensure backups reflect the current model weights (e.g. after checkpoint load).
            self.save_state()
        self._averaged_parameters.clear()
        work_items: list[Any] = []
        if self._low_rank_enabled:
            # Drop leftovers from a previous round before collecting new ones.
            self._pre_sync_parameters.clear()
            if self._low_rank_projector_source == "full_rank_grad":
                self._averaged_gradients.clear()

        # Determine if we should use sparse pseudo-gradient communication.
        # This applies when sparsification is enabled and we're in averaging-only mode.
        use_sparse_pseudo_grads = (
            self._pseudo_grad_top_k is not None
            and 0.0 < self._pseudo_grad_top_k < 1.0
            and self._outer_optimizer is None
        )
        # Remembered so perform_sync interprets the reduced tensors correctly.
        self._use_sparse_pseudo_grads = use_sparse_pseudo_grads

        for name, param in self._model.named_parameters():
            avg_param = _extract_local_tensor(param.data)
            if self._low_rank_enabled:
                # Keep the pre-reduction value; avg_param itself is reduced below.
                self._pre_sync_parameters[name] = avg_param.clone()
                if self._low_rank_projector_source == "full_rank_grad":
                    grad = param.grad
                    # Parameters without a gradient contribute zeros.
                    grad_tensor = (
                        torch.zeros_like(avg_param)
                        if grad is None
                        else _extract_local_tensor(grad)
                    )
                    self._averaged_gradients[name] = grad_tensor
                    work_items.append(self._manager.allreduce(grad_tensor))

            if use_sparse_pseudo_grads:
                # Compute pseudo-gradient and apply sparsity to it (not raw params).
                reference = self._original_parameters[name].to(
                    device=avg_param.device, dtype=avg_param.dtype
                )
                pseudo_grad = reference - avg_param
                _apply_topk_sparsity(pseudo_grad, self._pseudo_grad_top_k)
                work_items.append(self._manager.allreduce(pseudo_grad))
                self._averaged_parameters.append((name, pseudo_grad))
            else:
                work_items.append(self._manager.allreduce(avg_param))
                self._averaged_parameters.append((name, avg_param))

        if self._outer_optimizer is not None and not self._reference_synced:
            # Also reduce the backups of trainable parameters; perform_sync
            # writes the results back into the backups.
            self._reference_pending.clear()
            for name, avg_param in self._averaged_parameters:
                param = self._param_map[name]
                if not param.requires_grad:
                    continue
                # copy=True: the reduction must not write into the backup itself.
                reference = self._original_parameters[name].to(
                    device=avg_param.device,
                    dtype=avg_param.dtype,
                    copy=True,
                )
                work_items.append(self._manager.allreduce(reference))
                self._reference_pending.append((name, reference))

        return work_items

    def perform_sync(self) -> None:
        if self._outer_optimizer is not None and not self._reference_synced:
            for name, reference in self._reference_pending:
                backup = self._original_parameters[name]
                backup.copy_(reference.to(backup.device, dtype=backup.dtype))
            self._reference_pending.clear()
            self._reference_synced = True

        if self._outer_optimizer is None:
            with torch.no_grad():
                use_sparse = getattr(self, "_use_sparse_pseudo_grads", False)
                for name, avg_tensor in self._averaged_parameters:
                    param = self._param_map[name]
                    if use_sparse:
                        # avg_tensor is the averaged sparse pseudo-gradient.
                        # Reconstruct: new_param = reference - averaged_sparse_pseudo_grad
                        reference = self._original_parameters[name].to(
                            device=avg_tensor.device, dtype=avg_tensor.dtype
                        )
                        new_param = reference - avg_tensor
                        _copy_into_tensor(param.data, new_param)
                    else:
                        # avg_tensor is the averaged parameter directly.
                        _copy_into_tensor(param.data, avg_tensor)
            if self._low_rank_enabled:
                self._update_low_rank_projectors()
                self._pre_sync_parameters.clear()
                self._averaged_gradients.clear()
            return

        pseudo_norm_sq = 0.0
        grads_assigned = False
        with torch.no_grad():
            for name, avg_param in self._averaged_parameters:
                param = self._param_map[name]
                if not param.requires_grad:
                    _copy_into_tensor(param.data, avg_param)
                    continue

                reference_native = self._original_parameters[name].to(
                    device=avg_param.device,
                    dtype=avg_param.dtype,
                )
                averaged_native = avg_param

                # Ensure every replica applies gradients starting from the shared reference.
                _copy_into_tensor(param.data, reference_native)

                grad_native = reference_native - averaged_native
                pseudo_norm_sq += grad_native.pow(2).sum().item()
                param.grad = grad_native
                grads_assigned = True

        if not grads_assigned:
            with torch.no_grad():
                for name, avg_param in self._averaged_parameters:
                    param = self._param_map[name]
                    _copy_into_tensor(param.data, avg_param)
            if self._low_rank_enabled:
                self._update_low_rank_projectors()
                self._pre_sync_parameters.clear()
                self._averaged_gradients.clear()
            return

        self._outer_optimizer.step()
        _zero_optimizer_grads(self._outer_optimizer)

        if self._log_outer_metrics and self._metrics_logger is not None:
            metrics: dict[str, float] = {
                "desloc_outer/pseudo_grad_l2": math.sqrt(max(pseudo_norm_sq, 0.0))
            }
            momentum_norm_sq = 0.0
            has_momentum = False
            if isinstance(self._outer_optimizer, torch.optim.SGD):
                for state in self._outer_optimizer.state.values():
                    buffer = state.get("momentum_buffer")
                    if isinstance(buffer, torch.Tensor):
                        has_momentum = True
                        momentum_norm_sq += buffer.pow(2).sum().item()
            if has_momentum:
                metrics["desloc_outer/momentum_l2"] = math.sqrt(
                    max(momentum_norm_sq, 0.0)
                )
            try:
                self._metrics_logger(metrics)
            except Exception:  # pragma: no cover - diagnostics only
                logger.exception(
                    "DES-LOC failed to log outer optimizer metrics; continuing."
                )

        if self._low_rank_enabled:
            self._update_low_rank_projectors()
            self._pre_sync_parameters.clear()
            self._averaged_gradients.clear()

    def register_state_dict_fn(self) -> None:
        """Register checkpoint load/save hooks for this fragment with the manager.

        The parameter backups are registered under ``"<name_prefix>_params"``.
        When an outer optimizer is present and ``_checkpoint_outer_optimizer``
        is set, its state dict is registered separately under
        ``"<name_prefix>_outer_optimizer"``.
        """

        def _load_backups(saved: dict[str, torch.Tensor]) -> None:
            if not saved:
                # Older checkpoints might not have stored the DES-LOC state;
                # fall back to a fresh capture.
                self.save_state()
            else:
                for key, value in saved.items():
                    # Entries for names this fragment does not track are ignored.
                    if key in self._original_parameters:
                        self._original_parameters[key].copy_(value)
            # After any load the reference is flagged as not synced and the
            # pending-reference collection is emptied.
            self._reference_synced = False
            self._reference_pending.clear()

        def _dump_backups() -> dict[str, torch.Tensor]:
            # Hands out the live backup dict, not a copy.
            return self._original_parameters

        params_key = f"{self._name_prefix}_params"
        self._manager.register_state_dict_fn(params_key, _load_backups, _dump_backups)

        if self._outer_optimizer is None or not self._checkpoint_outer_optimizer:
            return

        def _load_outer_state(saved: dict[str, Any]) -> None:
            self._outer_optimizer.load_state_dict(saved)

        def _dump_outer_state() -> dict[str, Any]:
            return self._outer_optimizer.state_dict()

        outer_key = f"{self._name_prefix}_outer_optimizer"
        self._manager.register_state_dict_fn(
            outer_key, _load_outer_state, _dump_outer_state
        )

    def _decode_projection_type(self, meta: dict[str, Any]) -> str:
        value = meta.get("resolved_proj_type") or meta.get("proj_type")
        if isinstance(value, int) and CODE_TO_PROJ:
            return CODE_TO_PROJ.get(value, STD_PROJ)
        if isinstance(value, str):
            return value
        return STD_PROJ

    @staticmethod
    def _canonicalize_projection_tensor(
        tensor: torch.Tensor,
    ) -> tuple[torch.Tensor | None, tuple[int, ...]]:
        shape = tensor.shape
        if tensor.ndim == 0:
            return tensor.reshape(1, 1), shape
        if tensor.ndim == 1:
            return tensor.unsqueeze(0), shape
        if tensor.ndim == 2:
            return tensor, shape
        return None, shape

    def _compute_projector_reconstruction(
        self,
        tensor: torch.Tensor,
        basis: torch.Tensor | list[torch.Tensor],
        proj_type: str,
    ) -> torch.Tensor | None:
        canonical, original_shape = self._canonicalize_projection_tensor(tensor)
        if canonical is None:
            return None

        device = tensor.device
        dtype = tensor.dtype
        if proj_type == FULL_PROJ:
            if not isinstance(basis, list) or len(basis) != 2:
                return None
            left, right = basis
            left = left.to(device=device, dtype=dtype)
            right = right.to(device=device, dtype=dtype)
            if (
                canonical.shape[0] != left.shape[0]
                or canonical.shape[1] != right.shape[1]
            ):
                return None
            low_rank = left.transpose(-1, -2) @ canonical @ right.transpose(-1, -2)
            reconstruction = left @ low_rank @ right
        elif proj_type == LEFT_PROJ:
            if not isinstance(basis, torch.Tensor):
                return None
            left = basis.to(device=device, dtype=dtype)
            if canonical.shape[0] != left.shape[0]:
                return None
            low_rank = left.transpose(-1, -2) @ canonical
            reconstruction = left @ low_rank
        else:  # RIGHT_PROJ and defaults
            if not isinstance(basis, torch.Tensor):
                return None
            right = basis.to(device=device, dtype=dtype)
            if canonical.shape[1] != right.shape[1]:
                return None
            low_rank = canonical @ right.transpose(-1, -2)
            reconstruction = low_rank @ right

        return reconstruction.reshape(original_shape).to(device=device, dtype=dtype)

    def _select_projector_signal(
        self,
        *,
        name: str,
        base_pseudo_grad: torch.Tensor | None,
    ) -> torch.Tensor | None:
        """Return the tensor used to refresh GaLore projector bases.

        Parameters:
            name: Parameter name for lookup in the cached gradients.
            base_pseudo_grad: Pseudo-gradient computed from synced parameters.

        Returns:
            The tensor to feed into the projector SVD update.
        """
        if self._low_rank_projector_source == "full_rank_grad":
            full_rank_grad = self._averaged_gradients.get(name)
            if full_rank_grad is None:
                logger.debug(
                    "DES-LOC projector refresh missing full-rank grad for %s.",
                    name,
                )
                return base_pseudo_grad
            if base_pseudo_grad is not None and full_rank_grad.shape != base_pseudo_grad.shape:
                logger.warning(
                    "DES-LOC projector refresh gradient shape mismatch for %s.",
                    name,
                )
                return base_pseudo_grad
            return full_rank_grad
        return base_pseudo_grad

    def _build_projection_basis(
        self,
        pseudo_grad: torch.Tensor,
        rank: int,
        meta: dict[str, Any],
    ) -> torch.Tensor | list[torch.Tensor] | None:
        if rank <= 0:
            return None

        matrix = pseudo_grad.detach()
        if matrix.ndim == 0:
            matrix = matrix.reshape(1, 1)
        elif matrix.ndim == 1:
            matrix = matrix.unsqueeze(0)

        if not torch.isfinite(matrix).all() or matrix.abs().max().item() == 0:
            matrix = torch.randn_like(matrix)

        reduced_rank = min(rank, *matrix.shape)
        if reduced_rank <= 0:
            return None

        try:
            u, _s, v_h = torch.linalg.svd(matrix, full_matrices=False)
            u = u[:, :reduced_rank].contiguous()
            v_h = v_h[:reduced_rank, :].contiguous()
        except RuntimeError:
            # Fall back to random orthogonal bases when SVD fails.
            u_rand = torch.randn(
                matrix.shape[0], reduced_rank, device=matrix.device, dtype=matrix.dtype
            )
            u, _ = torch.linalg.qr(u_rand, mode="reduced")
            v_h = torch.randn(
                reduced_rank, matrix.shape[1], device=matrix.device, dtype=matrix.dtype
            )

        proj_type = self._decode_projection_type(meta)
        if proj_type == FULL_PROJ:
            return [u, v_h]
        if proj_type == LEFT_PROJ:
            return u
        return v_h

    def _update_low_rank_projectors(self) -> None:
        optimizer = self._local_optimizer
        if not self._low_rank_enabled or optimizer is None:
            return

        active_feedback_names: set[str] = set()
        for name, avg_param in self._averaged_parameters:
            local_snapshot = self._pre_sync_parameters.get(name)
            if local_snapshot is None:
                continue
            param = self._param_map[name]
            state = optimizer.state.get(param)
            if not state:
                continue
            meta = state.get("projector_meta")
            if not isinstance(meta, dict):
                continue
            rank = int(meta.get("rank") or 0)
            if rank <= 0:
                continue
            if self._error_feedback_enabled:
                active_feedback_names.add(name)

            # When using sparse pseudo-gradients, avg_param IS the averaged sparse
            # pseudo-gradient. Otherwise, compute the pseudo-gradient from params.
            use_sparse = getattr(self, "_use_sparse_pseudo_grads", False)
            if use_sparse:
                # avg_param is already the averaged sparse pseudo-gradient.
                base_pseudo_grad = avg_param
            else:
                # Compute pseudo-gradient: local_snapshot - averaged_params
                base_pseudo_grad = local_snapshot - avg_param

            projector_signal = self._select_projector_signal(
                name=name,
                base_pseudo_grad=base_pseudo_grad,
            )
            logger.info(f"Projector signal for {name} selected from {self._low_rank_projector_source}.")
            if projector_signal is None:
                continue
            # Log spectrum stats of the projector signal (pseudo/full grad) for analysis.
            spectrum_metrics = _spectrum_metrics(projector_signal, rank)
            if spectrum_metrics:
                metrics_payload = {f"desloc_outer/{k}/{name}": v for k, v in spectrum_metrics.items()}
                if self._metrics_logger is not None:
                    try:
                        self._metrics_logger(metrics_payload)
                    except Exception:  # pragma: no cover - diagnostics only
                        logger.exception("DES-LOC spectrum metrics logger failed.")
                _log_wandb_metrics(metrics_payload, step=self._step_ctx)
            if self._error_feedback_enabled:
                feedback = self._projector_error_feedback.get(name)
                if feedback is not None:
                    if feedback.shape == projector_signal.shape:
                        projector_signal = projector_signal + feedback.to(
                            device=projector_signal.device,
                            dtype=projector_signal.dtype,
                        )
                    else:
                        self._projector_error_feedback.pop(name, None)
            basis = self._build_projection_basis(
                projector_signal,
                rank,
                meta,
            )
            if basis is not None:
                old_basis = state.get("projector_basis")
                proj_type = self._decode_projection_type(meta)
                rotate_fn = getattr(self._local_optimizer, "rotate_momenta", None)
                if rotate_fn is not None and old_basis is not None:
                    try:
                        rotate_fn(
                            param,
                            old_basis=old_basis,
                            new_basis=basis,
                            proj_type=proj_type,
                        )
                    except Exception:  # pragma: no cover - diagnostics only
                        logger.exception(
                            "DES-LOC momentum rotation failed; continuing without rotation."
                        )
                state["projector_basis"] = basis
                state.pop("_placeholder_projector", None)
                state.pop("_bootstrap_projector", None)
                if self._outer_low_rank_enabled:
                    self._apply_outer_low_rank_basis(param, basis, meta)

                if self._error_feedback_enabled:
                    reconstruction = self._compute_projector_reconstruction(
                        projector_signal,
                        basis,
                        proj_type,
                    )
                    if reconstruction is None or reconstruction.shape != projector_signal.shape:
                        self._projector_error_feedback.pop(name, None)
                    else:
                        residual = (projector_signal - reconstruction).detach().clone()
                        self._projector_error_feedback[name] = residual

        finalize_placeholders = getattr(
            optimizer, "finalize_placeholder_projectors", None
        )
        if callable(finalize_placeholders):
            finalize_placeholders()

        if self._error_feedback_enabled:
            stale = set(self._projector_error_feedback.keys()) - active_feedback_names
            for entry in stale:
                self._projector_error_feedback.pop(entry, None)

    def _apply_outer_low_rank_basis(
        self,
        param: torch.Tensor,
        basis: torch.Tensor | list[torch.Tensor],
        meta: dict[str, Any],
    ) -> None:
        msg = (
            "DES-LOC outer optimizer low-rank support is not implemented yet. "
            "Disable desloc.low_rank_outer_optimizer while this branch is under development."
        )
        raise NotImplementedError(msg)


class _OuterOptimizingParameterFragment(_ParameterFragment):
    """Marker subclass instantiated when an outer optimizer is configured.

    The class defines no attributes or methods of its own; all behavior is
    inherited from ``_ParameterFragment``.
    """


class _OptimizerStateFragment(_BaseFragment):
    """Synchronize a specific optimizer state tensor across replicas.

    One state entry, selected by ``state_key``, is tracked for every model
    parameter whose optimizer state holds a tensor under that key at
    construction time. A backup of each tracked tensor is kept so the state
    can be saved, restored and checkpointed.
    """

    def __init__(self, config: OptimizerFragmentConfig) -> None:
        """Wire up the fragment from ``config`` and capture the initial state."""
        super().__init__(config.sync_every)
        self.state_key = config.state_key
        self._name_prefix = config.name_prefix
        self._backup_device = config.backup_device
        self._manager = config.manager
        self._optimizer = config.optimizer
        self._model = config.model

        self._param_map = dict(self._model.named_parameters())
        self._original_state_tensors: dict[str, torch.Tensor] = {}
        self._averaged_state_tensors: list[torch.Tensor] = []

        self._init_backup_storage()
        self.save_state()

    def _init_backup_storage(self) -> None:
        """Allocate an uninitialized backup tensor for every tracked state."""
        for name, param in self._model.named_parameters():
            candidate = self._optimizer.state.get(param, {}).get(self.state_key)
            if not isinstance(candidate, torch.Tensor):
                # No tensor under ``state_key`` for this parameter: untracked.
                continue
            if self._backup_device is None:
                target_device = candidate.device
            else:
                target_device = self._backup_device
            self._original_state_tensors[name] = torch.empty_like(
                candidate, device=target_device
            )

    def save_state(self) -> None:
        """Copy the current optimizer state tensors into their backups."""
        with torch.no_grad():
            for name, backup in self._original_state_tensors.items():
                owner = self._param_map[name]
                live = self._optimizer.state[owner][self.state_key]
                backup.copy_(live, non_blocking=True)

    def restore_state(self) -> None:
        """Overwrite the optimizer state tensors with their backups.

        Parameters that no longer have the tracked state entry are skipped.
        """
        with torch.no_grad():
            for name, backup in self._original_state_tensors.items():
                owner = self._param_map[name]
                if owner not in self._optimizer.state:
                    continue
                if self.state_key not in self._optimizer.state[owner]:
                    continue
                self._optimizer.state[owner][self.state_key].copy_(backup)

    def prepare_sync(self) -> list[Any]:
        """Clone every tracked state tensor and start its allreduce.

        Returns:
            The handles returned by the manager's ``allreduce``, in tracking
            order.
        """
        self._averaged_state_tensors.clear()
        pending: list[Any] = []
        for name in self._original_state_tensors:
            owner = self._param_map[name]
            snapshot = self._optimizer.state[owner][self.state_key].detach().clone()
            pending.append(self._manager.allreduce(snapshot))
            self._averaged_state_tensors.append(snapshot)
        return pending

    def perform_sync(self) -> None:
        """Copy the reduced clones back into the optimizer state.

        Raises:
            ValueError: When the number of reduced clones differs from the
                number of tracked tensors (strict ``zip``).
        """
        pairs = zip(
            self._original_state_tensors,
            self._averaged_state_tensors,
            strict=True,
        )
        with torch.no_grad():
            for name, reduced in pairs:
                owner = self._param_map[name]
                self._optimizer.state[owner][self.state_key].copy_(reduced)

    def register_state_dict_fn(self) -> None:
        """Register checkpoint hooks for the backups with the manager."""

        def _load_backups(saved: dict[str, torch.Tensor]) -> None:
            for key, value in saved.items():
                # Entries for untracked names are ignored.
                if key in self._original_state_tensors:
                    self._original_state_tensors[key].copy_(value)

        def _dump_backups() -> dict[str, torch.Tensor]:
            # Hands out the live backup dict, not a copy.
            return self._original_state_tensors

        registration_key = f"{self._name_prefix}_state_{self.state_key}"
        self._manager.register_state_dict_fn(
            registration_key, _load_backups, _dump_backups
        )


class _StreamingOptimizerStateFragment(_BaseFragment):
    """Streaming-aware optimizer state fragment.

    Clones one optimizer state tensor (selected by ``state_key``) per tracked
    parameter, allreduces the clones through the manager, optionally packed
    into flat buckets, and copies the reduced values back into the optimizer
    state once the manager agrees to commit.
    """

    # Bucket capacity in BYTES despite the ``_mb`` suffix (default 1 GiB);
    # ``__init__`` converts a megabyte value from the config into bytes.
    bucket_cap_mb: int = 1 * 1024 * 1024 * 1024
    # Whether state tensors are packed into flat buckets before allreduce.
    use_bucketization: bool = False

    def __init__(self, config: StreamingOptimizerFragmentConfig) -> None:
        """Wire up the fragment from ``config`` and capture the initial state."""
        super().__init__(config.sync_every)
        self._manager = config.manager
        self._fragment_id = config.fragment_id
        self._name_prefix = config.name_prefix
        self._param_entries = config.param_entries
        self._param_map = dict(self._param_entries)
        self._optimizer = config.optimizer
        self.state_key = config.state_key
        self._backup_device = config.backup_device
        self._pin_memory = config.pin_memory
        self._should_quantize = config.should_quantize
        self._current_sync_step: int | None = None

        self._original_state_tensors: dict[str, torch.Tensor] = {}
        self._averaged_state_tensors: list[tuple[str, torch.Tensor]] = []
        self._allreduce_work: list[Work] = []
        # A dedicated CUDA stream is used only when CUDA is available.
        self._stream = torch.cuda.Stream() if torch.cuda.is_available() else None
        self._stop_event: torch.cuda.Event | None = None

        if config.bucket_cap_mb is not None:
            self.bucket_cap_mb = int(config.bucket_cap_mb * 1024 * 1024)

        # The environment variable can force bucketization on; otherwise the
        # config value decides.
        if os.getenv(USE_BUCKETIZATION_ENV, "False") == "True":
            self.use_bucketization = True
        else:
            self.use_bucketization = config.use_bucketization

        self._init_backup_storage()
        self.save_state()

    def set_step_context(self, step: int) -> None:
        """Record the step number reported in the sync log messages."""
        self._current_sync_step = step

    @property
    def fragment_id(self) -> int:
        """Identifier of this fragment as given in the config."""
        return self._fragment_id

    @property
    def parameter_names(self) -> list[str]:
        """Names of the parameters assigned to this fragment."""
        return [name for name, _ in self._param_entries]

    def _init_backup_storage(self) -> None:
        """Allocate an uninitialized backup tensor for every tracked state."""
        for name, param in self._param_entries:
            state = self._optimizer.state.get(param, {})
            tensor = state.get(self.state_key)
            if isinstance(tensor, torch.Tensor):
                device = (
                    self._backup_device
                    if self._backup_device is not None
                    else tensor.device
                )
                backup = torch.empty_like(tensor, device=device)
                # Pin CPU backups only when requested and CUDA is available.
                if (
                    self._pin_memory
                    and backup.device.type == "cpu"
                    and torch.cuda.is_available()
                    and not backup.is_pinned()
                ):
                    backup = backup.pin_memory()
                self._original_state_tensors[name] = backup

    def save_state(self) -> None:
        """Copy the current optimizer state tensors into their backups."""
        with torch.no_grad():
            for name, backup in self._original_state_tensors.items():
                param = self._param_map[name]
                tensor = self._optimizer.state[param][self.state_key]
                backup.copy_(tensor, non_blocking=True)

    def restore_state(self) -> None:
        """Overwrite the optimizer state tensors with their backups.

        Parameters that no longer have the tracked state entry are skipped.
        """
        with torch.no_grad():
            for name, backup in self._original_state_tensors.items():
                param = self._param_map[name]
                if (
                    param in self._optimizer.state
                    and self.state_key in self._optimizer.state[param]
                ):
                    self._optimizer.state[param][self.state_key].copy_(backup)

    def prepare_sync(self) -> None:
        """Clone the tracked state tensors and launch their allreduce.

        Does nothing when no state tensor is tracked. On CUDA the work is
        issued on the fragment's own stream after it has been ordered behind
        the current stream.
        """
        if not self._original_state_tensors:
            return
        assert not self._allreduce_work
        if self._stream is not None:
            self._stream.wait_stream(torch.cuda.current_stream())

        logger.info(
            "DES-LOC streaming optimizer state '%s' fragment=%s sync starting (step=%s, manager_step=%s)",
            self.state_key,
            self._fragment_id,
            self._current_sync_step
            if self._current_sync_step is not None
            else "unknown",
            self._manager.current_step(),
        )

        context = (
            torch.cuda.stream(self._stream)
            if self._stream is not None
            else nullcontext()
        )
        with context:
            self._capture_states()
            self._allreduce_states()

    def _capture_states(self) -> None:
        """Replace the pending list with fresh clones of the tracked states."""
        self._averaged_state_tensors.clear()
        with torch.no_grad():
            for name in self._original_state_tensors:
                param = self._param_map[name]
                tensor = self._optimizer.state[param][self.state_key]
                self._averaged_state_tensors.append((name, tensor.detach().clone()))

    def _allreduce_states(self) -> None:
        """Start an allreduce for every captured clone, bucketized if enabled."""
        tensors = [tensor for _, tensor in self._averaged_state_tensors]
        if not tensors:
            return
        if self.use_bucketization:
            self._bucketize_and_allreduce(tensors)
            return
        for tensor in tensors:
            work = self._manager.allreduce(
                tensor,
                should_quantize=self._should_quantize,
            )
            self._allreduce_work.append(work)

    def _bucketize_and_allreduce(self, tensors: list[torch.Tensor]) -> None:
        """Pack ``tensors`` into flat buckets and allreduce each bucket.

        Tensors are packed greedily, in order, into buckets of at most
        ``bucket_cap_mb`` bytes. A tensor larger than the cap gets a bucket of
        its own, so every tensor is reduced exactly once. When a bucket's
        future completes, the reduced slices are copied back into the
        original tensors.

        The flat buffer uses the dtype and device of the first tensor.
        """
        if not tensors:
            return

        dtype = tensors[0].dtype
        device = tensors[0].device
        # At least one element per bucket so packing always makes progress.
        cap_elems = max(1, self.bucket_cap_mb // tensors[0].element_size())

        index = 0
        total = len(tensors)
        while index < total:
            bucket_tensors: list[tuple[torch.Tensor, int, int]] = []
            pack_offset = 0
            while index < total:
                numel = tensors[index].numel()
                # Close the bucket once the next tensor would overflow it; an
                # empty bucket always accepts one tensor, however large.
                if bucket_tensors and pack_offset + numel > cap_elems:
                    break
                bucket_tensors.append((tensors[index], pack_offset, numel))
                pack_offset += numel
                index += 1

            if pack_offset == 0:
                # Only zero-element tensors were collected; nothing to reduce.
                continue

            flat_buffer = torch.zeros(pack_offset, dtype=dtype, device=device)
            for tensor, tensor_offset, numel in bucket_tensors:
                flat_buffer[tensor_offset : tensor_offset + numel].copy_(
                    tensor.reshape(-1)
                )

            work = self._manager.allreduce(
                flat_buffer,
                should_quantize=self._should_quantize,
            )

            # Bind this bucket's contents as default arguments: a plain closure
            # would read whatever the loop variables hold when the future
            # fires, i.e. possibly a later bucket.
            def callback(
                fut: torch.futures.Future[list[torch.Tensor]],
                bucket_tensors: list[tuple[torch.Tensor, int, int]] = bucket_tensors,
                flat_buffer: torch.Tensor = flat_buffer,
            ) -> list[torch.Tensor]:
                for tensor, tensor_offset, numel in bucket_tensors:
                    tensor.copy_(
                        flat_buffer[tensor_offset : tensor_offset + numel].view_as(
                            tensor
                        )
                    )
                return []

            work.get_future().then(callback)
            self._allreduce_work.append(work)

    def wait(self) -> None:
        """Synchronize on the recorded stop event and drop the work handles.

        Does nothing when no allreduce work is outstanding.
        """
        if not self._allreduce_work:
            return
        if self._stream is not None and self._stop_event is not None:
            self._stop_event.synchronize()
            self._stop_event = None
        self._allreduce_work = []

    def perform_sync(self) -> None:
        """Finish the pending allreduce and commit or roll back the state.

        Waits for every outstanding work handle, then asks the manager whether
        to commit. On commit the reduced values are written into the optimizer
        state and re-captured as the new backup; otherwise the optimizer state
        is restored from the backup.
        """
        if not self._averaged_state_tensors:
            return
        context = (
            torch.cuda.stream(self._stream)
            if self._stream is not None
            else nullcontext()
        )
        with context:
            for work in self._allreduce_work:
                work.wait()
            if self._stream is not None:
                # Recorded on the fragment stream; ``wait`` synchronizes on it.
                self._stop_event = torch.cuda.Event()
                self._stop_event.record()
        self.wait()

        should_commit = self._manager.should_commit()
        if should_commit:
            self._apply_states()
            self.save_state()
        else:
            self.restore_state()
        self._averaged_state_tensors.clear()
        logger.info(
            "DES-LOC streaming optimizer state '%s' fragment=%s sync complete (commit=%s, step=%s, manager_step=%s)",
            self.state_key,
            self._fragment_id,
            should_commit,
            self._current_sync_step
            if self._current_sync_step is not None
            else "unknown",
            self._manager.current_step(),
        )
        self._current_sync_step = None

    def _apply_states(self) -> None:
        """Copy the reduced clones back into the optimizer state."""
        with torch.no_grad():
            for name, averaged in self._averaged_state_tensors:
                param = self._param_map[name]
                self._optimizer.state[param][self.state_key].copy_(averaged)

    def register_state_dict_fn(self) -> None:
        """Register checkpoint hooks for the backups with the manager."""

        def load_fn(state_dict: dict[str, torch.Tensor]) -> None:
            for name, tensor in state_dict.items():
                # Entries for untracked names are ignored.
                if name in self._original_state_tensors:
                    self._original_state_tensors[name].copy_(tensor)

        def save_fn() -> dict[str, torch.Tensor]:
            # Hands out the live backup dict, not a copy.
            return self._original_state_tensors

        self._manager.register_state_dict_fn(
            f"{self._name_prefix}_state_{self.state_key}",
            load_fn,
            save_fn,
        )


class _StreamingParameterFragment:
    """Streaming-enabled parameter fragment with asynchronous allreduce."""

    bucket_cap_mb: int = 1 * 1024 * 1024 * 1024
    use_bucketization: bool = False

    def __init__(
        self,
        *,
        manager,
        fragment_id: int,
        name_prefix: str,
        param_entries: list[tuple[str, nn.Parameter]],
        backup_device: torch.device | None,
        pin_memory: bool,
        outer_optimizer: Optimizer | None,
        inner_optimizer: Optimizer,
        fragment_sync_offset: int,
        fragment_sync_delay: int,
        sync_window: int,
        fragment_update_alpha: float,
        use_bucketization: bool,
        bucket_cap_mb: float | None,
        should_quantize: bool,
        log_outer_metrics: bool,
        metrics_logger: Callable[[dict[str, float]], None] | None,
        checkpoint_outer_optimizer: bool,
        pseudo_grad_top_k: float | None = None,
    ) -> None:
        """Store the configuration, allocate backups and capture the parameters.

        All arguments are keyword-only and are stored on the instance. The
        fragment runs in averaging-only mode when ``outer_optimizer`` is
        ``None``. ``bucket_cap_mb`` is given in megabytes and stored in bytes;
        ``None`` keeps the class default.
        """
        # Identity and collaborators.
        self._manager = manager
        self._fragment_id = fragment_id
        self._name_prefix = name_prefix
        self._param_entries = param_entries
        self._param_map = dict(param_entries)

        # Optimizers.
        self._inner_optimizer = inner_optimizer
        self._outer_optimizer = outer_optimizer
        self._averaging_only = outer_optimizer is None
        self._checkpoint_outer_optimizer = checkpoint_outer_optimizer

        # Backup placement.
        self._backup_device = backup_device
        self._pin_memory = pin_memory

        # Sync scheduling and merge behavior.
        self._fragment_sync_offset = fragment_sync_offset
        self._fragment_sync_delay = fragment_sync_delay
        self._sync_window = sync_window
        self._fragment_update_alpha = fragment_update_alpha
        self._should_quantize = should_quantize
        self._pseudo_grad_top_k = pseudo_grad_top_k
        self._current_sync_step: int | None = None

        # Metrics.
        self._log_outer_metrics = log_outer_metrics
        self._metrics_logger = metrics_logger

        # Per-sync working storage.
        self._grads: dict[str, torch.Tensor] = {}
        self._averaged_parameters: list[tuple[str, torch.Tensor]] = []
        self._local_parameters: dict[str, torch.Tensor] = {}
        self.original_parameters: dict[str, torch.Tensor] = {}

        # Communication state; a dedicated stream exists only with CUDA.
        self._allreduce_work: list[Work] = []
        if torch.cuda.is_available():
            self._stream = torch.cuda.Stream()
        else:
            self._stream = None
        self._stop_event: torch.cuda.Event | None = None

        if bucket_cap_mb is not None:
            self.bucket_cap_mb = int(bucket_cap_mb * 1024 * 1024)

        # The environment variable can force bucketization on; otherwise the
        # constructor argument decides.
        forced_by_env = os.getenv(USE_BUCKETIZATION_ENV, "False") == "True"
        self.use_bucketization = True if forced_by_env else use_bucketization

        self._init_backup_storage()
        self.save_parameters()

    def set_metrics_logger(
        self, logger_fn: Callable[[dict[str, float]], None] | None
    ) -> None:
        """Replace the metrics callback stored on this fragment.

        Parameters:
            logger_fn: Callable receiving a metrics dict, or ``None``.
        """
        self._metrics_logger = logger_fn

    def set_step_context(self, step: int) -> None:
        """Record ``step`` as the current sync step of this fragment."""
        self._current_sync_step = step

    @property
    def parameter_names(self) -> list[str]:
        """Names of the parameters assigned to this fragment, in entry order."""
        return [name for name, _ in self._param_entries]

    @property
    def fragment_id(self) -> int:
        """Identifier of this fragment as passed to the constructor."""
        return self._fragment_id

    @property
    def fragment_sync_offset(self) -> int:
        """Sync offset of this fragment as passed to the constructor."""
        return self._fragment_sync_offset

    @property
    def fragment_sync_delay(self) -> int:
        """Sync delay of this fragment as passed to the constructor."""
        return self._fragment_sync_delay

    def _named_parameters(self):
        """Yield the ``(name, parameter)`` entries assigned to this fragment."""
        yield from self._param_entries

    def _init_backup_storage(self) -> None:
        """Allocate an uninitialized backup tensor for every parameter.

        Each backup matches the parameter's local tensor and lives on
        ``_backup_device`` when one is configured, otherwise on the local
        tensor's device. CPU backups are pinned when pinning is requested and
        CUDA is available.
        """
        may_pin = self._pin_memory and torch.cuda.is_available()
        for name, param in self._named_parameters():
            shard = _extract_local_tensor(param.data)
            if self._backup_device is None:
                target_device = shard.device
            else:
                target_device = self._backup_device
            storage = torch.empty_like(shard, device=target_device)
            if may_pin and storage.device.type == "cpu":
                storage = storage.pin_memory()
            self.original_parameters[name] = storage

    def register_state_dict_fn(self) -> None:
        """Register checkpoint hooks under ``"<name_prefix>_params"``.

        The saved payload holds the parameter backups under
        ``"original_parameters"`` and, when an outer optimizer exists and
        checkpointing it is enabled, its state dict under
        ``"outer_optimizer"``. Loading also accepts a flat name-to-tensor
        mapping without the ``"original_parameters"`` wrapper.
        """

        def _load_payload(saved: dict[str, Any]) -> None:
            if not saved:
                # Nothing stored: capture the current parameters instead.
                self.save_parameters()
                return

            backups = saved.get("original_parameters")
            if backups is None:
                # Flat layout: the payload itself maps names to tensors.
                backups = saved
            for key, value in backups.items():
                if key in self.original_parameters:
                    self.original_parameters[key].copy_(value)

            restore_outer = (
                self._outer_optimizer is not None
                and self._checkpoint_outer_optimizer
                and "outer_optimizer" in saved
            )
            if restore_outer:
                self._outer_optimizer.load_state_dict(saved["outer_optimizer"])

        def _dump_payload() -> dict[str, Any]:
            backups: dict[str, Any] = {}
            for key, value in self.original_parameters.items():
                backups[key] = _extract_local_tensor(value)
            payload: dict[str, Any] = {"original_parameters": backups}
            if self._outer_optimizer is not None and self._checkpoint_outer_optimizer:
                payload["outer_optimizer"] = self._outer_optimizer.state_dict()
            return payload

        registration_key = f"{self._name_prefix}_params"
        self._manager.register_state_dict_fn(
            registration_key, _load_payload, _dump_payload
        )

    def save_parameters(self) -> None:
        """Copy every parameter's current local tensor into its backup."""
        with torch.no_grad():
            for name, param in self._named_parameters():
                current = _extract_local_tensor(param.data)
                backup = self.original_parameters[name]
                backup.copy_(current, non_blocking=True)

    def restore_parameters(self) -> None:
        """Write every backup into its parameter via ``_copy_into_tensor``."""
        with torch.no_grad():
            for name, param in self._named_parameters():
                backup = self.original_parameters[name]
                _copy_into_tensor(param.data, backup)

    def _save_local_parameters(self) -> None:
        """Record each parameter's local tensor in ``_local_parameters``.

        Existing entries for the same names are overwritten.
        """
        with torch.no_grad():
            self._local_parameters.update(
                (name, _extract_local_tensor(param.data))
                for name, param in self._named_parameters()
            )

    def _clear_local_parameters(self) -> None:
        """Drop every entry recorded by ``_save_local_parameters``."""
        self._local_parameters.clear()

    def _merge_parameters(self) -> None:
        if self._fragment_update_alpha <= 0 or not self._local_parameters:
            return
        with torch.no_grad():
            for name, param in self._named_parameters():
                local = self._local_parameters[name]
                if isinstance(param, DTensor):
                    param.data.lerp_(
                        DTensor.from_local(
                            local,
                            param.device_mesh,
                            param.placements,
                            shape=param.shape,
                            stride=param.stride(),
                        ),
                        self._fragment_update_alpha,
                    )
                else:
                    param.data.lerp_(local, self._fragment_update_alpha)

    def _save_grads(self) -> None:
        """Compute each parameter's pseudo-gradient and store it in ``self._grads``.

        The pseudo-gradient is ``original_parameters[name] - current`` where
        ``current`` is the local tensor of a DTensor parameter, or the
        parameter itself otherwise. ``_apply_topk_sparsity`` is invoked on the
        result with ``self._pseudo_grad_top_k`` before it is stored.
        """
        with torch.no_grad():
            for key, live_param in self._named_parameters():
                is_sharded = DTensor is not None and isinstance(live_param, DTensor)
                current = live_param.to_local() if is_sharded else live_param
                baseline = self.original_parameters[key].to(current.device)
                delta = baseline - current
                # Sparsify before the allreduce. The helper's return value is
                # ignored, so it presumably edits ``delta`` in place.
                _apply_topk_sparsity(delta, self._pseudo_grad_top_k)
                self._grads[key] = delta

    def _save_averaged_parameters(self) -> None:
        """Fill ``self._averaged_parameters`` with the tensors to be allreduced.

        Sparse mode is selected when ``self._pseudo_grad_top_k`` is set and lies
        strictly between 0 and 1; the decision is also recorded on
        ``self._use_sparse_pseudo_grads``. In sparse mode each entry is the
        pseudo-gradient ``original - current`` passed through
        ``_apply_topk_sparsity``; otherwise each entry is the result of
        ``_extract_local_tensor(param.data)``.
        """
        self._averaged_parameters.clear()

        top_k = self._pseudo_grad_top_k
        sparse_mode = top_k is not None and 0.0 < top_k < 1.0
        self._use_sparse_pseudo_grads = sparse_mode

        with torch.no_grad():
            for key, live_param in self._named_parameters():
                if not sparse_mode:
                    payload = _extract_local_tensor(live_param.data)
                else:
                    # Communicate the sparsified pseudo-gradient rather than
                    # the raw parameter values.
                    is_sharded = DTensor is not None and isinstance(
                        live_param, DTensor
                    )
                    current = live_param.to_local() if is_sharded else live_param
                    baseline = self.original_parameters[key].to(current.device)
                    payload = baseline - current
                    _apply_topk_sparsity(payload, top_k)
                self._averaged_parameters.append((key, payload))

    def _set_grads(self) -> None:
        """Install the stored pseudo-gradients as ``.grad`` on the parameters.

        Each parameter's entry is popped from ``self._grads``; parameters
        without an entry are left untouched. For DTensor parameters the stored
        local tensor is wrapped with ``DTensor.from_local`` using the
        parameter's mesh, placements, shape and stride before assignment.
        """
        with torch.no_grad():
            for name, param in self._named_parameters():
                grad = self._grads.pop(name, None)
                if grad is None:
                    continue
                # DTensor is an optional import and is None when unavailable;
                # isinstance() against None would raise TypeError.
                if DTensor is not None and isinstance(param, DTensor):
                    grad = DTensor.from_local(
                        grad,
                        param.device_mesh,
                        param.placements,
                        shape=param.shape,
                        stride=param.stride(),
                    )
                param.grad = grad

    def _apply_averaged_parameters(self) -> None:
        """Write the allreduced tensors back into the live parameters.

        When ``self._use_sparse_pseudo_grads`` is set (it defaults to False if
        the attribute is absent) each stored tensor is treated as a
        pseudo-gradient and the value written is
        ``original_parameters[name] - tensor``; otherwise the stored tensor is
        written as is. ``self._averaged_parameters`` is emptied afterwards.
        """
        sparse_mode = getattr(self, "_use_sparse_pseudo_grads", False)
        with torch.no_grad():
            for key, reduced in self._averaged_parameters:
                target = self._param_map[key]
                if sparse_mode:
                    # ``reduced`` holds the reduced sparse pseudo-gradient, so
                    # rebuild the parameter value from the saved reference.
                    is_sharded = DTensor is not None and isinstance(target, DTensor)
                    local_view = target.to_local() if is_sharded else target
                    baseline = self.original_parameters[key].to(local_view.device)
                    value_to_write = baseline - reduced
                else:
                    # ``reduced`` already is the parameter value.
                    value_to_write = reduced
                _copy_into_tensor(target.data, value_to_write)
        self._averaged_parameters.clear()

    def wait(self) -> None:
        """Finish a pending sync, if there is one, and forget its work handles.

        With no queued work this is a no-op. Otherwise, when both a stream and
        a recorded stop event exist, the event is synchronized and dropped;
        ``self._allreduce_work`` is then rebound to a fresh empty list.
        """
        if self._allreduce_work:
            stream_in_use = self._stream is not None
            if stream_in_use and self._stop_event is not None:
                self._stop_event.synchronize()
                self._stop_event = None
            self._allreduce_work = []

    def _bucketize_and_allreduce(self, tensors: list[torch.Tensor]) -> None:
        """Allreduce ``tensors`` through flat buffers of bounded size.

        Consecutive tensors are packed into one flat buffer until adding the
        next tensor would exceed the bucket capacity. Each buffer is passed to
        ``self._manager.allreduce``; a callback chained on the work's future
        copies the reduced values back into the source tensors. Every work
        handle is appended to ``self._allreduce_work``.

        Compared with the previous implementation this version:

        * never drops a tensor - a tensor that does not fit the remaining
          space starts the next bucket, and a tensor larger than the capacity
          gets a bucket of its own;
        * sizes every buffer to exactly what was packed into it;
        * binds each callback to its own buffer, so callbacks that run after
          later buckets were built still unpack the right data;
        * cannot spin forever when the capacity rounds down to zero elements.

        Args:
            tensors: Tensors to reduce in place. The buffer dtype and device
                are taken from ``tensors[0]``.
                NOTE(review): assumes all tensors share that dtype/device and
                support ``view(-1)`` - confirm against callers.
        """
        if not tensors:
            return

        dtype = tensors[0].dtype
        device = tensors[0].device
        # NOTE(review): ``bucket_cap_mb`` is used directly as a byte count, as
        # before - confirm the unit where the attribute is assigned.
        bucket_size_bytes = self.bucket_cap_mb
        capacity = max(1, bucket_size_bytes // tensors[0].element_size())

        total = len(tensors)
        next_index = 0
        while next_index < total:
            # Select the run of tensors that forms this bucket.
            members: list[tuple[torch.Tensor, int, int]] = []
            packed = 0
            while next_index < total:
                tensor = tensors[next_index]
                numel = tensor.numel()
                # The first tensor is always accepted so an oversized tensor
                # still makes progress instead of being skipped.
                if members and packed + numel > capacity:
                    break
                members.append((tensor, packed, numel))
                packed += numel
                next_index += 1

            if packed == 0:
                # Only empty tensors were selected; nothing to communicate.
                continue

            flat_buffer = torch.zeros(packed, dtype=dtype, device=device)
            for tensor, start, numel in members:
                flat_buffer[start : start + numel].copy_(tensor.view(-1))

            work = self._manager.allreduce(
                flat_buffer,
                should_quantize=self._should_quantize,
            )

            # Default arguments bind this iteration's buffer and layout; a
            # plain closure would see whatever the loop assigned last.
            def callback(
                fut: torch.futures.Future[list[torch.Tensor]],
                reduced: torch.Tensor = flat_buffer,
                packed_tensors: list[tuple[torch.Tensor, int, int]] = members,
            ) -> list[torch.Tensor]:
                for target, start, numel in packed_tensors:
                    target.copy_(reduced[start : start + numel].view_as(target))
                return []

            work.get_future().then(callback)
            self._allreduce_work.append(work)

    def _allreduce_grads(self) -> None:
        """Start an allreduce for every tensor currently held in ``self._grads``.

        With bucketization enabled the tensors are handed to
        ``_bucketize_and_allreduce``; otherwise one allreduce is issued per
        tensor and its work handle is appended to ``self._allreduce_work``.
        Does nothing when there are no tensors.
        """
        pending = list(self._grads.values())
        if not pending:
            return
        if not self.use_bucketization:
            for grad in pending:
                handle = self._manager.allreduce(
                    grad,
                    should_quantize=self._should_quantize,
                )
                self._allreduce_work.append(handle)
            return
        self._bucketize_and_allreduce(pending)

    def _allreduce_parameters(self) -> None:
        """Start an allreduce for every tensor in ``self._averaged_parameters``.

        With bucketization enabled the tensors are handed to
        ``_bucketize_and_allreduce``; otherwise one allreduce is issued per
        tensor and its work handle is appended to ``self._allreduce_work``.
        Does nothing when there are no tensors.
        """
        pending = [payload for _, payload in self._averaged_parameters]
        if not pending:
            return
        if not self.use_bucketization:
            for payload in pending:
                handle = self._manager.allreduce(
                    payload,
                    should_quantize=self._should_quantize,
                )
                self._allreduce_work.append(handle)
            return
        self._bucketize_and_allreduce(pending)

    def prepare_sync(self) -> None:
        """Queue the allreduce work for this fragment's next synchronization.

        Requires that no earlier work is still pending. When a side stream is
        configured it first waits on the current CUDA stream and the
        communication is issued under that stream. In averaging-only mode the
        tensors from ``_save_averaged_parameters`` are reduced; otherwise the
        pseudo-gradients from ``_save_grads`` are.
        """
        assert not self._allreduce_work
        side_stream = self._stream
        if side_stream is not None:
            side_stream.wait_stream(torch.cuda.current_stream())

        step_label = (
            self._current_sync_step
            if self._current_sync_step is not None
            else "unknown"
        )
        logger.info(
            "DES-LOC streaming parameter fragment=%s sync starting (step=%s, manager_step=%s)",
            self._fragment_id,
            step_label,
            self._manager.current_step(),
        )

        if self._stream is not None:
            context = torch.cuda.stream(self._stream)
        else:
            context = nullcontext()
        with context:
            if not self._averaging_only:
                self._save_grads()
                self._allreduce_grads()
            else:
                self._save_averaged_parameters()
                self._allreduce_parameters()

    def _zero_outer_optimizer_grads(self) -> None:
        """Pass this fragment's outer optimizer to ``_zero_optimizer_grads``."""
        outer = self._outer_optimizer
        _zero_optimizer_grads(outer)

    def _emit_outer_metrics(
        self, pseudo_norm_sq: float, momentum_norm_sq: float, has_momentum: bool
    ) -> None:
        """Report outer-step norms through the configured metrics logger.

        Args:
            pseudo_norm_sq: Squared L2 norm of the pseudo-gradient.
            momentum_norm_sq: Squared L2 norm of the momentum buffers.
            has_momentum: Whether the momentum norm should be reported.

        Negative squared norms are clamped to zero before the square root.
        Nothing is emitted when outer metrics are disabled or no logger is
        set; an exception raised by the logger is logged and swallowed.
        """
        if not self._log_outer_metrics:
            return
        if self._metrics_logger is None:
            return

        payload: dict[str, float] = {
            "desloc_outer/pseudo_grad_l2": math.sqrt(max(pseudo_norm_sq, 0.0)),
        }
        if has_momentum:
            payload["desloc_outer/momentum_l2"] = math.sqrt(
                max(momentum_norm_sq, 0.0)
            )

        try:
            self._metrics_logger(payload)
        except Exception:  # pragma: no cover - diagnostics only
            logger.exception("DES-LOC streaming metrics logger failed.")

    def perform_sync(self) -> bool:
        """Complete the sync started by ``prepare_sync`` and apply its result.

        Waits for every queued allreduce, restores the parameters from
        ``original_parameters`` and asks the manager whether the step should
        be committed. On commit the reduced result is applied (directly in
        averaging-only mode, otherwise through the outer optimizer) and the
        backup is refreshed with ``save_parameters``. Per-sync scratch state
        (``_local_parameters``, ``_grads``, ``_averaged_parameters`` and the
        step context) is cleared before returning.

        Returns:
            The value returned by ``self._manager.should_commit()``.
        """
        # prepare_sync must have queued work beforehand.
        assert self._allreduce_work
        context = (
            torch.cuda.stream(self._stream)
            if self._stream is not None
            else nullcontext()
        )
        with context:
            for work in self._allreduce_work:
                work.wait()
            if self._stream is not None:
                # Recorded inside the stream context; wait() synchronizes on it.
                self._stop_event = torch.cuda.Event()
                self._stop_event.record()

        # Synchronizes on the stop event (if any) and drops the work handles.
        self.wait()

        if not self._averaging_only:
            # Stash the locally trained values; _merge_parameters reads them.
            self._save_local_parameters()

        # Roll the live parameters back to the last saved backup.
        self.restore_parameters()
        should_commit = self._manager.should_commit()

        if should_commit:
            if self._averaging_only:
                self._apply_averaged_parameters()
                self.save_parameters()
            else:
                # Outer step on the restored parameters, using the reduced
                # pseudo-gradients as .grad.
                self._set_grads()
                self._outer_optimizer.step()
                # Back up the post-step values before blending the local
                # values back in.
                self.save_parameters()
                self._merge_parameters()
                pseudo_norm_sq = 0.0
                momentum_norm_sq = 0.0
                has_momentum = False
                if self._log_outer_metrics:
                    # Grads are still populated here; they are zeroed below.
                    for name, param in self._named_parameters():
                        grad = param.grad
                        if grad is not None:
                            pseudo_norm_sq += grad.pow(2).sum().item()
                    # Momentum is only collected for torch.optim.SGD instances.
                    if isinstance(self._outer_optimizer, torch.optim.SGD):
                        for state in self._outer_optimizer.state.values():
                            buffer = state.get("momentum_buffer")
                            if isinstance(buffer, torch.Tensor):
                                has_momentum = True
                                momentum_norm_sq += buffer.pow(2).sum().item()
                # _emit_outer_metrics returns early when metrics are disabled.
                self._emit_outer_metrics(pseudo_norm_sq, momentum_norm_sq, has_momentum)

            self._zero_outer_optimizer_grads()
        else:
            # NOTE(review): the parameters were already restored above, so
            # this second call repeats the same copy.
            self.restore_parameters()

        self._clear_local_parameters()
        self._grads.clear()
        self._averaged_parameters.clear()

        logger.info(
            "DES-LOC streaming parameter fragment=%s sync complete (commit=%s, step=%s, manager_step=%s)",
            self._fragment_id,
            should_commit,
            self._current_sync_step
            if self._current_sync_step is not None
            else "unknown",
            self._manager.current_step(),
        )
        self._current_sync_step = None

        return should_commit


@dataclass
class _StreamingFragmentSchedule:
    """Scheduling record for one streaming parameter fragment."""

    # Fragment this record schedules.
    fragment: _StreamingParameterFragment
    # Step at which the next prepare is due.
    next_prepare_step: int
    # Step at which the next sync is due.
    next_sync_step: int
    # Starts out False; this class itself never changes it.
    pending: bool = False

    def advance(self, sync_window: int) -> None:
        """Shift both scheduled steps forward by ``sync_window``."""
        self.next_prepare_step = self.next_prepare_step + sync_window
        self.next_sync_step = self.next_sync_step + sync_window


class DesLocController:
    """Attach DES-LOC synchronization hooks to a PyTorch optimizer."""

    _EXCLUDED_STATE_KEYS: set[str] = {
        # Local-only GaLore state; should not be synchronized across replicas.
        "error_feedback",
        # Projection tensors are provided out-of-band; syncing them risks clobbering server state.
        "projector_basis",
    }

    def __init__(self, config: DesLocControllerConfig) -> None:
        self._manager = config.manager
        self._model = config.model
        self._optimizer = config.optimizer
        self._backup_device = config.backup_device
        self._pin_memory = config.pin_memory
        self._name_prefix = config.name_prefix
        self._raw_optimizer_sync_config = config.optimizer_sync_every
        self._quorum_timeout = timedelta(seconds=max(1, config.quorum_timeout_seconds))
        self._optimizer_state_sync_enabled = not config.disable_optimizer_state_sync
        self._local_step = 0

        param_fragment_cfg = ParameterFragmentConfig(
            manager=config.manager,
            model=config.model,
            sync_every=config.param_sync_every,
            backup_device=config.backup_device,
            pin_memory=config.pin_memory,
            name_prefix=config.name_prefix,
            outer_optimizer=config.outer_optimizer,
            local_optimizer=config.optimizer,
            log_outer_metrics=config.log_outer_metrics,
            metrics_logger=config.metrics_logger,
            checkpoint_outer_optimizer=config.checkpoint_outer_optimizer,
            low_rank_server_update=config.low_rank_server_update,
            outer_optimizer_low_rank=config.outer_optimizer_low_rank,
            low_rank_projector_error_feedback=config.low_rank_projector_error_feedback,
            low_rank_projector_source=config.low_rank_projector_source,
            pseudo_grad_top_k=config.pseudo_grad_top_k,
        )
        fragment_cls = (
            _OuterOptimizingParameterFragment
            if param_fragment_cfg.outer_optimizer is not None
            else _ParameterFragment
        )
        self._param_fragment = fragment_cls(param_fragment_cfg)
        self._param_fragment.register_state_dict_fn()

        self._fragments: list[_BaseFragment] = [self._param_fragment]
        self._allreduce_work: list[Any] = []
        self._is_opt_init = not self._optimizer_state_sync_enabled

        self._hook = config.optimizer.register_step_post_hook(self._step_post_hook)
        self._register_state_dict_fn()

    def close(self) -> None:
        """Detach the registered optimizer step hook."""
        if self._hook is not None:
            self._hook.remove()
            self._hook = None

    def set_metrics_logger(
        self, logger_fn: Callable[[dict[str, float]], None] | None
    ) -> None:
        self._param_fragment.set_metrics_logger(logger_fn)

    def _resolve_optimizer_sync_intervals(self, state_keys: Iterable[str]) -> list[int]:
        keys = list(state_keys)
        if not keys:
            return []

        spec = self._raw_optimizer_sync_config
        if spec is None:
            return [self._param_fragment.sync_every for _ in keys]
        if isinstance(spec, int):
            return self._expand_single_interval(spec, keys)
        if isinstance(spec, list):
            return self._expand_list_intervals(spec, keys)
        if isinstance(spec, dict):
            return self._expand_dict_intervals(spec, keys)

        msg = f"optimizer_sync_every must be an int, list, dict, or None; received {type(spec)!r}"
        raise TypeError(msg)

    def _expand_single_interval(self, interval: int, keys: list[str]) -> list[int]:
        self._validate_positive_interval(interval)
        return [interval for _ in keys]

    def _expand_list_intervals(
        self, intervals: list[int], keys: list[str]
    ) -> list[int]:
        if not intervals:
            return [self._param_fragment.sync_every for _ in keys]

        # Allow under-specification by repeating the last provided interval, and allow
        # over-specification by truncating extras. This keeps runs tolerant to optional
        # optimizer states being added or removed (e.g., error_feedback filtered out).
        if len(intervals) < len(keys):
            pad = [intervals[-1] for _ in range(len(keys) - len(intervals))]
            normalized = [int(value) for value in intervals + pad]
        else:
            normalized = [int(value) for value in intervals[: len(keys)]]
        for value in normalized:
            self._validate_positive_interval(value)
        return normalized

    def _expand_dict_intervals(
        self, mapping: dict[str, int], keys: list[str]
    ) -> list[int]:
        resolved: list[int] = []
        for key in keys:
            if key not in mapping:
                msg = f"Missing DES-LOC sync interval for optimizer state '{key}'."
                raise ValueError(msg)
            value = int(mapping[key])
            self._validate_positive_interval(value)
            resolved.append(value)
        return resolved

    def _validate_positive_interval(self, value: int) -> None:
        if value <= 0:
            msg = "optimizer_sync_every values must be positive"
            raise ValueError(msg)

    def _lazy_init_optimizer_fragments(self) -> None:
        if not self._optimizer_state_sync_enabled:
            self._is_opt_init = True
            return
        state_sets = set()
        for state in self._optimizer.state.values():
            for key, value in state.items():
                if key in self._EXCLUDED_STATE_KEYS:
                    continue
                if isinstance(value, torch.Tensor) and value.numel() > 1:
                    state_sets.add(str(key))

        state_keys = sorted(state_sets)
        sync_intervals = self._resolve_optimizer_sync_intervals(state_keys)

        if not state_keys and self._raw_optimizer_sync_config is not None:
            logger.warning(
                "DES-LOC optimizer_sync_every provided but no tensor states were discovered; skipping state synchronization."
            )

        for idx, key in enumerate(state_keys):
            fragment_config = OptimizerFragmentConfig(
                manager=self._manager,
                model=self._model,
                optimizer=self._optimizer,
                state_key=key,
                sync_every=sync_intervals[idx],
                backup_device=self._backup_device,
                name_prefix=f"{self._name_prefix}_{key}",
            )
            fragment = _OptimizerStateFragment(fragment_config)
            fragment.register_state_dict_fn()
            self._fragments.append(fragment)

        self._is_opt_init = True

    def _step_post_hook(
        self,
        _optimizer: Optimizer,
        _args: tuple[Any, ...],
        _kwargs: dict[str, Any],
    ) -> None:
        self._local_step += 1
        if not self._is_opt_init:
            self._lazy_init_optimizer_fragments()

        ready_fragments = [fragment for fragment in self._fragments if fragment.tick()]

        if ready_fragments:
            self._sync(ready_fragments)

    def _sync(self, fragments: list[_BaseFragment]) -> None:
        self._manager.disallow_state_dict_read()
        try:
            try:
                self._manager.start_quorum(
                    allow_heal=False,
                    shrink_only=False,
                    timeout=self._quorum_timeout,
                )
            except TimeoutError as err:
                logger.warning(
                    "DES-LOC quorum timed out after %.1f seconds; skipping synchronization.",
                    self._quorum_timeout.total_seconds(),
                )
                self._manager.report_error(err)
                for fragment in fragments:
                    fragment.restore_state()
                    fragment.reset()
                return

            step_ctx = self._local_step
            for fragment in fragments:
                setter = getattr(fragment, "set_step_context", None)
                if callable(setter):
                    setter(step_ctx)

            self._prepare_sync(fragments)
            self._perform_sync(fragments)
            for fragment in fragments:
                fragment.reset()
        finally:
            self._manager.allow_state_dict_read()

    def _prepare_sync(self, fragments: list[_BaseFragment]) -> None:
        self._allreduce_work.clear()
        for fragment in fragments:
            self._allreduce_work.extend(fragment.prepare_sync())

    def _perform_sync(self, fragments: list[_BaseFragment]) -> None:
        for work in self._allreduce_work:
            work.wait()

        commit_allowed = self._manager.should_commit()

        if commit_allowed:
            for fragment in fragments:
                fragment.perform_sync()
                fragment.save_state()
        else:
            for fragment in fragments:
                fragment.restore_state()

    def _register_state_dict_fn(self) -> None:
        def load_fn(state: dict[str, Any]) -> None:
            step_value = state.get("controller_step")
            if isinstance(step_value, torch.Tensor):
                self._local_step = int(step_value.item())
            elif isinstance(step_value, (int, float)):
                self._local_step = int(step_value)
            else:
                self._local_step = 0

            fragment_steps_obj = state.get("fragment_steps")
            fragment_steps = (
                fragment_steps_obj.tolist()
                if isinstance(fragment_steps_obj, torch.Tensor)
                else []
            )

            for idx, fragment in enumerate(self._fragments):
                if idx < len(fragment_steps):
                    fragment._local_step = int(fragment_steps[idx])
                else:
                    fragment._local_step = self._local_step % fragment.sync_every

        def save_fn() -> dict[str, torch.Tensor]:
            fragment_steps = [getattr(fragment, "_local_step", 0) for fragment in self._fragments]
            return {
                "controller_step": torch.tensor(int(self._local_step), dtype=torch.int64),
                "fragment_steps": torch.tensor(fragment_steps, dtype=torch.int64),
            }

        self._manager.register_state_dict_fn(
            f"{self._name_prefix}_controller_meta",
            load_fn,
            save_fn,
        )


class StreamingDesLocController:
    """Streaming DES-LOC controller which mirrors TorchFT's Streaming DiLoCo."""

    _EXCLUDED_STATE_KEYS: set[str] = {
        # Local-only GaLore state; should not be synchronized across replicas.
        "error_feedback",
        # Projection tensors are provided out-of-band; syncing them risks clobbering server state.
        "projector_basis",
    }

    def __init__(
        self,
        config: DesLocControllerConfig,
        streaming: DesLocStreamingConfig,
    ) -> None:
        """Partition the model into fragments and schedule their syncs.

        Args:
            config: Shared controller settings (manager, model, optimizer,
                backup options, ``param_sync_every``, outer optimizer, ...).
            streaming: Streaming-specific settings (fragment count/strategy,
                ``sync_delay``, ``update_alpha``, bucketization, ...).

        Raises:
            ValueError: If the 'custom' strategy is chosen without custom
                fragments, if partitioning yields no fragments, if
                ``param_sync_every`` is smaller than or not divisible by the
                fragment count, if ``sync_delay`` is not smaller than the
                per-fragment stride, or if ``update_alpha`` is outside
                ``[0, 1]``. The helper methods called here may raise too.
        """
        self._manager = config.manager
        self._model = config.model
        self._optimizer = config.optimizer
        self._backup_device = config.backup_device
        self._pin_memory = config.pin_memory
        self._name_prefix = config.name_prefix
        self._raw_optimizer_sync_config = config.optimizer_sync_every
        # The quorum timeout is clamped to at least one second.
        self._quorum_timeout = timedelta(seconds=max(1, config.quorum_timeout_seconds))
        self._log_outer_metrics = config.log_outer_metrics
        self._metrics_logger = config.metrics_logger
        self._checkpoint_outer_optimizer = config.checkpoint_outer_optimizer
        self._streaming_cfg = streaming
        self._optimizer_state_sync_enabled = not config.disable_optimizer_state_sync
        self._pseudo_grad_top_k = config.pseudo_grad_top_k

        # Both options are read with getattr, so configs lacking them fall
        # back to the "strided" strategy with no custom fragments.
        fragment_strategy = getattr(streaming, "fragment_strategy", "strided")
        custom_fragments = getattr(streaming, "custom_fragments", None)
        if fragment_strategy == "custom" and not custom_fragments:
            msg = "desloc.streaming.custom_fragments must be provided when using the 'custom' strategy."
            raise ValueError(msg)

        partitions = _partition_named_parameters(
            self._model,
            streaming.fragments,
            strategy=fragment_strategy,
            custom_fragments=custom_fragments,
        )
        if not partitions:
            msg = "DES-LOC streaming requires at least one model parameter."
            raise ValueError(msg)

        if not streaming.separate_non_layer_fragment:
            before_len = len(partitions)
            partitions = _merge_non_layer_partition(partitions)
            # Only log when the merge actually reduced the partition count.
            if len(partitions) < before_len:
                logger.info(
                    "DES-LOC streaming merged non-layer parameters into fragment 0."
                )

        # Every partition is treated as a layer fragment, so the two counts
        # below are always equal.
        layer_fragment_indices = list(range(len(partitions)))

        layer_fragment_count = len(layer_fragment_indices)
        num_fragments = len(partitions)

        if config.param_sync_every < layer_fragment_count:
            msg = (
                "desloc.param_sync_every must be >= the number of streaming fragments."
            )
            raise ValueError(msg)
        if config.param_sync_every % layer_fragment_count != 0:
            msg = "desloc.param_sync_every must be divisible by the number of streaming fragments."
            raise ValueError(msg)

        self._sync_window = config.param_sync_every
        # Number of steps between consecutive fragments' syncs.
        self._fragment_stride = self._sync_window // layer_fragment_count
        if streaming.sync_delay >= self._fragment_stride:
            msg = "desloc.streaming.sync_delay must be smaller than param_sync_every / fragments."
            raise ValueError(msg)
        if not (0.0 <= streaming.update_alpha <= 1.0):
            msg = "desloc.streaming.update_alpha must be between 0 and 1."
            raise ValueError(msg)

        outer_handles = self._build_outer_optimizer_handles(
            config.outer_optimizer, partitions
        )
        self._partitions = partitions
        # Must be set before _resolve_fragment_offsets, which reads it.
        self._fragment_sync_delay = streaming.sync_delay
        layer_offsets = self._resolve_fragment_offsets(layer_fragment_count, streaming)
        fragment_offsets = self._assign_fragment_offsets(
            num_fragments, layer_fragment_indices, layer_offsets
        )
        outer_checkpoint_flags = self._build_outer_checkpoint_flags(outer_handles)
        self._schedule_entries: list[_StreamingFragmentSchedule] = []
        self._fragments: list[_StreamingParameterFragment] = []
        for idx, (params, offset) in enumerate(
            zip(partitions, fragment_offsets, strict=True)
        ):
            fragment = _StreamingParameterFragment(
                manager=self._manager,
                fragment_id=idx,
                name_prefix=f"{self._name_prefix}_fragment_{idx}",
                param_entries=params,
                backup_device=self._backup_device,
                pin_memory=self._pin_memory,
                outer_optimizer=outer_handles[idx],
                inner_optimizer=self._optimizer,
                fragment_sync_offset=offset,
                fragment_sync_delay=self._fragment_sync_delay,
                sync_window=self._sync_window,
                fragment_update_alpha=streaming.update_alpha,
                use_bucketization=streaming.use_bucketization,
                bucket_cap_mb=streaming.bucket_cap_mb,
                should_quantize=streaming.should_quantize,
                log_outer_metrics=self._log_outer_metrics,
                metrics_logger=self._metrics_logger,
                # A shared outer optimizer is checkpointed by at most one
                # fragment (see _build_outer_checkpoint_flags).
                checkpoint_outer_optimizer=(
                    self._checkpoint_outer_optimizer and outer_checkpoint_flags[idx]
                ),
                pseudo_grad_top_k=self._pseudo_grad_top_k,
            )
            param_names = fragment.parameter_names
            logger.info(
                "DES-LOC streaming parameter fragment=%s initialized with %d parameters: %s",
                idx,
                len(param_names),
                _format_fragment_membership(param_names),
            )
            # Prepare runs sync_delay steps ahead of the sync, never before step 0.
            prepare_step = max(offset - self._fragment_sync_delay, 0)
            schedule_entry = _StreamingFragmentSchedule(
                fragment=fragment,
                next_prepare_step=prepare_step,
                next_sync_step=offset,
            )
            self._schedule_entries.append(schedule_entry)
            self._fragments.append(fragment)
        # Both a pre-step and a post-step hook are installed; close() removes them.
        self._hooks: list[RemovableHandle] = []
        self._hooks.append(self._optimizer.register_step_pre_hook(self._step_pre_hook))
        self._hooks.append(
            self._optimizer.register_step_post_hook(self._step_post_hook)
        )

        self._inner_step = 0
        self._state_cursor = 0
        self._optimizer_state_log_emitted = False
        self._optimizer_state_schedule = streaming.optimizer_state_schedule

        self._state_fragments_per_fragment: list[
            list[_StreamingOptimizerStateFragment]
        ] = []
        # With state sync disabled there is nothing left to initialise lazily.
        self._is_opt_init = not self._optimizer_state_sync_enabled
        self._fragments_synced_this_step: set[int] = set()
        self._pending_aligned_state_frags: dict[
            int, list[tuple[_StreamingOptimizerStateFragment, int]]
        ] = {}

        self._register_state_dict_functions()
        self._log_parameter_fragment_assignments()

    def close(self) -> None:
        for hook in self._hooks:
            hook.remove()
        self._hooks.clear()

    def set_metrics_logger(
        self, logger_fn: Callable[[dict[str, float]], None] | None
    ) -> None:
        self._metrics_logger = logger_fn
        for fragment in self._fragments:
            fragment.set_metrics_logger(logger_fn if self._log_outer_metrics else None)

    def _register_state_dict_functions(self) -> None:
        for fragment in self._fragments:
            fragment.register_state_dict_fn()

    def _log_parameter_fragment_assignments(self) -> None:
        """Log which fragment each parameter component ended up in.

        Parameter names are reduced to component keys with
        ``_component_key_from_name``. A warning is logged whenever a component
        is seen in more than one fragment; the fragment seen last is the one
        reported in the summary line.
        """
        component_owner: dict[str, int] = {}
        for fragment in self._fragments:
            for param_name in fragment.parameter_names:
                component = _component_key_from_name(param_name)
                previous_owner = component_owner.get(component)
                conflicting = (
                    previous_owner is not None
                    and previous_owner != fragment.fragment_id
                )
                if conflicting:
                    logger.warning(
                        "DES-LOC streaming parameter component %s mapped to multiple fragments (%s, %s).",
                        component,
                        previous_owner,
                        fragment.fragment_id,
                    )
                component_owner[component] = fragment.fragment_id

        if not component_owner:
            logger.info("DES-LOC streaming parameter fragments: none discovered.")
            return

        pieces = [
            f"{component}->frag{fragment_id}"
            for component, fragment_id in sorted(component_owner.items())
        ]
        logger.info("DES-LOC streaming parameter fragments: %s", "; ".join(pieces))

    def _build_outer_optimizer_handles(
        self,
        outer_spec: DesLocOuterOptimizerConfig | Optimizer | list[Optimizer] | None,
        partitions: list[list[tuple[str, nn.Parameter]]],
    ) -> list[Optimizer | None]:
        if outer_spec is None:
            return [None for _ in partitions]
        if isinstance(outer_spec, list):
            if len(outer_spec) != len(partitions):
                msg = "When providing a list of outer optimizers, its length must match desloc.streaming.fragments."
                raise ValueError(msg)
            return outer_spec
        if isinstance(outer_spec, Optimizer):
            return [outer_spec for _ in partitions]
        if isinstance(outer_spec, DesLocOuterOptimizerConfig):
            handles: list[Optimizer] = []
            optimizer_cls = outer_spec.resolve_optimizer_cls()
            for params in partitions:
                trainable = [param for _, param in params if param.requires_grad]
                if not trainable:
                    msg = "DES-LOC outer optimizer requires at least one trainable parameter per fragment."
                    raise ValueError(msg)
                handles.append(optimizer_cls(trainable, **outer_spec.kwargs))
            return handles
        msg = "desloc.outer_optimizer must be a config, Optimizer, list of Optimizers, or None."
        raise TypeError(msg)

    def _build_outer_checkpoint_flags(
        self, outer_handles: list[Optimizer | None]
    ) -> list[bool]:
        if not self._checkpoint_outer_optimizer:
            return [False for _ in outer_handles]
        seen: set[int] = set()
        flags: list[bool] = []
        for optimizer in outer_handles:
            if optimizer is None:
                flags.append(False)
                continue
            ident = id(optimizer)
            if ident in seen:
                flags.append(False)
                continue
            seen.add(ident)
            flags.append(True)
        return flags

    def _resolve_fragment_offsets(
        self,
        num_fragments: int,
        streaming: DesLocStreamingConfig,
    ) -> list[int]:
        fragment_sync_offsets = getattr(streaming, "fragment_sync_offsets", None)
        if fragment_sync_offsets is None:
            stride = self._sync_window / num_fragments
            offsets = [
                max(1, math.floor(stride * (idx + 1))) for idx in range(num_fragments)
            ]
            offsets[-1] = self._sync_window
        else:
            offsets = [int(value) for value in fragment_sync_offsets]
            if len(offsets) != num_fragments:
                msg = "desloc.streaming.fragment_sync_offsets must match the fragment count."
                raise ValueError(msg)

        if offsets != sorted(offsets):
            msg = "desloc.streaming.fragment_sync_offsets must be strictly increasing."
            raise ValueError(msg)
        if offsets[0] <= 0 or offsets[-1] > self._sync_window:
            msg = "desloc.streaming.fragment_sync_offsets must lie within the sync window."
            raise ValueError(msg)
        for offset in offsets:
            if offset <= self._fragment_sync_delay:
                msg = (
                    "Each fragment sync offset must exceed desloc.streaming.sync_delay."
                )
                raise ValueError(msg)
        return offsets

    @staticmethod
    def _assign_fragment_offsets(
        total_fragments: int,
        layer_fragment_indices: list[int],
        layer_offsets: list[int],
    ) -> list[int]:
        offset_map: dict[int, int] = {}
        for slot, fragment_idx in enumerate(layer_fragment_indices):
            offset_map[fragment_idx] = layer_offsets[slot]
        default_offset = layer_offsets[0]
        for fragment_idx in range(total_fragments):
            offset_map.setdefault(fragment_idx, default_offset)
        return [offset_map[idx] for idx in range(total_fragments)]

    def _drive_fragment_schedule(self) -> None:
        if not self._schedule_entries:
            return
        for entry in self._schedule_entries:
            if not entry.pending and self._inner_step == entry.next_prepare_step:
                self._attempt_prepare_fragment(entry)
            if entry.pending and self._inner_step == entry.next_sync_step:
                self._complete_fragment_sync(entry)

    def _attempt_prepare_fragment(self, entry: _StreamingFragmentSchedule) -> None:
        fragment = entry.fragment
        try:
            self._manager.start_quorum(
                allow_heal=False,
                shrink_only=False,
                timeout=self._quorum_timeout,
            )
        except TimeoutError as err:
            logger.warning(
                "DES-LOC streaming quorum timed out after %.1f seconds; skipping synchronization.",
                self._quorum_timeout.total_seconds(),
            )
            self._manager.report_error(err)
            fragment.restore_parameters()
            entry.advance(self._sync_window)
            return

        logger.info(
            "Preparing fragment=%s step=%s",
            fragment.fragment_id,
            self._inner_step,
        )
        fragment.set_step_context(self._inner_step)
        fragment.prepare_sync()
        self._maybe_prepare_aligned_state_sync(fragment.fragment_id)
        entry.pending = True

    def _complete_fragment_sync(self, entry: _StreamingFragmentSchedule) -> None:
        fragment = entry.fragment
        logger.info(
            "Syncing fragment=%s step=%s manager_step=%s",
            fragment.fragment_id,
            self._inner_step,
            self._manager.current_step(),
        )
        fragment.perform_sync()
        entry.pending = False
        self._fragments_synced_this_step.add(fragment.fragment_id)
        entry.advance(self._sync_window)

    def _maybe_prepare_aligned_state_sync(self, fragment_idx: int) -> None:
        if not self._optimizer_state_sync_enabled:
            return
        if self._optimizer_state_schedule != "aligned":
            return
        commit_step = self._inner_step + self._fragment_sync_delay
        ready = self._resolve_aligned_state_candidates(
            fragment_idx,
            commit_step=commit_step,
        )
        if not ready:
            self._pending_aligned_state_frags.pop(fragment_idx, None)
            return
        entries: list[tuple[_StreamingOptimizerStateFragment, int]] = []
        for state_fragment in ready:
            state_fragment.set_step_context(self._inner_step)
            state_fragment.prepare_sync()
            entries.append((state_fragment, commit_step))
        self._pending_aligned_state_frags[fragment_idx] = entries

    def _drive_aligned_state_completion(self) -> None:
        if not self._optimizer_state_sync_enabled:
            return
        if self._optimizer_state_schedule != "aligned":
            return
        if not self._pending_aligned_state_frags:
            return
        current_step = self._inner_step
        for fragment_idx in list(self._pending_aligned_state_frags.keys()):
            entries = self._pending_aligned_state_frags.get(fragment_idx)
            if not entries:
                self._pending_aligned_state_frags.pop(fragment_idx, None)
                continue
            completed: list[tuple[_StreamingOptimizerStateFragment, int]] = []
            remaining: list[tuple[_StreamingOptimizerStateFragment, int]] = []
            for state_fragment, commit_step in entries:
                if current_step >= commit_step:
                    completed.append((state_fragment, commit_step))
                else:
                    remaining.append((state_fragment, commit_step))
            for state_fragment, _commit in completed:
                state_fragment.perform_sync()
                state_fragment.reset()
            if remaining:
                self._pending_aligned_state_frags[fragment_idx] = remaining
            else:
                self._pending_aligned_state_frags.pop(fragment_idx, None)

    def _resolve_aligned_state_candidates(
        self,
        fragment_idx: int,
        *,
        commit_step: int,
    ) -> list[_StreamingOptimizerStateFragment]:
        if not self._state_fragments_per_fragment:
            return []
        if fragment_idx >= len(self._state_fragments_per_fragment):
            return []
        states = self._state_fragments_per_fragment[fragment_idx]
        if not states:
            return []

        fragment = self._fragments[fragment_idx]
        offset = fragment.fragment_sync_offset
        if commit_step < offset:
            return []

        ready: list[_StreamingOptimizerStateFragment] = []
        for state_fragment in states:
            interval = max(1, state_fragment.sync_every)
            if (commit_step - offset) % interval == 0:
                ready.append(state_fragment)
        return ready

    def _step_pre_hook(
        self,
        _optimizer: Optimizer,
        _args: tuple[Any, ...],
        _kwargs: dict[str, Any],
    ) -> None:
        self._manager.disallow_state_dict_read()

    def _step_post_hook(
        self,
        _optimizer: Optimizer,
        _args: tuple[Any, ...],
        _kwargs: dict[str, Any],
    ) -> None:
        self._manager.allow_state_dict_read()
        if not self._is_opt_init:
            self._lazy_init_optimizer_fragments()
        self._inner_step += 1
        self._drive_fragment_schedule()

        if not self._fragments or not self._state_fragments_per_fragment:
            self._fragments_synced_this_step.clear()
            self._pending_aligned_state_frags.clear()
            return

        if self._optimizer_state_schedule == "aligned":
            self._drive_aligned_state_completion()
            self._fragments_synced_this_step.clear()
            return

        synced_fragments = tuple(self._fragments_synced_this_step)
        self._fragments_synced_this_step.clear()

        if not synced_fragments:
            self._drive_staggered_state_schedule()

    def _resolve_optimizer_sync_intervals(self, state_keys: Iterable[str]) -> list[int]:
        keys = list(state_keys)
        if not keys:
            return []

        spec = self._raw_optimizer_sync_config
        if spec is None:
            return [self._fragment_stride for _ in keys]
        if isinstance(spec, int):
            return self._expand_single_interval(spec, keys)
        if isinstance(spec, list):
            return self._expand_list_intervals(spec, keys)
        if isinstance(spec, dict):
            return self._expand_dict_intervals(spec, keys)

        msg = f"optimizer_sync_every must be an int, list, dict, or None; received {type(spec)!r}"
        raise TypeError(msg)

    def _expand_single_interval(self, interval: int, keys: list[str]) -> list[int]:
        self._validate_positive_interval(interval)
        return [interval for _ in keys]

    def _expand_list_intervals(
        self, intervals: list[int], keys: list[str]
    ) -> list[int]:
        if not intervals:
            return [self._fragment_stride for _ in keys]

        # Tolerate optional optimizer states by padding with the last interval when under-specified
        # and truncating extras when over-specified.
        if len(intervals) < len(keys):
            pad = [intervals[-1] for _ in range(len(keys) - len(intervals))]
            normalized = [int(value) for value in intervals + pad]
        else:
            normalized = [int(value) for value in intervals[: len(keys)]]
        for value in normalized:
            self._validate_positive_interval(value)
        return normalized

    def _expand_dict_intervals(
        self, mapping: dict[str, int], keys: list[str]
    ) -> list[int]:
        resolved: list[int] = []
        for key in keys:
            if key not in mapping:
                msg = f"Missing DES-LOC sync interval for optimizer state '{key}'."
                raise ValueError(msg)
            value = int(mapping[key])
            self._validate_positive_interval(value)
            resolved.append(value)
        return resolved

    def _validate_positive_interval(self, value: int) -> None:
        if value <= 0:
            msg = "optimizer_sync_every values must be positive"
            raise ValueError(msg)

    def _lazy_init_optimizer_fragments(self) -> None:
        if not self._optimizer_state_sync_enabled:
            self._state_fragments_per_fragment = [[] for _ in self._fragments]
            self._is_opt_init = True
            return
        state_sets: set[str] = set()
        for state in self._optimizer.state.values():
            for key, value in state.items():
                if key in self._EXCLUDED_STATE_KEYS:
                    continue
                if isinstance(value, torch.Tensor) and value.numel() > 1:
                    state_sets.add(str(key))

        state_keys = sorted(state_sets)
        sync_intervals = self._resolve_optimizer_sync_intervals(state_keys)

        if not state_keys and self._raw_optimizer_sync_config is not None:
            logger.warning(
                "DES-LOC optimizer_sync_every provided but no tensor states were discovered; skipping state synchronization."
            )

        if not state_keys:
            self._state_fragments_per_fragment = [[] for _ in self._fragments]
            self._is_opt_init = True
            return

        self._state_fragments_per_fragment = [[] for _ in self._fragments]
        for idx, key in enumerate(state_keys):
            sync_every = sync_intervals[idx]
            for fragment_idx, params in enumerate(self._partitions):
                fragment_config = StreamingOptimizerFragmentConfig(
                    manager=self._manager,
                    fragment_id=fragment_idx,
                    name_prefix=f"{self._name_prefix}_{key}_fragment_{fragment_idx}",
                    param_entries=params,
                    optimizer=self._optimizer,
                    state_key=key,
                    sync_every=sync_every,
                    backup_device=self._backup_device,
                    pin_memory=self._pin_memory,
                    use_bucketization=self._streaming_cfg.use_bucketization,
                    bucket_cap_mb=self._streaming_cfg.bucket_cap_mb,
                    should_quantize=self._streaming_cfg.should_quantize,
                )
                fragment = _StreamingOptimizerStateFragment(fragment_config)
                param_names = fragment.parameter_names
                logger.info(
                    "DES-LOC streaming optimizer state '%s' fragment=%s initialized with %d parameters: %s",
                    key,
                    fragment_idx,
                    len(param_names),
                    _format_fragment_membership(param_names),
                )
                fragment.register_state_dict_fn()
                self._state_fragments_per_fragment[fragment_idx].append(fragment)

        self._is_opt_init = True
        self._log_optimizer_state_fragment_assignments()

    def _sync_state_fragments(
        self, fragment_idx: int, *, limit_one: bool = False
    ) -> None:
        if not self._optimizer_state_sync_enabled:
            return
        if not self._state_fragments_per_fragment:
            return
        if fragment_idx >= len(self._state_fragments_per_fragment):
            return

        candidates = self._state_fragments_per_fragment[fragment_idx]
        ready: list[_StreamingOptimizerStateFragment] = []
        for fragment in candidates:
            if fragment.tick():
                ready.append(fragment)
                if limit_one:
                    break

        self._execute_state_sync_batch(ready)

    def _execute_state_sync_batch(
        self, fragments: list[_StreamingOptimizerStateFragment]
    ) -> None:
        if not fragments:
            return
        try:
            self._manager.start_quorum(
                allow_heal=False,
                shrink_only=False,
                timeout=self._quorum_timeout,
            )
        except TimeoutError as err:
            logger.warning(
                "DES-LOC optimizer state quorum timed out after %.1f seconds; skipping synchronization.",
                self._quorum_timeout.total_seconds(),
            )
            self._manager.report_error(err)
            for fragment in fragments:
                fragment.restore_state()
                fragment.reset()
            return

        for fragment in fragments:
            fragment.set_step_context(self._inner_step)
            fragment.prepare_sync()
        for fragment in fragments:
            fragment.perform_sync()
            fragment.reset()

    def _drive_staggered_state_schedule(self) -> None:
        if not self._optimizer_state_sync_enabled:
            return
        if not self._state_fragments_per_fragment or not self._fragments:
            return
        fragment_idx = self._state_cursor
        self._sync_state_fragments(fragment_idx, limit_one=True)
        self._state_cursor = (self._state_cursor + 1) % len(self._fragments)

    def _log_optimizer_state_fragment_assignments(self) -> None:
        if self._optimizer_state_log_emitted:
            return

        if not any(self._state_fragments_per_fragment):
            logger.info("DES-LOC streaming optimizer state fragments: none discovered.")
            self._optimizer_state_log_emitted = True
            return

        per_state: dict[str, dict[str, int]] = defaultdict(dict)
        for fragments in self._state_fragments_per_fragment:
            for fragment in fragments:
                state_map = per_state[fragment.state_key]
                for name in fragment.parameter_names:
                    key = _component_key_from_name(name)
                    existing = state_map.get(key)
                    if existing is not None and existing != fragment.fragment_id:
                        logger.warning(
                            "DES-LOC streaming optimizer state '%s' component %s mapped to multiple fragments (%s, %s).",
                            fragment.state_key,
                            key,
                            existing,
                            fragment.fragment_id,
                        )
                    state_map[key] = fragment.fragment_id

        for state_key, mapping in sorted(per_state.items()):
            formatted = "; ".join(
                f"{component}->frag{fragment_id}"
                for component, fragment_id in sorted(mapping.items())
            )
            logger.info(
                "DES-LOC streaming optimizer state '%s' fragments: %s",
                state_key,
                formatted or "none",
            )
        self._optimizer_state_log_emitted = True


class DesLocFTOptimizersContainer(FTOptimizersContainer):
    """Fault-tolerant optimizer container that attaches DES-LOC controllers.

    One controller is created per ``(model part, optimizer)`` pair: a
    ``StreamingDesLocController`` when a streaming config is available,
    otherwise a ``DesLocController``.
    """

    def __init__(self, config: DesLocFTOptimizersConfig) -> None:
        """Validate the DES-LOC settings, build the base container and controllers.

        Raises:
            ValueError: If ``param_sync_every`` is not positive or
                ``low_rank_projector_source`` is not one of the accepted
                values.
        """
        desloc = config.desloc_config
        if desloc.param_sync_every <= 0:
            raise ValueError("desloc.param_sync_every must be a positive integer.")
        accepted_sources = ("pseudo_grad", "full_rank_grad")
        if desloc.low_rank_projector_source not in accepted_sources:
            raise ValueError(
                "desloc.low_rank_projector_source must be 'pseudo_grad' or 'full_rank_grad'; "
                f"received {desloc.low_rank_projector_source!r}."
            )

        # An explicit streaming config on ``config`` wins over the one
        # resolved from the DES-LOC settings.
        streaming_cfg = config.streaming or desloc.resolved_streaming()

        super().__init__(
            config.model_parts,
            config.optimizer_cls,
            config.optimizer_kwargs,
            config.ft_manager,
            use_ft_optimizer=config.use_ft_optimizer,
            param_groups=config.param_groups,
        )

        backup_device = desloc.resolved_backup_device()
        optimizer_sync = desloc.normalized_optimizer_sync()
        outer_optimizer_spec = (
            config.outer_optimizer or desloc.normalized_outer_optimizer()
        )

        def _controller_config(
            idx: int, model: nn.Module, optimizer: Optimizer
        ) -> DesLocControllerConfig:
            # Per-pair controller settings; only model, optimizer and the
            # name prefix differ between pairs.
            return DesLocControllerConfig(
                manager=config.ft_manager,
                model=model,
                optimizer=optimizer,
                param_sync_every=desloc.param_sync_every,
                optimizer_sync_every=optimizer_sync,
                backup_device=backup_device,
                pin_memory=desloc.pin_memory,
                name_prefix=f"desloc_{idx}",
                quorum_timeout_seconds=desloc.quorum_timeout_seconds,
                outer_optimizer=outer_optimizer_spec,
                log_outer_metrics=desloc.log_outer_metrics,
                metrics_logger=None,
                checkpoint_outer_optimizer=desloc.checkpoint_outer_optimizer,
                disable_optimizer_state_sync=desloc.disable_optimizer_state_sync,
                low_rank_server_update=desloc.low_rank_server_update,
                outer_optimizer_low_rank=desloc.low_rank_outer_optimizer,
                low_rank_projector_error_feedback=desloc.low_rank_projector_error_feedback,
                low_rank_projector_source=desloc.low_rank_projector_source,
                pseudo_grad_top_k=desloc.pseudo_grad_top_k,
            )

        self._desloc_controllers: list[
            DesLocController | StreamingDesLocController
        ] = []
        pairs = zip(self.model_parts, self.optimizers, strict=True)
        for idx, (model, optimizer) in enumerate(pairs):
            controller_config = _controller_config(idx, model, optimizer)
            controller = (
                StreamingDesLocController(controller_config, streaming_cfg)
                if streaming_cfg is not None
                else DesLocController(controller_config)
            )
            self._desloc_controllers.append(controller)

    def close_desloc(self) -> None:
        """Close every DES-LOC controller and forget them."""
        controllers = self._desloc_controllers
        for controller in controllers:
            controller.close()
        controllers.clear()

    def set_desloc_metrics_logger(
        self, logger_fn: Callable[[dict[str, float]], None] | None
    ) -> None:
        """Forward ``logger_fn`` (or ``None``) to every controller's metrics logger."""
        for controller in self._desloc_controllers:
            controller.set_metrics_logger(logger_fn)


@contextmanager
def desloc_semi_sync_context(
    _ft_manager: FTManager, optimizer: torch.optim.Optimizer
) -> Iterator[None]:
    """Run the managed block, then close DES-LOC on ``optimizer`` if supported.

    On exit, whether normal or via an exception, ``optimizer.close_desloc()``
    is invoked when the optimizer exposes a callable attribute of that name.
    ``_ft_manager`` is accepted but not used.
    """
    try:
        yield
    finally:
        if callable(teardown := getattr(optimizer, "close_desloc", None)):
            teardown()


# Re-publish the module globals onto the module object registered in
# ``sys.modules`` so that names defined since the earlier update near the top
# of the file are also reachable through ``_MODULE_PROXY``.
_MODULE_PROXY.__dict__.update(globals())
