from math import sqrt

import torch
from torch import Tensor

from . import config


def truncated_eigh(
    A: Tensor,
    rank: int | None = None,
) -> tuple[Tensor, Tensor]:
    """Eigendecompose symmetric matrices, keeping the leading positive part.

    ``torch.linalg.eigh`` is run on a NaN-sanitized copy of ``A``. The
    eigenpairs are reordered from largest to smallest eigenvalue and optionally
    truncated to the first ``rank``. Any pair whose eigenvalue is not strictly
    positive is then zeroed out (both the eigenvalue and its eigenvector
    column).

    Args:
        A: Tensor of shape ``(*, n, n)``. The caller's tensor is never modified.
        rank: Number of leading eigenpairs to keep; ``None`` keeps all ``n``.

    Returns:
        ``(L, U)`` with ``L`` of shape ``(*, k)`` and ``U`` of shape
        ``(*, n, k)``, where ``k`` is ``n`` or ``min(rank, n)``. Both are cast
        back to ``A.dtype``. When ``config.SCALING_UNIT`` is set, ``U`` is
        multiplied by ``sqrt(n)`` and ``L`` is divided by ``n``. That scaling
        leaves ``U @ diag(L) @ U.mT`` unchanged.
    """
    assert isinstance(A, Tensor)
    assert isinstance(rank, int | None)

    n = A.shape[-1]

    # eigh only supports single/double precision, so upcast low-precision
    # inputs to float32. Float64 inputs stay in float64 so no precision is lost.
    # nan_to_num is deliberately out of place: when A already has work_dtype,
    # `.to()` returns A itself, and an in-place nan_to_num_ would overwrite the
    # caller's tensor (and fails outright on leaf tensors that require grad).
    work_dtype = torch.float64 if A.dtype == torch.float64 else torch.float32
    L, U = torch.linalg.eigh(A.to(work_dtype).nan_to_num(nan=0.0))

    if config.SCALING_UNIT:
        # Out of place so eigh's outputs, which autograd may save for the
        # backward pass, are not modified.
        U = U * sqrt(n)
        L = L / n

    # eigh returns eigenvalues in ascending order; flip to descending.
    L = L.to(A.dtype).flip(-1)  # (*, n)
    U = U.to(A.dtype).flip(-1)  # (*, n, n)

    if rank is not None:
        L = L[..., :rank]  # (*, k)
        U = U[..., :, :rank]  # (*, n, k)

    # Zero out non-positive eigenvalues together with their eigenvector columns.
    keep = L.gt(0)  # (*, k)
    L = L * keep  # (*, k)
    U = U * keep.unsqueeze(-2)  # (*, n, k)
    return L, U


def eigh_reconstruct(
    L: Tensor,
    U: Tensor,
) -> Tensor:
    """Rebuild ``U @ diag(L) @ U.mT`` from eigenvalues and eigenvectors.

    Args:
        L: Eigenvalues, shape ``(*, k)``.
        U: Eigenvectors stored as columns, shape ``(*, n, k)``. Leading batch
            dimensions broadcast against those of ``L``.

    Returns:
        Tensor of shape ``(*, n, n)``.
    """
    assert isinstance(L, Tensor)
    assert isinstance(U, Tensor)

    # Scaling column j of U by L[j] equals right-multiplying by diag(L).
    # Doing it by broadcasting avoids materializing a (*, k, k) diagonal
    # matrix and the extra matmul against it.
    return (U * L.unsqueeze(-2)) @ U.mT
