"""
Tests for the Neumann iterates solver
"""
from functools import partial

import jax
import jax.numpy as jnp
import pytest

from id_in_practice.data import gen_l2_data_diag_correl
from id_in_practice.matrix_inversion import neumann_invert
from id_in_practice.resolution_l2 import solve_l2


@pytest.mark.parametrize("dimension", [1, 2, 5])
def test_neumann_invert(dimension):
    """Recover ``x`` from ``b = A @ x`` with the Neumann solver.

    The system matrix is built as ``M.T @ M`` so that it is symmetric, which
    mirrors the Hessians this solver is meant to invert.
    """
    # Fixed seed; the first key returned by the split is not needed here.
    _, matrix_key, vector_key = jax.random.split(jax.random.PRNGKey(0), 3)

    factor = jax.random.normal(matrix_key, (dimension, dimension))
    sym_matrix = factor.T @ factor
    x_true = jax.random.normal(vector_key, (dimension,))
    rhs = jnp.dot(sym_matrix, x_true)

    # The solver only sees the matrix through its matrix-vector product.
    x_neumann = neumann_invert(
        partial(jnp.dot, sym_matrix),
        rhs,
        atol=1e-8,
        maxiter=10000000,
    )

    solutions_match = jnp.allclose(x_true, x_neumann, atol=1e-7)
    assert solutions_match, "Neumann iterates failed to invert A"


def test_neumann_invert_in_l2_resolution():
    """Check that ``solve_l2`` gives matching ``alpha`` with Neumann and CG.

    The same training data is solved twice, once with
    ``hessian_inversion_solver="neumann"`` and once with ``"cg"``, and the
    resulting ``alpha`` values are required to agree up to ``atol=1e-4``.
    """
    train_data, _ = gen_l2_data_diag_correl(
        dimension=10,
        sigma_data_range=(0.01, 0.2),
        noise_level=0.01,
        n_samples_train=100,
    )

    # solve with neumann
    # NOTE(review): m_inner differs between the two runs (3000 here vs 100 for
    # cg); presumably the Neumann solver needs a larger budget -- confirm.
    alpha, _ = solve_l2(
        train_data,
        m_inner=3000,
        n_inner=1000,
        inner_lr=1e0,
        n_outer=1000,
        outer_lr=1e-1,
        hessian_inversion_solver="neumann",
    )

    # solve with cg
    alpha_cg, _ = solve_l2(
        train_data,
        m_inner=100,
        n_inner=1000,
        inner_lr=1e0,
        n_outer=1000,
        outer_lr=1e-1,
        hessian_inversion_solver="cg",
    )

    # allclose reduces to a single boolean, so the assertion is well defined
    # whether alpha is a scalar or an array (isclose is element-wise and its
    # truth value is ambiguous for arrays with more than one element).
    assert jnp.allclose(alpha, alpha_cg, atol=1e-4), "alpha is not close"

