from functools import partial
from typing import Any, NamedTuple

import jax.numpy as jnp
from jax import tree_util as jtu
from optax import tree_utils as otu
from optax._src import base, numerics


class ApplyIfFiniteState(NamedTuple):
    """State of the `GradientTransformation` returned by `apply_if_finite`.

    Attributes:
      notfinite_count: Number of consecutive gradient updates containing an Inf or
        a NaN. This number is reset to 0 whenever a gradient update without an Inf
        or a NaN is done.
      last_finite: Whether the last gradient update was finite, i.e. contained
        no Inf and no NaN (True when finite, False otherwise).
      total_notfinite: Total number of gradient updates containing an Inf or
        a NaN since this optimizer was initialised. This number is never reset.
      inner_state: The state of the inner `GradientTransformation`.

    """

    # Scalar counter of consecutive non-finite updates.
    notfinite_count: Any
    # Scalar boolean flag: True when the most recent update was finite.
    last_finite: Any
    # Scalar counter of all non-finite updates seen so far.
    total_notfinite: Any
    # State of the wrapped (inner) transformation.
    inner_state: Any


def apply_if_finite(
    inner: base.GradientTransformation, max_consecutive_errors: int
) -> base.GradientTransformationExtraArgs:
    """Wraps an optimizer to make it robust to a few NaNs or Infs.

    The purpose of this function is to prevent any optimization from happening
    if the gradients contain NaNs or Infs. That is, when a NaN or Inf is
    detected in the gradients, the wrapped optimizer ignores that gradient
    update: it emits all-zero updates and keeps the inner state unchanged. If
    the NaNs or Infs persist for more than ``max_consecutive_errors``
    consecutive updates, the wrapped optimizer gives up and accepts the update.

    Args:
      inner: Inner transformation to be wrapped. It is passed through
        ``base.with_extra_args_support``, and any extra keyword arguments given
        to ``update`` are forwarded to ``inner.update``.
      max_consecutive_errors: Maximum number of consecutive gradient updates
        containing NaNs or Infs that the wrapped optimizer will ignore. After
        that many ignored updates, the optimizer will give up and accept.

    Returns:
      New ``GradientTransformationExtraArgs``.
    """

    inner = base.with_extra_args_support(inner)

    def init(params):
        return ApplyIfFiniteState(
            notfinite_count=jnp.zeros([], jnp.int32),
            last_finite=jnp.array(True, jnp.bool_),
            total_notfinite=jnp.zeros([], jnp.int32),
            inner_state=inner.init(params),
        )

    def update(updates, state, params=None, **extra_args):
        inner_state = state.inner_state
        flat_updates = jtu.tree_flatten(updates)[0]
        # Scalar boolean: True iff every element of every leaf is finite.
        isfinite = jnp.all(
            jnp.array([jnp.all(jnp.isfinite(p)) for p in flat_updates])
        )
        # Reset on a finite update, otherwise count one more consecutive failure.
        notfinite_count = jnp.where(
            isfinite,
            jnp.zeros([], jnp.int32),
            numerics.safe_int32_increment(state.notfinite_count),
        )

        def do_update():
            return inner.update(updates, inner_state, params, **extra_args)

        def reject_update():
            # All-zero updates and the untouched inner state.
            return otu.tree_zeros_like(updates), inner_state

        # Accept finite updates, or give up once the number of consecutive
        # non-finite updates exceeds the allowed maximum. The counter is not
        # reset on give-up, so subsequent non-finite updates keep being
        # accepted until a finite update resets it.
        accept = jnp.logical_or(
            isfinite, notfinite_count > max_consecutive_errors
        )
        select = partial(jnp.where, accept)
        # NOTE: both branches are evaluated before the leaf-wise jnp.where
        # selection, so inner.update always runs; its result is discarded
        # when the update is rejected.
        new_updates, new_inner_state = jtu.tree_map(
            select, do_update(), reject_update()
        )

        return new_updates, ApplyIfFiniteState(
            notfinite_count=notfinite_count,
            last_finite=isfinite,
            total_notfinite=jnp.where(
                isfinite,
                state.total_notfinite,
                numerics.safe_int32_increment(state.total_notfinite),
            ),
            inner_state=new_inner_state,
        )

    return base.GradientTransformationExtraArgs(init=init, update=update)
