# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# mypy: allow-untyped-defs
import logging
import math
from collections import defaultdict

import torch
import torch.distributed as dist
from torch.distributed import distributed_c10d

# Reuse torch's default hooks for vanilla allreduce future
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks as default
from torch.utils._typing_utils import not_none


# Public API of this module: the shared state object and the two PowerSGD DDP communication hooks.
__all__ = ["PowerSGDState", "powerSGD_hook", "batched_powerSGD_hook"]

# Module logger; used for configuration, buffer-allocation and compression-stats messages.
logger = logging.getLogger(__name__)


def _orthogonalize(matrices, epsilon=0):
    assert len(matrices.shape) == 3 and matrices.shape[2] <= matrices.shape[1]

    num_matrices = matrices.shape[0]
    rank = matrices.shape[2]
    dtype = matrices.dtype
    if rank <= 2 or dtype in [torch.float16, torch.bfloat16]:
        _orthogonalize_gram_schmidt(matrices, epsilon=epsilon)
    else:
        torch.linalg.qr(
            matrices,
            out=(
                matrices,
                torch.empty(
                    num_matrices, rank, rank, device=matrices.device, dtype=dtype
                ),
            ),
        )


def _orthogonalize_gram_schmidt(matrices, epsilon=0):
    num_cols = matrices.shape[2]
    for i in range(num_cols):
        col = matrices[:, :, i : i + 1]
        if epsilon == 0:
            try:
                col /= torch.linalg.norm(col, dim=1, keepdim=True)
            except ZeroDivisionError:
                logger.error(
                    "The matrices to be orthogonalized has at least a column of all 0s. "
                    "Please set a small value such as 1e-8 as `orthogonalization_epsilon` in PowerSGD state."
                )
                col.fill_(0.0)
        else:
            col /= torch.linalg.norm(col, dim=1, keepdim=True) + epsilon
        if i + 1 < num_cols:
            rest = matrices[:, :, i + 1 :]
            rest -= torch.sum(col * rest, dim=1, keepdim=True) * col


def _should_compress(
    num_rows, num_cols, matrix_approximation_rank, min_compression_rate
):
    uncompressed_size = num_rows * num_cols
    compressed_size = (num_rows + num_cols) * matrix_approximation_rank
    return (
        compressed_size * min_compression_rate < uncompressed_size,
        uncompressed_size,
        compressed_size,
    )


def _report_compression_stats(bucket, state):
    if bucket.is_last() and state.iter >= state.next_stats_report:
        stats = state.compression_stats()
        logger.info(
            "Compression stats: iter %s, total before compression %s, total after compression %s, rate %s",
            state.iter,
            stats[1],
            stats[2],
            stats[0],
        )
        state.next_stats_report = state.iter + state.compression_stats_logging_frequency


def _ensure_buffer(
    t: torch.Tensor, *, device: torch.device, dtype: torch.dtype, numel: int
) -> torch.Tensor:
    """Ensure tensor has expected device/dtype/numel. Reallocate if needed."""
    if t.device != device or t.dtype != dtype or t.numel() != numel:
        return torch.empty(numel, device=device, dtype=dtype)
    return t


def _ensure_2d(
    t: torch.Tensor,
    *,
    device: torch.device,
    dtype: torch.dtype,
    shape: tuple[int, int],
    fill_random: bool,
    rng,
) -> torch.Tensor:
    rows, cols = shape
    expected_numel = rows * cols
    if (
        t is None
        or t.device != device
        or t.dtype != dtype
        or t.numel() != expected_numel
    ):
        if fill_random:
            with torch.random.fork_rng(devices=[]):
                torch.manual_seed(rng.randint(1_000_000_000))
                return torch.randn(rows, cols, device="cpu", dtype=dtype).to(device)
        else:
            return torch.empty(rows, cols, device=device, dtype=dtype)
    return t.view(rows, cols).to(device=device, dtype=dtype)


class PowerSGDState:
    __slots__ = [
        "process_group",
        "matrix_approximation_rank",
        "start_powerSGD_iter",
        "min_compression_rate",
        "orthogonalization_epsilon",
        "use_error_feedback",
        "warm_start",
        "batch_tensors_with_same_shape",
        "rng",
        "error_dict",
        "p_memory_dict",
        "q_memory_dict",
        "p_sig_dict",
        "q_sig_dict",
        "iter",
        "total_numel_before_compression",
        "total_numel_after_compression",
        "compression_stats_logging_frequency",
        "next_stats_report",
        "error_feedback_reset_frequency",
        # Per-step error norm logging placeholders
        "latest_error_norms",
        "curr_error_norms",
        "param_id_to_name",
        "bucket_index_to_param_names",
        "param_ptr_to_name",
        "model_weakref",
        "grad_ptr_to_name",
    ]

    def __init__(
        self,
        process_group,
        matrix_approximation_rank=1,
        start_powerSGD_iter=1_000,
        min_compression_rate=2,
        use_error_feedback=True,
        warm_start=True,
        orthogonalization_epsilon=0,
        random_seed=0,
        compression_stats_logging_frequency=10_000,
        batch_tensors_with_same_shape: bool = False,
        error_feedback_reset_frequency: int | None = None,
    ):
        logger.info(
            "PowerSGD config: matrix_approximation_rank = %s; start_powerSGD_iter = %s; min_compression_rate = %s; "
            "orthogonalization_epsilon = %s; use_error_feedback = %s; warm_start = %s; random_seed = %s; "
            "compression_stats_logging_frequency = %s; batch_tensors_with_same_shape = %s; error_feedback_reset_frequency = %s",
            matrix_approximation_rank,
            start_powerSGD_iter,
            min_compression_rate,
            orthogonalization_epsilon,
            use_error_feedback,
            warm_start,
            random_seed,
            compression_stats_logging_frequency,
            batch_tensors_with_same_shape,
            error_feedback_reset_frequency,
        )

        self.process_group = process_group
        if (use_error_feedback or warm_start) and start_powerSGD_iter <= 1:
            raise ValueError(
                "Expect `start_powerSGD_iter` > 1 if `use_error_feedback` or `warm_start` is enabled, "
                "because PowerSGD can only be applied after the first two iterations in DDP."
            )
        self.matrix_approximation_rank = matrix_approximation_rank
        self.start_powerSGD_iter = start_powerSGD_iter
        self.min_compression_rate = min_compression_rate
        self.use_error_feedback = use_error_feedback
        self.warm_start = warm_start
        self.orthogonalization_epsilon = orthogonalization_epsilon

        import numpy as np

        self.rng = np.random.RandomState(random_seed)
        self.error_dict: dict[int, torch.Tensor] = {}
        self.p_memory_dict: dict[int, torch.Tensor] = {}
        self.q_memory_dict: dict[int, torch.Tensor] = {}
        # Alternative lookup keyed by a stable shape signature: tuple of (n, m, r) per compressed matrix,
        # sorted to be order-invariant. This survives DDP bucket reordering across runs.
        self.p_sig_dict: dict[tuple[tuple[int, int, int], ...], torch.Tensor] = {}
        self.q_sig_dict: dict[tuple[tuple[int, int, int], ...], torch.Tensor] = {}
        self.iter = 0
        self.total_numel_before_compression = 0
        self.total_numel_after_compression = 0
        self.compression_stats_logging_frequency = max(
            1, compression_stats_logging_frequency
        )
        self.next_stats_report = 0
        self.batch_tensors_with_same_shape = batch_tensors_with_same_shape
        self.error_feedback_reset_frequency = error_feedback_reset_frequency
        # Hold a snapshot of latest step's per-tensor error norms as a flat dict of
        # metric_name -> float. Also keep a transient accumulator during the step.
        self.latest_error_norms: dict[str, float] = {}
        self.curr_error_norms: dict[str, float] = {}
        self.param_id_to_name: dict[int, str] = {}
        self.bucket_index_to_param_names: dict[int, list[str]] = {}
        self.param_ptr_to_name: dict[int, str] = {}
        self.model_weakref = None
        self.grad_ptr_to_name: dict[int, str] = {}

    def __getstate__(self):
        logger.warning(
            "NOTE: Process group is not serializable and excluded from a saved state."
        )
        return {
            slot: getattr(self, slot)
            for slot in self.__slots__
            if slot not in ("process_group", "model_weakref")
        }

    def __setstate__(self, state):
        self.process_group = distributed_c10d._get_default_group()
        logger.warning(
            "NOTE: Process group will be set to a default group (i.e. the world size). \
            If a different group is desired, please set `self.process_group` after PowerSGD state is loaded."
        )
        for slot, value in state.items():
            setattr(self, slot, value)
        # Backward compatibility for checkpoints saved before p_sig_dict/q_sig_dict existed
        if not hasattr(self, "p_sig_dict"):
            self.p_sig_dict = {}
        if not hasattr(self, "q_sig_dict"):
            self.q_sig_dict = {}
        # Backward compatibility for checkpoints saved before error norm placeholders existed
        if not hasattr(self, "latest_error_norms"):
            self.latest_error_norms = {}
        if not hasattr(self, "curr_error_norms"):
            self.curr_error_norms = {}
        if not hasattr(self, "param_id_to_name"):
            self.param_id_to_name = {}
        if not hasattr(self, "bucket_index_to_param_names"):
            self.bucket_index_to_param_names = {}
        if not hasattr(self, "param_ptr_to_name"):
            self.param_ptr_to_name = {}
        if not hasattr(self, "model_weakref"):
            self.model_weakref = None
        if not hasattr(self, "grad_ptr_to_name"):
            self.grad_ptr_to_name = {}

    def maybe_increase_iter(self, bucket):
        if bucket.is_last():
            self.iter += 1
        if self.iter == self.start_powerSGD_iter:
            logger.info("Start to apply PowerSGD after %s iterations.", self.iter)

    def compression_stats(self):
        compress_rate = (
            self.total_numel_before_compression / self.total_numel_after_compression
            if self.total_numel_after_compression > 0
            else 0
        )
        return (
            compress_rate,
            self.total_numel_before_compression,
            self.total_numel_after_compression,
        )


def _powersgd_pin_if_possible(t: torch.Tensor) -> torch.Tensor:
    """Return host tensor ``t`` in page-locked memory when CUDA is available.

    ``Tensor.pin_memory()`` needs an accelerator, so on CPU-only hosts ``t`` is
    returned unchanged instead of raising.
    """
    if torch.cuda.is_available() and not t.is_pinned():
        return t.pin_memory()
    return t


def _powersgd_adopt_buffer(buf, numel, memory_dict, bucket_index, device, dtype):
    """Return ``buf`` converted to ``device``/``dtype`` if ``bucket_index`` may reuse it, else None.

    A buffer is reusable only if it has exactly ``numel`` elements and is not the
    current buffer of a *different* bucket. Two buckets sharing one P/Q buffer
    would overwrite each other's values during their async allreduces.
    """
    if buf is None or buf.numel() != numel:
        return None
    if any(other is buf for idx, other in memory_dict.items() if idx != bucket_index):
        return None
    return buf.to(device=device, dtype=dtype)


def _powersgd_reuse_low_rank_buffers(
    state, bucket_index, signature, p_numel, q_numel, device, dtype
):
    """Find warm-start P/Q buffers for a bucket, or return ``(None, None)``.

    Candidates are tried in order: the buffers last used by this bucket index,
    then the buffers registered under the bucket's shape signature (which
    survives DDP bucket reordering across runs).
    """
    candidates = (
        (state.p_memory_dict.get(bucket_index), state.q_memory_dict.get(bucket_index)),
        (state.p_sig_dict.get(signature), state.q_sig_dict.get(signature)),
    )
    for p_candidate, q_candidate in candidates:
        p_buf = _powersgd_adopt_buffer(
            p_candidate, p_numel, state.p_memory_dict, bucket_index, device, dtype
        )
        q_buf = _powersgd_adopt_buffer(
            q_candidate, q_numel, state.q_memory_dict, bucket_index, device, dtype
        )
        if p_buf is not None and q_buf is not None:
            return p_buf, q_buf
    return None, None


def _powersgd_store_error(state, bucket_index, err):
    """Keep this step's local compression error on the host for the next step.

    The existing host buffer is overwritten in place when it fits, which avoids
    a new (pinned) allocation every step. The copy is blocking so the host value
    is complete before anything (next step, checkpointing) reads it.
    """
    dst = state.error_dict.get(bucket_index)
    if (
        dst is not None
        and dst.device.type == "cpu"
        and dst.numel() == err.numel()
        and dst.dtype == err.dtype
    ):
        dst.copy_(err)
    else:
        state.error_dict[bucket_index] = _powersgd_pin_if_possible(
            err.detach().to("cpu")
        )


def _powersgd_names_from_grad_ptrs(model, input_tensor, tensors):
    """Match each gradient's offset in ``input_tensor`` against ``param.grad`` addresses.

    Assumes the gradients sit back to back in the bucket buffer in the order of
    ``tensors``; unmatched entries are ``None``.
    """
    grad_ptr_to_name = {}
    for pname, param in model.named_parameters(recurse=True):
        if param.grad is None:
            continue
        try:
            grad_ptr_to_name[param.grad.data_ptr()] = pname
        except Exception:  # best-effort: skip grads without an addressable storage
            continue
    base_ptr = input_tensor.data_ptr()
    element_size = input_tensor.element_size()
    resolved = []
    offset = 0
    for t in tensors:
        resolved.append(grad_ptr_to_name.get(base_ptr + offset * element_size))
        offset += t.numel()
    return resolved


def _powersgd_bucket_param_names(state, bucket, input_tensor, tensors):
    """Return a metric label for every gradient in ``tensors``, cached per bucket index.

    Best-effort resolution, in order:
    1. ``bucket.parameters()`` looked up in ``state.param_id_to_name`` and then
       ``state.param_ptr_to_name``.
    2. Gradient-address matching against ``state.model_weakref()``.
    3. The label ``"b{bucket_index}_t{i}"``.
    """
    bucket_index = bucket.index()
    cached = state.bucket_index_to_param_names.get(bucket_index)
    if cached is not None:
        return cached

    names = None
    try:
        params = bucket.parameters()  # type: ignore[attr-defined]
        if params is not None:
            names = []
            for param in params:
                name = state.param_id_to_name.get(id(param))
                if name is None:
                    try:
                        name = state.param_ptr_to_name.get(
                            param.untyped_storage().data_ptr()
                        )
                    except Exception:  # best-effort name lookup
                        name = None
                names.append(name)
    except Exception:  # best-effort: older GradBucket may lack parameters()
        names = None

    if names is None or any(name is None for name in names):
        try:
            model = state.model_weakref() if state.model_weakref else None
            resolved = (
                _powersgd_names_from_grad_ptrs(model, input_tensor, tensors)
                if model is not None
                else None
            )
        except Exception:  # best-effort name lookup
            resolved = None
        if resolved is not None:
            if names is None:
                names = resolved
            else:
                names = [
                    resolved[i] if name is None and i < len(resolved) else name
                    for i, name in enumerate(names)
                ]

    if names is None:
        names = [None] * len(tensors)
    labels = [
        name if name is not None else f"b{bucket_index}_t{i}"
        for i, name in enumerate(names)
    ]
    state.bucket_index_to_param_names[bucket_index] = labels
    return labels


def _powersgd_record_error_norms(
    state, bucket, input_tensor, input_tensor_cp, err, tensors, compressed_mask
):
    """Accumulate per-tensor input and error norms into ``state.curr_error_norms``.

    Records the norm of the pre-compression snapshot for the whole bucket and
    for every tensor, and the compression-error norm for compressed tensors
    only. All norms are computed in float32 on-device and fetched to the host
    in a single transfer.
    """
    bucket_index = bucket.index()
    labels = _powersgd_bucket_param_names(state, bucket, input_tensor, tensors)
    keys = [f"psgd_input_cp_norm/b{bucket_index}"]
    norms = [torch.linalg.norm(input_tensor_cp.float())]
    offset = 0
    for i, (t, compressed) in enumerate(zip(tensors, compressed_mask)):
        length = t.numel()
        label = labels[i] if i < len(labels) else f"b{bucket_index}_t{i}"
        keys.append(f"psgd_input_cp_norm/{label}")
        norms.append(torch.linalg.norm(input_tensor_cp.narrow(0, offset, length).float()))
        if compressed:
            keys.append(f"psgd_errnorm/{label}")
            norms.append(torch.linalg.norm(err.narrow(0, offset, length).float()))
        offset += length
    values = torch.stack(norms).tolist()
    state.curr_error_norms.update(zip(keys, values))


def powerSGD_hook(  # noqa: N802
    state: PowerSGDState, bucket: dist.GradBucket
) -> torch.futures.Future[torch.Tensor]:
    """PowerSGD gradient-compression communication hook for DDP.

    Before ``state.start_powerSGD_iter`` the bucket is allreduced uncompressed,
    unless ``state`` already holds P/Q/error buffers (e.g. restored from a
    checkpoint). In that case compression starts immediately.

    Otherwise each gradient is viewed as a matrix ``M`` (first dim x the rest).
    Matrices that ``_should_compress`` accepts are approximated at rank ``r``;
    all others are packed together and allreduced as-is. For compressed
    matrices the hook:

    1. computes ``P = M Q``, then allreduces and orthogonalizes ``P``;
    2. computes ``Q = M^T P``, then allreduces and averages ``Q``;
    3. writes ``P Q^T`` back into the gradient.

    With ``state.use_error_feedback``, the previous step's local error is added
    to the bucket before compression. The new error is the pre-compression
    bucket minus the bucket after communication, and it is kept on the host.
    With ``state.warm_start``, P/Q buffers are reused across steps instead of
    being re-randomized.

    Args:
        state: Hook state (configuration, buffers, error feedback, statistics).
        bucket: The DDP gradient bucket to communicate.

    Returns:
        A future resolving to the bucket's flat buffer, updated in place.
    """
    process_group = state.process_group
    group_to_use = (
        process_group if process_group is not None else not_none(dist.group.WORLD)
    )
    world_size = group_to_use.size()

    input_tensor = bucket.buffer()
    if state.iter < state.start_powerSGD_iter:
        # If we have preloaded buffers (resume), skip vanilla fallback and use PowerSGD immediately.
        # Otherwise, keep vanilla allreduce until start_powerSGD_iter.
        if not (state.p_memory_dict or state.q_memory_dict or state.error_dict):
            state.maybe_increase_iter(bucket)
            return default._allreduce_fut(group_to_use, input_tensor)

    bucket_index = bucket.index()
    # A falsy frequency (None or 0) disables the reset; 0 would otherwise raise
    # ZeroDivisionError.
    if (
        state.error_feedback_reset_frequency
        and state.iter % state.error_feedback_reset_frequency == 0
    ):
        # Drop only this bucket's error. Clearing the whole dict here would also
        # wipe errors other buckets have already stored during this iteration.
        logger.info("Resetting error feedback for PowerSGD")
        state.error_dict.pop(bucket_index, None)

    device = input_tensor.device
    dtype = input_tensor.dtype
    total_length = input_tensor.shape[0]

    input_tensor_cp = None
    if state.use_error_feedback:
        prev = state.error_dict.get(bucket_index)
        if prev is not None and prev.numel() == total_length and prev.dtype == dtype:
            # Reuse the stored error even if it lost its host/pinned placement
            # (e.g. after a checkpoint load) instead of silently resetting it.
            prev = _powersgd_pin_if_possible(prev.to("cpu"))
            state.error_dict[bucket_index] = prev
            input_tensor.add_(prev.to(device, non_blocking=True))
        else:
            state.error_dict[bucket_index] = _powersgd_pin_if_possible(
                torch.zeros(total_length, device="cpu", dtype=dtype)
            )
        # Pre-compression snapshot; the new error is measured against it.
        input_tensor_cp = input_tensor.detach().clone()

    tensors = bucket.gradients()

    matrices_to_compress, uncompressed_tensors = [], []
    compressed_mask: list[bool] = []
    compress_shapes: list[tuple[int, int, int]] = []
    total_Ps_size = 0
    total_Qs_size = 0
    for tensor in tensors:
        # View as (first dim) x (rest). 0-dim gradients become 1x1 and empty
        # ones 0x0, because view(n, -1) fails for those.
        rows = tensor.shape[0] if tensor.dim() > 0 else 1
        matrix = tensor.view(rows, tensor.numel() // rows if rows else 0)
        n, m = matrix.shape
        matrix_approximation_rank = min(n, m, state.matrix_approximation_rank)
        should_compress, numel_before, numel_after = _should_compress(
            n, m, matrix_approximation_rank, state.min_compression_rate
        )
        state.total_numel_before_compression += numel_before
        if should_compress:
            matrices_to_compress.append(matrix)
            compress_shapes.append((n, m, matrix_approximation_rank))
            total_Ps_size += n * matrix_approximation_rank
            total_Qs_size += m * matrix_approximation_rank
            state.total_numel_after_compression += numel_after
        else:
            uncompressed_tensors.append(tensor)
            state.total_numel_after_compression += numel_before
        compressed_mask.append(bool(should_compress))

    _report_compression_stats(bucket, state)

    uncompressed_tensors_memory = (
        torch.cat([tensor.view(-1) for tensor in uncompressed_tensors])
        if uncompressed_tensors
        else torch.tensor([], device=device, dtype=dtype)
    )

    # Choose this bucket's contiguous P/Q buffers. Anything not safely reusable
    # is freshly allocated, and Q is then re-randomized.
    signature = tuple(sorted(compress_shapes))
    need_randomize_qs = False
    p_buffer = q_buffer = None
    if state.warm_start:
        p_buffer, q_buffer = _powersgd_reuse_low_rank_buffers(
            state, bucket_index, signature, total_Ps_size, total_Qs_size, device, dtype
        )
    if p_buffer is None or q_buffer is None:
        need_randomize_qs = True
        p_buffer = torch.empty(total_Ps_size, device=device, dtype=dtype)
        q_buffer = torch.empty(total_Qs_size, device=device, dtype=dtype)
        if state.warm_start:
            logger.info(
                "Allocating contiguous memory of length %s for Ps, and of length %s for Qs, respectively.",
                total_Ps_size,
                total_Qs_size,
            )
    state.p_memory_dict[bucket_index] = p_buffer
    state.q_memory_dict[bucket_index] = q_buffer
    if state.warm_start:
        state.p_sig_dict[signature] = p_buffer
        state.q_sig_dict[signature] = q_buffer

    shape_to_tensors = defaultdict(list)
    for matrix in matrices_to_compress:
        shape_to_tensors[matrix.shape].append(matrix)

    def maybe_batched_tensors_to_compress():
        for tensors_ in shape_to_tensors.values():
            if state.batch_tensors_with_same_shape:
                batch_size = len(tensors_)
                if batch_size == 1:
                    yield tensors_[0].unsqueeze(0)
                else:
                    yield torch.stack(tensors_)
            else:
                for tensor_ in tensors_:
                    yield tensor_.unsqueeze(0)

    # Batched (B, n, m) matrices with matching (B, n, r) / (B, m, r) views into
    # the P/Q buffers.
    tensors_to_compress = []
    ps = []
    qs = []
    p_idx = 0
    q_idx = 0
    for tensor in maybe_batched_tensors_to_compress():
        batch_size, n, m = tensor.shape
        matrix_approximation_rank = min(n, m, state.matrix_approximation_rank)
        p_len = batch_size * n * matrix_approximation_rank
        q_len = batch_size * m * matrix_approximation_rank
        tensors_to_compress.append(tensor)
        ps.append(
            p_buffer[p_idx : p_idx + p_len].view(batch_size, n, matrix_approximation_rank)
        )
        qs.append(
            q_buffer[q_idx : q_idx + q_len].view(batch_size, m, matrix_approximation_rank)
        )
        p_idx += p_len
        q_idx += q_len

    if not need_randomize_qs:
        for q in qs:
            _orthogonalize(q, state.orthogonalization_epsilon)
    else:
        # Random Q drawn on CPU under a seed from state.rng, isolated from the
        # global torch RNG.
        with torch.random.fork_rng(devices=[]):
            torch.manual_seed(state.rng.randint(1_000_000_000))
            for q in qs:
                q.copy_(
                    torch.randn(
                        *q.shape,
                        device="cpu",
                        dtype=dtype,
                    )
                )
                _orthogonalize(q, state.orthogonalization_epsilon)

    for tensor, q, p in zip(tensors_to_compress, qs, ps):
        torch.bmm(tensor, q, out=p)

    allreduce_contiguous_uncompressed_tensors_fut = dist.all_reduce(
        uncompressed_tensors_memory, group=group_to_use, async_op=True
    ).get_future()

    # The closures below use the local p_buffer/q_buffer rather than looking the
    # bucket up in state dicts, which other buckets' callbacks may modify.
    def unpack_uncompressed_tensors_and_allreduce_ps(fut):
        uncompressed_tensors_memory_ = fut.value()[0].div_(world_size)
        idx = 0
        for tensor in uncompressed_tensors:
            tensor.copy_(
                uncompressed_tensors_memory_[idx : idx + tensor.numel()].view_as(tensor)
            )
            idx += tensor.numel()
        return (
            dist.all_reduce(p_buffer, group=group_to_use, async_op=True)
            .get_future()
            .wait()[0]
        )

    def compute_qs(fut):
        state.p_memory_dict[bucket_index] = fut.value()
        for p in ps:
            _orthogonalize(p, state.orthogonalization_epsilon)
        for tensor, p, q in zip(tensors_to_compress, ps, qs):
            torch.bmm(tensor.transpose(1, 2), p, out=q)
        return (
            dist.all_reduce(q_buffer, group=group_to_use, async_op=True)
            .get_future()
            .wait()[0]
        )

    def decompress(fut):
        state.q_memory_dict[bucket_index] = fut.value().div_(world_size)
        for p, q, tensor in zip(ps, qs, tensors_to_compress):
            torch.bmm(p, q.transpose(1, 2), out=tensor)
        if state.batch_tensors_with_same_shape:
            # Stacked batches are copies; write the results back to the originals.
            for tensor in tensors_to_compress:
                if tensor.shape[0] == 1:
                    continue
                original_tensors = shape_to_tensors[tensor.shape[1:]]
                for i, original_tensor in enumerate(original_tensors):
                    original_tensor.copy_(tensor[i])
        # Only a CUDA device can be synchronized. This also guarantees that the
        # H2D copy of the stored error has finished before it is overwritten below.
        if device.type == "cuda":
            torch.cuda.synchronize(device)
        if input_tensor_cp is not None:
            err = input_tensor_cp - input_tensor
            try:
                _powersgd_record_error_norms(
                    state,
                    bucket,
                    input_tensor,
                    input_tensor_cp,
                    err,
                    tensors,
                    compressed_mask,
                )
            except Exception:
                # Metrics are best-effort and must never fail the training step.
                logger.debug(
                    "Could not record PowerSGD error norms for bucket %s",
                    bucket_index,
                    exc_info=True,
                )
            _powersgd_store_error(state, bucket_index, err)
        if not state.warm_start:
            # Release only this bucket's buffers; other buckets may still be in flight.
            state.p_memory_dict.pop(bucket_index, None)
            state.q_memory_dict.pop(bucket_index, None)
        # If this is the last bucket of the iteration, publish snapshot
        if bucket.is_last():
            state.latest_error_norms = dict(state.curr_error_norms)
            state.curr_error_norms.clear()
        state.maybe_increase_iter(bucket)
        return input_tensor

    return (
        allreduce_contiguous_uncompressed_tensors_fut.then(
            unpack_uncompressed_tensors_and_allreduce_ps
        )
        .then(compute_qs)
        .then(decompress)
    )


def _batched_to_cpu_staging(tensor):
    """Return a detached CPU copy of ``tensor`` for error-feedback storage.

    The copy is page-locked (pinned) only when CUDA is available, because
    ``Tensor.pin_memory()`` raises on hosts without an accelerator runtime.
    """
    cpu_tensor = tensor.detach().to("cpu", non_blocking=True)
    if torch.cuda.is_available():
        cpu_tensor = cpu_tensor.pin_memory()
    return cpu_tensor


def _batched_error_is_reusable(prev, numel, dtype):
    """Return True when a stored error buffer matches the padded bucket layout.

    Only size and dtype matter: the buffer is moved to the bucket's device
    before use, so unpinned or differently placed buffers (e.g. restored from
    a checkpoint) are still valid and must not be discarded.
    """
    return prev is not None and prev.numel() == numel and prev.dtype == dtype


def _batched_lookup_param_name(state, param):
    """Return the registered name of ``param`` (by id, then storage pointer), or None."""
    name = state.param_id_to_name.get(id(param))
    if name is None:
        try:
            name = state.param_ptr_to_name.get(param.untyped_storage().data_ptr())
        except Exception:
            name = None
    return name


def _batched_bucket_param_names(state, bucket, bucket_index, num_tensors):
    """Resolve, cache and return display names for the tensors of ``bucket``.

    Names are looked up from the bucket's parameters first, then (for any
    still unknown) positionally from the model, and finally fall back to
    ``b{bucket_index}_t{i}``. The result is cached per bucket index.
    """
    cached = state.bucket_index_to_param_names.get(bucket_index)
    if cached is not None:
        return cached

    names = None
    try:
        params = bucket.parameters()  # type: ignore[attr-defined]
        if params is not None:
            names = [_batched_lookup_param_name(state, p) for p in params]
    except Exception:
        names = None

    if names is None or any(n is None for n in names):
        try:
            model = state.model_weakref() if state.model_weakref else None
            if model is not None:
                # NOTE(review): this fallback pairs the bucket's tensors with the
                # model's parameters in declaration order starting at the first
                # parameter, without regard to which bucket this is, so the
                # labels may be wrong for most buckets.
                ordered_names = [n for n, _ in model.named_parameters(recurse=True)]
                if names is None:
                    names = ordered_names[:num_tensors]
                else:
                    for idx, n in enumerate(names):
                        if n is None and idx < len(ordered_names):
                            names[idx] = ordered_names[idx]
        except Exception:
            pass

    if names is None:
        names = [f"b{bucket_index}_t{i}" for i in range(num_tensors)]
    else:
        names = [
            n if n is not None else f"b{bucket_index}_t{i}"
            for i, n in enumerate(names)
        ]
    state.bucket_index_to_param_names[bucket_index] = names
    return names


def _batched_record_error_norms(
    state, bucket, bucket_index, input_tensor_cp, err, total_length
):
    """Store per-bucket and per-tensor norms in ``state.curr_error_norms``.

    Slices stop at ``total_length`` so the zero padding added for the square
    reshape is never attributed to a tensor.
    """
    # Norm of the whole pre-compression snapshot, for reference.
    try:
        state.curr_error_norms[f"psgd_input_cp_norm/b{bucket_index}"] = float(
            torch.linalg.norm(input_tensor_cp.float()).detach().cpu().item()
        )
    except Exception:
        pass

    tensors = bucket.gradients()
    names = _batched_bucket_param_names(state, bucket, bucket_index, len(tensors))
    start = 0
    for i, tensor in enumerate(tensors):
        if start >= total_length:
            break
        numel = tensor.numel()
        length = min(numel, total_length - start)
        label = names[i] if i < len(names) else None

        # Input snapshot norm per parameter (best-effort).
        try:
            cp_norm = torch.linalg.norm(
                input_tensor_cp.narrow(0, start, length).float()
            )
            if label is not None:
                cp_key = f"psgd_input_cp_norm/{label}"
            else:
                cp_key = f"psgd_input_cp_norm_b{bucket_index}_t{i}"
            state.curr_error_norms[cp_key] = float(cp_norm.detach().cpu().item())
        except Exception:
            pass

        # Compression-error norm per parameter, in float32 for stability.
        err_norm = torch.linalg.norm(err.narrow(0, start, length).float())
        if label is not None:
            key = f"psgd_errnorm/{label}"
        else:
            key = f"psgd_errnorm/b{bucket_index}_t{i}"
        state.curr_error_norms[key] = float(err_norm.detach().cpu().item())
        start += numel


def batched_powerSGD_hook(  # noqa: N802
    state: PowerSGDState, bucket: dist.GradBucket
) -> torch.futures.Future[torch.Tensor]:
    """Compress a whole gradient bucket as one square matrix using PowerSGD.

    The flattened bucket of ``n`` elements is zero-padded to ``s * s`` with
    ``s = ceil(sqrt(n))`` and viewed as an ``s x s`` matrix ``M``. With
    low-rank factors ``P`` and ``Q`` of shape ``s x r``
    (``r = state.matrix_approximation_rank``), the hook:

    1. orthogonalizes ``Q``, computes ``P = M Q``, then all-reduces ``P``;
    2. orthogonalizes ``P``, computes ``Q = M^T P``, then all-reduces ``Q``;
    3. divides ``Q`` by the world size, writes ``P Q^T`` back into the bucket
       buffer, and truncates the buffer to its original length.

    Before ``state.start_powerSGD_iter`` the bucket is all-reduced
    uncompressed, unless P/Q or error-feedback buffers are already populated
    (e.g. when resuming). With ``state.use_error_feedback``, the stored
    compression error is added before compression. The new error (the
    pre-compression snapshot minus the reconstruction) is then stored on CPU.
    Per-tensor norms accumulate in ``state.curr_error_norms`` and are
    published to ``state.latest_error_norms`` after the last bucket.

    Args:
        state: PowerSGD state holding configuration and per-bucket buffers.
        bucket: The DDP gradient bucket to communicate.

    Returns:
        A future resolving to the bucket's flattened (decompressed) gradient.
    """
    process_group = state.process_group
    group_to_use = (
        process_group if process_group is not None else not_none(dist.group.WORLD)
    )
    world_size = group_to_use.size()

    input_tensor = bucket.buffer()
    if state.iter < state.start_powerSGD_iter:
        # If resuming with preloaded buffers, immediately use PowerSGD; else fallback.
        if not (state.p_memory_dict or state.q_memory_dict or state.error_dict):
            state.maybe_increase_iter(bucket)
            return default._allreduce_fut(group_to_use, input_tensor)

    if (
        state.error_feedback_reset_frequency is not None
        and state.iter % state.error_feedback_reset_frequency == 0
    ):
        # Reset error feedback periodically
        logger.info("Resetting error feedback for PowerSGD")
        state.error_dict.clear()

    device = input_tensor.device
    total_length = input_tensor.shape[0]
    state.total_numel_before_compression += total_length

    # Pad the flattened bucket so it can be viewed as a square matrix.
    square_side_length = math.ceil(math.sqrt(total_length))
    state.total_numel_after_compression += (
        square_side_length * state.matrix_approximation_rank * 2
    )
    padded_total_length = square_side_length**2
    input_tensor.resize_(padded_total_length)
    input_tensor[total_length:padded_total_length].fill_(0)

    _report_compression_stats(bucket, state)

    bucket_index = bucket.index()
    input_tensor_cp = None
    if state.use_error_feedback:
        prev = state.error_dict.get(bucket_index)
        if _batched_error_is_reusable(prev, padded_total_length, input_tensor.dtype):
            # Stage the stored error onto the bucket's device and add it.
            input_tensor.add_(prev.to(device, non_blocking=True))
        else:
            state.error_dict[bucket_index] = _batched_to_cpu_staging(
                torch.zeros(padded_total_length, device="cpu", dtype=input_tensor.dtype)
            )
        # Snapshot kept on device; compared with the reconstruction in `decompress`.
        input_tensor_cp = input_tensor.detach().clone()

    expected_shape = (square_side_length, state.matrix_approximation_rank)
    if not state.warm_start or bucket_index not in state.p_memory_dict:
        if state.warm_start:
            logger.info(
                "Initializing low-rank tensors P and Q, each of which has a shape of %s x %s.",
                square_side_length,
                state.matrix_approximation_rank,
            )
        state.p_memory_dict[bucket_index] = _ensure_2d(
            None,
            device=device,
            dtype=input_tensor.dtype,
            shape=expected_shape,
            fill_random=False,
            rng=state.rng,
        )
        state.q_memory_dict[bucket_index] = _ensure_2d(
            None,
            device=device,
            dtype=input_tensor.dtype,
            shape=expected_shape,
            fill_random=True,
            rng=state.rng,
        )
    else:
        # Ensure buffers have expected shapes; if shape/device/dtype mismatches, reallocate and randomize Q
        existing_p = state.p_memory_dict[bucket_index]
        existing_q = state.q_memory_dict[bucket_index]
        new_p = _ensure_2d(
            existing_p,
            device=device,
            dtype=input_tensor.dtype,
            shape=expected_shape,
            fill_random=False,
            rng=state.rng,
        )
        new_q = _ensure_2d(
            existing_q,
            device=device,
            dtype=input_tensor.dtype,
            shape=expected_shape,
            fill_random=False,
            rng=state.rng,
        )
        if new_p.numel() != existing_p.numel() or new_q.numel() != existing_q.numel():
            # size mismatch: reinitialize Q randomly
            logger.info("PowerSGD: Size mismatch, reinitializing Q randomly")
            new_q = _ensure_2d(
                None,
                device=device,
                dtype=input_tensor.dtype,
                shape=expected_shape,
                fill_random=True,
                rng=state.rng,
            )
        state.p_memory_dict[bucket_index] = new_p
        state.q_memory_dict[bucket_index] = new_q

    # `_orthogonalize` only accepts a batch of matrices (3-D), while P and Q are
    # single 2-D matrices. Orthogonalize a batch-of-one view that shares storage,
    # so the in-place update lands in the stored tensor.
    _orthogonalize(state.q_memory_dict[bucket_index].unsqueeze(0))

    matrix = input_tensor.view(square_side_length, square_side_length)
    torch.matmul(
        matrix, state.q_memory_dict[bucket_index], out=state.p_memory_dict[bucket_index]
    )
    allreduce_p_fut = dist.all_reduce(
        state.p_memory_dict[bucket_index], group=group_to_use, async_op=True
    ).get_future()

    def compute_q(fut):
        # P has been all-reduced; orthogonalize it and compute Q = M^T P.
        state.p_memory_dict[bucket_index] = fut.value()[0]
        _orthogonalize(state.p_memory_dict[bucket_index].unsqueeze(0))
        torch.matmul(
            matrix.t(),
            state.p_memory_dict[bucket_index],
            out=state.q_memory_dict[bucket_index],
        )
        return (
            dist.all_reduce(
                state.q_memory_dict[bucket_index], group=group_to_use, async_op=True
            )
            .get_future()
            .wait()[0]
        )

    def decompress(fut):
        state.q_memory_dict[bucket_index] = fut.value().div_(world_size)
        # Write the rank-r reconstruction P Q^T back into the bucket buffer.
        torch.matmul(
            state.p_memory_dict[bucket_index],
            state.q_memory_dict[bucket_index].t(),
            out=matrix,
        )
        if input_tensor_cp is not None:
            # Difference between the pre-compression snapshot and the reconstruction.
            err = input_tensor_cp - input_tensor
            try:
                _batched_record_error_norms(
                    state, bucket, bucket_index, input_tensor_cp, err, total_length
                )
            except Exception:
                # Norm logging is best-effort and must never fail the step.
                logger.debug(
                    "PowerSGD: failed to record error norms for bucket %s",
                    bucket_index,
                    exc_info=True,
                )
            if state.use_error_feedback:
                # Offload error to CPU memory for EF
                state.error_dict[bucket_index] = _batched_to_cpu_staging(err)
        if device.type == "cuda":
            # Let queued device work, including any non-blocking device-to-host
            # copy of the error above, finish before returning.
            torch.cuda.synchronize(device)
        if not state.warm_start:
            state.p_memory_dict.clear()
            state.q_memory_dict.clear()
        ret = input_tensor.resize_(total_length)
        # Publish per-step snapshot on last bucket
        if bucket.is_last():
            state.latest_error_norms = dict(state.curr_error_norms)
            state.curr_error_norms.clear()
        state.maybe_increase_iter(bucket)
        return ret

    return allreduce_p_fut.then(compute_q).then(decompress)
