from functools import partial

import jax
import jax.numpy as jnp
import pytest

from id_in_practice.data import gen_l1_data, gen_l2_data, gen_l2_data_diag_correl
from id_in_practice.grid_search import grid_search_1d
from id_in_practice.objectives import outer_objective_l1, outer_objective_l2
from id_in_practice.resolution_l1 import solve_l1
from id_in_practice.resolution_l2 import solve_l2


def test_convergence_inner_l1():
    """Check the inner error reported by the L1 outer objective.

    The objective is evaluated once on generated L1 data with
    ``implicit_diff=False``; the test passes when the ``error`` attribute of
    the returned inner state is below ``1e-6``.
    """
    data_settings = {
        "dimension": 10,
        "sparsity_level": 0.1,
        "noise_level": 0.1,
        "n_samples_train": 100,
    }
    train_data, _ = gen_l1_data(**data_settings)

    solver_settings = {
        "m_inner": 100,
        "n_inner": 1000,
        "inner_lr": 1e1,
        "implicit_diff": False,
    }
    # The objective returns a pair whose second item is itself a pair; the
    # inner state is the last element of that nested pair.
    _, aux = outer_objective_l1(-4., train_data, **solver_settings)
    _, inner_state = aux

    assert inner_state.error < 1e-6, "inner problem did not converge"


def test_convergence_inner_l2():
    """Check the inner error reported by the L2 outer objective.

    The objective is evaluated once on generated L2 data with
    ``implicit_diff=False``; the test passes when the ``error`` attribute of
    the returned inner state is below ``1e-6``.
    """
    data_settings = {
        "dimension": 10,
        "sigma_data": 0.1,
        "noise_level": 0.1,
        "n_samples_train": 100,
    }
    train_data, _ = gen_l2_data(**data_settings)

    solver_settings = {
        "m_inner": 100,
        "n_inner": 1000,
        "inner_lr": 1e0,
        "implicit_diff": False,
    }
    # The objective returns a pair whose second item is itself a pair; the
    # inner state is the last element of that nested pair.
    _, aux = outer_objective_l2(-4., train_data, **solver_settings)
    _, inner_state = aux

    assert inner_state.error < 1e-6, "inner problem did not converge"


def test_grad_id_l1():
    """Compare L1 outer-objective gradients with and without implicit diff.

    The gradient with respect to ``alpha`` is computed twice with
    ``jax.grad``: once with ``implicit_diff=True`` and once with
    ``implicit_diff=False`` (the "unrolled" variant). Both must agree up to
    an absolute tolerance of ``1e-5``.
    """
    train_data, _ = gen_l1_data(
        dimension=10,
        sparsity_level=0.1,
        noise_level=0.1,
        n_samples_train=100,
    )

    def make_objective(implicit_diff):
        """Build a scalar function of ``alpha`` for the requested mode."""
        def objective(alpha):
            outputs = outer_objective_l1(
                alpha,
                train_data,
                m_inner=100,
                n_inner=1000,
                inner_lr=1e1,
                implicit_diff=implicit_diff,
            )
            # Only the first output is differentiated.
            return outputs[0]

        return objective

    start_alpha = -4.
    grad_id = jax.grad(make_objective(True))(start_alpha)
    grad_unrolled = jax.grad(make_objective(False))(start_alpha)

    assert jnp.isclose(grad_id, grad_unrolled, atol=1e-5), "grads are not close"


def _test_resolution(
    data_gen_fun,
    solve_fun,
    outer_objective_fun,
    n_inner=1000,
    inner_lr=1e1,
    n_outer=10000,
    outer_lr=1e-1,
    no_grid=False,
    **solver_kwargs,
):
    """Run a solver and, unless disabled, compare it with a 1D grid search.

    Args:
        data_gen_fun: Zero-argument callable returning a pair; its first
            element is used as the training data.
        solve_fun: Solver called with the training data and the keyword
            arguments below; it must return ``(alpha, state)``.
        outer_objective_fun: Objective evaluated by the grid search; only
            its first output is used.
        n_inner: Forwarded to the solver and the objective as ``n_inner``.
        inner_lr: Forwarded to the solver and the objective as ``inner_lr``.
        n_outer: Forwarded to the solver as ``n_outer``.
        outer_lr: Forwarded to the solver as ``outer_lr``.
        no_grid: When true, the grid search and both assertions are skipped,
            so the helper only checks that the solver runs.
        **solver_kwargs: Extra keyword arguments forwarded to ``solve_fun``.
    """
    inner_budget = 100

    train_data, _ = data_gen_fun()

    # Filled by the solver callback; not inspected afterwards by this helper.
    recorded_values = []

    def record_value(alpha, state):
        recorded_values.append(state.value)

    # resolution with HOAG
    alpha, state = solve_fun(
        train_data,
        m_inner=inner_budget,
        n_inner=n_inner,
        inner_lr=inner_lr,
        outer_lr=outer_lr,
        n_outer=n_outer,
        callback=record_value,
        **solver_kwargs,
    )

    if no_grid:
        return

    # Reference result: scan alpha over a fixed grid with the same inner
    # settings as the solver.
    def grid_objective(alpha, data):
        outputs = outer_objective_fun(
            alpha,
            data,
            m_inner=inner_budget,
            n_inner=n_inner,
            inner_lr=inner_lr,
        )
        return outputs[0]

    alpha_gs, error_gs = grid_search_1d(
        grid_objective,
        lookup_range=jnp.linspace(-10, 0, 200),
        data=train_data,
    )

    # The solver may match the grid-search value or do better than it.
    assert (
        jnp.isclose(error_gs, state.value, atol=1e-5) or state.value <= error_gs
    ), "error of GS is not that of HOAG"
    # Alphas are compared after exponentiation.
    assert jnp.isclose(jnp.exp(alpha), jnp.exp(alpha_gs), atol=1e-1), "alpha is not close"


@pytest.mark.parametrize(
    ("dimension", "noise_level"),
    [
        (50, 1e-0),
        (50, 1e-1),
        (10, 1e-0),
    ],
)
def test_l1_resolution(dimension, noise_level):
    """Run the L1 solver against the grid-search check for one data setting.

    The data generator is bound to the parametrized ``dimension`` and
    ``noise_level``; the solver uses an outer learning rate of ``1e0`` and
    the helper's defaults for everything else.
    """
    make_l1_data = partial(
        gen_l1_data,
        dimension=dimension,
        sparsity_level=0.1,
        noise_level=noise_level,
        n_samples_train=100,
    )
    _test_resolution(
        data_gen_fun=make_l1_data,
        solve_fun=solve_l1,
        outer_objective_fun=outer_objective_l1,
        outer_lr=1e0,
    )


@pytest.mark.parametrize(
    ("dimension", "noise_level"),
    [
        (50, 1e-0),
        (50, 1e-2),
        (10, 1e-0),
    ],
)
def test_l2_resolution(dimension, noise_level):
    """Run the L2 solver against the grid-search check for one data setting.

    The data generator is bound to the parametrized ``dimension`` and
    ``noise_level``; the solver settings are fixed below.
    """
    make_l2_data = partial(
        gen_l2_data,
        dimension=dimension,
        sigma_data=0.1,
        noise_level=noise_level,
        n_samples_train=100,
    )
    solver_settings = {
        "n_inner": 1000,
        "inner_lr": 1e0,
        "n_outer": 1000,
        "outer_lr": 1e-1,
    }
    _test_resolution(
        data_gen_fun=make_l2_data,
        solve_fun=solve_l2,
        outer_objective_fun=outer_objective_l2,
        **solver_settings,
    )


def test_l2_resolution_diag_call():
    """Check that the L2 solver can be called with ``diag=True``.

    The grid-search comparison is disabled (``no_grid=True``), so this test
    only verifies that the solver call completes on the generated data.
    """
    make_diag_data = partial(
        gen_l2_data_diag_correl,
        dimension=10,
        sigma_data_range=(0.01, 0.2),
        noise_level=0.01,
        n_samples_train=100,
    )
    solver_settings = {
        "n_inner": 1000,
        "inner_lr": 1e0,
        "n_outer": 1000,
        "outer_lr": 1e-1,
    }
    _test_resolution(
        data_gen_fun=make_diag_data,
        solve_fun=solve_l2,
        outer_objective_fun=outer_objective_l2,
        no_grid=True,
        # ``diag`` is not a helper parameter; it is forwarded to the solver.
        diag=True,
        **solver_settings,
    )
