"""Public API for Metropolis Adjusted Langevin kernels."""
import operator
from typing import Callable, NamedTuple, Tuple

import jax
import jax.numpy as jnp

import bblackjax.mcmc.diffusions as diffusions
import bblackjax.mcmc.proposal as proposal
from bblackjax.base import SamplingAlgorithm
from bblackjax.types import PRNGKey, PyTree

__all__ = ["MALAState", "MALAInfo", "init", "build_kernel", "mala"]


class MALAState(NamedTuple):
    """Current state of a MALA chain.

    Besides the position itself, the state caches the log-density and its
    gradient evaluated at that position, so a transition starting from this
    state can reuse them instead of evaluating the log-density again.

    """

    # Current position of the chain.
    position: PyTree
    # Log-density evaluated at `position`.
    logdensity: float
    # Gradient of the log-density evaluated at `position`.
    logdensity_grad: PyTree


class MALAInfo(NamedTuple):
    """Additional information on the MALA transition.

    This additional information can be used for debugging or computing
    diagnostics.

    acceptance_rate
        The acceptance rate of the transition.
    is_accepted
        Whether the proposed position was accepted or the original position
        was returned.
    proposed_position
        The position produced by the Langevin proposal step, reported whether
        or not it was accepted.
    proposed_weight
        A scalar weight attached to ``proposed_position``. The kernel built by
        ``build_kernel`` sets it to
        ``exp(logdensity(x') + ||x - x' - step_size * grad_logdensity(x')||^2 / (4 * step_size))``,
        where ``x`` is the starting position and ``x'`` the proposed one.

    """

    acceptance_rate: float
    is_accepted: bool
    proposed_position: PyTree
    proposed_weight: float


def init(position: PyTree, logdensity_fn: Callable) -> MALAState:
    """Create the initial MALA state at ``position``.

    ``logdensity_fn`` and its gradient are evaluated once at ``position``
    (through ``jax.value_and_grad``) and stored alongside the position.
    """
    value, gradient = jax.value_and_grad(logdensity_fn)(position)
    return MALAState(position=position, logdensity=value, logdensity_grad=gradient)


def build_kernel():
    """Build a MALA kernel.

    The kernel proposes a new position with one step of the overdamped
    Langevin diffusion (``diffusions.overdamped_langevin``) and accepts or
    rejects it with ``proposal.static_binomial_sampling``, using the
    asymmetric proposal generator built from ``transition_energy`` below.

    Returns
    -------
    A kernel ``kernel(rng_key, state, logdensity_fn, step_size)`` that takes a
    rng_key, the current state of the chain, the log-density function and the
    step size, and returns a new state of the chain along with information
    about the transition.

    """

    def squared_residual(to_position, from_position, from_grad, step_size):
        """Return ``sum ||to - from - step_size * from_grad||^2`` over all pytree leaves."""
        residual = jax.tree_util.tree_map(
            lambda to_x, from_x, g: to_x - from_x - step_size * g,
            to_position,
            from_position,
            from_grad,
        )
        return jax.tree_util.tree_reduce(
            operator.add, jax.tree_util.tree_map(lambda r: jnp.sum(r * r), residual)
        )

    def transition_energy(state, new_state, step_size):
        """Transition energy to go from `state` to `new_state`.

        Computed as
        ``-logdensity(state) + ||new - old - step_size * grad(old)||^2 / (4 * step_size)``.
        """
        theta_dot = squared_residual(
            new_state.position, state.position, state.logdensity_grad, step_size
        )
        return -state.logdensity + 0.25 * (1.0 / step_size) * theta_dot

    init_proposal, generate_proposal = proposal.asymmetric_proposal_generator(
        transition_energy, divergence_threshold=jnp.inf
    )
    sample_proposal = proposal.static_binomial_sampling

    def kernel(
        rng_key: PRNGKey, state: MALAState, logdensity_fn: Callable, step_size: float
    ) -> Tuple[MALAState, MALAInfo]:
        """Generate a new sample with the MALA kernel.

        Parameters
        ----------
        rng_key
            Key split between the Langevin proposal and the accept/reject step.
        state
            Current state of the chain.
        logdensity_fn
            Log-density of the target distribution.
        step_size
            Step size of the Langevin proposal.

        Returns
        -------
        The next state of the chain and a ``MALAInfo`` describing the transition.

        """
        grad_fn = jax.value_and_grad(logdensity_fn)
        integrator = diffusions.overdamped_langevin(grad_fn)

        key_integrator, key_rmh = jax.random.split(rng_key)

        # One Langevin step from the current state, repackaged as a MALAState.
        new_state = MALAState(*integrator(key_integrator, state, step_size))

        # Named so it does not shadow the `proposal` module imported at file top.
        current_proposal = init_proposal(state)
        new_proposal, _ = generate_proposal(state, new_state, step_size=step_size)
        sampled_proposal, do_accept, p_accept = sample_proposal(
            key_rmh, current_proposal, new_proposal
        )

        # Squared residual of the reverse move new_state -> state.
        reverse_theta_dot = squared_residual(
            state.position, new_state.position, new_state.logdensity_grad, step_size
        )
        # NOTE(review): this is not exp(-transition_energy(new_state, state)):
        # the quadratic term enters with a `+` sign here. Confirm this is the
        # intended weight. Large residuals may also overflow `exp` to inf.
        proposed_weight = jnp.exp(
            new_state.logdensity + 0.25 * (1.0 / step_size) * reverse_theta_dot
        )

        info = MALAInfo(p_accept, do_accept, new_state.position, proposed_weight)

        return sampled_proposal.state, info

    return kernel


class mala:
    """Implements the (basic) user interface for the MALA kernel.

    The general mala kernel builder (:meth:`blackjax.mcmc.mala.build_kernel`, alias `blackjax.mala.build_kernel`) can be
    cumbersome to manipulate. Since most users only need to specify the kernel
    parameters at initialization time, we provide a helper function that
    specializes the general kernel.

    We also add the general kernel and state generator as an attribute to this class so
    users only need to pass `blackjax.mala` to SMC, adaptation, etc. algorithms.

    Examples
    --------

    A new MALA kernel can be initialized and used with the following code:

    .. code::

        mala = blackjax.mala(logdensity_fn, step_size)
        state = mala.init(position)
        new_state, info = mala.step(rng_key, state)

    Kernels are not jit-compiled by default so you will need to do it manually:

    .. code::

       step = jax.jit(mala.step)
       new_state, info = step(rng_key, state)

    Should you need to you can always use the base kernel directly:

    .. code::

       kernel = blackjax.mala.build_kernel()
       state = blackjax.mala.init(position, logdensity_fn)
       state, info = kernel(rng_key, state, logdensity_fn, step_size)

    Parameters
    ----------
    logdensity_fn
        The log-density function we wish to draw samples from.
    step_size
        The step size used to discretize the overdamped Langevin diffusion
        that generates the proposals.

    Returns
    -------
    A ``SamplingAlgorithm``.

    """

    init = staticmethod(init)
    build_kernel = staticmethod(build_kernel)

    def __new__(  # type: ignore[misc]
        cls,
        logdensity_fn: Callable,
        step_size: float,
    ) -> SamplingAlgorithm:
        kernel = cls.build_kernel()

        def init_fn(position: PyTree):
            return cls.init(position, logdensity_fn)

        def step_fn(rng_key: PRNGKey, state):
            return kernel(rng_key, state, logdensity_fn, step_size)

        return SamplingAlgorithm(init_fn, step_fn)