import math
import torch

def orthonormal_operator(m, n, dtype):
    """
    Generates a random matrix Q with m rows and n columns that has orthonormal
    columns if m >= n, otherwise orthonormal rows.

    Q is the Q factor of a QR decomposition of a Gaussian matrix. Its columns
    are sign-flipped so that the matching R factor has a non-negative diagonal,
    which makes the result independent of the QR routine's sign convention.

    torch.linalg.qr does not accept half-precision inputs, so for float16 and
    bfloat16 the matrix is generated and factorised in float32 and the result
    is cast to ``dtype``. Other dtypes are used directly.

    Args:
        m: number of rows.
        n: number of columns.
        dtype: torch dtype of the returned tensor.

    Returns:
        A tensor of shape (m, n) with dtype ``dtype``.
    """
    if m < n:
        # Build the tall (n, m) version and transpose: orthonormal columns
        # become orthonormal rows.
        return orthonormal_operator(n, m, dtype).T
    # QR is not implemented for half precision; do the work in float32.
    work_dtype = torch.float32 if dtype in (torch.float16, torch.bfloat16) else dtype
    scale = math.sqrt(1 / n)
    Q = torch.randn((m, n), dtype=work_dtype) * scale
    Q, R = torch.linalg.qr(Q)
    # Multiply column j by sign(R[j, j]) so the implied R has a non-negative diagonal.
    Q = Q * torch.sign(torch.diag(R))
    # No-op (returns the same tensor) when work_dtype == dtype.
    return Q.to(dtype)

def fit_newtonschulz_coeff(a, b, c, d, low=0, high=1, x0=0.5, epsilon=1.0e-8):
    """
    Pick the x that minimises p(x) = a*x**4 + b*x**3 + c*x**2 + d*x on [low, high].

    Candidate points, in order:
      1. the starting guess ``x0``;
      2. the endpoints ``low`` and ``high``;
      3. every value returned by ``cubic_roots_real`` for the derivative
         4a*x**3 + 3b*x**2 + 2c*x + d that lies strictly inside (low, high).
    The candidate with the smallest p(x) is returned. On ties, the earliest
    candidate in that order wins.

    If |a|, |b| and |c| are all below ``epsilon``, ``x0`` is returned
    immediately without evaluating p.
    """
    if all(math.fabs(coef) < epsilon for coef in (a, b, c)):
        return x0

    def quartic(x):
        return a*x**4 + b*x**3 + c*x**2 + d*x

    stationary = [
        root for root in cubic_roots_real(4*a, 3*b, 2*c, d) if low < root < high
    ]
    # min() keeps the first of several equal minima, matching a strict "<" scan.
    return min([x0, low, high, *stationary], key=quartic)

def cubic_roots_real(a, b, c, d):
    """
    Returns the real roots of the cubic equation a * x**3 + b * x**2 + c * x + d = 0

    Coefficients with magnitude below 1e-12 are treated as zero, so the
    equation falls back to a quadratic or linear one when leading terms vanish.

    Returns:
        A list of floats:

        * ``[]`` when a, b and c are all negligible;
        * ``[-d / c]`` for a linear equation;
        * for a quadratic with non-negative discriminant D, both roots, ordered
          as ``(-c + sqrt(D)) / (2b)`` then ``(-c - sqrt(D)) / (2b)``;
        * a single value for a (numerically) triple root of the cubic;
        * three roots when the cubic's quantity ``h <= 0``.

    Note:
        When the equation has a complex-conjugate pair of roots (a quadratic
        with negative discriminant, or a cubic with ``h > 0``), the pair's
        common real part is included in the result even though it is not
        itself a root. For the quadratic it is the only element. Callers that
        need strict roots must filter.
    """
    eps = 1e-12

    if math.fabs(a) < eps and math.fabs(b) < eps and math.fabs(c) < eps:
        return []

    if math.fabs(a) < eps and math.fabs(b) < eps:
        # Linear: c * x + d = 0.
        return [-d / c]

    if math.fabs(a) < eps:
        # Quadratic: b * x**2 + c * x + d = 0.
        disc = c * c - 4 * b * d
        if disc < 0:
            # No real roots: return the real part of the complex pair.
            return [-c / (2 * b)]
        sqrt_disc = math.sqrt(disc)
        # Cancellation-free form: add terms of equal sign to get the
        # larger-magnitude root, then get the other from Vieta (x1 * x2 = d / b).
        # The naive (-c +/- sqrt_disc) / (2b) loses accuracy when |c| >> |b*d|.
        sign = 1.0 if c >= 0 else -1.0
        q = -0.5 * (c + sign * sqrt_disc)
        if q == 0:
            # c and disc are both zero: a double root.
            root = -c / (2 * b)
            return [root, root]
        large = q / b
        small = d / q
        # Keep the original ordering: (-c + sqrt(D)) / (2b) first.
        return [small, large] if sign > 0 else [large, small]

    # Cardano: substitute x = t + shift to get the depressed cubic
    # t**3 + f*t + g = 0; h decides how many real roots there are.
    f = ((3 * c / a) - ((b ** 2) / (a ** 2))) / 3
    g = (((2 * (b ** 3)) / (a ** 3)) - ((9 * b * c) / (a ** 2)) + (27 * d / a)) / 27
    h = (g ** 2) / 4 + (f ** 3) / 27
    shift = -(b / (3 * a))

    if math.fabs(f) < eps and math.fabs(g) < eps and math.fabs(h) < eps:
        # Triple root: (x - r)**3 = 0, so r = cbrt(-d / a).
        if (d / a) >= 0:
            return [-((d / a) ** (1 / 3))]
        return [(-d / a) ** (1 / 3)]

    if h <= 0:
        # Three real roots: trigonometric form.
        i = math.sqrt(((g ** 2) / 4) - h)
        j = i ** (1 / 3)
        # Clamp defensively so rounding cannot leave acos's domain.
        k = math.acos(max(-1.0, min(1.0, -(g / (2 * i)))))
        cos_third = math.cos(k / 3)
        sin_third = math.sqrt(3) * math.sin(k / 3)
        x1 = 2 * j * cos_third + shift
        x2 = -j * (cos_third + sin_third) + shift
        x3 = -j * (cos_third - sin_third) + shift
        return [x1, x2, x3]

    # h > 0: one real root plus a complex pair (real cube roots keep the sign).
    R = -(g / 2) + math.sqrt(h)
    S = R ** (1 / 3) if R >= 0 else -((-R) ** (1 / 3))
    T = -(g / 2) - math.sqrt(h)
    U = T ** (1 / 3) if T >= 0 else -((-T) ** (1 / 3))
    x1 = (S + U) + shift
    # Real part of the complex-conjugate pair (not itself a root; see docstring).
    x2 = -(S + U) / 2 + shift
    return [x1, x2]