# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Metrics for FL experiments."""

from __future__ import annotations

import contextlib

import math
from collections.abc import Iterator, Mapping, Sequence
from dataclasses import dataclass
from enum import Enum
from typing import Any, Self, TYPE_CHECKING

import torch
from torch import Tensor
from torchmetrics import Metric

from torchtitan.components.metrics import MetricsProcessor
from torchtitan.distributed import utils as dist_utils
from torchtitan.experiments.fl.optimizers._metric_utils import (
    METRIC_COUNT_PREFIX,
    metric_count_key,
)
from torchtitan.experiments.fl.callbacks import Callback, CallbackSetupContext, CallbackStepContext, CallbackValidationContext
from torchtitan.experiments.fl.optimizers.galore import (
    FULL_PROJ,
    LEFT_PROJ,
    RIGHT_PROJ,
    REV_STD_PROJ,
    STD_PROJ,
    GaLore,
    _orthogonal_matrix,
)
from torchtitan.tools.logging import logger

if TYPE_CHECKING:
    from collections.abc import Callable
    from types import TracebackType

    from torch.distributed.device_mesh import DeviceMesh
    from torch.optim import Optimizer

    from torchtitan.experiments.fl.configs.config import (
        ActivationMonitorConfig,
        BetasMonitorConfig,
        GaLoreMomentumProjectionConfig,
        HyperparameterSwitchConfig,
        LRMonitorConfig,
        MetricsConfig,
        OptimizerMonitorConfig,
        VSMonitorConfig,
    )


# Type alias: a projection basis is either a single tensor or a list of tensors.
ProjectionBasis = Tensor | list[Tensor]

# String identifiers for the available projection variants.
# NOTE(review): presumably matched against ``GaLoreMomentumProjectionParams.transform``
# by the projection callback -- confirm against the consumer.
SVD_PROJECTION = "svd"
COLUMN_PROJECTION = "columns"
RANDOM_PROJECTION = "random"

# String identifiers for re-initialisation modes. ``PROJECT_REINIT_MODE`` is the
# default value of ``GaLoreMomentumProjectionParams.reinit_mode``.
PROJECT_REINIT_MODE = "project"
ZERO_REINIT_MODE = "zero"
# Both reinit mode strings gathered in one set (presumably for validation elsewhere).
VALID_REINIT_MODES = {PROJECT_REINIT_MODE, ZERO_REINIT_MODE}


@dataclass(frozen=True)
class GaLoreMomentumProjectionParams:
    """Configuration payload for :class:`GaLoreMomentumProjectionCallback`.

    The dataclass is frozen, so instances cannot be mutated after construction.
    Every field except ``reinit_mode`` is required. The per-field notes below are
    inferred from the field names only; this module does not show how the
    callback consumes them, so confirm against the callback before relying on
    them.
    """

    enabled: bool  # presumably the on/off switch for the callback
    steps: Sequence[int]  # presumably the training steps at which projection runs -- TODO confirm
    target_ranks: Sequence[int]  # presumably one target rank per entry in ``steps`` -- TODO confirm
    state_keys: Sequence[str]  # presumably optimizer-state keys to project -- TODO confirm
    transform: str  # presumably one of the *_PROJECTION module constants -- TODO confirm
    proj_type: str  # presumably a GaLore projection type (STD_PROJ, LEFT_PROJ, ...) -- TODO confirm
    shared_source: str | None  # optional; meaning not visible in this module
    column_count: int | None  # optional; presumably used by the column-based transform
    random_seed: int | None  # optional; presumably seeds the random transform
    random_std: float  # presumably the std-dev of the random transform
    log_metrics: bool  # presumably toggles metric logging by the callback
    # Defaults to PROJECT_REINIT_MODE ("project"); ZERO_REINIT_MODE is the other
    # mode string defined in this module.
    reinit_mode: str = PROJECT_REINIT_MODE


@dataclass(frozen=True)
class HyperparameterSwitchParams:
    """Configuration payload for :class:`HyperparameterSwitchCallback`.

    The dataclass is frozen, so instances cannot be mutated after construction,
    and every field is required. The per-field notes below are inferred from the
    field names only; confirm against the callback that consumes them.
    """

    enabled: bool  # presumably the on/off switch for the callback
    steps: Sequence[int]  # presumably the training steps at which a switch happens -- TODO confirm
    new_vs: Sequence[float] | None  # optional replacement values; ``None`` presumably means "leave unchanged"
    new_betas: Sequence[float] | None  # optional replacement betas; ``None`` presumably means "leave unchanged"
    reset_momenta: Sequence[str]  # presumably names of momentum buffers to reset -- TODO confirm
    log_metrics: bool  # presumably toggles metric logging by the callback


class AggregationType(Enum):
    """Types of metric aggregation.

    Each member's value is the string that ``OptimizerMonitor.on_step_end``
    expects as the first ``/``-separated segment of an optimizer metric name;
    a multi-segment name whose first segment is not one of these values makes
    that method raise ``ValueError``.
    """

    L2_NORM = "l2_norm"  # accumulator is initialised to 0.0
    MIN = "min"  # accumulator is initialised to +inf
    MAX = "max"  # accumulator is initialised to -inf
    ZERO_COUNT = "zero_count"  # accumulator is initialised to 0.0
    MEAN = "mean"  # accumulator is initialised to 0.0 alongside a separate sample count


class PureUnigramCrossEntropy(Metric):
    """TorchMetric that computes unigram cross entropy for LM targets.

    This metric accumulates the per-token cross entropy between the provided
    targets and a pre-computed unigram distribution. Targets equal to
    ``ignore_index`` and targets outside ``[0, vocab_size)`` are excluded from
    both the loss and the item count.

    State:
        sum_loss: Running sum of ``-log p(target)``, summed across processes.
        total_items: Number of targets that contributed to ``sum_loss``.
    """

    full_state_update = False

    def __init__(
        self,
        unigram_probabilities: Tensor,
        ignore_index: int = -100,
        *,
        dist_sync_on_step: bool = False,
    ) -> None:
        """Validate the distribution and register the metric state.

        Args:
            unigram_probabilities: 1D tensor of per-token probabilities, indexed
                by token id. A non-floating tensor is converted to ``float32``.
            ignore_index: Target value that is skipped entirely.
            dist_sync_on_step: Forwarded to :class:`torchmetrics.Metric`.

        Raises:
            ValueError: If the tensor is not 1D, contains a negative value, or
                contains no positive value.
        """
        super().__init__(dist_sync_on_step=dist_sync_on_step)

        if unigram_probabilities.dim() != 1:
            msg = "unigram_probabilities must be a 1D tensor."
            raise ValueError(msg)

        if torch.any(unigram_probabilities < 0):
            msg = "unigram_probabilities must contain non-negative values."
            raise ValueError(msg)

        if not torch.any(unigram_probabilities > 0):
            msg = "unigram_probabilities must include at least one positive value."
            raise ValueError(msg)

        prob_dtype = unigram_probabilities.dtype if unigram_probabilities.is_floating_point() else torch.float32
        self.ignore_index = ignore_index
        # Store as buffer so it moves with the metric across devices.
        self.register_buffer(
            "unigram_probabilities",
            unigram_probabilities.clone().detach().to(dtype=prob_dtype),
        )
        self.add_state(
            "sum_loss",
            default=torch.tensor(0.0, dtype=prob_dtype),
            dist_reduce_fx="sum",
        )
        self.add_state(
            "total_items",
            default=torch.tensor(0, dtype=torch.long),
            dist_reduce_fx="sum",
        )

    def update(self, output: Mapping | Tensor, target: Tensor) -> None:  # noqa: ARG002
        """Update the metric state with a batch of targets.

        Args:
            output: Unused; accepted so the metric matches the usual
                ``(output, target)`` call shape.
            target: Tensor of token ids of any shape; it is flattened first.
        """
        # ``reshape`` rather than ``view``: label tensors are frequently
        # non-contiguous slices (e.g. a shifted ``labels[:, 1:]``), for which
        # ``view(-1)`` raises a RuntimeError.
        target = target.reshape(-1).to(torch.long)

        valid_mask = target != self.ignore_index
        if not torch.any(valid_mask):
            return

        target_filtered = target[valid_mask]
        vocab_size = self.unigram_probabilities.shape[0]
        # Out-of-vocabulary ids cannot index the distribution; drop them.
        in_vocab_mask = (target_filtered >= 0) & (target_filtered < vocab_size)
        if not torch.any(in_vocab_mask):
            return

        # Keep the lookup table on the same device as the targets it is indexed with.
        if self.unigram_probabilities.device != target.device:
            self.unigram_probabilities = self.unigram_probabilities.to(target.device)

        target_in_vocab = target_filtered[in_vocab_mask]
        selected_probs = self.unigram_probabilities[target_in_vocab]
        # Clamp to the smallest positive normal value so zero-probability
        # tokens yield a large finite loss instead of ``inf``.
        eps = torch.finfo(selected_probs.dtype).tiny
        losses = -torch.log(selected_probs.clamp_min(eps))

        loss_sum = losses.sum().to(self.sum_loss.device)
        self.sum_loss += loss_sum

        items = int(in_vocab_mask.sum().item())
        self.total_items += items

    def compute(self) -> Tensor:
        """Return the average unigram cross entropy across all updates.

        Returns:
            ``sum_loss / total_items`` as a scalar tensor, or a scalar zero with
            the dtype/device of ``sum_loss`` when nothing has been accumulated.
        """
        if int(self.total_items.item()) == 0:
            return self.sum_loss.new_zeros(())
        total_items = self.total_items.to(self.sum_loss.dtype)
        return self.sum_loss / total_items


class UnigramMetricHandle:
    """Handle returned when registering a unigram metric with a manager."""

    def __init__(self, manager: UnigramMetricManager, metric: PureUnigramCrossEntropy) -> None:
        self._manager = manager
        self.metric = metric
        self._active = True

    def close(self) -> None:
        """Unregister the associated metric from the manager."""
        if not self._active:
            return
        self._manager.unregister(self.metric)
        self._active = False

    def __enter__(self) -> Self:
        """Return this handle to support context manager usage."""
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        traceback: TracebackType | None,
    ) -> bool:
        """Close the metric handle upon exiting the context manager."""
        self.close()
        # Do not suppress exceptions.
        return False


class UnigramMetricManager:
    """Track and aggregate unigram cross-entropy metrics for a training run.

    Registered metrics are kept in registration order. The manager reads and
    zeroes each metric's ``sum_loss`` and ``total_items`` tensors directly.
    """

    def __init__(self) -> None:
        self._metrics: list[PureUnigramCrossEntropy] = []

    def register(self, metric: PureUnigramCrossEntropy, group_key: str | None = None) -> UnigramMetricHandle:
        """Register a metric and return a handle that can be closed to unregister it.

        Args:
            metric: Metric to track.
            group_key: Accepted for API compatibility; currently ignored.
        """
        del group_key  # group key is currently informational only
        self._metrics.append(metric)
        return UnigramMetricHandle(self, metric)

    def unregister(self, metric: PureUnigramCrossEntropy) -> None:
        """Remove a metric from the registry if present; unknown metrics are ignored.

        Matching is by identity. ``list.remove`` compares with ``==``, and
        torchmetrics overloads ``==`` on ``Metric`` to build a (truthy)
        compositional metric, so an equality-based removal would drop the first
        registered metric instead of the requested one.
        """
        for index, registered in enumerate(self._metrics):
            if registered is metric:
                del self._metrics[index]
                return

    def collect(self, *, reset: bool = True) -> tuple[float, int]:
        """Return the total accumulated loss and token count across all metrics.

        Args:
            reset: When true, zero every metric's accumulators after reading.

        Returns:
            ``(total_loss, total_items)``; metrics with no items contribute nothing.
        """
        total_loss = 0.0
        total_items = 0
        for metric in self._metrics:
            items = int(metric.total_items.item())
            if items > 0:
                total_loss += float(metric.sum_loss.item())
                total_items += items

        if reset:
            self.reset()

        return total_loss, total_items

    def reset(self) -> None:
        """Zero out the accumulators for all registered metrics."""
        for metric in self._metrics:
            metric.sum_loss.zero_()
            metric.total_items.zero_()

    def update(self, labels: Tensor) -> None:
        """Update all registered metrics with a batch of labels."""
        for metric in self._metrics:
            # The metric ignores its ``output`` argument, so pass an empty mapping.
            metric.update({}, labels)

    def clear(self) -> None:
        """Remove all registered metrics."""
        self._metrics.clear()

    def has_metrics(self) -> bool:
        """Return ``True`` if any metrics are currently registered."""
        return bool(self._metrics)


# Attribute name under which the manager is cached on the job config object.
_UNIGRAM_MANAGER_ATTR = "_fl_unigram_manager"


def get_or_create_unigram_manager(job_config: Any) -> UnigramMetricManager:
    """Return the unigram metric manager cached on ``job_config``, creating it on first use.

    A missing attribute and an attribute set to ``None`` are both treated as
    "no manager yet"; in either case a fresh manager is attached and returned.
    """
    existing = getattr(job_config, _UNIGRAM_MANAGER_ATTR, None)
    if existing is not None:
        return existing

    created = UnigramMetricManager()
    setattr(job_config, _UNIGRAM_MANAGER_ATTR, created)
    return created


def compute_skewness(value: torch.Tensor) -> torch.Tensor:
    """Compute the mean skewness of a tensor along its last dimension.

    Skewness is the third central moment divided by ``variance ** 1.5``, using
    population (biased) moments taken over the last dimension. The variance is
    clamped from below at the dtype's machine epsilon to avoid dividing by zero,
    and the per-row results are averaged into a single scalar.

    Args:
        value: Input tensor of shape (..., N).

    Returns:
        Scalar tensor containing the skewness averaged over all leading dimensions.
    """
    centered = value - value.mean(dim=-1, keepdim=True)
    second_moment = centered.pow(2).mean(dim=-1)
    third_moment = centered.pow(3).mean(dim=-1)

    if second_moment.dtype.is_floating_point:
        floor = torch.finfo(second_moment.dtype).eps
    else:
        floor = 1e-12
    variance = second_moment.clamp(min=floor)

    per_row = third_moment / (variance * variance.sqrt())
    return per_row.mean()


def compute_kurtosis(value: torch.Tensor) -> torch.Tensor:
    """Compute the mean kurtosis of a tensor along its last dimension.

    Kurtosis is the fourth central moment divided by the squared variance, using
    population (biased) moments taken over the last dimension; no excess
    correction (``- 3``) is applied. The variance is clamped from below at the
    dtype's machine epsilon to avoid dividing by zero, and the per-row results
    are averaged into a single scalar.

    Args:
        value: Input tensor of shape (..., N).

    Returns:
        Scalar tensor containing the kurtosis averaged over all leading dimensions.
    """
    centered = value - value.mean(dim=-1, keepdim=True)
    second_moment = centered.pow(2).mean(dim=-1)
    fourth_moment = centered.pow(4).mean(dim=-1)

    if second_moment.dtype.is_floating_point:
        floor = torch.finfo(second_moment.dtype).eps
    else:
        floor = 1e-12
    variance = second_moment.clamp(min=floor)

    per_row = fourth_moment / variance.pow(2)
    return per_row.mean()


class ActivationMonitor(Callback):
    """Collects activation statistics across the full model.

    A forward hook is attached to every sub-module of the first model part. On
    a step that will be logged, every floating-point (or complex) tensor found
    in a hooked module's inputs and outputs feeds a set of running statistics.
    :meth:`finalize` reduces them (across the device mesh when one is given),
    logs them, and arms or disarms collection for the following step.

    By default, only the following metrics are collected:
    - activations/average/max/full_model_input
    - activations/average/max/full_model_output
    - activations/average/median/full_model_input
    - activations/average/median/full_model_output
    - activations/l2_norm/full_model_input
    - activations/l2_norm/full_model_output
    - activations/max/full_model_input
    - activations/max/full_model_output
    """

    def __init__(
        self,
        *,
        interval: int = 25,
        ignore_module_types: Sequence[str] | None = None,
        gradient_accumulation_steps: int = 1,
        enabled_metrics: set[str] | None = None,
    ) -> None:
        """Configure the monitor; no hooks are installed until :meth:`register`.

        Args:
            interval: Log every ``interval`` steps; a value ``<= 0`` disables the monitor.
            ignore_module_types: Case-insensitive substrings; a module whose
                *qualified name* contains any of them is skipped.
            gradient_accumulation_steps: Stored (clamped to at least 1).
            enabled_metrics: Full metric keys to collect; ``None`` selects the
                default set listed in the class docstring.
        """
        self.interval = interval
        self.ignore_module_types = tuple(ignore_module_types) if ignore_module_types is not None else None
        self.gradient_accumulation_steps = max(1, gradient_accumulation_steps)

        # Default enabled metrics - only the essential ones
        if enabled_metrics is None:
            self.enabled_metrics = {
                "activations/average/max/full_model_input",
                "activations/average/max/full_model_output",
                "activations/average/median/full_model_input",
                "activations/average/median/full_model_output",
                "activations/l2_norm/full_model_input",
                "activations/l2_norm/full_model_output",
                "activations/max/full_model_input",
                "activations/max/full_model_output",
            }
        else:
            self.enabled_metrics = enabled_metrics

        self._handles: list[torch.utils.hooks.RemovableHandle] = []
        self._pre_handle: torch.utils.hooks.RemovableHandle | None = None
        self._module_names: dict[torch.nn.Module, str] = {}
        # Raw per-step accumulators: a float (sum of squares) for L2 keys and a
        # list of per-tensor samples for every other statistic.
        self._metrics: dict[str, float | list[float]] = {}
        self._collect_this_step = False
        self._microbatch_counter = 0
        self._device: torch.device | None = None
        self._registered = False
        self._model_ref: torch.nn.Module | None = None

    def setup(self, context: CallbackSetupContext) -> None:
        """Register activation hooks when the first model shard is available."""
        if not self.enabled or self._registered:
            return
        if not context.model_parts:
            return
        model = context.model_parts[0]
        self.register(model)
        self._model_ref = model

    def on_step_end(self, context: CallbackStepContext) -> None:
        """Finalize metrics at the end of each logging interval."""
        if not self.enabled or not self._registered:
            return
        self.finalize(context.step, context.logger, context.mesh)

    def close(self) -> None:
        """Remove hooks and release references when training ends."""
        if self._pre_handle is not None:
            self._pre_handle.remove()
            self._pre_handle = None
        for handle in self._handles:
            handle.remove()
        self._handles.clear()
        self._module_names.clear()
        self._metrics.clear()
        self._collect_this_step = False
        self._microbatch_counter = 0
        # Forget the cached device so a later re-registration (possibly on a
        # different device) does not reduce on a stale one.
        self._device = None
        self._registered = False
        self._model_ref = None

    def _is_metric_enabled(self, metric_key: str) -> bool:
        """Check if a metric is enabled for collection."""
        return metric_key in self.enabled_metrics

    def _any_reduction_enabled(self, statistic: str, suffix: str) -> bool:
        """Return True if the max, min or median of ``statistic`` is enabled for ``suffix``."""
        return any(
            self._is_metric_enabled(f"activations/{statistic}/{reduction}/full_model{suffix}")
            for reduction in ("max", "min", "median")
        )

    def _append_sample(self, key: str, sample: float) -> None:
        """Append one per-tensor sample to the list accumulator stored under ``key``."""
        bucket = self._metrics.setdefault(key, [])
        if isinstance(bucket, list):
            bucket.append(sample)

    @property
    def enabled(self) -> bool:
        """Check if the monitor is enabled based on the interval.

        Returns:
            bool: True if the monitor is enabled, False otherwise.
        """
        return self.interval > 0

    def should_log_step(self, step: int) -> bool:
        """Determine if metrics should be logged at the current step.

        Args:
            step: Current training step.

        Returns:
            bool: True if metrics should be logged this step, False otherwise.
        """
        return self.enabled and step % self.interval == 0

    @property
    def is_registered(self) -> bool:
        """Check if hooks are registered.

        Returns:
            bool: True if hooks are registered, False otherwise.
        """
        return self._registered

    def register(self, model: torch.nn.Module) -> None:
        """Register forward hooks on the model to collect activations.

        Does nothing when the monitor is disabled or already registered.

        Args:
            model: The model to register hooks on.
        """
        if not self.enabled or self._registered:
            return

        self._module_names = {module: name for name, module in model.named_modules()}
        self._pre_handle = model.register_forward_pre_hook(self._forward_pre_hook, with_kwargs=True)
        model.apply(self._register_forward_hook)
        self._registered = True

    def _register_forward_hook(self, module: torch.nn.Module) -> None:
        """Attach the collecting forward hook to ``module`` and remember its handle."""
        self._handles.append(module.register_forward_hook(self._forward_hook))

    def _forward_pre_hook(
        self,
        module: torch.nn.Module,
        args: tuple[Any, ...],
        kwargs: dict[str, Any],
    ) -> None:
        """No-op pre-hook on the root model."""
        del module, args, kwargs
        # Pre-hook doesn't need to do anything; collection is controlled by finalize()

    def _forward_hook(
        self,
        module: torch.nn.Module,
        inputs: tuple[Any, ...],
        output: Any,
    ) -> None:
        """Fold a module's inputs and outputs into the running statistics."""
        if not self._collect_this_step:
            return

        module_name = self._module_names.get(module, "")
        if self.ignore_module_types is not None:
            # Despite the attribute name, matching is against the module's
            # qualified name, case-insensitively.
            lowered_name = module_name.lower()
            if any(ignore.lower() in lowered_name for ignore in self.ignore_module_types):
                return

        self._recursively_add_metrics("_input", inputs)
        self._recursively_add_metrics("_output", output)

    def _recursively_add_metrics(self, suffix: str, values: Any) -> None:
        """Walk nested mappings/sequences and record every tensor found.

        ``None``, strings, bytes and any other non-container leaf are ignored.
        """
        if values is None or isinstance(values, (str, bytes)):
            return
        if isinstance(values, torch.Tensor):
            self._add_metrics(suffix, values)
            return
        if isinstance(values, Mapping):
            for nested in values.values():
                self._recursively_add_metrics(suffix, nested)
            return
        if isinstance(values, Sequence):
            for nested in values:
                self._recursively_add_metrics(suffix, nested)

    def _add_metrics(  # noqa: C901, PLR0912
        self, suffix: str, value: torch.Tensor
    ) -> None:
        """Fold one tensor into the accumulators for ``suffix``.

        Args:
            suffix: ``"_input"`` or ``"_output"``; appended to every metric key.
            value: Tensor to measure. Non-floating, non-complex tensors are
                skipped; complex tensors contribute their real part.
        """
        # Bool and integer tensors (masks, token ids) carry no activation statistics.
        if not (value.is_floating_point() or value.is_complex()):
            return

        with torch.no_grad():
            tensor = value.detach()
            if tensor.is_complex():
                tensor = tensor.real
            if self._device is None:
                self._device = tensor.device

            # Accumulate sum of squares for L2 norm (will sqrt after gathering)
            l2_key = f"activations/l2_norm/full_model{suffix}"
            if self._is_metric_enabled(l2_key):
                current_l2 = self._metrics.get(l2_key, 0.0)
                if isinstance(current_l2, float):
                    self._metrics[l2_key] = current_l2 + float(torch.sum(tensor**2).item())

            # Every remaining statistic is undefined on an empty tensor: its
            # mean is NaN, and one NaN sample would turn the logged
            # average/max|min|median into NaN for the whole step.
            if tensor.numel() == 0:
                return

            if self._any_reduction_enabled("average", suffix):
                self._append_sample(
                    f"activations/average/full_model{suffix}",
                    float(tensor.mean().item()),
                )

            # Compute max over last dimension and take mean (consistent with reference)
            max_key = f"activations/max/full_model{suffix}"
            if self._is_metric_enabled(max_key):
                if tensor.ndim >= 1:
                    max_value = tensor.max(dim=-1).values.mean().item()
                else:
                    max_value = tensor.max().item()
                self._append_sample(max_key, float(max_value))

            # Skewness and kurtosis are taken along the last dimension, so a
            # 0-dim tensor has nothing to contribute.
            if tensor.ndim < 1:
                return

            if self._any_reduction_enabled("skewness", suffix):
                self._append_sample(
                    f"activations/skewness/full_model{suffix}",
                    float(compute_skewness(tensor).item()),
                )

            if self._any_reduction_enabled("kurtosis", suffix):
                self._append_sample(
                    f"activations/kurtosis/full_model{suffix}",
                    float(compute_kurtosis(tensor).item()),
                )

    def finalize(
        self,
        step: int,
        logger: Any,
        mesh: DeviceMesh | None,
    ) -> None:
        """Finalize metric collection for the current step.

        Logs what was gathered if ``step`` is a logging step, then clears the
        accumulators and enables collection only if ``step + 1`` will be logged.

        Args:
            step: Current training step.
            logger: Logger to log metrics.
            mesh: Device mesh for distributed reduction.
        """
        if not self.enabled or not self._registered:
            return

        # If this IS a logging step, log the metrics we collected during this step
        if self.should_log_step(step) and self._metrics:
            metrics = self._prepare_local_metrics()
            if metrics:
                reduced_metrics = self._reduce_metrics(metrics, mesh)
                if reduced_metrics:
                    logger.log(reduced_metrics, step)

        # Prepare for next step: start from empty accumulators and collect only
        # if the next step should be logged.
        self._reset_metrics()
        self._collect_this_step = self.should_log_step(step + 1)

    def _prepare_local_metrics(self) -> dict[str, float | list[float]]:  # noqa: C901
        """Turn raw accumulators into per-rank values keyed by final metric name.

        L2 entries stay as sums of squares, max/min entries become floats, and
        median entries keep the raw sample list so :meth:`_reduce_metrics` can
        take the median.
        """
        prepared: dict[str, float | list[float]] = {}
        for suffix in ("_input", "_output"):
            l2_key = f"activations/l2_norm/full_model{suffix}"
            if l2_key in self._metrics and self._is_metric_enabled(l2_key):
                l2_val = self._metrics[l2_key]
                if isinstance(l2_val, float):
                    prepared[l2_key] = l2_val

            max_key = f"activations/max/full_model{suffix}"
            if self._is_metric_enabled(max_key):
                max_vals = self._metrics.get(max_key)
                if max_vals and isinstance(max_vals, list):
                    prepared[max_key] = float(max(max_vals))

            for metric_name in ("average", "skewness", "kurtosis"):
                key = f"activations/{metric_name}/full_model{suffix}"
                values = self._metrics.get(key)
                if not values or not isinstance(values, list):
                    continue
                tensor_values = torch.tensor(values)

                max_metric_key = f"activations/{metric_name}/max/full_model{suffix}"
                if self._is_metric_enabled(max_metric_key):
                    prepared[max_metric_key] = float(tensor_values.max().item())

                min_metric_key = f"activations/{metric_name}/min/full_model{suffix}"
                if self._is_metric_enabled(min_metric_key):
                    prepared[min_metric_key] = float(tensor_values.min().item())

                median_metric_key = f"activations/{metric_name}/median/full_model{suffix}"
                if self._is_metric_enabled(median_metric_key):
                    prepared[median_metric_key] = values

        return prepared

    @staticmethod
    def _median(samples: list[float]) -> float:
        """Return the median of a non-empty list, averaging the two middle values for even sizes."""
        ordered = sorted(samples)
        middle, is_odd = divmod(len(ordered), 2)
        if is_odd:
            return ordered[middle]
        return (ordered[middle - 1] + ordered[middle]) / 2

    def _reduce_metrics(  # noqa: C901, PLR0912
        self, metrics: dict[str, float | list[float]], mesh: DeviceMesh | None
    ) -> dict[str, float]:
        """Reduce prepared metrics to one float per key.

        Without a mesh the values are finalised locally. With a mesh, L2 sums of
        squares are summed across ranks before the square root, ``max``/``min``
        keys take the extreme across ranks, sample lists are reduced to the
        maximum of the per-rank medians, and anything else is averaged.

        Args:
            metrics: Output of :meth:`_prepare_local_metrics`.
            mesh: Device mesh to reduce over, or ``None`` for local-only.

        Returns:
            Mapping from metric key to its reduced value; empty lists are dropped.
        """
        reduced: dict[str, float] = {}

        # Handle single-rank or no mesh case
        if mesh is None:
            for key, value in metrics.items():
                if isinstance(value, list):
                    if not value:
                        continue
                    if "l2_norm" in key:
                        # Compute sqrt of sum of squares for L2 norm
                        reduced[key] = math.sqrt(sum(x**2 for x in value))
                    else:
                        reduced[key] = self._median(value)
                elif "l2_norm" in key:
                    # The accumulator already holds the sum of squares.
                    reduced[key] = math.sqrt(value)
                else:
                    reduced[key] = value
            return reduced

        device = self._device or torch.device("cpu")

        for key, value in metrics.items():
            if isinstance(value, list):
                if not value:
                    continue
                # Same median definition as the single-process path, so logged
                # values are comparable with and without a mesh.
                local_median = torch.tensor(self._median(value), device=device)
                # Take max of medians across ranks (for consistency with reference)
                reduced[key] = dist_utils.dist_max(local_median, mesh)
                continue

            # Create tensor once and reuse
            tensor_value = torch.tensor(value, device=device)
            if "l2_norm" in key:
                # For L2 norm: sum across ranks then sqrt
                # Note: value is already squared locally
                reduced[key] = math.sqrt(dist_utils.dist_sum(tensor_value, mesh))
            elif "max" in key:
                reduced[key] = dist_utils.dist_max(tensor_value, mesh)
            elif "min" in key:
                # No dist_min is used here; negate, take the max, negate back.
                reduced[key] = -dist_utils.dist_max(-tensor_value, mesh)
            else:
                # Default to mean for other metrics
                reduced[key] = dist_utils.dist_mean(tensor_value, mesh)

        return reduced

    def _reset_metrics(self) -> None:
        """Drop all accumulated samples."""
        self._metrics = {}


class OptimizerMonitor(Callback):
    """Compute and log the L2 norm of gradients.

    Args:
        interval: How often to log metrics (every N steps).
        only_global: Whether to only log global metrics.
        log_optimizer_metrics: Whether to log optimizer-specific metrics.
    """

    def __init__(
        self,
        interval: int = 10,
        *,
        only_global: bool = True,
        log_optimizer_metrics: bool = True,
    ) -> None:
        self.log_optimizer_metrics = log_optimizer_metrics
        self.only_global = only_global
        self.interval = interval
        self._model_ref: torch.nn.Module | None = None

    def setup(self, context: CallbackSetupContext) -> None:
        """Store a reference to the first model shard for later logging."""
        if context.model_parts:
            self._model_ref = context.model_parts[0]

    def _reduce_metrics_across_ranks(
        self, optimizer_metrics: dict[str, torch.Tensor], mesh: DeviceMesh
    ) -> dict[str, float]:
        """Reduce optimizer metrics across all ranks using TorchTitan's distributed utilities.

        Follows the pattern from torchtitan.distributed.utils._dist_reduce.
        """
        reduced_metrics = {}

        for metric_name, metric_value in list(optimizer_metrics.items()):
            if metric_name.startswith(METRIC_COUNT_PREFIX):
                continue

            count_key = metric_count_key(metric_name)
            count_value = optimizer_metrics.pop(count_key, None)
            if not isinstance(metric_value, torch.Tensor):
                # Skip non-tensor metrics
                reduced_metrics[metric_name] = metric_value
                continue

            if isinstance(count_value, torch.Tensor):
                count_tensor = count_value.to(device=metric_value.device, dtype=metric_value.dtype)
            elif isinstance(count_value, (int, float)):
                count_tensor = torch.tensor(
                    count_value,
                    device=metric_value.device,
                    dtype=metric_value.dtype,
                )
            else:
                count_tensor = torch.tensor(
                    1.0,
                    device=metric_value.device,
                    dtype=metric_value.dtype,
                )

            # Determine reduction operation based on metric name
            if "l2_norm" in metric_name or "norm" in metric_name:
                # For L2 norms, the values are already squared by pre_reduce_metrics
                # Sum across ranks then sqrt (dist_utils returns float)
                sum_squared = dist_utils.dist_sum(metric_value, mesh)
                replica_count = dist_utils.dist_sum(count_tensor, mesh)
                denom = max(replica_count, 1.0)
                reduced_metrics[metric_name] = math.sqrt(sum_squared / denom)
            elif "max" in metric_name:
                reduced_metrics[metric_name] = dist_utils.dist_max(metric_value, mesh)
            elif "min" in metric_name:
                # dist_min not implemented, use -dist_max(-x)
                reduced_metrics[metric_name] = -dist_utils.dist_max(-metric_value, mesh)
            elif "mean" in metric_name or "avg" in metric_name:
                total = dist_utils.dist_sum(metric_value, mesh)
                replica_count = dist_utils.dist_sum(count_tensor, mesh)
                denom = max(replica_count, 1.0)
                reduced_metrics[metric_name] = total / denom
            elif "zero_count" in metric_name:
                reduced_metrics[metric_name] = dist_utils.dist_sum(metric_value, mesh)
            else:
                # Default to sum for other metrics
                reduced_metrics[metric_name] = dist_utils.dist_sum(metric_value, mesh)

        # Remove any auxiliary replica count keys that may remain.
        for key in list(optimizer_metrics.keys()):
            if key.startswith(METRIC_COUNT_PREFIX):
                optimizer_metrics.pop(key)

        return reduced_metrics

    def on_step_end(  # noqa: C901, PLR0912, PLR0915
        self,
        context: CallbackStepContext,
    ) -> None:
        """Calculate the statistics at the end of the batch."""
        optimizers = context.optimizers
        if optimizers is None:
            return
        if optimizers.optimizers is None:
            return
        if len(optimizers.optimizers) == 0:
            return

        model = context.model_parts[0] if context.model_parts else self._model_ref
        if model is None:
            return

        step = context.step
        mesh = context.mesh
        logger = context.logger

        # Early exit if monitoring is disabled (interval <= 0)
        if self.interval <= 0:
            return

        if step % self.interval != 0:
            return

        optimizer_metrics: dict = {}
        optimizer = optimizers.optimizers[0]

        for name, p in model.named_parameters():
            if p.grad is not None and p.requires_grad:
                metric_reporter: Callable[[Any, Any, Any], dict] | None = getattr(
                    optimizer,
                    "report_per_parameter_metrics",
                    None,
                )
                if callable(metric_reporter) and self.log_optimizer_metrics:
                    optimizer_metrics.update(
                        metric_reporter(p, name, optimizer_metrics),
                    )

        if mesh is not None and self.log_optimizer_metrics:
            # Pre-process metrics before reduction
            pre_reduce_metrics: Callable[[Any]] | None = getattr(
                optimizer,
                "pre_reduce_metrics",
                None,
            )
            if callable(pre_reduce_metrics):
                optimizer_metrics = pre_reduce_metrics(optimizer_metrics)

            # Reduce metrics across all ranks using TorchTitan's distributed utilities
            optimizer_metrics = self._reduce_metrics_across_ranks(optimizer_metrics, mesh)

        # Dynamically aggregate all metric names found in optimizer_metrics
        agg_type_values = {agg_type.value for agg_type in AggregationType}
        agg_dict: dict[tuple[AggregationType, str], float] = {}
        agg_counts: dict[tuple[AggregationType, str], int] = {}
        metric_parts_required = 2

        # Initialize aggregation dictionary
        for metric in optimizer_metrics:
            parts = metric.split("/")
            if len(parts) < metric_parts_required:
                continue
            agg_type_str, metric_name = parts[0], parts[1]

            if agg_type_str not in agg_type_values:
                msg = f"Unknown aggregation type: {agg_type_str}"
                raise ValueError(msg)

            # Get the corresponding enum member
            agg_type = AggregationType(agg_type_str)
            key = (agg_type, metric_name)

            if key not in agg_dict:
                # Pattern match on aggregation type to get initial value
                match agg_type:
                    case AggregationType.L2_NORM:
                        agg_dict[key] = 0.0
                    case AggregationType.MIN:
                        agg_dict[key] = float("inf")
                    case AggregationType.MAX:
                        agg_dict[key] = float("-inf")
                    case AggregationType.ZERO_COUNT:
                        agg_dict[key] = 0.0
                    case AggregationType.MEAN:
                        agg_dict[key] = 0.0
                        agg_counts[key] = 0

        # Aggregate metrics
        for metric in optimizer_metrics:
            parts = metric.split("/")
            if len(parts) < metric_parts_required:
                continue
            agg_type_str, metric_name = parts[0], parts[1]
            if agg_type_str not in agg_type_values:
                continue

            agg_type = AggregationType(agg_type_str)
            key = (agg_type, metric_name)
            value = optimizer_metrics[metric]
            if isinstance(value, torch.Tensor):
                value = value.item()

            # Pattern match on aggregation type to perform aggregation
            match agg_type:
                case AggregationType.L2_NORM:
                    agg_dict[key] += value**2
                case AggregationType.MIN:
                    agg_dict[key] = min(agg_dict[key], value)
                case AggregationType.MAX:
                    agg_dict[key] = max(agg_dict[key], value)
                case AggregationType.ZERO_COUNT:
                    agg_dict[key] += value
                case AggregationType.MEAN:
                    agg_dict[key] += value
                    agg_counts[key] = agg_counts.get(key, 0) + 1

        # Report all aggregated metrics as agg_type/metric_name/global
        for (agg_type, metric_name), agg_value in agg_dict.items():
            # Pattern match on aggregation type to finalize value
            match agg_type:
                case AggregationType.L2_NORM:
                    final_value = agg_value**0.5
                case AggregationType.MIN | AggregationType.MAX:
                    final_value = agg_value
                case AggregationType.ZERO_COUNT:
                    final_value = agg_value
                case AggregationType.MEAN:
                    count = agg_counts.get((agg_type, metric_name), 0)
                    final_value = agg_value / count if count > 0 else float("nan")

            optimizer_metrics[f"{agg_type.value}/{metric_name}/global"] = final_value

        # If only_global is set, remove all non-global metrics
        if self.only_global:
            optimizer_metrics = {k: v for k, v in optimizer_metrics.items() if k.endswith("/global")}

        # Convert any remaining tensors to floats (shouldn't be any after reduction, but just in case)
        optimizer_metrics = {k: v.item() if isinstance(v, torch.Tensor) else v for k, v in optimizer_metrics.items()}
        logger.log(optimizer_metrics, step)


class LRMonitor(Callback):
    """Callback that reports the ``lr`` of every optimizer parameter group.

    Metrics are keyed ``lr-<OptimizerClassName>/group<index>``.
    """

    def __init__(self, *, interval: int = 1, enabled: bool = True) -> None:
        """Store the logging cadence (in steps) and the on/off switch."""
        self.enabled = enabled
        self.interval = interval

    def on_step_end(self, context: CallbackStepContext) -> None:
        """Emit learning rates when the step is a multiple of ``interval``.

        Nothing is logged when the monitor is disabled, the interval is
        non-positive, no optimizers are available, or no group defines ``lr``.
        """
        active = self.enabled and self.interval > 0
        if not active or context.optimizers is None:
            return
        if context.step % self.interval != 0:
            return

        # Groups without an ``lr`` entry (or with ``lr=None``) are skipped.
        metrics: dict[str, float] = {
            f"lr-{optimizer.__class__.__name__}/group{group_idx}": float(lr)
            for optimizer in context.optimizers
            for group_idx, group in enumerate(optimizer.param_groups)
            if (lr := group.get("lr")) is not None
        }

        if metrics:
            context.logger.log(metrics, context.step)


class BetasMonitor(Callback):
    """Logs optimizer beta hyperparameters and epsilon values.

    For every parameter group the monitor emits ``beta<i>-<Optimizer>/group<g>``
    (1-based ``i``), ``eps-<Optimizer>/group<g>`` and, when the optimizer has a
    ``_step_count`` attribute, ``inner_step-<Optimizer>/group<g>``. The monitor
    is disabled by default.
    """

    def __init__(self, *, interval: int = 0, enabled: bool = False) -> None:
        """Store the logging cadence (in steps) and the on/off switch."""
        self.interval = interval
        self.enabled = enabled

    def on_step_end(self, context: CallbackStepContext) -> None:
        """Record beta hyperparameters for each optimizer group when enabled."""
        if not self._should_log(context):
            return

        metrics = dict(self._collect_metrics(context.optimizers))
        if metrics:
            context.logger.log(metrics, context.step)

    def _should_log(self, context: CallbackStepContext) -> bool:
        """Return ``True`` when beta metrics should be emitted for the step."""
        if not self.enabled or self.interval <= 0:
            return False
        if context.optimizers is None:
            return False
        return context.step % self.interval == 0

    def _collect_metrics(self, optimizers: Sequence[torch.optim.Optimizer]) -> Iterator[tuple[str, float]]:
        """Yield ``(metric_name, value)`` pairs for every parameter group."""
        for optimizer, name, group_idx, group in self._iter_param_groups(optimizers):
            betas = group.get("betas")
            if betas is not None:
                for beta_idx, beta_value in enumerate(self._iter_values(betas), start=1):
                    yield (
                        f"beta{beta_idx}-{name}/group{group_idx}",
                        self._as_float(beta_value),
                    )

            epsilon = self._get_epsilon(group)
            if epsilon is not None:
                yield (f"eps-{name}/group{group_idx}", epsilon)

            # ``_step_count`` lives on the optimizer, so the same value is
            # reported under each of its groups.
            inner_step = self._get_inner_step(optimizer)
            if inner_step is not None:
                yield (f"inner_step-{name}/group{group_idx}", inner_step)

    def _iter_param_groups(
        self, optimizers: Sequence[torch.optim.Optimizer]
    ) -> Iterator[tuple[torch.optim.Optimizer, str, int, dict[str, Any]]]:
        """Yield ``(optimizer, class_name, group_index, group)`` for every group."""
        for optimizer in optimizers:
            name = optimizer.__class__.__name__
            for group_idx, group in enumerate(optimizer.param_groups):
                yield optimizer, name, group_idx, group

    def _iter_values(self, betas: Any) -> Iterator[Any]:
        """Yield the individual beta values contained in ``betas``.

        Accepts a scalar, a non-string sequence, or a tensor. Tensors with one
        or more dimensions are flattened so each element is logged separately;
        yielding such a tensor whole would make ``_as_float`` fail, because
        ``Tensor.item()`` only works on single-element tensors.
        """
        if isinstance(betas, torch.Tensor):
            if betas.ndim == 0:
                yield betas
            else:
                yield from betas.detach().flatten().tolist()
        elif isinstance(betas, Sequence) and not isinstance(betas, (str, bytes)):
            yield from betas
        else:
            yield betas

    def _get_epsilon(self, group: Mapping[str, Any]) -> float | None:
        """Return the group's ``eps`` (or ``epsilon``) as a float, if present."""
        eps_value = group.get("eps")
        if eps_value is None:
            eps_value = group.get("epsilon")
        if eps_value is None:
            return None
        return self._as_float(eps_value)

    def _get_inner_step(self, optimizer: torch.optim.Optimizer) -> float | None:
        """Return the optimizer's ``_step_count`` as a float, if it has one."""
        step_count = getattr(optimizer, "_step_count", None)
        if step_count is None:
            return None
        return self._as_float(step_count)

    def _as_float(self, value: Any) -> float:
        """Convert a number or single-element tensor to a Python float."""
        if isinstance(value, torch.Tensor):
            return float(value.detach().item())
        return float(value)


class VSMonitor(Callback):
    """Logs quasi-hyperbolic v parameters if available.

    Values are read from the ``vs`` entry of each optimizer parameter group and
    emitted as ``v<i>-<Optimizer>/group<g>`` (0-based ``i``). The monitor is
    disabled by default.
    """

    def __init__(self, *, interval: int = 0, enabled: bool = False) -> None:
        """Store the logging cadence (in steps) and the on/off switch."""
        self.interval = interval
        self.enabled = enabled

    def on_step_end(self, context: CallbackStepContext) -> None:
        """Log quasi-hyperbolic ``v`` parameters at the configured cadence."""
        if not self.enabled or self.interval <= 0:
            return
        if context.optimizers is None:
            return
        if context.step % self.interval != 0:
            return

        metrics: dict[str, float] = {}
        for optimizer in context.optimizers:
            name = optimizer.__class__.__name__
            for idx, group in enumerate(optimizer.param_groups):
                vs = group.get("vs")
                if vs is None:
                    continue
                if isinstance(vs, torch.Tensor):
                    # A tensor is not a ``Sequence``; flatten it so every element
                    # gets its own metric instead of failing in ``.item()`` on a
                    # multi-element tensor.
                    v_values = vs.detach().reshape(-1).tolist()
                elif isinstance(vs, Sequence) and not isinstance(vs, (str, bytes)):
                    v_values = list(vs)
                else:
                    v_values = [vs]
                for v_idx, v_value in enumerate(v_values):
                    v_scalar = float(v_value.detach().item()) if isinstance(v_value, torch.Tensor) else float(v_value)
                    metrics[f"v{v_idx}-{name}/group{idx}"] = v_scalar

        if metrics:
            context.logger.log(metrics, context.step)


def _resolve_galore_proj_type(param: Tensor, proj_type: str) -> str:
    """Map ``STD_PROJ``/``REV_STD_PROJ`` to a concrete left/right projection.

    Any other ``proj_type`` is returned unchanged. For the standard mode, a
    parameter with fewer than two dimensions, or with ``shape[0] >= shape[1]``,
    resolves to ``RIGHT_PROJ`` and everything else to ``LEFT_PROJ``; the
    reversed mode picks the opposite side.
    """
    if proj_type not in {STD_PROJ, REV_STD_PROJ}:
        return proj_type

    tall_or_flat = param.ndim < 2 or param.shape[0] >= param.shape[1]
    is_standard = proj_type == STD_PROJ
    # Right projection is chosen exactly when the two flags agree.
    return RIGHT_PROJ if tall_or_flat == is_standard else LEFT_PROJ


def _clamp_galore_rank(params: Sequence[Tensor], target_rank: int) -> int:
    """Clamp the requested rank to the smallest supported dimension."""
    if target_rank <= 0:
        msg = "GaLore projection rank must be positive."
        raise ValueError(msg)

    if not params:
        return target_rank

    min_dim = min(
        int(param.numel()) if param.ndim <= 1 else int(min(param.shape))
        for param in params
    )
    min_dim = max(1, min_dim)
    if target_rank > min_dim:
        logger.warning(
            "Requested GaLore projection rank %s exceeds parameter dimensions; using %s instead.",
            target_rank,
            min_dim,
        )
    return min(target_rank, min_dim)


def _build_column_basis(tensor: Tensor, rank: int, proj_type: str, column_count: int | None) -> ProjectionBasis:
    """Construct a deterministic column/row selector basis.

    The basis is the leading part of an identity matrix, so projecting with it
    keeps the first ``count`` rows/columns, where ``count`` is ``column_count``
    when given and ``rank`` otherwise.

    Args:
        tensor: Tensor whose ``shape[0]``/``shape[1]``, device and dtype define
            the basis.
        rank: Number of rows/columns to keep when ``column_count`` is ``None``.
        proj_type: ``RIGHT_PROJ``, ``LEFT_PROJ`` or ``FULL_PROJ``.
        column_count: Optional override for the number of rows/columns kept.

    Returns:
        A ``(count, shape[1])`` matrix for right projection, a
        ``(shape[0], count)`` matrix for left projection, or ``[left, right]``
        for full projection.

    Raises:
        ValueError: If ``proj_type`` is not one of the supported types.
    """
    count = column_count if column_count is not None else rank
    device = tensor.device
    dtype = tensor.dtype

    def _kept(size: int) -> int:
        # Number of indices that ``[:count]`` would keep out of ``size``; this
        # reproduces slice semantics exactly (clamping and negative counts).
        return len(range(size)[:count])

    def _row_selector(size: int) -> Tensor:
        # Build the rectangular identity directly rather than slicing a full
        # ``size x size`` identity, which costs O(size^2) memory.
        return torch.eye(_kept(size), size, device=device, dtype=dtype)

    def _column_selector(size: int) -> Tensor:
        return torch.eye(size, _kept(size), device=device, dtype=dtype)

    if proj_type == RIGHT_PROJ:
        return _row_selector(tensor.shape[1])
    if proj_type == LEFT_PROJ:
        return _column_selector(tensor.shape[0])
    if proj_type == FULL_PROJ:
        left = _column_selector(tensor.shape[0])
        right = _row_selector(tensor.shape[1])
        return [left, right]
    msg = f"Unsupported projection type {proj_type!r} for column projection."
    raise ValueError(msg)


def _build_random_basis(
    tensor: Tensor,
    rank: int,
    proj_type: str,
    *,
    random_std: float,
    generator: torch.Generator | None,
) -> ProjectionBasis:
    """Construct a random projection basis from a QR-orthogonalised Gaussian.

    A Gaussian matrix with ``rank`` columns is drawn and its reduced-QR ``Q``
    factor is multiplied by ``random_std``; the basis is therefore orthonormal
    only when ``random_std`` is 1.

    Args:
        tensor: Tensor whose ``shape[0]``/``shape[1]``, device and dtype define
            the basis.
        rank: Number of columns drawn for the Gaussian matrix.
        proj_type: ``RIGHT_PROJ``, ``LEFT_PROJ`` or ``FULL_PROJ``.
        random_std: Scale applied to the ``Q`` factor.
        generator: Optional generator forwarded to ``torch.randn``.

    Returns:
        ``Q.T * random_std`` for right projection, ``Q * random_std`` for left
        projection, or ``[left, right]`` for full projection.

    Raises:
        ValueError: If ``proj_type`` is not one of the supported types.
    """
    device = tensor.device
    dtype = tensor.dtype
    # ``torch.linalg.qr`` has no half/bfloat16 kernels, so sample and factorise
    # in float32 for those dtypes and cast the result back afterwards.
    compute_dtype = torch.float32 if dtype in (torch.float16, torch.bfloat16) else dtype

    def _scaled_q(rows: int) -> Tensor:
        base = torch.randn(
            (rows, rank),
            device=device,
            dtype=compute_dtype,
            generator=generator,
        )
        q_matrix, _ = torch.linalg.qr(base, mode="reduced")
        return (q_matrix * random_std).to(dtype=dtype)

    if proj_type == RIGHT_PROJ:
        return _scaled_q(tensor.shape[1]).T
    if proj_type == LEFT_PROJ:
        return _scaled_q(tensor.shape[0])
    if proj_type == FULL_PROJ:
        # Draw the left factor first so the generator is consumed in a fixed order.
        left = _scaled_q(tensor.shape[0])
        right = _scaled_q(tensor.shape[1]).T
        return [left, right]
    msg = f"Unsupported projection type {proj_type!r} for random projection."
    raise ValueError(msg)


def _build_projection_basis(
    tensor: Tensor,
    rank: int,
    proj_type: str,
    transform: str,
    *,
    column_count: int | None,
    random_std: float,
    generator: torch.Generator | None,
) -> ProjectionBasis:
    """Build the basis for ``transform`` by delegating to the matching builder.

    ``SVD_PROJECTION`` uses ``_orthogonal_matrix``, ``COLUMN_PROJECTION`` uses
    ``_build_column_basis`` and ``RANDOM_PROJECTION`` uses
    ``_build_random_basis``.

    Raises:
        ValueError: If ``transform`` is none of the known transforms.
    """
    if transform == RANDOM_PROJECTION:
        basis = _build_random_basis(
            tensor,
            rank,
            proj_type,
            random_std=random_std,
            generator=generator,
        )
    elif transform == COLUMN_PROJECTION:
        basis = _build_column_basis(tensor, rank, proj_type, column_count)
    elif transform == SVD_PROJECTION:
        basis = _orthogonal_matrix(tensor, rank, proj_type)
    else:
        msg = f"Unknown GaLore momentum projection transform {transform!r}."
        raise ValueError(msg)
    return basis


def _apply_projection(tensor: Tensor, basis: ProjectionBasis, proj_type: str) -> Tensor:
    """Project ``tensor`` with the provided basis and return the reduced tensor."""
    working = tensor.float()
    original_dtype = tensor.dtype

    if isinstance(basis, list):
        left_basis, right_basis = basis
        projected = (
            left_basis.T.to(device=working.device, dtype=working.dtype)
            @ working
            @ right_basis.T.to(device=working.device, dtype=working.dtype)
        )
    elif proj_type == RIGHT_PROJ:
        projected = working @ basis.to(device=working.device, dtype=working.dtype).T
    elif proj_type == LEFT_PROJ:
        projected = basis.to(device=working.device, dtype=working.dtype).T @ working
    elif proj_type == FULL_PROJ:
        msg = "Full projection requires a pair of bases."
        raise ValueError(msg)
    else:
        msg = f"Unknown projection type {proj_type!r}."
        raise ValueError(msg)
    return projected.to(device=tensor.device, dtype=original_dtype)


_SECOND_MOMENT_HINTS = ("second_moment",)


def _is_second_moment_key(key: str) -> bool:
    """Heuristically detect optimizer second-moment buffers (e.g., ``exp_avg_sq``)."""
    if key.endswith("_sq"):
        return True
    return any(hint in key for hint in _SECOND_MOMENT_HINTS)


def _apply_squared_projection(tensor: Tensor, basis: ProjectionBasis, proj_type: str) -> Tensor:
    """Project ``tensor`` using the element-wise square of ``basis``.

    For a ``[left, right]`` pair the left basis is applied first and the right
    basis second. The result stays in the dtype produced by ``tensor.float()``.

    Raises:
        ValueError: If a single basis is used with a projection type other than
            ``LEFT_PROJ`` or ``RIGHT_PROJ``.
    """
    working = tensor.float()

    if isinstance(basis, list):
        left_basis, right_basis = basis
        after_left = _apply_squared_projection(working, left_basis, LEFT_PROJ)
        return _apply_squared_projection(after_left, right_basis, RIGHT_PROJ)

    squared_basis = basis.to(device=working.device, dtype=working.dtype).pow(2)

    if proj_type == LEFT_PROJ:
        # Contract the leading axis of ``working`` with the basis rows.
        return torch.einsum("ij,i...->j...", squared_basis, working)
    if proj_type == RIGHT_PROJ:
        # Contract the trailing axis of ``working`` with the basis columns.
        return torch.einsum("...i,ji->...j", working, squared_basis)

    msg = "Positive projection only supports left/right GaLore projections."
    raise ValueError(msg)


def _apply_positive_projection(
    tensor: Tensor,
    basis: ProjectionBasis,
    proj_type: str,
    *,
    reference_moment: Tensor | None = None,
    beta2: float | None = None,
) -> Tensor:
    r"""Project a second-moment style ``tensor`` without producing negatives.

    If either ``reference_moment`` or ``beta2`` is missing, the tensor is
    projected with the element-wise squared basis only. Otherwise the variance
    preserving rule is applied:

    .. math::

        v_{t-1/2} = (1-\beta_2) (U^\top U_{t-1})^2 (\hat v_{t-1} - \hat m_{t-1}^2)
        + (U^\top U_{t-1} \hat m_{t-1})^2,

    where the centered term :math:`\hat v - \hat m^2` is clamped at zero before
    projection and the absolute value of the sum is returned. The result is
    cast back to the dtype of ``tensor``.
    """
    original_dtype = tensor.dtype
    working = tensor.float()

    def _restore(result: Tensor) -> Tensor:
        # Return results on the input's device and in its original dtype.
        return result.to(device=tensor.device, dtype=original_dtype)

    if reference_moment is None or beta2 is None:
        return _restore(_apply_squared_projection(working, basis, proj_type))

    first_moment = reference_moment.to(device=working.device, dtype=working.dtype)
    mean_projection = _apply_projection(first_moment, basis, proj_type).to(dtype=working.dtype)
    centered = torch.clamp(working - first_moment.pow(2), min=0.0)
    centered_projection = _apply_squared_projection(centered, basis, proj_type)

    variance_weight = 1.0 - float(beta2)
    return _restore(torch.abs(centered_projection.mul(variance_weight) + mean_projection.pow(2)))


class GaLoreMomentumProjectionCallback(Callback):
    """Project GaLore optimizer momenta to a new rank at configured steps.

    At each scheduled step the callback sets ``group["rank"]`` on every
    low-rank GaLore parameter group and either projects the configured state
    buffers onto a new basis (``reinit_mode == "project"``) or replaces them
    with zeros of the new shape (``reinit_mode == "zero"``). Each scheduled
    step is applied at most once.
    """

    def __init__(self, params: GaLoreMomentumProjectionParams) -> None:
        """Store the projection schedule and options.

        Raises:
            ValueError: If the rank schedule is inconsistent (see
                ``_build_step_rank_schedule``) or ``reinit_mode`` is unknown.
        """
        ranks = tuple(int(rank) for rank in params.target_ranks)
        steps = tuple(int(step) for step in params.steps)
        self.step_ranks = self._build_step_rank_schedule(steps, ranks)
        # An empty schedule disables the callback regardless of the flag.
        self.enabled = params.enabled and bool(self.step_ranks)
        self.state_keys = tuple(params.state_keys)
        self.transform = params.transform.lower()
        self.proj_type = params.proj_type
        self.shared_source = params.shared_source
        self.column_count = params.column_count
        self.random_seed = params.random_seed
        self.random_std = params.random_std
        self.log_metrics = params.log_metrics
        self.reinit_mode = params.reinit_mode.lower()
        if self.reinit_mode not in VALID_REINIT_MODES:
            msg = f"Unknown GaLore momentum reinit mode {self.reinit_mode!r}."
            raise ValueError(msg)
        # Steps already handled, so a repeated call for the same step is a no-op.
        self._applied_steps: set[int] = set()

    def on_step_end(self, context: CallbackStepContext) -> None:
        """Run the momentum projection when its scheduled step is reached."""
        if not self.enabled:
            return
        target_rank = self.step_ranks.get(context.step)
        if target_rank is None or context.step in self._applied_steps:
            return

        optimizers = context.optimizers
        if optimizers is None:
            return

        generator = self._build_generator(context.step, optimizers)
        applied = False
        for optimizer in optimizers:
            if isinstance(optimizer, GaLore):
                self._apply_to_optimizer(optimizer, target_rank, generator)
                applied = True

        # Only mark the step as done (and log) if a GaLore optimizer was touched.
        if applied:
            self._applied_steps.add(context.step)
            if self.log_metrics:
                context.logger.log({"galore_projection/rank": float(target_rank)}, context.step)
            logger.info("GaLore momentum projection applied at step %s to rank %s.", context.step, target_rank)

    def _build_step_rank_schedule(self, steps: Sequence[int], ranks: Sequence[int]) -> dict[int, int]:
        """Create a mapping of training step to target rank.

        A single rank is reused for every step. Negative steps are dropped.

        Raises:
            ValueError: If ``ranks`` has a length other than 1 or ``len(steps)``,
                or if any rank is not positive.
        """
        if not steps or not ranks:
            return {}
        if len(ranks) not in {1, len(steps)}:
            msg = "GaLore momentum projection ranks must be length 1 or match the number of steps."
            raise ValueError(msg)
        rank_values = ranks if len(ranks) == len(steps) else (ranks * len(steps))
        if any(rank <= 0 for rank in rank_values):
            msg = "GaLore momentum projection ranks must all be positive."
            raise ValueError(msg)
        return {step: rank_values[idx] for idx, step in enumerate(steps) if step >= 0}

    def _build_generator(
        self,
        step: int,
        optimizers: Sequence[torch.optim.Optimizer],
    ) -> torch.Generator | None:
        """Return a per-step generator for random projections.

        Returns ``None`` unless the transform is the random projection. The
        generator lives on the device of the first tensor parameter found
        (CPU when none is found) and is seeded with ``random_seed + step`` when
        a seed is configured.
        """
        if self.transform != RANDOM_PROJECTION:
            return None

        device = torch.device("cpu")
        found_device = False
        for optimizer in optimizers:
            for group in optimizer.param_groups:
                params = group.get("params", [])
                if not params:
                    continue
                first_param = params[0]
                if isinstance(first_param, torch.Tensor):
                    device = first_param.device
                    found_device = True
                # Only the first non-empty group of each optimizer is inspected.
                break
            # Stop at the first device found; without this the inner ``break``
            # alone would let a later optimizer overwrite the choice.
            if found_device:
                break
        generator = torch.Generator(device=device)
        if self.random_seed is not None:
            generator.manual_seed(self.random_seed + step)
        return generator

    def _apply_to_optimizer(
        self,
        optimizer: GaLore,
        target_rank: int,
        generator: torch.Generator | None,
    ) -> None:
        """Update GaLore optimizer groups and project their momenta.

        A group is treated as low-rank when it has a ``rank`` entry or when any
        of its parameters already carries ``projector_meta`` state; other groups
        are left untouched.
        """
        for group in optimizer.param_groups:
            initial_rank = group.get("rank")
            has_projector = any(
                "projector_meta" in optimizer.state.get(param, {})
                for param in group.get("params", [])
                if isinstance(param, torch.Tensor)
            )
            use_low_rank = initial_rank is not None or has_projector
            if not use_low_rank:
                continue

            params = [param for param in group.get("params", []) if isinstance(param, torch.Tensor)]
            # The rank is clamped once per group, against its smallest parameter.
            rank = _clamp_galore_rank(params, target_rank)
            group["rank"] = rank
            for param in params:
                if self.reinit_mode == ZERO_REINIT_MODE:
                    self._reset_param_state(
                        optimizer=optimizer,
                        param=param,
                        group=group,
                        rank=rank,
                    )
                else:
                    self._project_param_state(
                        optimizer=optimizer,
                        param=param,
                        group=group,
                        rank=rank,
                        generator=generator,
                    )

    def _project_param_state(
        self,
        optimizer: GaLore,
        param: Tensor,
        group: Mapping[str, Any],
        rank: int,
        generator: torch.Generator | None,
    ) -> None:
        """Project configured momentum buffers for a single parameter.

        Does nothing when none of the configured keys (nor ``shared_source``)
        holds a tensor. When ``shared_source`` is set, its basis is built once
        and reused for every other key.
        """
        state = optimizer.state[param]
        has_targets = any(isinstance(state.get(key), Tensor) for key in self.state_keys)
        if self.shared_source is not None:
            has_targets = has_targets or isinstance(state.get(self.shared_source), Tensor)
        if not has_targets:
            return

        resolved_proj_type = _resolve_galore_proj_type(param, group.get("proj_type", STD_PROJ))
        state_basis: ProjectionBasis | None = None
        betas = group.get("betas")
        beta2: float | None = None
        if isinstance(betas, Sequence) and len(betas) >= 2:
            beta2 = float(betas[1])

        # Snapshot the first moment before anything is projected; the positive
        # projection of second moments needs the un-projected values.
        first_moment_key = self.shared_source
        if first_moment_key is None:
            first_moment_key = next((key for key in self.state_keys if not _is_second_moment_key(key)), None)
        original_first_moment: Tensor | None = None
        if first_moment_key is not None:
            source_tensor = state.get(first_moment_key)
            if isinstance(source_tensor, Tensor):
                original_first_moment = source_tensor.detach().clone()

        if self.shared_source is not None:
            source_tensor = state.get(self.shared_source)
            if isinstance(source_tensor, Tensor):
                state_basis = _build_projection_basis(
                    source_tensor,
                    rank,
                    resolved_proj_type,
                    self.transform,
                    column_count=self.column_count,
                    random_std=self.random_std,
                    generator=generator,
                )
                state[self.shared_source] = _apply_projection(source_tensor, state_basis, resolved_proj_type)

        for key in self.state_keys:
            tensor = state.get(key)
            self._project_state_entry(
                state=state,
                key=key,
                tensor=tensor,
                resolved_proj_type=resolved_proj_type,
                shared_basis=state_basis,
                rank=rank,
                generator=generator,
                reference_moment=original_first_moment,
                beta2=beta2,
            )

        self._update_projector_meta(state, group, rank)

    def _project_state_entry(
        self,
        *,
        state: dict[str, Any],
        key: str,
        tensor: Tensor | None,
        resolved_proj_type: str,
        shared_basis: ProjectionBasis | None,
        rank: int,
        generator: torch.Generator | None,
        reference_moment: Tensor | None,
        beta2: float | None,
    ) -> None:
        """Project one state buffer in place of ``state[key]``.

        Non-tensor entries are skipped, as is the shared source once it has
        been projected with the shared basis. Second-moment keys use the
        positive projection, but only for the SVD transform.
        """
        if not isinstance(tensor, Tensor):
            return
        if self.shared_source is not None and key == self.shared_source and shared_basis is not None:
            return

        basis = shared_basis
        if basis is None:
            basis = _build_projection_basis(
                tensor,
                rank,
                resolved_proj_type,
                self.transform,
                column_count=self.column_count,
                random_std=self.random_std,
                generator=generator,
            )

        use_positive_projection = self.transform == SVD_PROJECTION and _is_second_moment_key(key)
        if use_positive_projection:
            state[key] = _apply_positive_projection(
                tensor,
                basis,
                resolved_proj_type,
                reference_moment=reference_moment,
                beta2=beta2,
            )
            return

        state[key] = _apply_projection(tensor, basis, resolved_proj_type)

    def _reset_param_state(
        self,
        optimizer: GaLore,
        param: Tensor,
        group: Mapping[str, Any],
        rank: int,
    ) -> None:
        """Replace configured buffers with zeros of the new rank and reset ``step``."""
        state = optimizer.state[param]
        resolved_proj_type = _resolve_galore_proj_type(param, group.get("proj_type", STD_PROJ))

        for key in self.state_keys:
            tensor = state.get(key)
            if isinstance(tensor, Tensor):
                state[key] = self._zero_like_with_rank(tensor, rank, resolved_proj_type)

        # The shared source is reset too, unless the loop above already covered it.
        if self.shared_source and self.shared_source not in self.state_keys:
            tensor = state.get(self.shared_source)
            if isinstance(tensor, Tensor):
                state[self.shared_source] = self._zero_like_with_rank(tensor, rank, resolved_proj_type)

        self._reset_bias_correction(state, param.device)
        self._update_projector_meta(state, group, rank)

    def _reset_bias_correction(self, state: dict[str, Any], device: torch.device) -> None:
        """Zero ``state["step"]`` in place, or create it as a zero scalar tensor."""
        step_value = state.get("step")
        if isinstance(step_value, torch.Tensor):
            step_value.zero_()
        else:
            state["step"] = torch.zeros((), dtype=torch.float32, device=device)

    def _zero_like_with_rank(self, tensor: Tensor, rank: int, proj_type: str) -> Tensor:
        """Return zeros shaped like ``tensor`` with the projected axis set to ``rank``.

        Right projection resizes the last axis, left projection the first axis,
        and full projection yields a ``(rank, rank)`` tensor. A 0-dim input
        yields a ``(rank,)`` tensor.

        Raises:
            ValueError: If ``proj_type`` is not supported.
        """
        shape = list(tensor.shape)
        if not shape:
            shape = [rank]
        elif proj_type == RIGHT_PROJ:
            shape[-1] = rank
        elif proj_type == LEFT_PROJ:
            shape[0] = rank
        elif proj_type == FULL_PROJ:
            shape = [rank, rank]
        else:
            msg = f"Unsupported projection type {proj_type!r} for zero reinit."
            raise ValueError(msg)
        return tensor.new_zeros(tuple(shape))

    def _update_projector_meta(
        self,
        state: dict[str, Any],
        group: Mapping[str, Any],
        rank: int,
    ) -> None:
        """Record the new rank in ``projector_meta`` and drop any ``projector_basis`` entry."""
        state["projector_meta"] = {
            "rank": rank,
            "update_proj_gap": group.get("update_proj_gap", 200),
            "scale": group.get("scale", 1.0),
            "proj_type": group.get("proj_type", STD_PROJ),
        }
        state.pop("projector_basis", None)


class HyperparameterSwitchCallback(Callback):
    """Replace optimizer ``betas``/``vs`` at scheduled steps, optionally zeroing state.

    Every scheduled step is applied at most once.
    """

    def __init__(self, params: HyperparameterSwitchParams) -> None:
        """Copy the switch schedule and replacement values from ``params``."""
        new_vs = params.new_vs
        new_betas = params.new_betas

        self.enabled = params.enabled and bool(params.steps)
        # Negative steps are ignored.
        self.steps = {int(step) for step in params.steps if step >= 0}
        self.new_vs = None if new_vs is None else tuple(float(v) for v in new_vs)
        self.new_betas = None if new_betas is None else tuple(float(b) for b in new_betas)
        self.reset_momenta = tuple(params.reset_momenta)
        self.log_metrics = params.log_metrics
        self._applied_steps: set[int] = set()

    def on_step_end(self, context: CallbackStepContext) -> None:
        """Perform the switch if ``context.step`` is scheduled and not yet handled."""
        current_step = context.step
        if not self._should_apply(current_step):
            return

        optimizers = context.optimizers
        if optimizers is None:
            return

        self._apply_switches(optimizers, current_step)
        self._log_switch_metrics(context)
        logger.info("Hyperparameter switch callback applied at step %s", current_step)
        self._applied_steps.add(current_step)

    def _should_apply(self, step: int) -> bool:
        """Tell whether ``step`` is a scheduled, not-yet-applied switch step."""
        return bool(self.enabled) and step in self.steps and step not in self._applied_steps

    def _apply_switches(self, optimizers: Sequence[Optimizer], step: int) -> None:
        """Write the new values into every optimizer and reset the requested state."""
        replacements = (("vs", self.new_vs), ("betas", self.new_betas))
        for optimizer in optimizers:
            for key, values in replacements:
                if values is not None:
                    self._update_group_values(optimizer.param_groups, key, values)
            if not self.reset_momenta:
                continue
            self._reset_momenta(optimizer.state, step)
            # Optimizers may expose an optional hook to be told their state changed.
            mark_dirty = getattr(optimizer, "mark_state_dirty", None)
            if callable(mark_dirty):
                mark_dirty()

    def _log_switch_metrics(self, context: CallbackStepContext) -> None:
        """Log the newly applied values under ``hyper_switch/`` when enabled."""
        if not self.log_metrics:
            return

        payload: dict[str, float] = {}
        if self.new_vs is not None:
            payload.update({f"hyper_switch/v{idx}": value for idx, value in enumerate(self.new_vs)})
        if self.new_betas is not None:
            payload.update(
                {f"hyper_switch/beta{idx}": value for idx, value in enumerate(self.new_betas, start=1)},
            )
        if payload:
            context.logger.log(payload, context.step)

    def _update_group_values(self, param_groups: list[dict[str, Any]], key: str, values: tuple[float, ...]) -> None:
        """Overwrite ``group[key]`` with ``values`` in every group that defines ``key``.

        Tensor values are updated in place when the shapes match and replaced
        by a new tensor otherwise. All other values become a tuple, except a
        numeric scalar stored under a key other than ``vs``/``betas``, which
        becomes ``float(values[0])``.
        """
        for group in param_groups:
            if key not in group:
                continue

            existing = group[key]
            if isinstance(existing, torch.Tensor):
                replacement = torch.tensor(values, device=existing.device, dtype=existing.dtype)
                if existing.shape == replacement.shape:
                    existing.copy_(replacement)
                else:
                    group[key] = replacement
                continue

            is_sequence = isinstance(existing, Sequence) and not isinstance(existing, (str, bytes))
            is_number = isinstance(existing, (float, int))
            # ``vs`` and ``betas`` always keep tuple semantics, even when the
            # current value happens to be a bare number.
            if not is_sequence and is_number and key not in ("vs", "betas"):
                group[key] = float(values[0])
            else:
                group[key] = tuple(values)

    def _reset_momenta(self, optimizer_state: dict[Any, dict[str, Any]], step: int) -> None:
        """Zero every configured state entry present in each parameter's state."""
        del step
        for param_state in optimizer_state.values():
            for name in self.reset_momenta:
                if name in param_state:
                    self._zero_state_value(param_state[name])

    def _zero_state_value(self, value: Any) -> None:
        """Zero tensors in place, recursing through dicts, lists and tuples."""
        if isinstance(value, torch.Tensor):
            value.zero_()
            return

        if isinstance(value, dict):
            children = value.values()
        elif isinstance(value, (list, tuple)):
            children = value
        else:
            # Anything else (numbers, None, ...) is left untouched.
            return

        for child in children:
            self._zero_state_value(child)


class FLMetricsProcessor(MetricsProcessor):
    """Extension of MetricsProcessor that wires the FL callback stack."""

    def __init__(  # noqa: PLR0913 - initializer wires multiple optional dependencies
        self,
        job_config: Any,
        parallel_dims: Any,
        metrics_config: MetricsConfig | None = None,
        *,
        unigram_manager: UnigramMetricManager | None = None,
        callbacks: Sequence[Callback] | None = None,
        tag: str | None = None,
    ) -> None:
        """Initialise the processor and its callback stack.

        Args:
            job_config: Forwarded to the base initializer; its ``fl_metrics``
                attribute is used when ``metrics_config`` is not given.
            parallel_dims: Forwarded to the base initializer.
            metrics_config: FL metrics configuration. Defaults to
                ``job_config.fl_metrics``.
            unigram_manager: Unigram metric manager to use; a new
                ``UnigramMetricManager`` is created when omitted or falsy.
            callbacks: Explicit callbacks. When ``None`` the callbacks are built
                from ``metrics_config``; otherwise the sequence is copied into a
                list and the known callback types are picked out of it.
            tag: Forwarded to the base initializer.
        """
        resolved_config = job_config.fl_metrics if metrics_config is None else metrics_config

        super().__init__(job_config, parallel_dims, tag)

        assert resolved_config is not None
        self.metrics_config = resolved_config
        self.unigram_metrics = unigram_manager or UnigramMetricManager()

        # Typed slots for the callbacks this processor knows about; they are
        # filled in by one of the two branches below.
        self.optimizer_monitor: OptimizerMonitor | None = None
        self.activation_monitor: ActivationMonitor | None = None
        self.lr_monitor: LRMonitor | None = None
        self.betas_monitor: BetasMonitor | None = None
        self.vs_monitor: VSMonitor | None = None
        self.galore_projection: GaLoreMomentumProjectionCallback | None = None
        self.hyperparameter_switch: HyperparameterSwitchCallback | None = None

        if callbacks is not None:
            self.callbacks = list(callbacks)
            self._assign_known_callbacks(self.callbacks)
        else:
            self.callbacks = self._build_callbacks_from_config(resolved_config)

        # Callback setup is deferred; see `_ensure_callbacks_setup`.
        self._callbacks_setup_done = False

    def _build_callbacks_from_config(self, metrics_config: MetricsConfig) -> list[Callback]:
        """Instantiate the configured callbacks and record them on the processor.

        Each entry below names both the ``metrics_config`` section to read and
        the instance attribute to assign. A factory returning ``None`` means the
        callback is disabled: the attribute is set to ``None`` and the callback
        is left out of the returned list.

        Args:
            metrics_config: FL metrics configuration with one section per callback.

        Returns:
            The enabled callbacks, in the order the factories are listed.
        """
        factories: tuple[tuple[str, Any], ...] = (
            ("optimizer_monitor", self._init_optimizer_monitor),
            ("activation_monitor", self._init_activation_monitor),
            ("lr_monitor", self._init_lr_monitor),
            ("betas_monitor", self._init_betas_monitor),
            ("vs_monitor", self._init_vs_monitor),
            ("galore_projection", self._init_galore_projection),
            ("hyperparameter_switch", self._init_hyperparameter_switch),
        )

        enabled: list[Callback] = []
        for slot, factory in factories:
            instance = factory(getattr(metrics_config, slot))
            setattr(self, slot, instance)
            if instance is not None:
                enabled.append(instance)
        return enabled

    def _assign_known_callbacks(self, callbacks: Sequence[Callback]) -> None:
        """Point each known-callback attribute at its first matching instance.

        For every callback type the processor tracks, the first element of
        ``callbacks`` that is an instance of that type is stored on the matching
        attribute; the attribute is set to ``None`` when there is no match.

        Args:
            callbacks: Callbacks to search, in priority order.
        """

        def first_of(kind: type) -> Any:
            for candidate in callbacks:
                if isinstance(candidate, kind):
                    return candidate
            return None

        self.optimizer_monitor = first_of(OptimizerMonitor)
        self.activation_monitor = first_of(ActivationMonitor)
        self.lr_monitor = first_of(LRMonitor)
        self.betas_monitor = first_of(BetasMonitor)
        self.vs_monitor = first_of(VSMonitor)
        self.galore_projection = first_of(GaLoreMomentumProjectionCallback)
        self.hyperparameter_switch = first_of(HyperparameterSwitchCallback)

    def _init_optimizer_monitor(self, optimizer_config: OptimizerMonitorConfig) -> OptimizerMonitor | None:
        """Build the optimizer monitor, or return ``None`` when its interval is not positive.

        Args:
            optimizer_config: Optimizer monitor section of the metrics config.

        Returns:
            A configured ``OptimizerMonitor``, or ``None`` if ``interval <= 0``.
        """
        interval = optimizer_config.interval
        if interval <= 0:
            return None

        monitor = OptimizerMonitor(
            interval=interval,
            only_global=optimizer_config.only_global,
            log_optimizer_metrics=optimizer_config.log_metrics,
        )
        return monitor

    def _init_activation_monitor(self, activation_config: ActivationMonitorConfig) -> ActivationMonitor | None:
        """Build the activation monitor when it is enabled or has a positive interval.

        Args:
            activation_config: Activation monitor section of the metrics config.

        Returns:
            A configured ``ActivationMonitor``, or ``None`` when ``enabled`` is
            falsy and ``interval`` is not greater than zero.
        """
        # Either flag is sufficient on its own to turn the monitor on.
        if not activation_config.enabled and not (activation_config.interval > 0):
            return None

        return ActivationMonitor(
            interval=activation_config.interval,
            # A missing/empty ignore list is normalised to an empty tuple.
            ignore_module_types=activation_config.ignore_module_types or (),
            gradient_accumulation_steps=activation_config.gradient_accumulation_steps,
            enabled_metrics=activation_config.enabled_metrics,
        )

    def _init_lr_monitor(self, lr_config: LRMonitorConfig) -> LRMonitor | None:
        """Build the learning-rate monitor when it is enabled with a positive interval.

        Args:
            lr_config: LR monitor section of the metrics config.

        Returns:
            A configured ``LRMonitor``, or ``None`` when disabled or when
            ``interval`` is not greater than zero.
        """
        if lr_config.enabled and lr_config.interval > 0:
            return LRMonitor(
                interval=lr_config.interval,
                enabled=lr_config.enabled,
            )
        return None

    def _init_betas_monitor(self, betas_config: BetasMonitorConfig) -> BetasMonitor | None:
        """Build the betas monitor when it is enabled with a positive interval.

        Args:
            betas_config: Betas monitor section of the metrics config.

        Returns:
            A configured ``BetasMonitor``, or ``None`` when disabled or when
            ``interval`` is not greater than zero.
        """
        if betas_config.enabled and betas_config.interval > 0:
            return BetasMonitor(
                interval=betas_config.interval,
                enabled=betas_config.enabled,
            )
        return None

    def _init_vs_monitor(self, vs_config: VSMonitorConfig) -> VSMonitor | None:
        """Build the ``vs`` monitor when it is enabled with a positive interval.

        Args:
            vs_config: VS monitor section of the metrics config.

        Returns:
            A configured ``VSMonitor``, or ``None`` when disabled or when
            ``interval`` is not greater than zero.
        """
        if vs_config.enabled and vs_config.interval > 0:
            return VSMonitor(
                interval=vs_config.interval,
                enabled=vs_config.enabled,
            )
        return None

    def _init_galore_projection(
        self,
        projection_config: GaLoreMomentumProjectionConfig,
    ) -> GaLoreMomentumProjectionCallback | None:
        """Build the GaLore momentum projection callback from its config section.

        Args:
            projection_config: GaLore projection section of the metrics config.

        Returns:
            The callback, or ``None`` when the section is disabled, has no
            steps, supplies no target rank at all, or the constructed callback
            reports itself as not enabled.
        """
        if not projection_config.enabled or not projection_config.steps:
            return None

        ranks = tuple(map(int, projection_config.target_ranks))
        if not ranks:
            # No rank list given: fall back to the single `target_rank` value.
            fallback_rank = projection_config.target_rank
            if fallback_rank is None:
                return None
            ranks = (int(fallback_rank),)

        callback = GaLoreMomentumProjectionCallback(
            GaLoreMomentumProjectionParams(
                enabled=projection_config.enabled,
                steps=tuple(map(int, projection_config.steps)),
                target_ranks=ranks,
                state_keys=tuple(projection_config.state_keys),
                transform=str(projection_config.transform).lower(),
                proj_type=projection_config.proj_type,
                shared_source=projection_config.shared_source,
                column_count=projection_config.column_count,
                random_seed=projection_config.random_seed,
                random_std=float(projection_config.random_std),
                log_metrics=projection_config.log_metrics,
                reinit_mode=str(projection_config.reinit_mode).lower(),
            )
        )
        # The callback has the final say on whether it is active.
        return callback if callback.enabled else None

    def _init_hyperparameter_switch(
        self, hyper_switch_config: HyperparameterSwitchConfig
    ) -> HyperparameterSwitchCallback | None:
        """Build the hyperparameter switch callback from its config section.

        Args:
            hyper_switch_config: Hyperparameter switch section of the metrics config.

        Returns:
            The callback, or ``None`` when the section is disabled, has no
            steps, or the constructed callback reports itself as not enabled.
        """
        if not hyper_switch_config.enabled or not hyper_switch_config.steps:
            return None

        def optional_tuple(raw: Any) -> tuple[Any, ...] | None:
            # `None` means "leave this hyperparameter alone" and must stay `None`.
            return None if raw is None else tuple(raw)

        callback = HyperparameterSwitchCallback(
            HyperparameterSwitchParams(
                enabled=hyper_switch_config.enabled,
                steps=tuple(map(int, hyper_switch_config.steps)),
                new_vs=optional_tuple(hyper_switch_config.new_vs),
                new_betas=optional_tuple(hyper_switch_config.new_betas),
                reset_momenta=tuple(hyper_switch_config.reset_momenta),
                log_metrics=hyper_switch_config.log_metrics,
            )
        )
        return callback if callback.enabled else None

    def should_log(self, step: int) -> bool:
        """Decide whether metrics should be logged at ``step``.

        The base-class decision is always evaluated first. It is overridden to
        ``True`` when an activation monitor is present and reports that it wants
        to log at this step.

        Args:
            step (int): Current training step.

        Returns:
            bool: True if metrics should be logged this step, False otherwise.
        """
        parent_verdict = super().should_log(step)
        monitor = self.activation_monitor
        monitor_requests_log = bool(monitor and monitor.should_log_step(step))
        return True if monitor_requests_log else parent_verdict

    def _build_unigram_payload(self, mesh: DeviceMesh | None) -> tuple[dict[str, float], dict[str, float]]:
        """Turn the accumulated unigram statistics into local and global log payloads.

        The accumulated loss and item count are read without resetting, combined
        across ``mesh`` via ``dist_utils.dist_sum`` when a mesh is given, and the
        accumulator is reset just before returning.

        Args:
            mesh: Device mesh to reduce over, or ``None`` to use local values only.

        Returns:
            A ``(local_payload, global_payload)`` pair. Averages are reported as
            ``0.0`` when the corresponding item count is not positive; the global
            token count is only included when it is positive.
        """
        loss_sum, item_count = self.unigram_metrics.collect(reset=False)
        loss_value = float(loss_sum)
        item_value = float(item_count)

        # Put the reduction buffers on the first model parameter's device when
        # one exists; otherwise stay on CPU.
        buffer_device = torch.device("cpu")
        if self.model_parts:
            with contextlib.suppress(StopIteration):
                buffer_device = next(self.model_parts[0].parameters()).device

        loss_buffer = torch.tensor(loss_value, device=buffer_device, dtype=torch.float64)
        items_buffer = torch.tensor(item_value, device=buffer_device, dtype=torch.float64)

        if mesh is None:
            total_loss = float(loss_buffer)
            total_items = float(items_buffer)
        else:
            total_loss = dist_utils.dist_sum(loss_buffer, mesh)
            total_items = dist_utils.dist_sum(items_buffer, mesh)
        total_items = float(total_items)

        local_payload = {
            "pure_unigram_cross_entropy/local": loss_value / item_value if item_count > 0 else 0.0,
            "pure_unigram_cross_entropy/token_count/local": item_value,
        }

        global_payload: dict[str, float] = {"pure_unigram_cross_entropy": 0.0}
        if total_items > 0:
            global_payload = {
                "pure_unigram_cross_entropy": float(total_loss) / total_items,
                "pure_unigram_cross_entropy/token_count": total_items,
            }

        # Start a fresh accumulation window for the next call.
        self.unigram_metrics.reset()
        return local_payload, global_payload

    def update_unigram_metrics(self, labels: Tensor) -> None:
        """Feed the latest batch labels to the unigram metrics, if any are tracked.

        Args:
            labels: Labels of the current batch, passed through unchanged.
        """
        tracker = self.unigram_metrics
        if tracker.has_metrics():
            tracker.update(labels)

    def _ensure_callbacks_setup(self) -> None:
        """Run ``setup`` on every callback once, as soon as its inputs exist.

        Nothing happens if setup already completed. With no callbacks the setup
        is simply marked as done. Otherwise setup is postponed (and retried on a
        later call) until both model parts and optimizers are available.
        """
        if self._callbacks_setup_done:
            return

        if self.callbacks:
            if not self.model_parts or self.optimizers is None:
                # Not ready yet; leave the flag unset so a later call retries.
                return

            shared_context = CallbackSetupContext(
                model_parts=self.model_parts,
                optimizers=self.optimizers,
                logger=self.logger,
                parallel_dims=self.parallel_dims,
                job_config=self.job_config,
            )
            for registered in self.callbacks:
                registered.setup(shared_context)

        self._callbacks_setup_done = True

    def _run_step_callbacks(self, step: int, mesh: DeviceMesh | None) -> None:
        """Invoke ``on_step_end`` on every callback with a shared step context.

        Does nothing when there are no callbacks, no model parts, or no
        optimizers.

        Args:
            step: Current training step.
            mesh: Device mesh handed to the callbacks, or ``None``.
        """
        if not self.callbacks or not self.model_parts or self.optimizers is None:
            return

        step_context = CallbackStepContext(
            step=step,
            model_parts=self.model_parts,
            optimizers=self.optimizers,
            logger=self.logger,
            mesh=mesh,
        )
        for registered in self.callbacks:
            registered.on_step_end(step_context)

    def _run_validation_callbacks(self, loss: float, step: int) -> None:
        """Invoke ``on_validation_end`` on every callback with a shared context.

        Args:
            loss: Validation loss to expose to the callbacks.
            step: Step at which validation ran.
        """
        if self.callbacks:
            validation_context = CallbackValidationContext(
                step=step,
                loss=loss,
                logger=self.logger,
            )
            for registered in self.callbacks:
                registered.on_validation_end(validation_context)

    def log(
        self,
        step: int,
        global_avg_loss: float,
        global_max_loss: float,
        grad_norm: float,
        extra_metrics: dict[str, Any] | None = None,
    ) -> None:
        """Log the metrics at the end of the step and run the step callbacks.

        The local unigram payload is logged directly; the global unigram payload
        is merged into a copy of ``extra_metrics`` and handed to the base-class
        ``log``. Callback setup and the step callbacks run afterwards.

        Args:
            step: The current training step.
            global_avg_loss: The average loss across all workers.
            global_max_loss: The maximum loss across all workers.
            grad_norm: The gradient norm.
            extra_metrics: Any additional metrics to log.

        """
        dims = self.parallel_dims
        mesh = None
        if dims.dp_cp_enabled:
            mesh = dims.world_mesh["dp_cp"]

        local_payload, global_payload = self._build_unigram_payload(mesh)
        if local_payload:
            self.logger.log(local_payload, step)

        # Work on a copy so the caller's dict is never mutated.
        merged: dict[str, Any] = {}
        if extra_metrics:
            merged.update(extra_metrics)
        if global_payload:
            merged.update(global_payload)

        super().log(
            step,
            global_avg_loss,
            global_max_loss,
            grad_norm,
            extra_metrics=merged or None,
        )

        self._ensure_callbacks_setup()
        self._run_step_callbacks(step, mesh)

    def log_validation(self, loss: float, step: int, local_loss: float | None) -> None:
        """Log validation metrics and run validation-specific callbacks.

        The base-class validation logging runs first. The unigram payloads are
        then logged one after the other, and finally the callbacks are set up
        (if still pending) and notified.

        Args:
            loss: Validation loss.
            step: Step at which validation ran.
            local_loss: Forwarded to the base class as ``local_loss``.
        """
        super().log_validation(loss, step, local_loss=local_loss)

        dims = self.parallel_dims
        mesh = None
        if dims.dp_cp_enabled:
            mesh = dims.world_mesh["dp_cp"]

        # Local payload first, then global; empty payloads are skipped.
        for payload in self._build_unigram_payload(mesh):
            if payload:
                self.logger.log(payload, step)

        self._ensure_callbacks_setup()
        self._run_validation_callbacks(loss, step)

    def close(self) -> None:
        """Close registered callbacks and flush any pending metrics.

        Callbacks are closed in registration order, followed by the base-class
        ``close``. Every ``close`` is attempted even if an earlier one raises, so
        a failing callback cannot prevent the remaining callbacks or the
        base-class cleanup from running. An exception raised during cleanup still
        propagates to the caller once all steps have been attempted.
        """
        with contextlib.ExitStack() as cleanup:
            # ExitStack unwinds last-in-first-out, so register in the reverse of
            # the desired call order: the base-class close goes in first (runs
            # last) and the callbacks go in reversed (run in original order).
            cleanup.callback(super().close)
            for callback in reversed(self.callbacks):
                cleanup.callback(callback.close)
