"""Conjugate Residual Krylov solver."""

import warnings
from warnings import warn

import numpy as np
from scipy import sparse
from ..util.linalg import norm
from ..util import make_system


def cr(A, b, x0=None, tol=1e-5, criteria='rr',
       maxiter=None, M=None,
       callback=None, residuals=None):
    """Conjugate Residual algorithm.

    Solves the linear system Ax = b. Left preconditioning is supported.
    The matrix A must be Hermitian symmetric (but not necessarily definite).

    Parameters
    ----------
    A : array, matrix, sparse matrix, LinearOperator
        Linear system of size (n,n) to solve.
    b : array, matrix
        Right hand side of size (n,) or (n,1).
    x0 : array, matrix
        Initial guess, default is a vector of zeros.
    tol : float
        Tolerance for stopping criteria.
    criteria : str
        Stopping criteria, let r=r_k, x=x_k

            'rr':        ||r||       < tol ||b||
            'rr+':       ||r||       < tol (||b|| + ||A||_F ||x||)
            'MrMr':      ||M r||     < tol ||M b||

        if ||b||=0, then set ||b||=1 for these tests (likewise
        ||M b||=0 is replaced by 1 for 'MrMr').
    maxiter : int
        Maximum number of iterations allowed.  Defaults to
        ``int(1.3*len(b)) + 2``.
    M : array, matrix, sparse matrix, LinearOperator
        Inverted preconditioner of size (n,n), i.e. solve M A x = M b.
    callback : function
        User-supplied function is called after each iteration as
        ``callback(xk)``, where xk is the current solution vector.
    residuals : list
        Residual history in the 2-norm, including the initial residual.
        The list is overwritten in place; ||r|| is recorded regardless
        of the chosen stopping criteria.

    Returns
    -------
    array
        Updated guess after k iterations to the solution of Ax = b.
    int
        Halting status

            ==  =======================================
            0   successful exit
            >0  convergence to tolerance not achieved,
                return iteration count instead.
            <0  numerical breakdown, or illegal input
            ==  =======================================

    Raises
    ------
    ValueError
        If ``maxiter`` is smaller than 1, if ``criteria`` is not one of
        the supported strings, or if ``criteria='rr+'`` is requested for
        a matrix format whose Frobenius norm cannot be computed here.

    Notes
    -----
    The LinearOperator class is in scipy.sparse.linalg.
    Use this class if you prefer to define A or M as a mat-vec routine
    as opposed to explicitly constructing the matrix.

    A status of -1 is returned (with a warning) when the iteration cannot
    continue: either ``M r`` vanishes for a non-converged residual, or one
    of the inner products used as a divisor (``<Ap, Ap>`` or ``<r, A M r>``)
    is exactly zero.

    References
    ----------
    .. [1] Yousef Saad, "Iterative Methods for Sparse Linear Systems,
       Second Edition", SIAM, pp. 262-67, 2003
       http://www-users.cs.umn.edu/~saad/books.html

    Examples
    --------
    >>> from pyamg.krylov import cr
    >>> from pyamg.util.linalg import norm
    >>> import numpy as np
    >>> from pyamg.gallery import poisson
    >>> A = poisson((10,10))
    >>> b = np.ones((A.shape[0],))
    >>> (x,flag) = cr(A,b, maxiter=2, tol=1e-8)
    >>> print(f'{norm(b - A@x):.6}')
    6.54282

    """
    A, M, x, b = make_system(A, M, x0, b)

    # Ensure that warnings are always reissued from this function
    warnings.filterwarnings('always', module='pyamg.krylov._cr')

    # determine maxiter
    if maxiter is None:
        maxiter = int(1.3*len(b)) + 2
    elif maxiter < 1:
        raise ValueError('Number of iterations must be positive')

    # setup method
    r = b - A @ x
    z = M @ r
    p = z.copy()

    normr = np.linalg.norm(r)

    if residuals is not None:
        residuals[:] = [normr]  # initial residual

    # Check initial guess if b != 0,
    normb = norm(b)
    if normb == 0.0:
        normb = 1.0  # reset so that tol is unscaled

    # set the stopping criteria (see the docstring)
    if criteria == 'rr':
        rtol = tol * normb
    elif criteria == 'rr+':
        # NOTE(review): A.A is presumably the matrix wrapped by make_system;
        # only sparse matrices and ndarrays are handled here.
        if sparse.issparse(A.A):
            normA = norm(A.A.data)
        elif isinstance(A.A, np.ndarray):
            normA = norm(np.ravel(A.A))
        else:
            raise ValueError('Unable to use ||A||_F with the current matrix format.')
        rtol = tol * (normA * np.linalg.norm(x) + normb)
    elif criteria == 'MrMr':
        # Use the same norm routine as inside the loop.  Taking np.sqrt of
        # <z, z> directly yields a complex scalar for complex systems, and
        # complex values cannot be ordered against rtol.
        normr = norm(z)
        normMb = norm(M @ b)
        if normMb == 0.0:
            normMb = 1.0  # reset so that tol is unscaled, as for ||b||
        rtol = tol * normMb
    else:
        raise ValueError('Invalid stopping criteria.')

    if normr < rtol:
        return (x, 0)

    # How often should r be recomputed
    recompute_r = 8

    Az = A @ z
    rAz = np.inner(r.conjugate(), Az)
    Ap = A @ p

    it = 0

    while True:

        # Both quantities are used as divisors below (alpha now, beta via
        # rAz_old later).  Stop before x is polluted with inf/nan.
        ApAp = np.inner(Ap.conjugate(), Ap)
        if rAz == 0.0 or ApAp == 0.0:
            warn('\nNumerical breakdown detected in CR, ceasing iterations\n')
            return (x, -1)

        rAz_old = rAz

        alpha = rAz / ApAp                          # 3
        x += alpha * p                              # 4

        # Recompute the residual explicitly on iterations 0, 8, 16, ...;
        # use the cheaper recursive update on the others.
        if it % recompute_r:                        # 5
            r -= alpha * Ap
        else:
            r = b - A @ x

        z = M @ r

        Az = A @ z
        rAz = np.inner(r.conjugate(), Az)

        beta = rAz / rAz_old                        # 6

        p *= beta                                   # 7
        p += z

        Ap *= beta                                  # 8
        Ap += Az

        it += 1

        zz = np.inner(z.conjugate(), z)

        normr = np.linalg.norm(r)

        if residuals is not None:
            residuals.append(normr)

        if callback is not None:
            callback(x)

        # Update the stopping test (see the docstring).  rtol is constant
        # for 'rr' and 'MrMr'; only 'rr+' depends on the current iterate.
        if criteria == 'rr+':
            rtol = tol * (normA * np.linalg.norm(x) + normb)
        elif criteria == 'MrMr':
            normr = norm(z)

        if normr < rtol:
            return (x, 0)

        if zz == 0.0:
            # M r vanished although the convergence test above did not pass
            warn('\nSingular preconditioner detected in CR, ceasing iterations\n')
            return (x, -1)

        if it == maxiter:
            return (x, it)
