import warnings
from BACKEND import cp, sp

def power_it(w_init, max_it, A, eps, debug=False, debug_ctxt=""):
    """Approximate the dominant eigenvector of ``A`` by power iteration.

    Parameters
    ----------
    w_init : array
        Starting vector (1-D); it does not need to be normalised.
    max_it : int
        Maximum number of iterations.
    A : array
        Square matrix, applied as ``A @ w``.
    eps : float
        Tolerance on the change between successive unit-norm iterates.
        The change is measured up to sign, because eigenvectors are only
        defined up to sign.
    debug : bool
        If true, print the number of iterations performed on convergence.
    debug_ctxt : str
        Extra context appended to the non-convergence warning.

    Returns
    -------
    array
        The last unit-norm iterate, or ``w_init`` itself if ``max_it`` is 0.
        Its sign is arbitrary.

    Warns
    -----
    RuntimeWarning
        If the iterates did not converge within ``max_it`` iterations.
    """
    w = w_old = w_init
    for it in range(max_it):
        w = A @ w_old
        w /= cp.linalg.norm(w)
        # If the dominant eigenvalue is negative, the iterates flip sign on
        # every step. Measure the distance to whichever of +w / -w is closer
        # so that such runs are still recognised as converged.
        delta = min(cp.linalg.norm(w_old - w), cp.linalg.norm(w_old + w))
        if delta < eps:
            if debug:
                # `it` is 0-based, so it + 1 iterations have been performed.
                print(f"Converged after {it + 1} iterations.")
            break
        w_old = w
    else:
        warnings.warn(f"Power iteration did not converge: {debug_ctxt}", RuntimeWarning)
    return w

def inverse_power_it(w_init, max_it, A, shift, eps, debug=False, debug_ctxt="",
                     reshift_it=3):
    """Approximate an eigenvector of ``A`` by shifted inverse iteration.

    Each step solves ``(A - s*I) w = w_old`` using a cached LU factorisation
    and then normalises ``w``. The initial shift is ``s = shift``. After
    iteration ``reshift_it`` (0-based), the shift is replaced once by the
    Rayleigh quotient of the current iterate and the matrix is refactorised.
    This sharpens convergence towards the eigenvalue nearest that estimate.

    Parameters
    ----------
    w_init : array
        Starting vector (1-D); it does not need to be normalised.
    max_it : int
        Maximum number of iterations.
    A : array
        Square matrix. It is not modified: the factorised matrix is a freshly
        built temporary, so ``overwrite_a=True`` only reuses that temporary.
    shift : scalar
        Initial shift. Iteration is drawn to the eigenvalue closest to it.
    eps : float
        Tolerance on the change between successive unit-norm iterates.
        The change is measured up to sign.
    debug : bool
        If true, print the number of iterations performed on convergence.
    debug_ctxt : str
        Extra context appended to the non-convergence warning.
    reshift_it : int or None
        Iteration index after which the Rayleigh-quotient reshift is applied.
        ``None`` disables reshifting. Defaults to 3.

    Returns
    -------
    array
        The last unit-norm iterate, or ``w_init`` itself if ``max_it`` is 0.
        Its sign is arbitrary.

    Warns
    -----
    RuntimeWarning
        If the iterates did not converge within ``max_it`` iterations.
    """
    # Identity reused for both the initial factorisation and the reshift.
    eye = cp.eye(A.shape[0], dtype=A.dtype)
    lup = sp.linalg.lu_factor(A - shift * eye, overwrite_a=True, check_finite=False)
    w = w_old = w_init
    for it in range(max_it):
        w = sp.linalg.lu_solve(lup, w_old, overwrite_b=False, check_finite=False)
        w /= cp.linalg.norm(w)
        # When the shift lies above the target eigenvalue, 1/(lambda - shift)
        # is negative and the iterates flip sign every step. Compare against
        # the closer of +w / -w so that this case is still seen as converged.
        delta = min(cp.linalg.norm(w_old - w), cp.linalg.norm(w_old + w))
        if delta < eps:
            if debug:
                # `it` is 0-based, so it + 1 iterations have been performed.
                print(f"Converged after {it + 1} iterations.")
            break
        w_old = w
        if reshift_it is not None and it == reshift_it:
            # Rayleigh quotient w^H A w (w has unit norm). The conjugate makes
            # this correct for complex vectors; it is a no-op for real ones.
            mu = w.conj() @ (A @ w)
            lup = sp.linalg.lu_factor(A - mu * eye, overwrite_a=True, check_finite=False)
    else:
        warnings.warn(f"Power iteration did not converge: {debug_ctxt}", RuntimeWarning)
    return w

def gerschgorin(A):
    """Return ``(l_min, l_max)`` eigenvalue bounds from Gerschgorin's theorem.

    Every eigenvalue lies in the union of discs centred at the diagonal
    entries. The radii are the off-diagonal absolute row sums, or equally the
    column sums, since ``A`` and ``A.T`` share eigenvalues. Each choice gives
    valid bounds, so the tighter lower and the tighter upper bound are kept.

    NOTE(review): the bounds are formed from ``cp.diag(A)`` directly, so this
    presumably expects a real diagonal. For non-symmetric real ``A`` they
    bound the real parts of the eigenvalues.
    """
    magnitudes = cp.abs(A)
    centres = cp.diag(A)
    self_magnitudes = cp.diag(magnitudes)
    # axis=0 sums down each column; axis=1 sums along each row.
    col_radii = magnitudes.sum(axis=0) - self_magnitudes
    row_radii = magnitudes.sum(axis=1) - self_magnitudes
    radii = (col_radii, row_radii)
    lower = max(cp.min(centres - r) for r in radii)
    upper = min(cp.max(centres + r) for r in radii)
    return lower, upper