# An example that showcases how to use JAXopt to solve the following bi-level optimization problem:
# hyperparameter optimization, where the hyperparameter is the regularization parameter
# of a noisy inverse problem and, for each instance, an inner optimization problem must be solved.
# argmin_alpha 1/2 ||x - est_x(alpha)||_2^2  s.t.  est_x(alpha) = argmin_z 1/2 ||Az - y||_2^2 + exp(alpha) ||z||_2^2
from functools import partial

from jaxopt import OptaxSolver
import jax.numpy as jnp
import optax

from id_in_practice.data import gen_l2_data
from id_in_practice.objectives import outer_objective_l2


def solve_l2(
    train_data,
    m_inner=50,
    n_inner=50,
    inner_lr=1e-2,
    outer_lr=1e-3,
    init_alpha=-6.0,
    n_outer=100,
    callback=None,
    diag=False,
    implicit_diff=True,
    hessian_inversion_solver="cg",
    *,
    verbose=True,
):
    """Learn the regularization hyperparameter ``alpha`` by bi-level optimization.

    The outer problem is minimized with Adam through ``jaxopt.OptaxSolver``.
    Each evaluation of the outer objective is delegated to
    ``outer_objective_l2``, configured with the inner-solver settings below.

    Args:
        train_data: Data passed as ``data=`` to the outer objective at every
            step. When ``diag`` is True, ``train_data[0].shape[1]`` sets the
            number of per-dimension hyperparameters.
        m_inner, n_inner, inner_lr: Inner-solver settings forwarded unchanged
            to ``outer_objective_l2`` (their exact meaning is defined there).
        outer_lr: Adam learning rate for the outer problem.
        init_alpha: Initial value of ``alpha``.
        n_outer: Number of outer update steps.
        callback: Optional callable, invoked as
            ``callback(alpha=alpha, state=state)`` after every outer step.
        diag: If True, learn one ``alpha`` per dimension (an array of shape
            ``(1, dimension)``) instead of a single scalar.
        implicit_diff, hessian_inversion_solver: Forwarded to
            ``outer_objective_l2``; presumably they select how hypergradients
            are computed (TODO confirm there).
        verbose: If True (the default, matching the previous behavior), print
            the outer objective value after every step.

    Returns:
        ``(alpha, state)``: the final hyperparameter and the final
        ``OptaxSolver`` state.
    """
    outer_objective = partial(
        outer_objective_l2,
        m_inner=m_inner,
        n_inner=n_inner,
        inner_lr=inner_lr,
        implicit_diff=implicit_diff,
        hessian_inversion_solver=hessian_inversion_solver,
    )

    # now onto the optimization loop
    # has_aux=True: the outer objective returns a (value, aux) pair.
    outer_solver = OptaxSolver(opt=optax.adam(outer_lr), fun=outer_objective, has_aux=True)
    if diag:
        # One hyperparameter per dimension. Assumes train_data[0] is 2-D with
        # the signal dimension on axis 1 (TODO confirm against gen_l2_data).
        dimension = train_data[0].shape[1]
        alpha = jnp.ones((1, dimension)) * init_alpha
    else:
        alpha = init_alpha
    state = outer_solver.init_state(alpha, data=train_data)

    for _ in range(n_outer):
        alpha, state = outer_solver.update(params=alpha, state=state, data=train_data)
        if verbose:
            print(f"[Step {state.iter_num}] Validation loss: {state.value}.")
        if callback is not None:
            callback(alpha=alpha, state=state)
    return alpha, state


if __name__ == "__main__":
    # let's first define the data
    train_data, test_data = gen_l2_data(
        dimension=50,
        sigma_data=1.0,
        noise_level=1e-0,
        n_samples_train=100,
        n_samples_test=100,
    )

    # second, let's define and solve the bi-level optimization problem
    N_inner = 20
    M_inner = 50
    # Share the inner step size between training and evaluation so the test
    # objective runs the inner solver exactly as it was configured in training.
    inner_lr = 1e-2
    alpha, state = solve_l2(train_data, m_inner=M_inner, n_inner=N_inner, inner_lr=inner_lr)

    # evaluation on the held-out test data with the learned alpha
    outer_objective = partial(
        outer_objective_l2,
        m_inner=M_inner,
        n_inner=N_inner,
        inner_lr=inner_lr,
    )
    print("Test error", outer_objective(alpha, test_data)[0])
