"""
A linear solver that approximates ``A^{-1} b`` with a truncated Neumann
series, and follows the ``solve(matvec, b)`` calling convention required by
jaxopt. This solver never materializes the matrix ``A`` in memory; it only
accesses ``A`` through ``matvec``.
"""
from functools import partial
import operator

import jax
from jax import lax
import jax.numpy as jnp
from jax.tree_util import tree_leaves, tree_map

# Leaf-wise pytree helpers: each one is ``tree_map`` with a fixed function,
# so when two trees are passed they must have matching structure.
_add = partial(tree_map, operator.add)  # leaf-wise x + y
_sub = partial(tree_map, operator.sub)  # leaf-wise x - y
# Leaf-wise ``jnp.dot``: yields a pytree of per-leaf results, not one summed
# scalar (unlike ``_vdot_tree``).
# NOTE(review): not referenced elsewhere in this module as far as visible.
_dot_tree = partial(tree_map, jnp.dot)
_sqrt_tree = partial(tree_map, jnp.sqrt)  # element-wise sqrt of every leaf


def _vdot_tree(x, y):
    return sum(tree_leaves(tree_map(partial(
        jnp.vdot, precision=lax.Precision.HIGHEST), x, y)))


def _div(tree, scalar):
    return tree_map(partial(lambda v: v / scalar), tree)


def _mul(scalar, tree):
    return tree_map(partial(operator.mul, scalar), tree)


def estimate_spectral_radius(matvec, shape, dtype, maxiter=1000, key=None):
    """Estimate the dominant eigenvalue of ``A`` by power iteration.

    Starting from a random unit-norm array, ``x`` is replaced by
    ``A x / ||A x||`` exactly ``maxiter`` times (there is no convergence
    test). The result is the Rayleigh quotient ``<x, A x>`` of the final
    iterate. It is a *signed* estimate, not an absolute value: for a
    negative-definite ``A`` it is negative.

    Args:
      matvec: function computing ``A @ x`` for an array ``x`` of ``shape``.
      shape: shape of the vectors ``A`` acts on.
      dtype: dtype of the random starting array (any dtype accepted by
        ``jax.random.normal``).
      maxiter: number of power-iteration steps.
      key: optional PRNG key for the random start. Defaults to
        ``jax.random.PRNGKey(0)``, so results are deterministic by default.

    Returns:
      The Rayleigh-quotient estimate as a JAX scalar.
    """
    if key is None:
        key = jax.random.PRNGKey(0)
    x = jax.random.normal(key, shape, dtype)
    x = _div(x, jnp.linalg.norm(x))

    def body_fun(_, x):
        # One power-iteration step: apply A, then renormalize to unit norm.
        x = matvec(x)
        return _div(x, jnp.linalg.norm(x))

    x_final = lax.fori_loop(0, maxiter, body_fun, x)
    # x_final has unit norm, so <x, Ax> is already the Rayleigh quotient.
    return _vdot_tree(x_final, matvec(x_final))


def neumann_invert(matvec, b, maxiter=1000, atol=1e-5):
    """Approximate ``A^{-1} b`` with a truncated Neumann series.

    Uses ``A^{-1} = mu * sum_{i>=0} (I - mu*A)^i``, valid when every
    eigenvalue of ``I - mu*A`` has magnitude below one. Here
    ``1/mu = 0.6 * lam``, where ``lam`` is the power-iteration estimate from
    ``estimate_spectral_radius``.

    NOTE(review): for symmetric positive-definite ``A`` the series converges
    only if ``mu < 2 / lambda_max``. That requires ``lam > lambda_max / 1.2``,
    so an under-converged eigenvalue estimate can make the series diverge.

    Args:
      matvec: function computing ``A @ v``. ``A`` is never materialized.
      b: right-hand side. Must be an array, because its ``shape`` and
        ``dtype`` seed the power iteration.
      maxiter: maximum number of series terms added after ``b`` itself.
      atol: the loop stops once the Euclidean norm of the latest term
        ``(I - mu*A)^k b`` is ``<= atol``. The norm is measured before the
        final rescaling by ``mu``.

    Returns:
      ``mu * sum_{i=0}^{k} (I - mu*A)^i b`` for the last ``k`` reached.
    """
    # This is 1/mu; dividing by it is equivalent to multiplying by mu.
    inv_mu = 0.6 * estimate_spectral_radius(matvec, b.shape, b.dtype)

    def cond_fun(value):
        _, r, k = value
        # <r, r> is real up to rounding, so keep only the real part.
        rs = jnp.sqrt(jnp.real(_vdot_tree(r, r)))
        # Stop on a NaN norm too, which signals the series has broken down.
        return (rs > atol) & (k < maxiter) & jnp.logical_not(jnp.isnan(rs))

    def body_fun(value):
        x, r, k = value
        # r = (I - mu*A)^k b is the latest term;
        # x = sum_{i=0}^{k} (I - mu*A)^i b is the running partial sum.
        mu_A_r = _div(matvec(r), inv_mu)  # mu*A (I - mu*A)^k b
        r = _sub(r, mu_A_r)  # (I - mu*A)^{k+1} b
        x = _add(x, r)  # sum_{i=0}^{k+1} (I - mu*A)^i b
        return x, r, k + 1

    # At k = 0, both the partial sum and the latest term equal b.
    x_final, _, _ = lax.while_loop(cond_fun, body_fun, (b, b, 0))
    # Multiply the series by mu to get the approximation of A^{-1} b.
    return _div(x_final, inv_mu)
