# An example that showcases how to use Jaxopt to solve the following bi-level optimization problem:
# hyperparameter optimization where the hyperparameter is the regularization parameter
# of a noisy inverse problem, so that an inner optimization problem must be solved for each instance.
#   argmin_alpha 1/2 ||x - est_x||_2^2
#   s.t. est_x = argmin_z 1/2 ||Az - y||_2^2 + exp(alpha) ||z||_2^2
from functools import partial

from jaxopt import OptaxSolver
import jax.numpy as jnp
import optax

from id_in_practice.data import gen_biquadratic_data
from id_in_practice.objectives import outer_objective_biquadratic


def solve_biquadratic(
    data,
    m_inner=50,
    n_inner=50,
    inner_lr=1e-2,
    outer_lr=1e-3,
    n_outer=100,
    theta_dim=1,
    callback=None,
    implicit_diff=True,
    hessian_inversion_solver="cg",
    z_0=None,
    *,
    verbose=True,
):
    """Minimize the biquadratic outer objective over ``theta`` with Adam.

    The outer objective is ``outer_objective_biquadratic`` with the inner
    settings below bound via ``functools.partial``. It is optimized with a
    jaxopt ``OptaxSolver`` wrapping ``optax.adam``. Because the solver is
    built with ``has_aux=True``, the objective is expected to return a
    ``(value, aux)`` pair.

    Args:
        data: Problem data, forwarded unchanged to the objective as its
            ``data`` keyword argument on every step.
        m_inner, n_inner, inner_lr: Forwarded unchanged to
            ``outer_objective_biquadratic``; their meaning is defined there.
        outer_lr: Adam learning rate for the outer parameter ``theta``.
        n_outer: Number of outer update steps to run.
        theta_dim: Length of the outer parameter vector, initialized to zeros.
        callback: Optional callable invoked after every outer step as
            ``callback(theta=theta, state=state)``.
        implicit_diff, hessian_inversion_solver, z_0: Forwarded unchanged to
            ``outer_objective_biquadratic``.
        verbose: Keyword-only. When True (the default, matching the previous
            behavior), print the outer loss after every step.

    Returns:
        Tuple ``(theta, state)``: the final outer parameter and the final
        ``OptaxSolver`` state.
    """
    outer_objective = partial(
        outer_objective_biquadratic,
        m_inner=m_inner,
        n_inner=n_inner,
        inner_lr=inner_lr,
        implicit_diff=implicit_diff,
        hessian_inversion_solver=hessian_inversion_solver,
        z_0=z_0,
    )

    # now onto the optimization loop
    outer_solver = OptaxSolver(opt=optax.adam(outer_lr), fun=outer_objective, has_aux=True)
    theta = jnp.zeros((theta_dim,))
    state = outer_solver.init_state(theta, data=data)

    for _ in range(n_outer):
        theta, state = outer_solver.update(params=theta, state=state, data=data)
        if verbose:
            print(f"[Step {state.iter_num}] Outer loss: {state.value}.")
        if callback is not None:
            callback(theta=theta, state=state)
    return theta, state


if __name__ == "__main__":
    # let's first define the data
    theta_dim = 10
    data = gen_biquadratic_data(dimension_theta=theta_dim)

    # second, let's define the bi-level optimization problem.
    # The same inner configuration is used for training and for evaluation,
    # so the reported validation error is computed with exactly the settings
    # the outer parameter was optimized under.
    inner_config = dict(
        m_inner=1,  # original note: not used
        n_inner=20,
        inner_lr=1e-2,
        implicit_diff=False,
        hessian_inversion_solver="cg",
        z_0=None,
    )
    alpha, state = solve_biquadratic(data, theta_dim=theta_dim, **inner_config)

    # evaluation: the objective returns (value, aux); report the value only.
    outer_objective = partial(outer_objective_biquadratic, **inner_config)
    print("Validation error", outer_objective(alpha, data)[0])
