from math import sqrt

import torch
from torch import Tensor

from .qr import qr
from .utils import scaled_matmul


def randomized_qb(
    A: Tensor,
    rank: int,
    niter: int = 0,
    test_matrix: str = "gauss",
    left: bool | None = None,
) -> tuple[Tensor, Tensor]:
    """Compute a randomized low-rank factorization ``A ~= Q @ B``.

    A random sketch ``Y = A @ Ohm`` of the range of ``A`` is orthonormalized
    with ``qr`` and ``A`` is then projected onto the resulting basis.

    Conjugate transposes (``.mH``) are used wherever an adjoint is meant, so
    the projection is also valid for complex ``A``; for real dtypes ``.mH``
    is identical to ``.mT``.
    NOTE(review): whether the project-local ``qr`` and ``scaled_matmul``
    accept complex inputs is not visible from this function -- confirm.

    Args:
        A: Tensor whose last two dimensions are ``(m, n)``; any leading
            dimensions are treated as batch dimensions.
        rank: Requested rank, must be positive. The rank actually used is
            ``k = min(m, n, rank)``.
        niter: Number of refinement iterations. Each one re-orthonormalizes
            the sketch and multiplies by ``A^H`` and ``A`` again. Values
            ``<= 0`` perform no refinement.
        test_matrix: ``"gauss"`` draws a standard normal ``(*, n, k)`` test
            matrix with the dtype and device of ``A``; ``"subs"`` uses the
            conjugate transpose of ``k`` randomly selected rows of ``A``.
        left: ``True`` sketches ``A`` directly, so the first returned factor
            is the ``Q`` produced by ``qr``. ``False`` factorizes ``A.mT``
            and transposes the result back, so the second returned factor is
            the transposed ``qr`` output. ``None`` behaves like ``True`` when
            ``m <= n`` and like ``False`` otherwise.

    Returns:
        A pair of tensors with shapes ``(*, m, k)`` and ``(*, k, n)`` whose
        product approximates ``A``.

    Raises:
        AssertionError: If an argument has the wrong type or ``rank <= 0``.
        ValueError: If ``test_matrix`` is not ``"gauss"`` or ``"subs"``.
    """
    assert isinstance(A, Tensor)
    assert isinstance(rank, int)
    assert isinstance(niter, int)
    assert isinstance(test_matrix, str)
    assert isinstance(left, bool | None)
    assert rank > 0, f"rank must be positive, got {rank}"
    m, n = A.shape[-2], A.shape[-1]

    if (left is True) or (left is None and m <= n):
        # The factorization cannot have more columns than either dimension.
        k = min(m, n, rank)

        if test_matrix == "gauss":
            Ohm = torch.randn(*A.shape[:-2], n, k, dtype=A.dtype, device=A.device)  # (*, n, k)
        elif test_matrix == "subs":
            # Random row subset of A. The adjoint (not the plain transpose) keeps
            # the test matrix inside the row space of A for complex inputs.
            idx = torch.randperm(m)[:k]
            Ohm = A[..., idx, :].mH  # (*, n, k)
        else:
            raise ValueError("Invalid value of `test_matrix`.")

        # A scalar rescale does not change the column span of the sketch;
        # presumably it keeps magnitudes in range for low-precision dtypes.
        Y = (A @ Ohm).div_(sqrt(n))  # (*, m, k)
        for _ in range(niter):
            # Re-orthonormalize before every further multiplication by A^H and A.
            Q, _ = qr(Y)  # (*, m, k)
            Ohm = scaled_matmul(A.mH, Q)  # (*, n, k)
            Y = (A @ Ohm).div_(sqrt(n))  # (*, m, k)
        Q, _ = qr(Y)  # (*, m, k)

        # Project A onto the basis: B = Q^H A, so that Q @ B = Q Q^H A.
        B = scaled_matmul(Q.mH, A)  # (*, k, n)
        return Q, B

    elif (left is False) or (left is None and m > n):
        # Factorize the transpose, A^T ~= QT @ BT, then transpose back:
        # A ~= BT^T @ QT^T. Plain transposes are correct here; the adjoints
        # are already applied inside the recursive call.
        QT, BT = randomized_qb(A.mT, rank, niter=niter, test_matrix=test_matrix, left=True)
        return BT.mT, QT.mT

    else:
        # Only reachable when the type assertions above are disabled (python -O).
        raise ValueError("Invalid value of `left`, must be `True`, `False` or `None`.")


def qb_reconstruct(
    Q: Tensor,
    B: Tensor,
) -> Tensor:
    """Return the matrix product ``Q @ B`` of a QB factor pair.

    Args:
        Q: First factor.
        B: Second factor.

    Returns:
        The (batched) matrix product of ``Q`` and ``B``.

    Raises:
        AssertionError: If either argument is not a ``Tensor``.
    """
    for factor in (Q, B):
        assert isinstance(factor, Tensor)

    product = torch.matmul(Q, B)
    return product
