from math import sqrt

import torch
from torch import Tensor

from . import config


def qr(
    A: Tensor,
) -> tuple[Tensor, Tensor]:
    """Compute the (reduced) QR decomposition of ``A``.

    The factorization runs in float32, or in float64 when ``A`` is already
    float64. This lets low-precision or non-float inputs be factorized without
    downcasting double-precision inputs. Before factorizing, NaNs are replaced
    by 0, and +/-inf by the largest/smallest finite value of the working dtype.
    The caller's ``A`` is never modified.

    When ``config.SCALING_UNIT`` is truthy, ``Q`` is multiplied by ``sqrt(m)``
    and ``R`` is divided by ``sqrt(m)``, where ``m = A.shape[-2]``. The product
    ``Q @ R`` is therefore unchanged up to rounding.

    Args:
        A: tensor of shape ``(..., m, n)``; leading dimensions are treated as
            a batch by ``torch.linalg.qr``.

    Returns:
        ``(Q, R)`` from ``torch.linalg.qr`` (default ``mode="reduced"``),
        optionally rescaled as described above, cast back to ``A.dtype``.
    """
    assert isinstance(A, Tensor)

    m = A.shape[-2]

    # Upcast for the factorization, but never downcast float64 to float32.
    work_dtype = torch.float64 if A.dtype == torch.float64 else torch.float32
    # Out-of-place nan_to_num: when A already has work_dtype, ``A.to`` returns
    # A itself, so the in-place variant would overwrite the caller's tensor.
    A_work = A.to(work_dtype).nan_to_num(0)

    Q, R = torch.linalg.qr(A_work)

    if config.SCALING_UNIT:
        # Reciprocal scaling keeps Q @ R unchanged. Out-of-place so the
        # Q/R outputs saved by linalg.qr for autograd are not modified.
        scale = sqrt(m)
        Q = Q * scale
        R = R / scale

    Q = Q.to(A.dtype)
    R = R.to(A.dtype)
    return Q, R


def qr_reconstruct(
    Q: Tensor,
    R: Tensor,
) -> Tensor:
    """Rebuild a matrix from its QR factors by returning the product ``Q @ R``.

    Leading (batch) dimensions follow ``torch.matmul`` semantics. Because
    ``qr`` rescales ``Q`` and ``R`` by reciprocal factors, the plain product
    recovers the input whether or not unit scaling was applied.

    Args:
        Q: left factor, typically the first output of ``qr``.
        R: right factor, typically the second output of ``qr``.

    Returns:
        The matrix product of ``Q`` and ``R``.
    """
    # Validate both factors, in argument order.
    for factor in (Q, R):
        assert isinstance(factor, Tensor)

    return torch.matmul(Q, R)
