from functools import partial
from typing import (
    Generic,
    NamedTuple,
    TypeVar,
)

import jax
import jax.numpy as jnp

T = TypeVar('T')


class EMA(NamedTuple, Generic[T]):
    """
    Izmailov et al. 2019
    "Averaging Weights Leads to Wider Optima and Better Generalization"
    https://doi.org/10.48550/arXiv.1803.05407.
    """

    data: T
    weight: float

    @classmethod
    def create(cls, tree: T) -> 'EMA[T]':
        return cls(jax.tree.map(lambda x: jnp.zeros_like(x), tree), 0)


@partial(jax.jit, static_argnames=('decay',))
def update(ema: EMA[T], value: T, decay: float) -> EMA[T]:
    """Fold one new observation into the running average.

    Every leaf of the accumulated sum is multiplied by ``decay``, and the
    matching leaf of ``value`` is added. The normaliser ``weight`` is
    multiplied by ``decay`` and incremented by one, so ``data / weight``
    remains a weighted average of the observations.

    Args:
        ema: Current accumulator state.
        value: New observation; its tree structure must match ``ema.data``.
        decay: Discount applied to the existing state. It is a static jit
            argument, so it must be hashable (e.g. a Python float), and each
            distinct value compiles separately.

    Returns:
        A new ``EMA``; the input ``ema`` is left unchanged.
    """

    def _decay_and_add(acc, new):
        return acc * decay + new

    decayed_sum = jax.tree.map(_decay_and_add, ema.data, value)
    total_weight = ema.weight * decay + 1
    return EMA(data=decayed_sum, weight=total_weight)


@jax.jit
def value(ema: EMA[T]) -> T:
    """Return the current average, i.e. ``ema.data / ema.weight`` per leaf.

    Note: an accumulator fresh from ``EMA.create`` has all-zero ``data`` and
    ``weight == 0``. Calling this before any ``update`` therefore computes
    0 / 0 and yields NaN leaves.
    """
    normaliser = ema.weight

    def _normalise(leaf):
        return leaf / normaliser

    return jax.tree.map(_normalise, ema.data)
