"""Search policies."""

import functools
from typing import Optional, TypeVar

import chex
import jax
import jax.numpy as jnp

from _mctx._src import action_selection, base, qtransforms, search, seq_halving

# UTILITIES


def _mask_invalid_actions(logits, invalid_actions):
    """Suppresses invalid actions by pushing their logits to a finite floor.

    Args:
        logits: logits to mask, shape `[B, num_actions]`.
        invalid_actions: optional mask with the same shape as `logits`. Invalid
            actions have ones, valid actions have zeros in the mask. When it is
            `None`, the `logits` are returned untouched.

    Returns:
        The unmodified `logits` when no mask is given. Otherwise the logits
        shifted so that the maximum along the last axis is zero, with every
        invalid entry replaced by the most negative finite value of the dtype.
    """
    if invalid_actions is not None:
        chex.assert_equal_shape([logits, invalid_actions])
        # Shift along the action axis so the largest logit becomes zero.
        shifted = logits - jnp.max(logits, axis=-1, keepdims=True)  # [B, num_actions]
        # At the end of an episode every action can be invalid. Using -inf
        # would then make a softmax produce NaNs, so a finite floor is used
        # for the invalid actions instead.
        floor = jnp.finfo(shifted.dtype).min
        return jnp.where(invalid_actions, floor, shifted)
    return logits


def _get_logits_from_probs(probs):
    """Converts probabilities to logits with a clipped logarithm.

    Args:
        probs: probabilities to convert to logits, shape `[B, num_actions]`.

    Returns:
        The element-wise logarithm of `probs`, where the probabilities are
        first clipped from below at `finfo(probs.dtype).tiny` so that zero
        probabilities do not reach the logarithm.
    """
    lower_bound = jnp.finfo(probs.dtype).tiny
    clipped_probs = jnp.maximum(probs, lower_bound)
    return jnp.log(clipped_probs)


def _add_dirichlet_noise(rng_key, probs, *, dirichlet_alpha, dirichlet_fraction):
    """Blends the given probabilities with freshly sampled Dirichlet noise.

    Args:
        rng_key: random number generator state, the key is consumed.
        probs: probabilities to mix with Dirichlet noise, shape `[B, num_actions]`.
        dirichlet_alpha: concentration parameter for the Dirichlet distribution;
            the same value is used for every action. Must be a float.
        dirichlet_fraction: float from 0 to 1 interpolating between using only the
            prior policy or just the Dirichlet noise.

    Returns:
        Noisy probabilities, shape `[B, num_actions]`.
    """
    chex.assert_rank(probs, 2)
    chex.assert_type([dirichlet_alpha, dirichlet_fraction], float)

    num_rows, num_columns = probs.shape
    # Symmetric concentration vector: one identical alpha per action.
    concentration = jnp.full([num_columns], fill_value=dirichlet_alpha)
    # One Dirichlet sample is drawn for each row of `probs`.
    dirichlet_sample = jax.random.dirichlet(
        rng_key,
        alpha=concentration,
        shape=(num_rows,),
    )
    return (1 - dirichlet_fraction) * probs + dirichlet_fraction * dirichlet_sample


def _apply_temperature(logits, temperature):
    """Returns `logits / temperature`, supporting also temperature=0.

    Args:
        logits: logits to rescale; the maximum is taken over the last axis.
        temperature: divisor applied to the shifted logits. It is clamped from
            below at `finfo(logits.dtype).tiny`, so a zero temperature does not
            divide by zero.

    Returns:
        The logits shifted so that their maximum along the last axis is zero,
        divided by the clamped temperature.
    """
    # Subtracting the maximum prevents +inf after dividing by a small temperature.
    row_max = jnp.max(logits, keepdims=True, axis=-1)
    centered = logits - row_max
    safe_temperature = jnp.maximum(jnp.finfo(centered.dtype).tiny, temperature)
    return centered / safe_temperature


# POLICIES
def muzero_policy(
    params: base.Params,
    rng_key: chex.PRNGKey,
    root: base.RootT,
    recurrent_fn: base.RecurrentFn,
    num_simulations: int,
    invalid_actions: Optional[chex.Array] = None,
    max_depth: Optional[int] = None,
    loop_fn: base.LoopFn = jax.lax.fori_loop,
    *,
    qtransform: base.QTransform = qtransforms.qtransform_by_parent_and_siblings,
    search_fn: base.SearchFn,
    dirichlet_fraction: chex.Numeric = 0.25,
    dirichlet_alpha: chex.Numeric = 0.3,
    pb_c_init: chex.Numeric = 1.25,
    pb_c_base: chex.Numeric = 19652,
    temperature: chex.Numeric = 1.0,
) -> base.PolicyOutput[None]:
    """Runs MuZero search and returns the `PolicyOutput`.

    The root prior is mixed with Dirichlet noise and masked for invalid
    actions, the search is delegated to `search_fn`, and the proposed action
    is sampled from the temperature-adjusted visit probabilities of the root.

    In the shape descriptions, `B` denotes the batch dimension.

    Args:
        params: params to be forwarded to root and recurrent functions.
        rng_key: random number generator state, the key is consumed.
        root: a `(prior_logits, value, embedding)` `RootFnOutput`. The
            `prior_logits` are from a policy network. The shapes are
            `([B, num_actions], [B], [B, ...])`, respectively.
        recurrent_fn: a callable to be called on the leaf nodes and unvisited
            actions retrieved by the simulation step, which takes as args
            `(params, rng_key, action, embedding)` and returns a
            `RecurrentFnOutput` and the new state embedding. The `rng_key`
            argument is consumed.
        num_simulations: the number of simulations.
        invalid_actions: a mask with invalid actions. Invalid actions
            have ones, valid actions have zeros in the mask. Shape
            `[B, num_actions]`.
        max_depth: maximum search tree depth allowed during simulation.
        loop_fn: Function used to run the simulations. It may be required to
            pass hk.fori_loop if using this function inside a Haiku module.
        qtransform: function to obtain completed Q-values for a node.
        search_fn: function to run the search. It is called with keyword
            arguments only and must return the search tree.
        dirichlet_fraction: float from 0 to 1 interpolating between using only
            the prior policy or just the Dirichlet noise.
        dirichlet_alpha: concentration parameter to parametrize the Dirichlet
            distribution.
        pb_c_init: constant c_1 in the PUCT formula.
        pb_c_base: constant c_2 in the PUCT formula.
        temperature: temperature for acting proportionally to
            `visit_counts**(1 / temperature)`.

    Returns:
        `PolicyOutput` containing the proposed action, action_weights and the
        used search tree.
    """
    # Three independent keys: final action sampling, root noise, and the search.
    sample_rng_key, dirichlet_rng_key, search_rng_key = jax.random.split(rng_key, 3)

    # Adding Dirichlet noise to the root prior, then masking invalid actions.
    noisy_probs = _add_dirichlet_noise(
        dirichlet_rng_key,
        jax.nn.softmax(root.prior_logits),
        dirichlet_fraction=dirichlet_fraction,
        dirichlet_alpha=dirichlet_alpha,
    )
    masked_noisy_logits = _mask_invalid_actions(
        _get_logits_from_probs(noisy_probs), invalid_actions
    )
    root = root.replace(prior_logits=masked_noisy_logits)  # type: ignore

    # Running the search. The root uses the interior selection rule with the
    # depth pinned to zero.
    select_interior = functools.partial(
        action_selection.muzero_action_selection,
        pb_c_base=pb_c_base,  # type: ignore
        pb_c_init=pb_c_init,  # type: ignore
        qtransform=qtransform,
    )
    select_root = functools.partial(select_interior, depth=0)
    search_tree = search_fn(
        params=params,
        rng_key=search_rng_key,
        root=root,
        recurrent_fn=recurrent_fn,
        root_action_selection_fn=select_root,
        interior_action_selection_fn=select_interior,
        num_simulations=num_simulations,
        max_depth=max_depth,
        invalid_actions=invalid_actions,
        loop_fn=loop_fn,
    )

    # Sampling the proposed action proportionally to the visit counts (with
    # temperature). The untempered visit probabilities are the action weights.
    visit_probs = search_tree.summary().visit_probs
    tempered_logits = _apply_temperature(
        _get_logits_from_probs(visit_probs), temperature
    )
    sampled_action = jax.random.categorical(sample_rng_key, tempered_logits)
    return base.PolicyOutput(
        action=sampled_action,  # type: ignore
        action_weights=visit_probs,  # type: ignore
        search_tree=search_tree,  # type: ignore
    )


def gumbel_muzero_policy(
    params: base.Params,
    rng_key: chex.PRNGKey,
    root: base.RootT,
    recurrent_fn: base.RecurrentFn,
    num_simulations: int,
    invalid_actions: Optional[chex.Array] = None,
    max_depth: Optional[int] = None,
    loop_fn: base.LoopFn = jax.lax.fori_loop,
    *,
    qtransform: base.QTransform = qtransforms.qtransform_by_parent_and_siblings,
    search_fn: base.SearchFn,
    max_num_considered_actions: int = 16,
    gumbel_scale: chex.Numeric = 1.0,
) -> base.PolicyOutput[action_selection.GumbelMuZeroExtraData]:
    """Runs Gumbel MuZero search and returns the `PolicyOutput`.

    This policy implements Full Gumbel MuZero from
    "Policy improvement by planning with Gumbel": https://openreview.net/forum?id=bERaNdoegnO

    At the root of the search tree, actions are selected by Sequential Halving
    with Gumbel. At non-root nodes (aka interior nodes), actions are selected by
    the Full Gumbel MuZero deterministic action selection.

    In the shape descriptions, `B` denotes the batch dimension.

    Args:
        params: params to be forwarded to root and recurrent functions.
        rng_key: random number generator state, the key is consumed.
        root: a `(prior_logits, value, embedding)` `RootFnOutput`. The
            `prior_logits` are from a policy network. The shapes are
            `([B, num_actions], [B], [B, ...])`, respectively.
        recurrent_fn: a callable to be called on the leaf nodes and unvisited
            actions retrieved by the simulation step, which takes as args
            `(params, rng_key, action, embedding)` and returns a
            `RecurrentFnOutput` and the new state embedding. The `rng_key`
            argument is consumed.
        num_simulations: the number of simulations.
        invalid_actions: a mask with invalid actions. Invalid actions
            have ones, valid actions have zeros in the mask. Shape
            `[B, num_actions]`.
        max_depth: maximum search tree depth allowed during simulation.
        loop_fn: Function used to run the simulations. It may be required to
            pass hk.fori_loop if using this function inside a Haiku module.
        qtransform: function to obtain completed Q-values for a node.
        search_fn: function to run the search. It is called with keyword
            arguments only and must return the search tree.
        max_num_considered_actions: the maximum number of actions expanded at
            the root node. A smaller number of actions will be expanded if the
            number of valid actions is smaller.
        gumbel_scale: scale for the Gumbel noise. Evaluation on
            perfect-information games can use gumbel_scale=0.0.

    Returns:
        `PolicyOutput` containing the proposed action, action_weights and the
        used search tree.
    """
    # Masking invalid actions in the root prior.
    masked_prior_logits = _mask_invalid_actions(root.prior_logits, invalid_actions)
    root = root.replace(prior_logits=masked_prior_logits)  # type: ignore
    prior_logits = root.prior_logits

    # Generating Gumbel noise matching the shape and dtype of the prior logits.
    search_rng_key, gumbel_rng_key = jax.random.split(rng_key)
    gumbel = gumbel_scale * jax.random.gumbel(
        gumbel_rng_key, shape=prior_logits.shape, dtype=prior_logits.dtype
    )  # [batch_size, num_actions]

    # Searching. The root Gumbel noise is handed to the search as extra data.
    select_root = functools.partial(
        action_selection.gumbel_muzero_root_action_selection,
        num_simulations=num_simulations,
        max_num_considered_actions=max_num_considered_actions,
        qtransform=qtransform,
    )
    select_interior = functools.partial(
        action_selection.gumbel_muzero_interior_action_selection,
        qtransform=qtransform,
    )
    search_tree = search_fn(
        params=params,
        rng_key=search_rng_key,
        root=root,
        recurrent_fn=recurrent_fn,
        root_action_selection_fn=select_root,
        interior_action_selection_fn=select_interior,
        num_simulations=num_simulations,
        max_depth=max_depth,
        invalid_actions=invalid_actions,
        extra_data=action_selection.GumbelMuZeroExtraData(root_gumbel=gumbel),  # type: ignore
        loop_fn=loop_fn,
    )
    visit_counts = search_tree.summary().visit_counts

    # Acting with the best action from the most visited actions.
    # The "best" action has the highest `gumbel + logits + q`.

    # The action that survived sequential halving at the root is the one with
    # the highest visit count (it did not get eliminated).
    considered_visit = jnp.max(visit_counts, axis=-1, keepdims=True)  # [B, 1]

    # The completed_qvalues include imputed values for unvisited actions. They
    # are used only for the scores and for generating an improved policy;
    # unvisited actions won't be selected.
    batched_qtransform = jax.vmap(qtransform, in_axes=[0, None])
    completed_qvalues = batched_qtransform(
        search_tree,
        search_tree.ROOT_INDEX,  # type: ignore
    )

    scores = seq_halving.score_considered(
        considered_visit,
        gumbel,
        prior_logits,
        completed_qvalues,
        visit_counts,
    )
    best_action = action_selection.masked_argmax(scores, invalid_actions)

    # Producing action_weights usable to train the policy network.
    improved_logits = _mask_invalid_actions(
        prior_logits + completed_qvalues, invalid_actions
    )
    improved_policy = jax.nn.softmax(improved_logits)  # type: ignore
    return base.PolicyOutput(
        action=best_action,  # type: ignore
        action_weights=improved_policy,  # type: ignore
        search_tree=search_tree,  # type: ignore
    )
