import functools

import numpy as np

import jax
import jax.numpy as jnp
from jax.scipy import linalg

import jaxopt
from jaxopt import linear_solve
from jaxopt import tree_util

import rlax

from acme.jax import utils as acme_utils


def hvp(f, primals, tangents, has_aux=False):
    """Return the Hessian-vector product of ``f`` at ``primals`` along ``tangents``.

    The product is computed forward-over-reverse: ``jax.grad`` builds the
    gradient of ``f`` (with respect to its first argument, ``jax.grad``'s
    default ``argnums=0``), and ``jax.jvp`` pushes ``tangents`` through it.

    Args:
      f: Scalar-valued function. When ``has_aux`` is True it must return a
        ``(scalar, aux)`` pair.
      primals: Tuple of arguments at which the product is evaluated.
      tangents: Tuple of tangent vectors, structured like ``primals``.
      has_aux: Whether ``f`` returns an auxiliary output alongside the scalar.

    Returns:
      The tangent output of the gradient map, i.e. ``H @ v`` when ``f`` takes
      a single argument. Any auxiliary output is discarded.
    """
    gradient_fn = jax.grad(f, has_aux=has_aux)
    jvp_result = jax.jvp(gradient_fn, primals, tangents, has_aux=has_aux)
    # jax.jvp yields (primal_out, tangent_out) or, with has_aux,
    # (primal_out, tangent_out, aux); the tangent is always at index 1.
    hessian_times_tangent = jvp_result[1]
    return hessian_times_tangent


def vec_td_learning(online_params, target_params, apply_fn, o_tm1, r_t, discount_t, o_t):
    """Compute per-transition TD errors for a batch of transitions.

    Values for ``o_tm1`` come from ``online_params``. Bootstrap values for
    ``o_t`` come from ``target_params``. Only index 0 of the last axis of
    ``apply_fn``'s output is used. The per-transition error is produced by
    ``rlax.td_learning``, vmapped over the leading batch axis. Per the rlax
    docs, that is roughly ``r_t + discount_t * v_t - v_tm1``; confirm against
    the rlax version in use.

    NOTE(review): assumes ``r_t`` and ``discount_t`` are shape ``(batch,)``.
    TODO confirm.
    """
    batched_td = jax.vmap(rlax.td_learning)
    # Online values are evaluated before target values, as in the original.
    values_tm1 = apply_fn(online_params, o_tm1)[..., 0]
    values_t = apply_fn(target_params, o_t)[..., 0]
    return batched_td(values_tm1, r_t, discount_t, values_t)


def mean_td_learning(online_params, target_params, apply_fn, o_tm1, r_t, discount_t, o_t):
    """Return half the mean squared TD error, plus an auxiliary metrics dict.

    Returns:
      ``(loss, aux)`` where ``loss = 0.5 * mean(td_errors ** 2)``. Note that
      ``aux["td_error"]`` holds this same loss value, not the raw TD errors.
    """
    td_errors = vec_td_learning(
        online_params, target_params, apply_fn, o_tm1, r_t, discount_t, o_t)
    loss = 0.5 * jnp.mean(td_errors ** 2)
    metrics = {"td_error": loss}
    return loss, metrics


def frm_td_loss(online_params,
                target_params,
                apply_fn,
                o_tm1,
                r_t,
                discount_t,
                o_t,
                ridge_coeff,
                linear_solver,
                hess_average_loss=True,
                use_iterative_refinement=False,
                tol=1e-5):
    """Compute a TD loss scaled by ``j^T (H + ridge I)^{-1} j``, matrix-free.

    For each observation ``x`` in ``o_tm1``, ``j`` is the gradient of the
    (squeezed, scalar) network output at ``x`` with respect to
    ``online_params``. ``H`` is the Hessian, at ``online_params``, of
    ``0.5 * agg((f(p, o_tm1) - f(online_params, o_tm1)) ** 2)``, where ``agg``
    is mean or sum. ``H`` is never materialised: it is applied through
    Hessian-vector products. The linear system is solved with
    ``jaxopt.linear_solve.solve_<linear_solver>``, optionally wrapped in
    ``jaxopt.IterativeRefinement``.

    The returned loss is ``mean(0.5 * td_error**2 / q + log(q))`` with
    ``q = j^T (H + ridge I)^{-1} j`` per sample.

    Args:
      online_params: Parameters being trained; also the point where ``H`` and
        ``j`` are evaluated.
      target_params: Parameters used for bootstrap values at ``o_t``.
      apply_fn: ``apply_fn(params, obs)`` returning value predictions.
      o_tm1, r_t, discount_t, o_t: Batched transitions.
      ridge_coeff: Ridge term passed to the jaxopt solver.
      linear_solver: Suffix of a ``jaxopt.linear_solve.solve_*`` function.
        NOTE(review): a ``tol`` keyword is always passed, so only solvers that
        accept it (e.g. the iterative ones) appear usable; confirm.
      hess_average_loss: Use mean (True) or sum (False) when building ``H``.
      use_iterative_refinement: Wrap the solver in
        ``jaxopt.IterativeRefinement``. The inner solver then runs at
        ``0.1 * tol``.
      tol: Solver / refinement tolerance.

    Returns:
      ``(frm_loss, aux)`` where ``aux["mean_td_error"]`` is the mean squared
      TD error and ``aux["mean_frm_loss"]`` repeats ``frm_loss``.
    """
    aggregate = jnp.mean if hess_average_loss else jnp.sum

    # Solver configuration does not depend on the observation, so build it once.
    inner_tol = tol * (1e-1 if use_iterative_refinement else 1.)
    base_solve = functools.partial(
        getattr(linear_solve, f"solve_{linear_solver}"),
        ridge=ridge_coeff,
        tol=inner_tol,
    )
    hessian_operand = (online_params, o_tm1)

    def _single_output(p, x):
        # Evaluate the network on one unbatched observation.
        batched_out = apply_fn(p, acme_utils.add_batch_dim(x))
        return acme_utils.squeeze_batch_dim(batched_out)

    def _hessian_matvec(operand, vec):
        # H @ vec, where H is the Hessian of the squared deviation from the
        # anchor parameters' predictions, evaluated at the anchor.
        anchor_params, inputs = operand

        def _half_sq_deviation(p):
            deviation = apply_fn(p, inputs) - apply_fn(anchor_params, inputs)
            return 0.5 * aggregate(deviation ** 2)

        return hvp(_half_sq_deviation, (anchor_params,), (vec,))

    def _quadratic_form(obs):
        output_grad = jax.jacrev(
            lambda p: _single_output(p, obs).squeeze())(online_params)

        if not use_iterative_refinement:
            solution = base_solve(
                lambda v: _hessian_matvec(hessian_operand, v),
                output_grad,
            )
        else:
            refiner = jaxopt.IterativeRefinement(
                matvec_A=_hessian_matvec, solve=base_solve, tol=tol)
            solution, _ = refiner.run(
                init_params=None, A=hessian_operand, b=output_grad)

        # TODO: check if using vjp is faster
        return tree_util.tree_vdot(output_grad, solution)

    td_error = vec_td_learning(online_params, target_params, apply_fn, o_tm1, r_t, discount_t, o_t)
    jHinvj = jax.lax.map(_quadratic_form, o_tm1)
    frm_loss = jnp.mean(0.5 * td_error ** 2 / jHinvj + jnp.log(jHinvj))

    aux = {"mean_td_error": jnp.mean(td_error ** 2), "mean_frm_loss": frm_loss}
    return frm_loss, aux


def dense_linear_frm_td_loss(online_params,
                             target_params,
                             apply_fn,
                             o_tm1,
                             r_t,
                             discount_t,
                             o_t,
                             ridge_coeff,
                             hessian_inputs=None,
                             hess_average_loss=True):
    """Dense, Cholesky-based variant of ``frm_td_loss`` for single-array models.

    ``H`` is the Hessian, at ``online_params``, of
    ``0.5 * agg((f(p, X) - f(online_params, X)) ** 2)`` with
    ``X = hessian_inputs``. Because the residual is zero at the evaluation
    point, this equals the Gauss-Newton matrix. It is materialised as an
    ``(n, n)`` matrix, ``n`` being the number of parameters, and
    Cholesky-factorised once with ``ridge_coeff`` added to the diagonal. For
    each observation in ``o_tm1`` the quadratic form
    ``q = j^T (H + ridge I)^{-1} j`` is computed, where ``j`` is the flattened
    gradient of the scalar network output. The loss is
    ``mean(0.5 * td_error**2 / q + log(q))``.

    Args:
      online_params: Parameter pytree with exactly one array leaf, of any
        shape.
      target_params: Parameters used for bootstrap values at ``o_t``.
      apply_fn: ``apply_fn(params, obs)`` returning value predictions. Per
        observation, the output must squeeze to a scalar.
      o_tm1, r_t, discount_t, o_t: Batched transitions.
      ridge_coeff: Value added to the diagonal of ``H`` before factorisation.
      hessian_inputs: Observations used to build ``H``; defaults to ``o_tm1``.
      hess_average_loss: Use mean (True) or sum (False) when building ``H``.

    Returns:
      ``(frm_loss, aux)`` where ``aux["mean_td_error"]`` is the mean squared
      TD error and ``aux["mean_frm_loss"]`` repeats ``frm_loss``.

    Raises:
      ValueError: If the parameters (hence the Hessian / Jacobian) do not
        consist of exactly one array leaf.
    """
    if hessian_inputs is None:
        hessian_inputs = o_tm1

    agg_func = jnp.mean if hess_average_loss else jnp.sum

    def loss_fn(p):
        deviation = apply_fn(p, hessian_inputs) - apply_fn(online_params, hessian_inputs)
        return 0.5 * agg_func(deviation ** 2)

    def _apply_single(p, x):
        return acme_utils.squeeze_batch_dim(
            apply_fn(p, acme_utils.add_batch_dim(x)))

    def _only_leaf(tree, what):
        # The dense path only supports a single parameter array.
        leaves = jax.tree_util.tree_leaves(tree)
        if len(leaves) != 1:
            raise ValueError(
                f"dense_linear_frm_td_loss expects a single-leaf {what}, "
                f"got {len(leaves)} leaves")
        return leaves[0]

    # NOTE(review): ensure_compile_time_eval evaluates eagerly only when the
    # inputs are concrete; under jit/grad the Hessian is traced as usual.
    with jax.ensure_compile_time_eval():
        hess_leaf = _only_leaf(jax.hessian(loss_fn)(online_params), "Hessian")
        # The Hessian of a parameter of shape P has shape P + P. Flatten it to
        # (n, n) rather than squeeze(): squeeze collapses a 1-element
        # parameter to a 0-d array and leaves multi-column shapes at 4-D.
        num_params = int(np.prod(hess_leaf.shape[:hess_leaf.ndim // 2]))
        hess = hess_leaf.reshape(num_params, num_params)

        hess_chol = linalg.cho_factor(hess + ridge_coeff * jnp.eye(num_params))

    def _jac_hess_inv_jac(obs):
        jac_tree = jax.jacrev(lambda p: _apply_single(p, obs).squeeze())(online_params)
        # Flatten in the same C order as the Hessian so cho_solve sees (n,).
        jac = _only_leaf(jac_tree, "Jacobian").reshape(num_params)

        hess_inv_jac = linalg.cho_solve(hess_chol, jac)
        return tree_util.tree_vdot(jac, hess_inv_jac)

    td_error = vec_td_learning(online_params, target_params, apply_fn, o_tm1, r_t, discount_t, o_t)
    jHinvj = jax.vmap(_jac_hess_inv_jac)(o_tm1)
    frm_loss = jnp.mean(0.5 * td_error ** 2 / jHinvj + jnp.log(jHinvj))

    return frm_loss, {"mean_td_error": jnp.mean(td_error ** 2), "mean_frm_loss": frm_loss}
