"""Public API for ChEES-HMC"""

from functools import partial
from typing import Callable, NamedTuple, Optional, Tuple

import jax
import jax.numpy as jnp
import numpy as np
import optax

import blackjax.mcmc.dynamic_hmc as dynamic_hmc
import blackjax.optimizers.dual_averaging as dual_averaging
from blackjax.adaptation.base import AdaptationResults, return_all_adapt_info
from blackjax.base import AdaptationAlgorithm
from blackjax.types import Array, ArrayLikeTree, PRNGKey
from blackjax.util import pytree_size

# Default acceptance rate targeted by the dual averaging step size adaptation
# (used as the default `target_acceptance_rate` of `chees_adaptation`).
# Optimal tuning for HMC, see https://arxiv.org/abs/1001.4460
OPTIMAL_TARGET_ACCEPTANCE_RATE = 0.651


class ChEESAdaptationState(NamedTuple):
    """State of the ChEES-HMC adaptation scheme.

    step_size
        Value of the step_size parameter of the HMC algorithm.
    log_step_size_moving_average
        Running moving average of the log step_size parameter.
    trajectory_length
        Value of the num_integration_steps * step_size parameter of
        the HMC algorithm.
    log_trajectory_length_moving_average
        Running moving average of the log trajectory length, i.e. of
        log(num_integration_steps * step_size).
    da_state
        State of the dual averaging procedure used to adapt the step size.
    optim_state
        Optax optimizer state used to maximize the ChEES criterion.
    random_generator_arg
        Utility array for generating a pseudo or quasi-random sequence of
        numbers.
    step
        Current iteration number.

    """

    step_size: float
    log_step_size_moving_average: float
    trajectory_length: float
    log_trajectory_length_moving_average: float
    da_state: dual_averaging.DualAveragingState
    optim_state: optax.OptState
    random_generator_arg: Array
    step: int


def base(
    jitter_generator: Callable,
    next_random_arg_fn: Callable,
    optim: optax.GradientTransformation,
    target_acceptance_rate: float,
    decay_rate: float,
) -> Tuple[Callable, Callable]:
    """Maximizing the Change in the Estimator of the Expected Square criterion
    (trajectory length) and dual averaging procedure (step size) for the jittered
    Hamiltonian Monte Carlo kernel :cite:p:`hoffman2021adaptive`.

    This adaptation algorithm tunes the step size and trajectory length, i.e.
    number of integration steps * step size, of the jittered HMC algorithm based
    on statistics collected from a population of many chains. It maximizes the Change
    in the Estimator of the Expected Square (ChEES) criterion to tune the trajectory
    length and uses dual averaging, targeting `target_acceptance_rate` for the
    harmonic mean of the chains' acceptance probabilities, to tune the step size.

    Parameters
    ----------
    jitter_generator
        Function that maps the current `random_generator_arg` to the jitter value
        (expected to lie in [0, 1]) that scales the trajectory length of the
        current iteration.
    next_random_arg_fn
        Function that generates the next `random_generator_arg` from its previous value.
    optim
        Optax compatible optimizer, which conforms to the `optax.GradientTransformation` protocol.
    target_acceptance_rate
        Average acceptance rate to target with dual averaging.
    decay_rate
        Float representing how much to favor recent iterations over earlier ones in the optimization
        of step size and trajectory length. The moving averages use the weight
        `step ** (-decay_rate)` for the newest value.


    Returns
    -------
    init
        Function that initializes the warmup.
    update
        Function that moves the warmup one step.

    """
    # `jax.flatten_util` is a submodule that a plain `import jax` is not
    # guaranteed to load, so import the helper explicitly rather than relying
    # on attribute access through the top-level `jax` package.
    from jax.flatten_util import ravel_pytree

    da_init, da_update, _ = dual_averaging.dual_averaging()

    def compute_parameters(
        proposed_positions: ArrayLikeTree,
        proposed_momentums: ArrayLikeTree,
        initial_positions: ArrayLikeTree,
        acceptance_probabilities: Array,
        is_divergent: Array,
        initial_adaptation_state: ChEESAdaptationState,
    ) -> ChEESAdaptationState:
        """Compute values for the parameters based on statistics collected from
        multiple chains.

        Parameters
        ----------
        proposed_positions:
            A PyTree that contains the position proposed by the HMC algorithm of
            every chain (proposal that is accepted or rejected using MH).
        proposed_momentums:
            A PyTree that contains the momentum variable proposed by the HMC algorithm
            of every chain (proposal that is accepted or rejected using MH).
        initial_positions:
            A PyTree that contains the initial position at the start of the HMC
            algorithm of every chain.
        acceptance_probabilities:
            Metropolis-Hastings acceptance probability of proposals of every chain.
        is_divergent:
            Boolean array flagging, for every chain, whether the transition was
            divergent. Flagged chains are excluded from the cross-chain averages.
        initial_adaptation_state:
            ChEES adaptation step used to generate proposals and acceptance probabilities.

        Returns
        -------
        New adaptation state holding the updated step size and trajectory length of
        the jittered HMC algorithm, their moving averages, the updated dual averaging
        and optimizer states, the next `random_generator_arg` and the incremented
        step counter.

        """
        (
            step_size,
            log_step_size_ma,
            trajectory_length,
            log_trajectory_length_ma,
            da_state,
            optim_state,
            random_generator_arg,
            step,
        ) = initial_adaptation_state

        # Step size: dual averaging on the gap between the target and the
        # harmonic mean of the non-divergent chains' acceptance probabilities.
        harmonic_mean = 1.0 / jnp.mean(
            1.0 / acceptance_probabilities, where=~is_divergent
        )
        da_state_ = da_update(da_state, target_acceptance_rate - harmonic_mean)
        step_size_ = jnp.exp(da_state_.log_x)
        # Keep the previous step size and dual averaging state when the
        # update produced a non-finite step size.
        new_step_size, new_da_state, new_log_step_size = jax.lax.cond(
            jnp.isfinite(step_size_),
            lambda _: (step_size_, da_state_, da_state_.log_x),
            lambda _: (step_size, da_state, da_state.log_x),
            None,
        )
        # Weight of the newest value in the moving averages. `init` starts the
        # counter at 1, so the first weight is 1 and the initial moving-average
        # values are overwritten on the first update.
        update_weight = step ** (-decay_rate)
        new_log_step_size_ma = (
            1.0 - update_weight
        ) * log_step_size_ma + update_weight * new_log_step_size

        # Center proposals and initial positions around their cross-chain
        # means (NaN entries are ignored when computing the means).
        proposals_mean = jax.tree_util.tree_map(
            lambda p: jnp.nanmean(p, axis=0), proposed_positions
        )
        initials_mean = jax.tree_util.tree_map(
            lambda p: jnp.nanmean(p, axis=0), initial_positions
        )
        proposals_centered = jax.tree_util.tree_map(
            lambda p, pm: p - pm, proposed_positions, proposals_mean
        )
        initials_centered = jax.tree_util.tree_map(
            lambda p, pm: p - pm, initial_positions, initials_mean
        )

        # Flatten every chain's PyTree into a single vector, giving one row
        # per chain.
        vmap_flatten_op = jax.vmap(lambda p: ravel_pytree(p)[0])
        proposals_matrix = vmap_flatten_op(proposals_centered)
        initials_matrix = vmap_flatten_op(initials_centered)
        momentums_matrix = vmap_flatten_op(proposed_momentums)

        # Per-chain gradient signal for the ChEES criterion, scaled by the
        # jittered trajectory length used in this iteration.
        trajectory_gradients = (
            jitter_generator(random_generator_arg)
            * trajectory_length
            * jax.vmap(
                lambda pm, im, mm: (jnp.dot(pm, pm) - jnp.dot(im, im)) * jnp.dot(pm, mm)
            )(proposals_matrix, initials_matrix, momentums_matrix)
        )
        # Acceptance-probability-weighted average over non-divergent chains.
        trajectory_gradient = jnp.sum(
            acceptance_probabilities * trajectory_gradients, where=~is_divergent
        ) / jnp.sum(acceptance_probabilities, where=~is_divergent)

        # The optimizer works on the log of the trajectory length.
        log_trajectory_length = jnp.log(trajectory_length)
        updates, optim_state_ = optim.update(
            trajectory_gradient, optim_state, log_trajectory_length
        )
        log_trajectory_length_ = optax.apply_updates(log_trajectory_length, updates)
        # Discard the optimizer step if it produced a non-finite value.
        new_log_trajectory_length, new_optim_state = jax.lax.cond(
            jnp.isfinite(ravel_pytree(log_trajectory_length_)[0]).all(),
            lambda _: (log_trajectory_length_, optim_state_),
            lambda _: (log_trajectory_length, optim_state),
            None,
        )
        new_log_trajectory_length_ma = (
            1.0 - update_weight
        ) * log_trajectory_length_ma + update_weight * new_log_trajectory_length
        # The trajectory length carried forward is the exponential of the
        # moving average, whereas the step size carried forward is the raw
        # dual averaging iterate.
        new_trajectory_length = jnp.exp(new_log_trajectory_length_ma)

        return ChEESAdaptationState(
            new_step_size,
            new_log_step_size_ma,
            new_trajectory_length,
            new_log_trajectory_length_ma,
            new_da_state,
            new_optim_state,
            next_random_arg_fn(random_generator_arg),
            step + 1,
        )

    def init(random_generator_arg: Array, step_size: float) -> ChEESAdaptationState:
        """Create the initial adaptation state.

        Parameters
        ----------
        random_generator_arg
            Initial argument passed to the jitter generator.
        step_size
            Initial step size. It is also used as the initial trajectory length
            and to initialize the dual averaging and optimizer states.

        Returns
        -------
        The initial adaptation state, with both moving averages set to 0 and the
        step counter set to 1.

        """
        return ChEESAdaptationState(
            step_size=step_size,
            log_step_size_moving_average=0.0,
            trajectory_length=step_size,
            log_trajectory_length_moving_average=0.0,
            da_state=da_init(step_size),
            optim_state=optim.init(step_size),
            random_generator_arg=random_generator_arg,
            step=1,
        )

    def update(
        adaptation_state: ChEESAdaptationState,
        proposed_positions: ArrayLikeTree,
        proposed_momentums: ArrayLikeTree,
        initial_positions: ArrayLikeTree,
        acceptance_probabilities: Array,
        is_divergent: Array,
    ) -> ChEESAdaptationState:
        """Update the adaptation state and parameter values.

        Parameters
        ----------
        adaptation_state
            The current state of the adaptation algorithm
        proposed_positions:
            The position proposed by the HMC algorithm of every chain.
        proposed_momentums:
            The momentum variable proposed by the HMC algorithm of every chain.
        initial_positions:
            The initial position at the start of the HMC algorithm of every chain.
        acceptance_probabilities:
            Metropolis-Hastings acceptance probability of proposals of every chain.
        is_divergent:
            Boolean array flagging, for every chain, whether the transition was
            divergent.

        Returns
        -------
        New adaptation state that contains the step size and trajectory length of the
        jittered HMC algorithm.

        """
        new_state = compute_parameters(
            proposed_positions,
            proposed_momentums,
            initial_positions,
            acceptance_probabilities,
            is_divergent,
            adaptation_state,
        )

        return new_state

    return init, update


def chees_adaptation(
    logdensity_fn: Callable,
    num_chains: int,
    *,
    jitter_generator: Optional[Callable] = None,
    jitter_amount: float = 1.0,
    target_acceptance_rate: float = OPTIMAL_TARGET_ACCEPTANCE_RATE,
    decay_rate: float = 0.5,
    adaptation_info_fn: Callable = return_all_adapt_info,
) -> AdaptationAlgorithm:
    """Adapt the step size and trajectory length (number of integration steps * step size)
    parameters of the jittered HMC algorithm.

    The jittered HMC algorithm depends on the value of a step size, controlling
    the discretization step of the integrator, and a trajectory length, given by the
    number of integration steps * step size, jittered by using only a random percentage
    of this trajectory length.

    This adaptation algorithm tunes the trajectory length by heuristically maximizing
    the Change in the Estimator of the Expected Square (ChEES) criterion over
    an ensemble of parallel chains. At equilibrium, the algorithm aims at eliminating
    correlations between target dimensions, making the HMC algorithm efficient.

    Jittering requires generating a random sequence of uniform variables in [0, 1].
    However, this adds another source of variance to the sampling procedure,
    which may slow adaptation or lead to suboptimal mixing. To alleviate this,
    rather than use uniform random noise to jitter the trajectory lengths, we use a
    quasi-random Halton sequence, which ensures a more even distribution of trajectory
    lengths.

    Examples
    --------

    An HMC adapted kernel can be learned and used with the following code:

    .. code::

        warmup = blackjax.chees_adaptation(logdensity_fn, num_chains)
        key_warmup, key_sample = jax.random.split(rng_key)
        optim = optax.adam(learning_rate)
        (last_states, parameters), _ = warmup.run(
            key_warmup,
            positions, #PyTree where each leaf has shape (num_chains, ...)
            initial_step_size,
            optim,
            num_warmup_steps,
        )
        kernel = blackjax.dynamic_hmc(logdensity_fn, **parameters).step
        keys_sample = jax.random.split(key_sample, num_chains)
        new_states, info = jax.vmap(kernel)(keys_sample, last_states)

    Parameters
    ----------
    logdensity_fn
        The log density probability density function from which we wish to sample.
    num_chains
        Number of chains used for cross-chain warm-up training.
    jitter_generator
        Optional function that generates a value in [0, 1] used to jitter the trajectory
        lengths given a PRNGKey, used to propose the number of integration steps. If None,
        then a quasi-random Halton is used to jitter the trajectory length.
    jitter_amount
        A percentage in [0, 1] representing how much of the calculated trajectory should
        be jittered. The generated value `u` is mapped to
        `u * jitter_amount + (1 - jitter_amount)`.
    target_acceptance_rate
        Average acceptance rate to target with dual averaging. Defaults to optimal tuning for HMC.
    decay_rate
        Float representing how much to favor recent iterations over earlier ones in the optimization
        of step size and trajectory length. A value of 1 gives equal weight to all history. A value
        of 0 gives weight only to the most recent iteration.
    adaptation_info_fn
        Function to select the adaptation info returned. See return_all_adapt_info
        and get_filter_adapt_info_fn in blackjax.adaptation.base.  By default all
        information is saved - this can result in excessive memory usage if the
        information is unused.

    Returns
    -------
    An `AdaptationAlgorithm` whose `run` function returns the last cross-chain
    states together with the tuned parameter values of the sampling kernel, and
    the adaptation info collected at every warm-up step for diagnostics.

    """

    def run(
        rng_key: PRNGKey,
        positions: ArrayLikeTree,
        step_size: float,
        optim: optax.GradientTransformation,
        num_steps: int = 1000,
        *,
        max_sampling_steps: int = 1000,
    ):
        """Run the cross-chain warm-up.

        Parameters
        ----------
        rng_key
            Key used to generate the per-step (and per-chain) random keys.
        positions
            PyTree of initial positions; every leaf has leading dimension `num_chains`.
        step_size
            Initial value of the step size.
        optim
            Optax optimizer used to update the log trajectory length.
        num_steps
            Number of warm-up steps.
        max_sampling_steps
            Number of post-warm-up sampling steps to budget for. It is only used,
            together with `num_steps`, to size the default Halton sequence.

        Returns
        -------
        `AdaptationResults(last_states, parameters)` and the stacked output of
        `adaptation_info_fn` for every warm-up step.

        """
        assert all(
            leaf.shape[0] == num_chains
            for leaf in jax.tree_util.tree_leaves(positions)
        ), "initial `positions` leading dimension must be equal to the `num_chains`"
        num_dim = pytree_size(positions) // num_chains

        # The argument of the jitter generator is a simple iteration counter.
        next_random_arg_fn = lambda i: i + 1
        init_random_arg = 0

        if jitter_generator is not None:
            # Reserve a dedicated key for jittering; a fresh key is derived
            # from it for every counter value.
            rng_key, carry_key = jax.random.split(rng_key)
            jitter_gn = lambda i: jitter_generator(
                jax.random.fold_in(carry_key, i)
            ) * jitter_amount + (1.0 - jitter_amount)
        else:
            # Number of bits needed to index every warm-up and sampling step.
            # It is a count, so compute it once and pass it as an integer.
            halton_bits = int(np.ceil(np.log2(num_steps + max_sampling_steps)))
            jitter_gn = lambda i: dynamic_hmc.halton_sequence(
                i, halton_bits
            ) * jitter_amount + (1.0 - jitter_amount)

        def integration_steps_fn(random_generator_arg, trajectory_length_adjusted):
            # Jittered number of integration steps, rounded up to an integer.
            return jnp.asarray(
                jnp.ceil(jitter_gn(random_generator_arg) * trajectory_length_adjusted),
                dtype=int,
            )

        step_fn = dynamic_hmc.build_kernel(
            next_random_arg_fn=next_random_arg_fn,
            integration_steps_fn=integration_steps_fn,
        )

        init, update = base(
            jitter_gn, next_random_arg_fn, optim, target_acceptance_rate, decay_rate
        )

        def one_step(carry, rng_key):
            # One warm-up iteration: advance every chain with the current
            # parameters, then update the parameters from the pooled statistics.
            states, adaptation_state = carry

            keys = jax.random.split(rng_key, num_chains)
            _step_fn = partial(
                step_fn,
                logdensity_fn=logdensity_fn,
                step_size=adaptation_state.step_size,
                inverse_mass_matrix=jnp.ones(num_dim),
                # Trajectory length expressed in number of integration steps.
                trajectory_length_adjusted=adaptation_state.trajectory_length
                / adaptation_state.step_size,
            )
            new_states, info = jax.vmap(_step_fn)(keys, states)
            new_adaptation_state = update(
                adaptation_state,
                info.proposal.position,
                info.proposal.momentum,
                states.position,
                info.acceptance_rate,
                info.is_divergent,
            )

            return (new_states, new_adaptation_state), adaptation_info_fn(
                new_states, info, new_adaptation_state
            )

        batch_init = jax.vmap(
            lambda p: dynamic_hmc.init(p, logdensity_fn, init_random_arg)
        )
        init_states = batch_init(positions)
        init_adaptation_state = init(init_random_arg, step_size)

        keys_step = jax.random.split(rng_key, num_steps)
        (last_states, last_adaptation_state), info = jax.lax.scan(
            one_step, (init_states, init_adaptation_state), keys_step
        )

        # The returned parameters are built from the moving averages rather
        # than from the last iterates.
        trajectory_length_adjusted = jnp.exp(
            last_adaptation_state.log_trajectory_length_moving_average
            - last_adaptation_state.log_step_size_moving_average
        )
        parameters = {
            "step_size": jnp.exp(last_adaptation_state.log_step_size_moving_average),
            "inverse_mass_matrix": jnp.ones(num_dim),
            "next_random_arg_fn": next_random_arg_fn,
            "integration_steps_fn": lambda arg: integration_steps_fn(
                arg, trajectory_length_adjusted
            ),
        }

        return AdaptationResults(last_states, parameters), info

    return AdaptationAlgorithm(run)  # type: ignore[arg-type]
