# Tests checking that the numerical resolution of the biquadratic bi-level
# problem matches the theoretical closed-form expressions.
import jax
import jax.numpy as jnp
import lovely_jax as lt
import pytest

from id_in_practice.data import gen_biquadratic_data
from id_in_practice.objectives import outer_objective_biquadratic
from id_in_practice.resolution_biquadratic import solve_biquadratic


# Module-level setup, executed once at import time.
# Apply lovely_jax's monkey patch (presumably to get more readable array
# summaries in debugging / assertion output — see the lovely_jax docs).
lt.monkey_patch()
# Turn on JAX's 64-bit mode via the "jax_enable_x64" flag, so that the
# computations in the tests below are not restricted to 32-bit precision.
jax.config.update("jax_enable_x64", True)


@pytest.mark.parametrize(
    "theta_dim, dimension_x, dimension_latent_inner, dimension_latent_outer",
    [
        (5, 5, 5, 5),
        (5, 5, 3, 2),
        (5, 5, 2, 3),
        (4, 5, 3, 2),
        (2, 5, 3, 2),
    ]
)
def test_solve_biquadratic(theta_dim, dimension_x, dimension_latent_inner, dimension_latent_outer):
    """Check the output of ``solve_biquadratic`` against closed-form expressions.

    The test generates a problem with ``gen_biquadratic_data`` (also asking
    for the factors ``sqrt_H, U, kappa, sqrt_G, omega``), solves it with
    ``solve_biquadratic`` and then asserts that:

    1. ``state.error`` is below ``1e-3``;
    2. the inner iterate returned by ``outer_objective_biquadratic`` at the
       solution matches the closed-form iterate ``z_t``;
    3. the solution satisfies the projected linear identity stated in the
       comments below.

    Notation used in the comments (names on the right are the variables):
        K = sqrt_H,  H_bar = K K^T,  Gamma = sqrt_G,  C = Gamma K^T,
        P_t(A) = (I - inner_lr * A)^t,  E_t = (P_t(H_bar) - I) H_bar^{-1},
        Proj(A) = A A^+  (A^+ is the pseudo-inverse).
    """
    # Problem data plus the factors needed for the closed-form expressions.
    *data, (sqrt_H, U, kappa, sqrt_G, omega) = gen_biquadratic_data(
        dimension_theta=theta_dim,
        dimension_x=dimension_x,
        dimension_latent_inner=dimension_latent_inner,
        dimension_latent_outer=dimension_latent_outer,
        return_factors=True,
    )
    # Settings of the bi-level optimization problem.
    N_inner = 20
    inner_lr = 1e-2
    M_inner = 1  # not used
    theta_star_t, state = solve_biquadratic(
        data,
        theta_dim=theta_dim,
        m_inner=M_inner,
        n_inner=N_inner,
        inner_lr=inner_lr,
        implicit_diff=False,
        outer_lr=1e0,
        n_outer=400,
    )

    # The outer solver must report a small final error.
    assert state.error < 1e-3

    # Check that the inner iterate has the expected closed form:
    #   z_t(theta) = K^T E_t U theta + K^T E_t kappa + P_t(H) z_0
    # No z_0 is passed in this test, so the P_t(H) z_0 term is left out
    # (presumably the default initialization is zero — confirm in the solver).
    identity = jnp.eye(dimension_latent_inner)
    H_bar = sqrt_H @ sqrt_H.T
    P_t = jnp.linalg.matrix_power(identity - inner_lr * H_bar, N_inner)
    E_t = (P_t - identity) @ jnp.linalg.inv(H_bar)
    z_t = sqrt_H.T @ E_t @ (U @ theta_star_t + kappa)
    # Use the same M_inner as for the solve so both calls stay in sync.
    _, (z_t_iterative, _) = outer_objective_biquadratic(
        theta_star_t,
        data,
        n_inner=N_inner,
        m_inner=M_inner,
        inner_lr=inner_lr,
        implicit_diff=False,
    )
    assert jnp.allclose(z_t, z_t_iterative, rtol=1e-5, atol=1e-5)

    # Theoretically, the solution theta^{*,t} satisfies
    #   Gamma K^T E_t U theta^{*,t} = -Proj(Gamma K^T E_t U) (Gamma r_t - z^*)
    # with
    #   r_t = K^T E_t kappa + P_t(H) z_0   (z_0 term omitted here, see above)
    #   z^* = -omega
    C = sqrt_G @ sqrt_H.T
    CEU = C @ E_t @ U  # shared by the projector and the left-hand side
    r_t = sqrt_H.T @ E_t @ kappa
    proj_CEU = CEU @ jnp.linalg.pinv(CEU)
    first_term = CEU @ theta_star_t
    second_term = - proj_CEU @ (sqrt_G @ r_t + omega)
    assert jnp.allclose(first_term, second_term, rtol=1e-5, atol=1e-5)


@pytest.mark.parametrize(
    "theta_dim, dimension_x, dimension_latent_inner, dimension_latent_outer",
    [
        (5, 5, 5, 5),
        (5, 5, 3, 2),
        (5, 5, 2, 3),
        (4, 5, 3, 2),
        (2, 5, 3, 2),
    ]
)
@pytest.mark.parametrize("w_init", [True, False])
def test_loss_value_w_diff_inner_time(theta_dim, dimension_x, dimension_latent_inner, dimension_latent_outer, w_init):
    """Check the outer loss at a different inner horizon against a closed form.

    The problem is solved with ``N_inner = 20`` inner steps, giving
    ``theta_star_t``. The outer objective is then evaluated at that same
    ``theta_star_t`` but with ``N_inner_new = 10`` inner steps, and the value
    is compared with a loss rebuilt from the generating factors
    (``sqrt_H, U, kappa, sqrt_G, omega``) and projectors derived from them.

    ``w_init`` selects whether a random inner initialization ``z_0`` (drawn
    with a fixed PRNG key) is passed to the solver and the objective, or
    ``z_0=None``.

    Notation used in the comments (names on the right are the variables):
        K = sqrt_H,  H_bar = K K^T,  Gamma = sqrt_G,  C = Gamma K^T,
        P_t(A) = (I - inner_lr * A)^t,  E_t = (P_t(H_bar) - I) H_bar^{-1},
        A^+ is the pseudo-inverse of A.
    """
    # Problem data plus the factors needed for the closed-form expressions.
    *data, (sqrt_H, U, kappa, sqrt_G, omega) = gen_biquadratic_data(
        dimension_theta=theta_dim,
        dimension_x=dimension_x,
        dimension_latent_inner=dimension_latent_inner,
        dimension_latent_outer=dimension_latent_outer,
        return_factors=True,
    )
    # Optional non-trivial initialization of the inner variable.
    if w_init:
        key = jax.random.PRNGKey(0)
        z_0 = jax.random.normal(key, (dimension_x,))
    else:
        z_0 = None
    # Settings of the bi-level optimization problem.
    N_inner = 20
    inner_lr = 1e-2
    M_inner = 1  # not used
    theta_star_t, state = solve_biquadratic(
        data,
        theta_dim=theta_dim,
        m_inner=M_inner,
        n_inner=N_inner,
        inner_lr=inner_lr,
        implicit_diff=False,
        outer_lr=1e0,
        n_outer=200,
        z_0=z_0,
    )

    # The outer solver must report a small final error.
    assert state.error < 1e-3

    # Evaluate the loss at theta_star_t for a different inner optimization
    # time t'. Only the loss value is needed; the auxiliary outputs are ignored.
    N_inner_new = 10  # t'
    new_loss, (_, _) = outer_objective_biquadratic(
        theta_star_t,
        data,
        n_inner=N_inner_new,
        m_inner=M_inner,  # same value as used for the solve
        inner_lr=inner_lr,
        implicit_diff=False,
        z_0=z_0,
    )

    # Build the matrices needed for the closed form. Only H (third entry of
    # data) is used directly; the other entries are not needed here.
    _, _, H, _, _ = data
    eye_inner = jnp.eye(dimension_latent_inner)
    eye_outer = jnp.eye(dimension_latent_outer)
    H_bar = sqrt_H @ sqrt_H.T
    P_t = jnp.linalg.matrix_power(eye_inner - inner_lr * H_bar, N_inner)
    P_t_H = jnp.linalg.matrix_power(jnp.eye(dimension_x) - inner_lr * H, N_inner)
    P_t_prime = jnp.linalg.matrix_power(eye_inner - inner_lr * H_bar, N_inner_new)
    E_t = (P_t - eye_inner) @ jnp.linalg.inv(H_bar)
    C = sqrt_G @ sqrt_H.T
    C_pinv = jnp.linalg.pinv(C)  # reused several times below
    CEU = C @ E_t @ U
    # M = (P_t' - I) (P_t - I)^{-1}
    M = (P_t_prime - eye_inner) @ jnp.linalg.inv(P_t - eye_inner)
    # Projectors of the form A A^+.
    proj_CEU = CEU @ jnp.linalg.pinv(CEU)
    proj_C = C @ C_pinv
    proj_K_T = sqrt_H.T @ jnp.linalg.pinv(sqrt_H.T)
    proj_C_transpose = C.T @ jnp.linalg.pinv(C.T)
    proj_C_perp = eye_outer - proj_C

    # r_t = K^T E_t kappa (+ P_t(H) z_0 when an initialization is given)
    r_t = sqrt_H.T @ E_t @ kappa
    if z_0 is not None:
        r_t += P_t_H @ z_0
    # Sanity check: this vector is expected to be annihilated by C.
    theta_ker_C = E_t @ U @ theta_star_t + C_pinv @ proj_CEU @ (sqrt_G @ r_t + omega)
    assert jnp.allclose(C @ theta_ker_C, jnp.zeros_like(C @ theta_ker_C), rtol=1e-4, atol=1e-4)

    # First contribution to the vector inside the norm, weighted by (I - M).
    omega_proj = C_pinv @ proj_C @ omega
    first_vector_inside = proj_C_transpose @ (E_t @ kappa + omega_proj)
    first_vector_inside -= E_t @ kappa
    if z_0 is not None:
        first_vector_inside += C_pinv @ sqrt_G @ P_t_H @ z_0
        xi_0 = jnp.linalg.pinv(sqrt_H.T) @ proj_K_T @ z_0
        first_vector_inside -= (P_t - eye_inner) @ xi_0
    first_vector_inside -= theta_ker_C
    first_vector_inside = (eye_inner - M) @ first_vector_inside

    # Second contribution, weighted by M C^+ (I - Proj(C E_t U)).
    second_vector_inside = C @ E_t @ kappa + C @ omega_proj
    if z_0 is not None:
        second_vector_inside += sqrt_G @ P_t_H @ z_0
    second_vector_inside = M @ C_pinv @ (eye_outer - proj_CEU) @ second_vector_inside

    vector_inside = first_vector_inside + second_vector_inside
    vector_inside = C @ vector_inside

    first_part_of_loss = 0.5 * jnp.linalg.norm(vector_inside) ** 2

    # Second part of the loss: component of omega (and of the propagated
    # initialization, if any) outside the range of C, minus 0.5 * ||omega||^2.
    omega_orth_C = proj_C_perp @ omega
    if z_0 is not None:
        omega_orth_C += proj_C_perp @ sqrt_G @ P_t_H @ z_0
    second_part_of_loss = 0.5 * jnp.linalg.norm(omega_orth_C) ** 2 - 0.5 * jnp.linalg.norm(omega) ** 2

    loss = first_part_of_loss + second_part_of_loss
    assert jnp.allclose(new_loss, loss, rtol=1e-4, atol=1e-4)
