"""
Flax NNX Metrics version of
https://github.com/google-deepmind/sonnet/blob/v2/sonnet/src/moving_averages.py
"""

from flax.nnx.training.metrics import Metric, MetricState
from jax import Array, numpy as jnp


class ExponentialMovingAverage(Metric):
    """Maintains an exponential moving average for a value.

    Note this module uses debiasing by default. If you don't want this please
    use an alternative implementation.

    This module keeps track of a hidden exponential moving average that is
    initialized as zeros and is then normalized to give the average. This
    gives us a moving average which isn't biased towards either zero or the
    initial value. Reference (https://arxiv.org/pdf/1412.6980.pdf)

    Initially:

        hidden_0 = 0

    Then iteratively:

        hidden_i = hidden_{i-1} - (hidden_{i-1} - value) * (1 - decay)
                 = decay * hidden_{i-1} + (1 - decay) * value
        average_i = hidden_i / (1 - decay^i)
    """

    def __init__(self, decay: float, values: Array):
        """Creates a debiased moving average module.

        Args:
            decay:
                The decay to use. Note values close to 1 result in a slow decay
                whereas values close to 0 result in faster decay, tracking the
                input values more closely. Presumably expected to lie in
                [0, 1); with a decay of exactly 1 the hidden state never
                moves and the debiasing denominator is always zero.
            values:
                An example array. Only its shape and dtype are used, to
                initialize the hidden state to zeros (via ``jnp.zeros_like``).
                Subsequent observations are passed to :func:`update` through
                the keyword argument of the same name.
        """
        self._decay = decay
        # Number of updates seen so far; drives the debiasing denominator.
        self.count = MetricState(jnp.array(0, dtype=jnp.int32))
        # Hidden (biased) EMA accumulator; debiased in `compute`.
        self.values = MetricState(jnp.zeros_like(values))

    def compute(self) -> Array:
        """Compute and return the debiased EMA.

        Returns:
            ``hidden / (1 - decay**count)``. Before the first :func:`update`
            (``count == 0``) that denominator is zero, so the all-zero hidden
            state is returned instead of NaN.
        """
        hidden = self.values.value
        count = jnp.astype(self.count.value, hidden.dtype)
        bias_correction = 1.0 - jnp.pow(self._decay, count)
        # Avoid 0 / 0 = NaN when no update has happened yet.
        bias_correction = jnp.where(self.count.value > 0, bias_correction, 1.0)
        return hidden / bias_correction

    def reset(self) -> None:
        """Reset the EMA to its freshly-constructed state (zeros, count 0)."""
        self.values.value = jnp.zeros_like(self.values.value)
        self.count.value = jnp.array(0, dtype=jnp.int32)

    def update(self, *, values: Array) -> None:  # type: ignore[override]
        """In-place update the EMA with a new observation.

        Args:
            values: The new observation. Presumably expected to match the
                shape of the example passed at construction (it is combined
                elementwise with the hidden state) — TODO confirm.
        """
        self.count.value = self.count.value + 1
        hidden = self.values.value
        # Move the hidden state a fraction (1 - decay) of the way towards the
        # new value: hidden_i = hidden_{i-1} - (hidden_{i-1} - value) * (1 - decay).
        # (Assigning the difference term alone, without subtracting it from
        # hidden, would flip the sign and discard history.)
        self.values.value = hidden - (hidden - values) * (1 - self._decay)
