# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""GaLore optimizer family for FL experiments."""

from __future__ import annotations

import logging
import math
import re
from typing import Any, NamedTuple, TYPE_CHECKING, cast

import torch
from torch import Tensor
from torch.optim import AdamW

from ._decoupled_decay import _compute_decay_factor
from ._metric_utils import prepare_metrics_for_reduction, reduce_metrics_across_ranks

if TYPE_CHECKING:
    from collections.abc import Callable, Iterable

# Module-level logger used by the projection helpers and the optimizer.
log = logging.getLogger(__name__)

# Public names exported by a star-import of this module.
__all__ = ["GaLore", "classify_low_rank_parameters", ]

# Largest tensor ``ndim`` accepted by ``_project``; also the expected length of
# the two-element basis list that ``_basis_similarity`` checks for FULL_PROJ.
GALORE_MAX_SUPPORT_DIM = 2
# ``GaLore.__init__`` logs a warning when ``weight_decay`` is at or above this value.
_HIGH_WEIGHT_DECAY_WARNING = 1e-1

# Projection-type names. ``std`` and ``reverse_std`` are resolved to ``left`` or
# ``right`` from the tensor's shape by ``_resolve_proj_choice``; ``full`` keeps
# both a left and a right basis.
STD_PROJ = "std"
RIGHT_PROJ = "right"
LEFT_PROJ = "left"
FULL_PROJ = "full"
REV_STD_PROJ = "reverse_std"
# Integer codes under which projection types are written into ``projector_meta``.
PROJ_TO_CODE: dict[str, int] = {
    STD_PROJ: 0,
    REV_STD_PROJ: 1,
    LEFT_PROJ: 2,
    RIGHT_PROJ: 3,
    FULL_PROJ: 4,
}
# Inverse lookup: integer code back to the projection-type name.
CODE_TO_PROJ: dict[int, str] = {code: name for name, code in PROJ_TO_CODE.items()}

# A projector is a single tensor (left/right) or a ``[left, right]`` list (full),
# matching what ``_orthogonal_matrix`` returns.
ProjectionBasis = Tensor | list[Tensor]


def _to_local_if_dtensor(tensor: Tensor) -> Tensor:
    """Return the local tensor behind a DTensor; pass anything else through unchanged.

    Metrics are computed on local tensors because DTensors can trip mixed-op
    checks during logging. If the DTensor class cannot be imported, the input
    is returned as-is.
    """
    try:
        from torch.distributed.tensor import DTensor as dtensor_cls  # type: ignore
    except Exception:  # pragma: no cover - DTensor may be unavailable
        return tensor

    if isinstance(tensor, dtensor_cls):
        return tensor.to_local()
    return tensor


class _RotationContext(NamedTuple):
    """Adam decay rates used to bias-correct the moments when they are rotated to a new basis."""

    beta1: float  # decay rate used for the ``exp_avg`` bias correction
    beta2: float  # decay rate used for the ``exp_avg_sq`` bias correction


def _flatten_for_svd(tensor: Tensor) -> Tensor:
    """Reshape ``tensor`` into a 2-D matrix.

    Scalars become ``(1, 1)``, vectors become a single row ``(1, n)``, and
    anything with two or more dimensions keeps its leading dimension while
    the remaining ones are merged.
    """
    n_dims = tensor.ndim
    if n_dims >= 2:
        return tensor.reshape(tensor.shape[0], -1)
    target_shape = (1, -1) if n_dims == 1 else (1, 1)
    return tensor.reshape(*target_shape)


def _stable_rank_from_singular_values(singular_values: Tensor) -> Tensor:
    """Return the stable rank: sum of squared singular values over the largest one squared.

    The result is a NaN scalar tensor when the input is empty or when the
    largest singular value is zero.
    """
    sigma = _to_local_if_dtensor(singular_values)

    def _nan() -> Tensor:
        return torch.tensor(float("nan"), device=sigma.device, dtype=sigma.dtype)

    if sigma.numel() == 0:
        return _nan()

    energy = sigma.square().sum()
    top_energy = sigma.max().square()
    if top_energy == 0:
        return _nan()
    return energy / top_energy


def _spectral_gap_from_singular_values(singular_values: Tensor, rank: int | None) -> Tensor:
    """Return the gap between the ``rank``-th and ``(rank + 1)``-th largest singular values.

    Values are sorted in descending order and clamped from below at 1e-12
    first. The result is NaN when ``rank`` is ``None`` or non-positive, when
    there are no singular values, or when fewer than ``rank`` values exist.
    A missing ``(rank + 1)``-th value is treated as zero.
    """
    local = _to_local_if_dtensor(singular_values)

    def _scalar(value: float, like: Tensor) -> Tensor:
        return torch.tensor(value, device=like.device, dtype=like.dtype)

    if rank is None or rank <= 0:
        return _scalar(float("nan"), local)

    ordered = torch.sort(local, descending=True).values
    count = ordered.numel()
    if count == 0:
        return _scalar(float("nan"), local)

    ordered = torch.clamp(ordered, min=1e-12)
    kept = ordered[rank - 1] if count >= rank else _scalar(float("nan"), ordered)
    dropped = ordered[rank] if count > rank else _scalar(0.0, ordered)
    return kept - dropped


def _powerlaw_alpha_from_singular_values(singular_values: Tensor) -> Tensor:
    """Return the negated least-squares slope of log(sigma) against log(index).

    Singular values are sorted in descending order, clamped from below at
    1e-12, and paired with 1-based indices. The result is NaN when the
    log-index variance is not positive.
    """
    ordered = torch.sort(_to_local_if_dtensor(singular_values), descending=True).values
    ordered = torch.clamp(ordered, min=1e-12)

    positions = torch.arange(1, ordered.numel() + 1, device=ordered.device, dtype=ordered.dtype)
    log_pos = torch.log(positions)
    log_val = torch.log(ordered)

    centred_pos = log_pos - log_pos.mean()
    centred_val = log_val - log_val.mean()

    spread = torch.sum(centred_pos ** 2)
    if spread <= 0:
        return torch.tensor(float("nan"), device=ordered.device, dtype=ordered.dtype)

    co_spread = torch.sum(centred_pos * centred_val)
    return -(co_spread / spread)


def _singular_values_from_gradient(grad: Tensor) -> Tensor:
    """Return the singular values of ``grad`` in descending order.

    The gradient is flattened to 2-D and cast to float32. The values are the
    square roots of the eigenvalues of the smaller Gram matrix, with negative
    eigenvalues clamped to zero first.
    """
    flat = _flatten_for_svd(grad).to(dtype=torch.float32)
    n_rows, n_cols = flat.shape
    flat_t = flat.transpose(-1, -2)
    # Use whichever Gram matrix is smaller: (n, n) for tall input, (m, m) otherwise.
    gram = flat_t @ flat if n_rows >= n_cols else flat @ flat_t
    spectrum = torch.linalg.eigvalsh(gram).clamp(min=0.0).sqrt()
    return torch.sort(spectrum, descending=True).values


def _apply_axis_transform(tensor: Tensor, matrix: Tensor, axis: int) -> Tensor:
    """Multiply ``tensor`` by ``matrix`` along its first or last axis, keeping its shape.

    ``axis == 0`` left-multiplies the tensor flattened to ``(shape[0], -1)``;
    ``axis == -1`` right-multiplies the tensor flattened to ``(-1, shape[-1])``.

    Raises:
        ValueError: if ``axis`` is neither ``0`` nor ``-1``.
    """
    shape = tensor.shape
    if axis == 0:
        flat = tensor.reshape(shape[0], -1)
        return (matrix @ flat).reshape(shape)
    if axis == -1:
        flat = tensor.reshape(-1, shape[-1])
        return (flat @ matrix).reshape(shape)
    msg = f"Unsupported axis {axis} for GaLore moment rotation."
    raise ValueError(msg)


def _rotate_moments_to_new_basis(
    state: dict[str, Any],
    *,
    old_basis: Tensor,
    new_basis: Tensor,
    proj_type: str,
    rotation_context: _RotationContext,
) -> None:
    """Re-express the stored Adam moments in the coordinates of a refreshed basis.

    The function replaces ``state["exp_avg"]`` and ``state["exp_avg_sq"]`` and
    returns nothing. It is a no-op when ``proj_type`` is not left/right, when
    either moment or the step counter is missing from ``state``, when the step
    is not positive, or when either bias-correction factor is not positive.

    Args:
        state: Per-parameter optimizer state holding ``exp_avg``,
            ``exp_avg_sq`` and ``step``.
        old_basis: Basis the stored moments are currently expressed in.
        new_basis: Basis the moments should be expressed in afterwards.
        proj_type: Resolved projection type; only ``LEFT_PROJ`` and
            ``RIGHT_PROJ`` are handled.
        rotation_context: Supplies the ``beta1``/``beta2`` decay rates used for
            bias correction.
    """
    if proj_type not in {LEFT_PROJ, RIGHT_PROJ}:
        return

    exp_avg: Tensor | None = state.get("exp_avg")
    exp_avg_sq: Tensor | None = state.get("exp_avg_sq")
    step_tensor: Tensor | None = state.get("step")
    if exp_avg is None or exp_avg_sq is None or step_tensor is None:
        return

    step_value = int(step_tensor.item())
    if step_value <= 0:
        return

    # Standard Adam bias-correction factors at the current step.
    beta1_corr = 1.0 - rotation_context.beta1**step_value
    beta2_corr = 1.0 - rotation_context.beta2**step_value
    if beta1_corr <= 0.0 or beta2_corr <= 0.0:
        return

    # Work in the device/dtype of the stored moments.
    device = exp_avg.device
    dtype = exp_avg.dtype
    old_basis_tensor = old_basis.to(device=device, dtype=dtype)
    new_basis_tensor = new_basis.to(device=device, dtype=dtype)

    # Right bases are transposed so that, in both cases, basis vectors are columns.
    old_columns = old_basis_tensor.T if proj_type == RIGHT_PROJ else old_basis_tensor
    new_columns = new_basis_tensor.T if proj_type == RIGHT_PROJ else new_basis_tensor
    # transform[i, j] is the inner product of new column i with old column j.
    transform = new_columns.transpose(-1, -2) @ old_columns
    # Right projections multiply from the right (last axis), so the coefficient
    # matrix is transposed; left projections multiply from the left (axis 0).
    coeff_matrix = transform.T if proj_type == RIGHT_PROJ else transform
    coeff_matrix = coeff_matrix.to(device=device, dtype=dtype)
    # The variance estimate is mixed with element-wise squared coefficients.
    var_matrix = coeff_matrix.pow(2)
    axis = -1 if proj_type == RIGHT_PROJ else 0

    # Bias-corrected first/second moments and the implied variance estimate,
    # clamped at zero.
    m_hat_old = exp_avg / beta1_corr
    v_hat_old = exp_avg_sq / beta2_corr
    var_hat_old = torch.clamp(v_hat_old - m_hat_old.pow(2), min=0.0)

    rotated_exp_avg = _apply_axis_transform(exp_avg, coeff_matrix, axis)
    m_hat_rot = _apply_axis_transform(m_hat_old, coeff_matrix, axis)
    var_hat_rot = _apply_axis_transform(var_hat_old, var_matrix, axis)
    # Rebuild the second moment as rotated variance plus squared rotated mean.
    v_hat_rot = torch.abs(var_hat_rot + m_hat_rot.pow(2))

    state["exp_avg"] = rotated_exp_avg
    # Undo the bias correction so the stored value matches Adam's raw moment.
    state["exp_avg_sq"] = v_hat_rot * beta2_corr


def _infer_projector_rank(
    orthogonal: Tensor | list[Tensor] | None,
    resolved_proj_type: str | None,
) -> int | None:
    """Read the projection rank off a stored basis.

    A single tensor reports ``shape[1]`` for a left projection and
    ``shape[0]`` otherwise. A non-empty list reports ``shape[1]`` of its first
    entry when that entry is a tensor. Every other input yields ``None``.
    """
    if orthogonal is None:
        return None

    if isinstance(orthogonal, list):
        head = orthogonal[0] if orthogonal else None
        return head.shape[1] if isinstance(head, Tensor) else None

    if isinstance(orthogonal, Tensor):
        rank_axis = 1 if resolved_proj_type == LEFT_PROJ else 0
        return orthogonal.shape[rank_axis]

    return None


def _orthogonal_matrix(weights: Tensor, rank: int, proj_type: str) -> ProjectionBasis:
    """Build a rank-``rank`` projection basis from the reduced SVD of ``weights``.

    The SVD is computed in float32 and the result is cast back to the input's
    device and dtype.

    Returns:
        ``Vh[:rank, :]`` for a right projection, ``U[:, :rank]`` for a left
        projection, or the list ``[U[:, :rank], Vh[:rank, :]]`` for a full one.

    Raises:
        ValueError: if ``proj_type`` is not right, left or full.
    """
    source = weights.data
    target_dtype = source.dtype
    target_device = source.device

    def _restore(piece: Tensor) -> Tensor:
        return piece.to(device=target_device, dtype=target_dtype)

    left_vectors, _, right_vectors_h = torch.linalg.svd(source.float(), full_matrices=False)

    if proj_type == FULL_PROJ:
        return [
            _restore(left_vectors[:, :rank]),
            _restore(right_vectors_h[:rank, :]),
        ]
    if proj_type == RIGHT_PROJ:
        return _restore(right_vectors_h[:rank, :])
    if proj_type == LEFT_PROJ:
        return _restore(left_vectors[:, :rank])

    msg = f"Unknown projection type {proj_type!r}."
    raise ValueError(msg)


def _basis_similarity(old_basis: ProjectionBasis, new_basis: ProjectionBasis, proj_type: str) -> dict[str, float]:
    """Measure how much two projector bases overlap.

    The score is mean(sigma^2), where sigma are the singular values of the
    rotation between the two subspaces. Full projections report ``"left"``
    and ``"right"`` scores; single-tensor bases report ``"single"``. An empty
    dict is returned when the inputs do not match the projection type.
    """

    def _overlap(new: Tensor, old: Tensor, *, left_space: bool) -> float:
        # Left bases are (m x r): use new^T @ old. Right bases are (r x n): use new @ old^T.
        if left_space:
            rotation = new.transpose(-2, -1) @ old
        else:
            rotation = new @ old.transpose(-2, -1)
        squared = torch.linalg.svdvals(rotation).pow(2)
        return float(squared.mean().item())

    if proj_type != FULL_PROJ:
        if isinstance(old_basis, Tensor) and isinstance(new_basis, Tensor):
            return {"single": _overlap(new_basis, old_basis, left_space=proj_type == LEFT_PROJ)}
        return {}

    if not isinstance(old_basis, list) or not isinstance(new_basis, list):
        return {}
    if len(old_basis) != GALORE_MAX_SUPPORT_DIM or len(new_basis) != GALORE_MAX_SUPPORT_DIM:
        return {}

    new_left, new_right = new_basis
    old_left, old_right = old_basis
    left_score = _overlap(new_left, old_left, left_space=True)
    right_score = _overlap(new_right, old_right, left_space=False)
    return {"left": left_score, "right": right_score}


def _resolve_proj_choice(proj_type: str, tensor: Tensor) -> str:
    """Turn the shape-dependent ``std``/``reverse_std`` choices into left or right.

    ``std`` selects a right projection when ``tensor`` has at least as many
    rows as columns and a left projection otherwise; ``reverse_std`` selects
    the opposite. Any other ``proj_type`` is returned unchanged.
    """
    if proj_type not in {STD_PROJ, REV_STD_PROJ}:
        return proj_type

    rows_dominate = tensor.shape[0] >= tensor.shape[1]
    is_standard = proj_type == STD_PROJ
    # Right is chosen exactly when "rows dominate" and "standard" agree.
    return RIGHT_PROJ if rows_dominate == is_standard else LEFT_PROJ


def _proj_name_from_value(value: Any, default: str = STD_PROJ) -> str:
    """Normalise a stored projection type, given as name or integer code, to its name.

    Known names are returned as-is and known integer codes are translated to
    their name. Anything else yields ``default``.
    """
    if isinstance(value, str):
        return value if value in PROJ_TO_CODE else default
    if isinstance(value, int):
        return CODE_TO_PROJ.get(value, default)
    return default


def _canonicalize_projection_tensor(tensor: Tensor) -> Tensor:
    """Promote scalars and vectors to 2-D so a projector can be built from them.

    Tensors that already have at least two dimensions are returned unchanged.
    Vectors become a column ``(n, 1)`` and scalars become ``(1, 1)``.
    """
    if tensor.ndim >= GALORE_MAX_SUPPORT_DIM:
        return tensor
    target_shape = (-1, 1) if tensor.ndim == 1 else (1, 1)
    return tensor.reshape(*target_shape)


def _maybe_refresh_projector(
    state: dict[str, Any],
    weights: Tensor,
    iteration: Tensor,
    rotation_context: _RotationContext | None = None,
) -> None:
    """Recompute the projection basis stored in ``state`` when a refresh is due.

    A refresh happens when no basis is stored yet or when ``iteration`` is a
    multiple of the configured ``update_proj_gap``. Nothing is refreshed while
    ``rank`` or ``update_proj_gap`` in ``projector_meta`` is ``None``.

    Side effects on ``state``:
        * ``projector_meta`` is created if missing, and its ``proj_type`` /
          ``resolved_proj_type`` codes are always rewritten.
        * ``projector_basis`` is replaced on refresh.
        * ``basis_similarity`` is stored when an old basis existed and a
          similarity could be computed.
        * ``exp_avg`` / ``exp_avg_sq`` are rotated when ``rotation_context`` is
          given and both the old and the new basis are single tensors.

    Args:
        state: Per-parameter optimizer state.
        weights: Tensor the basis is computed from. ``_project`` passes the
            gradient here.
        iteration: Step counter tensor used for the refresh schedule.
        rotation_context: Optional decay rates that enable moment rotation.
    """
    weights = _canonicalize_projection_tensor(weights)
    meta = state.setdefault(
        "projector_meta",
        {
            "rank": None,
            "update_proj_gap": None,
            "scale": None,
            "proj_type": PROJ_TO_CODE[STD_PROJ],
            "resolved_proj_type": PROJ_TO_CODE[STD_PROJ],
        },
    )
    rank = meta["rank"]
    update_proj_gap = meta["update_proj_gap"]
    # Stored values may be names or integer codes; normalise, then write codes back.
    proj_type = _proj_name_from_value(meta.get("proj_type", STD_PROJ))
    resolved_proj_type = _resolve_proj_choice(proj_type, weights)
    meta["proj_type"] = PROJ_TO_CODE[proj_type]
    meta["resolved_proj_type"] = PROJ_TO_CODE[resolved_proj_type]
    if rank is None or update_proj_gap is None:
        return

    orthogonal = state.get("projector_basis")
    should_refresh = orthogonal is None or (iteration % update_proj_gap).item() == 0
    if not should_refresh:
        return
    new_basis = _orthogonal_matrix(weights, rank, resolved_proj_type)
    if orthogonal is not None:
        similarity = _basis_similarity(orthogonal, new_basis, resolved_proj_type)
        if similarity:
            # Store similarity in state for metric logging
            if "single" in similarity:
                state["basis_similarity"] = torch.tensor(
                    similarity["single"], device=weights.device, dtype=torch.float32
                )
            elif "left" in similarity and "right" in similarity:
                # For FULL_PROJ, store average of left and right similarities
                avg_sim = (similarity["left"] + similarity["right"]) / 2.0
                state["basis_similarity"] = torch.tensor(
                    avg_sim, device=weights.device, dtype=torch.float32
                )
            log.info("GaLore basis similarity (%s): %s", resolved_proj_type, similarity)
    # Moment rotation only applies to single-tensor (left/right) bases.
    if (
        rotation_context is not None
        and orthogonal is not None
        and isinstance(orthogonal, Tensor)
        and isinstance(new_basis, Tensor)
    ):
        _rotate_moments_to_new_basis(
            state,
            old_basis=orthogonal,
            new_basis=new_basis,
            proj_type=resolved_proj_type,
            rotation_context=rotation_context,
        )
    state["projector_basis"] = new_basis


def _project(
    state: dict[str, Any],
    full_rank_grad: Tensor,
    iteration: Tensor,
    rotation_context: _RotationContext | None = None,
) -> Tensor:
    """Project a gradient into the low-rank subspace stored in ``state``.

    The gradient is promoted to 2-D first. The requested and resolved
    projection codes, and the gradient's original shape when it is non-empty,
    are written into ``state["projector_meta"]``. The basis is refreshed via
    ``_maybe_refresh_projector`` before it is applied.

    Returns:
        ``G @ P.T`` for a right projection, ``P.T @ G`` for a left projection,
        and ``A.T @ G @ B.T`` for a full projection with basis ``[A, B]``.

    Raises:
        NotImplementedError: if the gradient has more than two dimensions.
        RuntimeError: if no projection basis is available after the refresh.
        ValueError: if the resolved projection type is not supported.
    """
    if full_rank_grad.ndim > GALORE_MAX_SUPPORT_DIM:
        msg = "GaLore currently supports tensors up to rank 2."
        raise NotImplementedError(msg)

    incoming_shape = tuple(full_rank_grad.shape)
    grad_matrix = _canonicalize_projection_tensor(full_rank_grad)
    device = grad_matrix.device

    meta = state.get("projector_meta", {})
    requested = _proj_name_from_value(meta.get("proj_type", STD_PROJ))
    resolved = _resolve_proj_choice(requested, grad_matrix)
    meta.update(
        proj_type=PROJ_TO_CODE[requested],
        resolved_proj_type=PROJ_TO_CODE[resolved],
    )
    if incoming_shape:
        # Remember the caller's shape so the back-projection can restore it.
        meta["full_rank_shape"] = torch.tensor(
            list(incoming_shape),
            device=device,
            dtype=torch.int64,
        )
    state["projector_meta"] = meta

    _maybe_refresh_projector(state, grad_matrix, iteration, rotation_context)
    basis = state.get("projector_basis")
    if basis is None:
        msg = "Projection matrix not initialised."
        raise RuntimeError(msg)

    if resolved == RIGHT_PROJ:
        assert isinstance(basis, Tensor)
        return grad_matrix @ basis.T.to(device)
    if resolved == LEFT_PROJ:
        assert isinstance(basis, Tensor)
        return basis.T.to(device) @ grad_matrix
    if resolved == FULL_PROJ:
        assert isinstance(basis, list)
        left_basis, right_basis = basis
        return left_basis.T.to(device) @ grad_matrix @ right_basis.T.to(device)

    msg = f"Unsupported projection type {resolved!r}"
    raise ValueError(msg)


def _basis_applies_from_right(matrix: Tensor, low_rank_grad: Tensor, meta: dict[str, Any]) -> bool:
    """Decide whether a single-tensor basis is applied as ``low_rank_grad @ matrix``.

    Shapes decide whenever only one orientation fits. When both fit (square
    full-rank parameter), the recorded ``resolved_proj_type`` is consulted
    first, then the aspect ratio of the basis: a basis with more rows than
    columns can only be a left basis ``U[:, :rank]``.
    """
    fits_right = matrix.shape[0] == low_rank_grad.shape[-1]
    fits_left = low_rank_grad.ndim >= 2 and matrix.shape[-1] == low_rank_grad.shape[-2]
    if not (fits_right and fits_left):
        # Unambiguous (or nothing fits): behave exactly like the plain shape test.
        return fits_right

    if "resolved_proj_type" in meta:
        resolved = _proj_name_from_value(meta["resolved_proj_type"], default="")
        if resolved == RIGHT_PROJ:
            return True
        if resolved == LEFT_PROJ:
            return False

    return matrix.shape[0] <= matrix.shape[-1]


def _project_back(state: dict[str, Any], low_rank_grad: Tensor) -> Tensor:
    """Map a low-rank tensor back to the full-rank space and apply the scale.

    Args:
        state: Per-parameter optimizer state. ``projector_basis`` holds the
            basis and ``projector_meta`` may hold ``scale``,
            ``resolved_proj_type`` and ``full_rank_shape``.
        low_rank_grad: Tensor expressed in the low-rank subspace.

    Returns:
        The restored tensor multiplied by ``scale`` (default ``1.0``). It is
        reshaped to ``full_rank_shape`` when that is recorded and differs.
        Without a stored basis the input is only scaled.

    A left basis ``(m, r)`` on a square ``(m, m)`` parameter also satisfies
    ``basis.shape[0] == low_rank_grad.shape[-1]``, so a bare shape test would
    multiply on the wrong side. That case is disambiguated explicitly.
    """
    orthogonal = state.get("projector_basis")
    meta = state.get("projector_meta", {})
    scale = meta.get("scale", 1.0)
    if orthogonal is None:
        return low_rank_grad * scale

    if isinstance(orthogonal, Tensor):
        matrix = orthogonal.to(low_rank_grad.device)
        if _basis_applies_from_right(matrix, low_rank_grad, meta):
            restored = low_rank_grad @ matrix
        else:
            restored = matrix @ low_rank_grad
    else:
        a_matrix, b_matrix = orthogonal
        restored = a_matrix.to(low_rank_grad.device) @ low_rank_grad @ b_matrix.to(low_rank_grad.device)

    restored = restored * scale
    shape_tensor = meta.get("full_rank_shape")
    if isinstance(shape_tensor, torch.Tensor):
        desired_shape = tuple(int(x) for x in shape_tensor.tolist())
        if desired_shape and tuple(restored.shape) != desired_shape:
            restored = restored.reshape(desired_shape)
    return restored


class GaLore(AdamW):
    """GaLore optimiser with optional quasi-hyperbolic momentum."""

    # Registry of per-parameter metrics evaluated in ``report_per_parameter_metrics``.
    # Each entry maps a metric name to a callable taking
    # ``(param, optimizer_state, step_tensor)`` and returning a tensor.
    # The two optional-state metrics fall back to a constant when their key is
    # absent: 0.0 for ``error_feedback`` and NaN for ``basis_similarity``.
    metric_functions: dict[str, Callable[[Tensor, dict[str, Any], Tensor], Tensor]] = {
        "l2_norm/exp_avg": (
            lambda _param, optim_state, _step_tensor: torch.linalg.vector_norm(
                optim_state["exp_avg"],
            )
        ),
        "l2_norm/exp_avg_sq": (
            lambda _param, optim_state, _step_tensor: torch.linalg.vector_norm(
                optim_state["exp_avg_sq"],
            )
        ),
        "min/exp_avg_sq": lambda _param, optim_state, _step_tensor: torch.min(
            optim_state["exp_avg_sq"],
        ),
        "max/exp_avg_sq": lambda _param, optim_state, _step_tensor: torch.max(
            optim_state["exp_avg_sq"],
        ),
        "l2_norm/param": (
            lambda param, _optim_state, _step_tensor: torch.linalg.vector_norm(
                param.data,
            )
        ),
        "l2_norm/update": (
            lambda _param, _optim_state, step_tensor: torch.linalg.vector_norm(
                step_tensor,
            )
        ),
        "l2_norm/grad": (
            lambda param, _optim_state, _step_tensor: torch.linalg.vector_norm(
                param.grad,
            )
        ),
        # Norm of the accumulated error-feedback residual, or 0.0 when none is stored.
        "l2_norm/error_feedback": (
            lambda _param, optim_state, _step_tensor: torch.linalg.vector_norm(
                optim_state["error_feedback"],
            )
            if "error_feedback" in optim_state
            else torch.tensor(0.0, device=_param.device)
        ),
        # Similarity recorded at the last basis refresh, or NaN when none is stored.
        "mean/basis_similarity": (
            lambda _param, optim_state, _step_tensor: optim_state["basis_similarity"]
            if "basis_similarity" in optim_state
            else torch.tensor(float("nan"), device=_param.device)
        ),
    }

    def __init__(  # noqa: C901, PLR0913
        self,
        params: Iterable[Tensor] | Iterable[dict],
        lr: float = 1e-3,
        betas: tuple[float, float] = (0.9, 0.95),
        eps: float = 1e-8,
        weight_decay: float = 1e-5,
        *,
        vs: tuple[float, ...] = (0.0,),
        rank: int | None = None,
        update_proj_gap: int = 200,
        scale: float = 1.0,
        proj_type: str = STD_PROJ,
        dim: int = 2,
        rank_overrides: dict[Tensor | int, int] | None = None,
        rotate_moments_on_refresh: bool = False,
        use_error_feedback: bool = False,
        qhm_outside_projection: bool = False,
    ) -> None:
        """Validate the hyper-parameters and set up GaLore-specific group options.

        Every GaLore option is copied into each parameter group unless the
        group already defines it, and the group's starting ``lr`` is recorded
        as ``initial_lr``. ``rank_overrides`` keys may be tensors or integer
        ids; tensors are keyed by ``id()``.

        Raises:
            ValueError: if ``lr``, ``eps`` or ``weight_decay`` is negative, a
                beta is outside ``[0, 1)``, ``vs`` is empty, or an entry of
                ``vs`` is outside ``[0, 1]``.
        """
        if lr < 0.0:
            msg = f"Invalid learning rate: {lr}"
            raise ValueError(msg)
        if any(not 0.0 <= betas[idx] < 1.0 for idx in (0, 1)):
            msg = f"Invalid betas: {betas}"
            raise ValueError(msg)
        if eps < 0.0:
            msg = f"Invalid epsilon value: {eps}"
            raise ValueError(msg)
        if weight_decay < 0.0:
            msg = f"Invalid weight_decay value: {weight_decay}"
            raise ValueError(msg)
        if weight_decay >= _HIGH_WEIGHT_DECAY_WARNING:
            log.warning(
                "High weight_decay=%s for GaLore. Model weights are multiplied by %.6f every step.",
                weight_decay,
                1.0 - weight_decay,
            )

        # Validate vs parameters (match QHAdamW style)
        if not vs:
            msg = "vs must be a non-empty tuple with at least one element"
            raise ValueError(msg)
        for position, mix in enumerate(vs):
            if not 0.0 <= mix <= 1.0:
                msg = f"Invalid vs parameter at index {position}: {mix}"
                raise ValueError(msg)

        super().__init__(params=params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)

        self.vs = vs
        self.use_error_feedback = use_error_feedback
        galore_options = {
            "rank": rank,
            "update_proj_gap": update_proj_gap,
            "scale": scale,
            "vs": vs,
            "proj_type": proj_type,
            "dim": dim,
            "rotate_moments_on_refresh": rotate_moments_on_refresh,
            "use_error_feedback": use_error_feedback,
            "qhm_outside_projection": qhm_outside_projection,
        }
        self._defaults = galore_options

        # Tensor keys are normalised to id() so lookups work with plain integers.
        self._param_rank_overrides = {
            (key if isinstance(key, int) else id(key)): override_rank
            for key, override_rank in (rank_overrides or {}).items()
        }

        for group in self.param_groups:
            for option, value in galore_options.items():
                group.setdefault(option, value)
            group["initial_lr"] = group["lr"]

        def _repair_after_load(optimizer):  # type: ignore[no-untyped-def]
            return optimizer._repair_projector_states()  # type: ignore[attr-defined]

        self.register_load_state_dict_post_hook(_repair_after_load)


    @torch.no_grad()
    def step(self, closure: Callable[[], Tensor] | None = None) -> Tensor | None:  # noqa: C901, PLR0912, PLR0915
        """Perform one optimisation step.

        For every parameter with a gradient:

        1. If a rank is configured, optionally add the stored error-feedback
           residual, project the gradient to the low-rank subspace, and
           optionally store the new reconstruction residual.
        2. Update the Adam moments in the (possibly low-rank) space.
        3. Build the update, optionally mixing in the raw gradient when
           ``vs[0] > 0``. The mix happens inside the low-rank space, or in
           full-rank space when ``qhm_outside_projection`` is set.
        4. Project the update back if needed, apply it with ``-lr``, then apply
           weight decay as ``param -= lr * weight_decay * param``.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss. It runs with gradients enabled.

        Returns:
            The loss returned by ``closure``, or ``None`` without a closure.

        Raises:
            RuntimeError: if a gradient is sparse.
            NotImplementedError: if low-rank projection is requested while the
                group's ``dim`` exceeds the supported maximum.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group["betas"]
            v1, *_ = cast("tuple[float,...]", group.get("vs", (0.0,)))
            eps = group["eps"]
            lr = group["lr"]
            weight_decay = group["weight_decay"]
            base_rank = group.get("rank")
            dim = group.get("dim", GALORE_MAX_SUPPORT_DIM)
            qhm_outside = group.get("qhm_outside_projection", False)
            rotate_moments = group.get("rotate_moments_on_refresh", False)
            # Diagnostics run on every step, so they are emitted lazily at DEBUG level.
            log.debug(
                "GaLore step with lr=%s, weight_decay=%s, rank=%s, dim=%s, v1=%s, "
                "qhm_outside=%s, betas=(%s, %s), eps=%s",
                lr,
                weight_decay,
                base_rank,
                dim,
                v1,
                qhm_outside,
                beta1,
                beta2,
                eps,
            )
            log.debug("Rotate Moments on Refresh: %s", rotate_moments)
            rotation_context = None
            if rotate_moments:
                rotation_context = _RotationContext(beta1=beta1, beta2=beta2)

            for param in group["params"]:
                grad = param.grad
                if grad is None:
                    continue
                if grad.is_sparse:
                    msg = "GaLore does not support sparse gradients."
                    raise RuntimeError(msg)

                rank = self._resolve_rank_for_param(param, base_rank)
                use_low_rank = rank is not None
                log.debug("rank %s use_low_rank %s for param id %s", rank, use_low_rank, id(param))
                if use_low_rank and dim > GALORE_MAX_SUPPORT_DIM:
                    msg = "GaLore supports tensors up to 2 dimensions."
                    raise NotImplementedError(msg)
                state = self.state[param]
                if "step" not in state:
                    state["step"] = torch.zeros((), dtype=torch.float32, device=param.device)
                # If we want to mix QHM outside projection, we must keep a reference
                # to the full rank gradient before overwriting 'grad' with the low-rank version.
                full_rank_grad = None
                if use_low_rank and qhm_outside:
                    full_rank_grad = grad
                if use_low_rank:
                    state.setdefault(
                        "projector_meta",
                        {
                            "rank": rank,
                            "update_proj_gap": group["update_proj_gap"],
                            "scale": group["scale"],
                            "proj_type": group["proj_type"],
                        },
                    )
                    # Refresh on every step so group-level changes propagate.
                    meta = state["projector_meta"]
                    meta["rank"] = rank
                    meta["update_proj_gap"] = group["update_proj_gap"]
                    meta["scale"] = group["scale"]
                    meta["proj_type"] = group["proj_type"]

                    # PowerSGD-style error feedback: add accumulated error before projection
                    use_ef = group.get("use_error_feedback", False)
                    if use_ef:
                        error_feedback = state.get("error_feedback")
                        if error_feedback is not None:
                            grad = grad + error_feedback

                    # Project gradient to low-rank subspace
                    low_rank_grad = _project(state, grad, state["step"], rotation_context)

                    # Compute reconstruction error for error feedback
                    if use_ef:
                        # NOTE(review): _project_back applies ``scale``, so the residual
                        # is measured against the scaled reconstruction - confirm intent.
                        reconstructed = _project_back(state, low_rank_grad)
                        # Reshape reconstructed if needed to match grad shape
                        if reconstructed.shape != grad.shape:
                            reconstructed = reconstructed.reshape(grad.shape)
                        new_error = grad - reconstructed
                        state["error_feedback"] = new_error

                    grad = low_rank_grad

                if "exp_avg" not in state:
                    state["exp_avg"] = torch.zeros_like(
                        grad,
                        memory_format=torch.preserve_format,
                    )
                    state["exp_avg_sq"] = torch.zeros_like(
                        grad,
                        memory_format=torch.preserve_format,
                    )

                exp_avg = state["exp_avg"]
                exp_avg_sq = state["exp_avg_sq"]
                state["step"].add_(1)

                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                step_count = state["step"]
                beta1_t = step_count.new_tensor(beta1)
                beta2_t = step_count.new_tensor(beta2)
                bias_correction1 = 1 - torch.pow(beta1_t, step_count)
                bias_correction2 = 1 - torch.pow(beta2_t, step_count)
                denom = exp_avg_sq.sqrt() / bias_correction2.sqrt() + eps

                # Calculate the adaptive step in low-rank space (Term 2)
                # This corresponds to: (exp_avg / bias_correction1) / denom
                adaptive_step_low_rank = (exp_avg / bias_correction1) / denom
                step_tensor = adaptive_step_low_rank
                has_been_projected_back = False

                if v1 > 0.0:
                    if use_low_rank and qhm_outside and full_rank_grad is not None:
                        # Variant: Apply QHM mixing AFTER up-projection.
                        # 1. Project the adaptive step back to full rank
                        step_tensor = _project_back(state, step_tensor)
                        # 2. Normalize the full rank gradient using scalar statistics from the low-rank denom.
                        # We cannot use the tensor 'denom' directly because it is low-rank.
                        # We cannot up-project 'denom' because it would be zero in the null space (div by zero).
                        # We use the mean of the low-rank curvature as a proxy for global curvature.
                        denom_scalar = denom.mean().item()
                        # Precondition the full rank gradient (approximate)
                        # This ensures the units of the gradient term match the adaptive term.
                        normalized_full_grad = full_rank_grad / (denom_scalar + eps)
                        # 3. Mix
                        step_tensor = (1.0 - v1) * normalized_full_grad + v1 * step_tensor
                        has_been_projected_back = True
                    else:
                        # Standard Variant: Apply QHM mixing INSIDE the (potentially low-rank) space.
                        # Here dimensions match, so we can use the exact tensor 'denom'.
                        log.debug("Using standard QHM inside projection.")
                        blended = (1 - v1) * grad + v1 * (exp_avg / bias_correction1)
                        step_tensor = blended / denom

                if use_low_rank and not has_been_projected_back:
                    step_tensor = _project_back(state, step_tensor)

                param.add_(step_tensor, alpha=-lr)
                if weight_decay > 0.0:
                    param.add_(param, alpha=-lr * weight_decay)
        return loss

    @staticmethod
    def pre_reduce_metrics(optimizer_metrics: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Prepare metrics for distributed reduction.

        Delegates to ``prepare_metrics_for_reduction`` and returns its result.
        """
        return prepare_metrics_for_reduction(optimizer_metrics)

    @staticmethod
    def dist_reduce_metrics(optimizer_metrics: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Reduce metrics across ranks.

        Delegates to ``reduce_metrics_across_ranks`` and returns its result.
        """
        return reduce_metrics_across_ranks(optimizer_metrics)

    def _projector_eigenvalues(
        self,
        optim_state: dict[str, Any],
        device: torch.device,
    ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
        """Return the eigenvalues of ``B @ B.T`` for the stored basis, and their product.

        ``B`` is the stored basis itself when it is a single tensor, or the
        first (left) entry when the basis is a list. ``(None, None)`` is
        returned when no usable basis is stored.
        """
        stored = optim_state.get("projector_basis")
        if isinstance(stored, list):
            factor, _ = stored
        elif isinstance(stored, Tensor):
            factor = stored
        else:
            return None, None

        gram = (factor @ factor.T).to(device=device)
        spectrum = torch.linalg.eigvalsh(gram).real
        return spectrum, torch.prod(spectrum)

    def report_per_parameter_metrics(  # noqa: C901, PLR0912, PLR0915
        self,
        param: torch.Tensor,
        name: str,
        optimizer_metrics: dict[str, torch.Tensor],
    ) -> dict[str, torch.Tensor]:
        """Report per-parameter metrics including GaLore projection stats."""
        lr = self.param_groups[0]["lr"]
        eps = self.param_groups[0]["eps"]
        weight_decay = self.param_groups[0]["weight_decay"]
        initial_lr = self.param_groups[0]["initial_lr"]
        qhm_outside = self.param_groups[0].get("qhm_outside_projection", False)
        beta1, beta2 = self.param_groups[0]["betas"]
        v1, *_ = cast("tuple[float,...]", self.param_groups[0].get("vs", (0.0,)))
        if param in self.state:
            param_optim_state = self.state[param]
            step_state = param_optim_state["step"]
            if "max/optimizer_step" not in optimizer_metrics:
                if isinstance(step_state, torch.Tensor):
                    step_tensor = step_state.detach().clone()
                    if step_tensor.device != param.device:
                        step_tensor = step_tensor.to(param.device)
                else:
                    step_tensor = torch.tensor(float(step_state), device=param.device)
                optimizer_metrics["max/optimizer_step"] = step_tensor

            step = param_optim_state["step"].item()
            grad = param.grad
            original_grad = grad.detach() if grad is not None else None
            meta = param_optim_state.get("projector_meta")
            use_low_rank = meta is not None and meta.get("rank") is not None
            # Hold reference to full rank grad for outside projection logic
            full_rank_grad = grad
            # Project grad to low rank for standard calculations
            if grad is not None and use_low_rank:
                grad = _project(param_optim_state, grad, param_optim_state["step"])

            bias_correction1 = 1 - beta1**step
            bias_correction2 = 1 - beta2**step
            # --- Common Components ---
            denom = param_optim_state["exp_avg_sq"].sqrt() / math.sqrt(bias_correction2) + eps
            m_hat = param_optim_state["exp_avg"] / bias_correction1
            # The adaptive term (preconditioned momentum)
            adaptive_step = m_hat / denom
            step_tensor = adaptive_step
            has_been_projected_back = False

            # --- Mixing Logic ---
            if v1 > 0.0 and grad is not None:
                if use_low_rank and qhm_outside and full_rank_grad is not None:
                    # Logic A: QHM Outside Projection (Full Rank Mixing)
                    # 1. Project adaptive term back to full rank
                    adaptive_step_full = _project_back(param_optim_state, adaptive_step)
                    # 2. Scalar normalize the full rank gradient
                    # (approximating the preconditioner with the mean of the low-rank subspace)
                    denom_scalar = denom.mean().item()
                    grad_norm = full_rank_grad / (denom_scalar + eps)
                    # 3. Mix
                    step_tensor = (1.0 - v1) * grad_norm + v1 * adaptive_step_full
                    has_been_projected_back = True
                else:
                    # Logic B: Standard QHM (Low Rank Mixing)
                    blended = (1.0 - v1) * grad + v1 * m_hat
                    step_tensor = blended / denom

            # --- Final Projection ---
            if use_low_rank and not has_been_projected_back:
                step_tensor = _project_back(param_optim_state, step_tensor)
            # --- Apply LR & Decay ---
            step_tensor = step_tensor.mul(lr)

            if weight_decay != 0:
                decay_factor = _compute_decay_factor(lr, initial_lr)
                scaling_factor = (decay_factor * weight_decay) / (1 - decay_factor * weight_decay)
                step_tensor.mul_(1 + scaling_factor).add_(param, alpha=scaling_factor)

            for metric in self.metric_functions:
                metric_value = self.metric_functions[metric](
                    param,
                    param_optim_state,
                    step_tensor,
                )
                optimizer_metrics[f"{metric}/{name}"] = _to_local_if_dtensor(metric_value)

            if use_low_rank:
                eigenvalues, eig_product = self._projector_eigenvalues(param_optim_state, param.device)
                if eigenvalues is not None:
                    for idx, eig in enumerate(eigenvalues):
                        optimizer_metrics[f"mean/projection_eigenvalue_{idx}/{name}"] = _to_local_if_dtensor(eig)
                if eig_product is not None:
                    optimizer_metrics[f"mean/projection_eigenvalue_product/{name}"] = _to_local_if_dtensor(
                        eig_product
                    )

            if original_grad is not None:
                with torch.no_grad():
                    singular_values = _singular_values_from_gradient(original_grad)
                    stable_rank = _stable_rank_from_singular_values(singular_values)
                    optimizer_metrics[f"mean/stable_rank/{name}"] = _to_local_if_dtensor(stable_rank)

                    rank = meta.get("rank") if meta is not None else None
                    if rank is not None:
                        spectral_gap = _spectral_gap_from_singular_values(singular_values, rank)
                        optimizer_metrics[f"mean/spectral_gap/{name}"] = _to_local_if_dtensor(spectral_gap)
                        sigma_sorted = torch.sort(singular_values, descending=True).values
                        sigma_r = sigma_sorted[rank - 1] if sigma_sorted.numel() >= rank else torch.tensor(float("nan"), device=sigma_sorted.device)
                        if torch.isfinite(sigma_r) and sigma_r != 0:
                            rel_gap = spectral_gap / sigma_r
                            optimizer_metrics[f"mean/relative_spectral_gap/{name}"] = _to_local_if_dtensor(rel_gap)

                    alpha = _powerlaw_alpha_from_singular_values(singular_values)
                    optimizer_metrics[f"mean/powerlaw_alpha/{name}"] = _to_local_if_dtensor(alpha)

        return optimizer_metrics

    def _repair_projector_states(self) -> None:
        """Bring stored projector metadata back in line with the group config.

        For every parameter that has optimizer state and a resolved rank, the
        ``projector_meta`` entry is created if absent and its ``rank``,
        ``update_proj_gap`` and ``scale`` fields are overwritten from the
        parameter group.  The projection type already recorded in the metadata
        is kept (falling back to the group's) and re-encoded via
        ``PROJ_TO_CODE``.  If the rank inferred from the stored
        ``projector_basis`` differs from the wanted rank, a warning is logged
        and the basis is removed from the state.
        """
        for group in self.param_groups:
            group_rank = group.get("rank")
            proj_gap = group.get("update_proj_gap")
            proj_scale = group.get("scale")
            group_proj_name = _proj_name_from_value(group.get("proj_type", STD_PROJ))

            for param in group["params"]:
                target_rank = self._resolve_rank_for_param(param, group_rank)
                if target_rank is None or param not in self.state:
                    # Not a low-rank parameter, or nothing to repair yet.
                    continue

                state = self.state[param]
                # Built unconditionally; only stored when no metadata exists.
                fresh_meta = {
                    "rank": target_rank,
                    "update_proj_gap": proj_gap,
                    "scale": proj_scale,
                    "proj_type": PROJ_TO_CODE[group_proj_name],
                    "resolved_proj_type": PROJ_TO_CODE[
                        _resolve_proj_choice(group_proj_name, param)
                    ],
                }
                meta = state.setdefault("projector_meta", fresh_meta)

                # Group-level settings always win over whatever was stored.
                meta.update(
                    {
                        "rank": target_rank,
                        "update_proj_gap": proj_gap,
                        "scale": proj_scale,
                    }
                )

                # The stored projection type is preserved, just normalised.
                stored_proj_name = _proj_name_from_value(
                    meta.get("proj_type", group_proj_name)
                )
                meta["proj_type"] = PROJ_TO_CODE[stored_proj_name]
                resolved_name = _resolve_proj_choice(stored_proj_name, param)
                meta["resolved_proj_type"] = PROJ_TO_CODE[resolved_name]

                basis_rank = _infer_projector_rank(
                    state.get("projector_basis"), resolved_name
                )
                rank_is_consistent = basis_rank is None or basis_rank == target_rank
                if not rank_is_consistent:
                    log.warning(
                        "Resetting projector basis for %s: checkpoint rank %s != configured rank %s.",
                        getattr(param, "_base_name", "<unnamed_param>"),
                        basis_rank,
                        target_rank,
                    )
                    state.pop("projector_basis", None)

    def _resolve_rank_for_param(self, param: Tensor, fallback: int | None) -> int | None:
        override = self._param_rank_overrides.get(id(param))
        if override is not None:
            return override
        return fallback


def _assign_ranks_by_pattern(
    rules: list[dict[str, Any]] | None,
    fallback_rank: int | None,
    unassigned: dict[str, None],
    assigned: dict[str, int],
) -> None:
    """Move names matching each rule's regex from *unassigned* into *assigned*.

    Rules are applied in order and the first matching rule claims a name.
    A rule is ignored when it has no ``param_str_match`` pattern or when its
    rank (``rule["rank"]``, else *fallback_rank*) is not an ``int``.
    Both *unassigned* and *assigned* are mutated in place.
    """
    for rule in rules or ():
        pattern = rule.get("param_str_match")
        rank = rule.get("rank", fallback_rank)
        if not pattern or not isinstance(rank, int):
            continue
        matcher = re.compile(pattern)
        # Snapshot the matches first so the dict is not mutated while iterated.
        for name in [candidate for candidate in unassigned if matcher.search(candidate)]:
            assigned[name] = rank
            del unassigned[name]


def classify_low_rank_parameters(
    parameter_names: list[str],
    optimizer_config: dict[str, Any] | None = None,
) -> dict[str, int]:
    """Classify parameter names as low-rank based on config patterns.

    Args:
        parameter_names: Names to classify; duplicates are collapsed.
        optimizer_config: Optimizer configuration. The keys read are
            ``param_groups`` (entries with ``param_str_match`` and an optional
            ``rank`` that defaults to ``galore_rank``),
            ``galore_param_regexes`` (entries with ``param_str_match`` and a
            required ``rank``) and ``galore_rank``.

    Returns:
        Mapping from parameter name to rank. ``param_groups`` rules are applied
        first, then ``galore_param_regexes``; names still unclaimed receive
        ``galore_rank`` when it is not ``None`` and are omitted otherwise.
        An empty or missing config yields ``{}``. Keys are inserted in a
        deterministic order: by rule, then by position in *parameter_names*.
    """
    if not optimizer_config:
        return {}

    default_rank = optimizer_config.get("galore_rank")

    # An insertion-ordered dict serves as an ordered set so the result order
    # does not depend on string hash randomization.
    pending: dict[str, None] = dict.fromkeys(parameter_names)
    low_rank: dict[str, int] = {}

    _assign_ranks_by_pattern(
        optimizer_config.get("param_groups"), default_rank, pending, low_rank
    )
    # Regex overrides must carry their own rank; there is no fallback here.
    _assign_ranks_by_pattern(
        optimizer_config.get("galore_param_regexes"), None, pending, low_rank
    )

    if default_rank is not None:
        for name in pending:
            low_rank[name] = default_rank
    return low_rank
