from functools import partial

import jax
import jax.numpy as jnp
from jaxopt import GradientDescent, ProximalGradient
from jaxopt import linear_solve
from jaxopt import prox

from id_in_practice.matrix_inversion import neumann_invert


@jax.jit
def data_fidelity_l2_reg(x, lam, data):
    """Least-squares data fidelity plus a weighted L2 (ridge) penalty.

    Args:
        x: current estimate of the unknown.
        lam: regularization weight(s); multiplied elementwise with ``x**2``,
            so it may be a scalar or anything broadcastable against ``x``.
        data: pair ``(A, y)``, forwarded unchanged to ``data_fidelity``.

    Returns:
        Scalar ``data_fidelity(x, data) + 0.5 * sum(lam * x**2)``.
    """
    # ``data`` is unpacked (and thereby validated) inside ``data_fidelity``;
    # unpacking it here as well only produced unused locals.
    df = data_fidelity(x, data)
    l2_reg = 0.5 * jnp.sum(lam * x**2)
    return df + l2_reg


@partial(jax.jit, static_argnames=("n_inner", "m_inner", "inner_lr", "implicit_diff", "hessian_inversion_solver"))
def outer_objective_l2(
    alpha,
    data,
    n_inner=50,
    m_inner=50,
    inner_lr=1e-0,
    implicit_diff=True,
    hessian_inversion_solver="cg",
):
    """Outer loss for an L2-regularized (ridge) inner least-squares problem.

    The inner problem ``data_fidelity_l2_reg`` is solved with jaxopt's
    ``GradientDescent``. All remaining arguments, and the returned
    ``(loss, (x_est, inner_state))`` tuple, are handled by
    ``outer_objective``. The integer, boolean and string options are static
    under ``jax.jit``, so changing them triggers a recompilation.
    """
    solver_options = dict(
        n_inner=n_inner,
        m_inner=m_inner,
        inner_lr=inner_lr,
        implicit_diff=implicit_diff,
        hessian_inversion_solver=hessian_inversion_solver,
    )
    return outer_objective(
        alpha,
        data,
        iterative_solver=GradientDescent,
        inner_objective=data_fidelity_l2_reg,
        **solver_options,
    )


@jax.jit
def data_fidelity(x, data):
    """Half the mean squared residual of the linear model ``A x ~ y``.

    ``data`` is the pair ``(A, y)``. The transposes let ``x`` be either a
    single vector or (presumably) a stack of row vectors, with ``A`` applied
    to each row. The mean is taken over every residual entry.
    """
    A, y = data
    predictions = jnp.dot(A, x.T).T
    residuals = predictions - y
    return 0.5 * jnp.mean(residuals ** 2)


@partial(jax.jit, static_argnames=("n_inner", "m_inner", "inner_lr", "implicit_diff", "hessian_inversion_solver"))
def outer_objective_l1(
    alpha,
    data,
    n_inner=50,
    m_inner=50,
    inner_lr=1e-0,
    implicit_diff=True,
    hessian_inversion_solver="cg",
):
    """Outer loss for an L1-regularized (lasso) inner least-squares problem.

    The inner problem minimizes ``data_fidelity`` with jaxopt's
    ``ProximalGradient`` using ``prox.prox_lasso``. The regularization
    strength ``exp(alpha)`` is passed to the solver as the prox
    hyperparameter. All other arguments, and the returned
    ``(loss, (x_est, inner_state))`` tuple, are handled by ``outer_objective``.

    ``hessian_inversion_solver`` (``"cg"`` or ``"neumann"``) selects the
    linear solver used for implicit differentiation, matching
    ``outer_objective_l2``. It defaults to ``"cg"``, which was the only
    behavior available before. It is declared static so that the string can
    pass through ``jax.jit``.
    """
    # NOTE(review): the fixed-point Jacobian of proximal gradient is not
    # symmetric in general, whereas conjugate gradient assumes a symmetric
    # system -- confirm that "cg" is appropriate for this solver.
    iterative_solver = partial(ProximalGradient, prox=prox.prox_lasso)
    inner_objective = data_fidelity
    return outer_objective(
        alpha,
        data,
        iterative_solver=iterative_solver,
        inner_objective=inner_objective,
        n_inner=n_inner,
        m_inner=m_inner,
        inner_lr=inner_lr,
        implicit_diff=implicit_diff,
        hessian_inversion_solver=hessian_inversion_solver,
    )


def outer_objective(
    alpha,
    data,
    iterative_solver,
    inner_objective,
    n_inner=50,
    m_inner=50,
    inner_lr=1e-0,
    implicit_diff=True,
    hessian_inversion_solver="cg",
):
    """Solve the inner problem at ``lam = exp(alpha)`` and score its solution.

    Args:
        alpha: log of the regularization strength given to the inner solver.
        data: triple ``(x_train, A, y_train)``. The inner problem is fit on
            ``(A, y_train)``, and the estimate is compared against ``x_train``.
        iterative_solver: solver class (or ``partial`` of one) constructed with
            ``fun``, ``implicit_diff``, ``implicit_diff_solve``, ``maxiter``,
            ``acceleration``, ``stepsize`` and ``tol``.
        inner_objective: objective passed to the solver as ``fun``.
        n_inner: maximum number of inner-solver iterations.
        m_inner: maximum number of iterations of the linear solver used for
            implicit differentiation.
        inner_lr: step size given to the inner solver.
        implicit_diff: forwarded to the solver's ``implicit_diff`` option.
        hessian_inversion_solver: ``"cg"`` (``linear_solve.solve_cg``) or
            ``"neumann"`` (``neumann_invert``).

    Returns:
        ``(loss, (x_train_est, inner_state))``, where ``loss`` is half the mean
        squared error between ``x_train`` and the inner solution.

    Raises:
        ValueError: if ``hessian_inversion_solver`` is not recognised.
    """
    x_train, A, y_train = data
    lam = jnp.exp(alpha)

    # Both linear solvers stop at 1e-7, but they name that tolerance
    # differently ("tol" for CG, "atol" for the Neumann series).
    if hessian_inversion_solver == "cg":
        linear_solver, tolerance_name = linear_solve.solve_cg, "tol"
    elif hessian_inversion_solver == "neumann":
        linear_solver, tolerance_name = neumann_invert, "atol"
    else:
        raise ValueError("Unknown hessian_inversion_solver")
    implicit_diff_solve = partial(
        linear_solver, maxiter=m_inner, **{tolerance_name: 1e-7}
    )

    solver = iterative_solver(
        fun=inner_objective,
        maxiter=n_inner,
        stepsize=inner_lr,
        tol=1e-7,
        acceleration=False,
        implicit_diff=implicit_diff,
        implicit_diff_solve=implicit_diff_solve,
    )
    # Always start the inner solve from zero; ``lam`` is the hyperparameter
    # argument of ``run``, and ``(A, y_train)`` is forwarded to the objective.
    init_params = jnp.zeros_like(x_train)
    x_train_est, inner_state = solver.run(init_params, lam, (A, y_train))

    squared_error = (x_train - x_train_est) ** 2
    loss_value = 0.5 * jnp.mean(squared_error)
    return loss_value, (x_train_est, inner_state)


# Functions for the very simple quadratic bilevel case

@jax.jit
def quadratic_loss(z, H, c):
    """Evaluate the quadratic form ``0.5 * z^T H z + z^T c``."""
    half_z_t = 0.5 * z.T
    curvature_term = half_z_t @ (H @ z)
    linear_term = z.T @ c
    return curvature_term + linear_term


@jax.jit
def quadratic_loss_bilevel(z, theta, data):
    """Inner objective of the quadratic bilevel problem.

    ``data`` holds ``(H, B, c)``, in that order. The value returned is
    ``1/2 z^T H z + z^T B theta + z^T c``. It is computed by folding
    ``B theta`` into the linear coefficient passed to ``quadratic_loss``.
    """
    H, B, c = data
    shifted_c = c + B @ theta
    return quadratic_loss(z, H, shifted_c)


def outer_objective_quadratic(
    theta,
    data,
    z_0=None,
    n_inner=50,
    m_inner=50,
    inner_lr=1e-0,
    implicit_diff=True,
    hessian_inversion_solver="cg",
):
    """Outer loss of the quadratic bilevel problem with a target ``x``.

    Args:
        theta: outer variable, passed to ``quadratic_loss_bilevel``.
        data: ``(x, H, B, c)``. ``x`` is the target for the inner solution,
            and ``(H, B, c)`` define the inner quadratic.
        z_0: initial inner iterate. Defaults to zeros shaped like ``x``.
        n_inner: maximum number of gradient-descent iterations.
        m_inner: maximum number of iterations of the implicit-diff linear
            solver.
        inner_lr: gradient-descent step size.
        implicit_diff: forwarded to ``GradientDescent``.
        hessian_inversion_solver: ``"cg"`` or ``"neumann"``.

    Returns:
        ``(loss, (z, inner_state))``, where ``loss = 0.5 * mean((x - z)**2)``.

    Raises:
        ValueError: if ``hessian_inversion_solver`` is not recognised.
    """
    x, H, B, c = data

    # Same 1e-7 tolerance for both solvers, under solver-specific names.
    if hessian_inversion_solver == "cg":
        linear_solver, tolerance_name = linear_solve.solve_cg, "tol"
    elif hessian_inversion_solver == "neumann":
        linear_solver, tolerance_name = neumann_invert, "atol"
    else:
        raise ValueError("Unknown hessian_inversion_solver")
    implicit_diff_solve = partial(
        linear_solver, maxiter=m_inner, **{tolerance_name: 1e-7}
    )

    solver = GradientDescent(
        fun=quadratic_loss_bilevel,
        maxiter=n_inner,
        stepsize=inner_lr,
        tol=1e-7,
        acceleration=False,
        implicit_diff=implicit_diff,
        implicit_diff_solve=implicit_diff_solve,
    )
    init_z = jnp.zeros_like(x) if z_0 is None else z_0
    z, inner_state = solver.run(init_z, theta, (H, B, c))
    loss_value = 0.5 * jnp.mean((x - z) ** 2)
    return loss_value, (z, inner_state)


def outer_objective_biquadratic(
    theta,
    data,
    n_inner=50,
    m_inner=50,
    inner_lr=1e-0,
    implicit_diff=True,
    hessian_inversion_solver="cg",
    z_0=None,
):
    """Outer loss of the bilevel problem whose outer objective is quadratic.

    Args:
        theta: outer variable, passed to ``quadratic_loss_bilevel``.
        data: ``(G, w, H, B, c)``. ``(G, w)`` define the outer quadratic
            ``quadratic_loss(z, G, w)``, and ``(H, B, c)`` define the inner one.
        n_inner: maximum number of gradient-descent iterations.
        m_inner: maximum number of iterations of the implicit-diff linear
            solver.
        inner_lr: gradient-descent step size.
        implicit_diff: forwarded to ``GradientDescent``.
        hessian_inversion_solver: ``"cg"`` or ``"neumann"``.
        z_0: initial inner iterate. Defaults to a zero vector of length
            ``G.shape[0]`` with ``jnp.zeros``' default dtype.

    Returns:
        ``(quadratic_loss(z, G, w), (z, inner_state))``.

    Raises:
        ValueError: if ``hessian_inversion_solver`` is not recognised.
    """
    G, w, H, B, c = data

    # Same 1e-7 tolerance for both solvers, under solver-specific names.
    if hessian_inversion_solver == "cg":
        linear_solver, tolerance_name = linear_solve.solve_cg, "tol"
    elif hessian_inversion_solver == "neumann":
        linear_solver, tolerance_name = neumann_invert, "atol"
    else:
        raise ValueError("Unknown hessian_inversion_solver")
    implicit_diff_solve = partial(
        linear_solver, maxiter=m_inner, **{tolerance_name: 1e-7}
    )

    solver = GradientDescent(
        fun=quadratic_loss_bilevel,
        maxiter=n_inner,
        stepsize=inner_lr,
        tol=1e-7,
        acceleration=False,
        implicit_diff=implicit_diff,
        implicit_diff_solve=implicit_diff_solve,
    )
    # The inner variable lives in the space on which the outer matrix G acts.
    z_dim = G.shape[0]
    init_z = jnp.zeros((z_dim,)) if z_0 is None else z_0
    z, inner_state = solver.run(init_z, theta, (H, B, c))
    return quadratic_loss(z, G, w), (z, inner_state)
