import functools
from dataclasses import dataclass
from typing import Callable, Generic, NamedTuple, Protocol, TypeVar

import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
from nix.utils.jax_utils import jit
from nix.utils.tree_utils import tree_add, tree_sub
from jaxtyping import ArrayLike, PyTree

# Generic placeholders shared by the averaging classes in this module.
# In the `Average` protocol: T annotates the sample passed to `init`/`add`,
# S the carried state returned by `init`/`add`, and V the per-entry result
# held in the dict returned by `values`.
T = TypeVar('T')
S = TypeVar('S')
V = TypeVar('V')
# Alias for an arbitrary pytree; not referenced by the definitions visible here.
AverageState = PyTree


class Average(Protocol[T, S, V]):
    """Structural interface shared by the averaging strategies in this module.

    An implementation is configured with a ``config`` dict (name -> int) and
    exposes a functional API: ``init`` builds a state from a sample, ``add``
    folds a new sample into a state, ``values`` reads one result per config
    entry and ``norms`` reads the matching normalisers.
    """

    # Name -> integer parameter of each tracked average.
    config: dict[str, int]
    # Builds the initial accumulator from a sample (default: zeros like the sample).
    state_init: Callable[[T], V]
    # Maps (sample, current accumulator, flag) to the quantity that is accumulated.
    # The default ignores the last two arguments and returns the sample itself.
    to_state: Callable[[T, V, bool], V]
    # Whether `values` should return normalised results.
    normalize: bool

    def __init__(
        self,
        config: dict[str, int],
        state_init: Callable[[T], V] = functools.partial(jtu.tree_map, jnp.zeros_like),  # type: ignore
        to_state: Callable[[T, V, bool], V] = lambda x, _, __: x,  # type: ignore
        normalize: bool = True,
    ): ...

    # Create the carried state from an example sample `x`.
    def init(self, x: T) -> S: ...

    # Return a new state with sample `x` folded in.
    def add(self, state: S, x: T) -> S: ...

    # Return one result per config entry; `x` is an optional sample that
    # implementations may use as a fallback when nothing has been added yet.
    def values(self, state: S, x: T | None = None) -> dict[str, V]: ...

    # Return the normaliser associated with each config entry.
    def norms(self, state: S) -> dict[str, jax.Array]: ...


class MovingAverageState(NamedTuple, Generic[T, S]):
    """State carried between `MovingAverage.add` calls."""

    # Ring buffer of past samples: each leaf has a leading axis of length
    # `max_size` and is written at position `idx % max_size`.
    history: T
    # One running accumulator per entry of the averager's `config`.
    states: S
    # Scalar int32 counter, incremented by one on every `add`.
    idx: jax.Array
    # Compensation terms for `kahan_update`, same structure as `states`.
    kahan_corrections: S


def kahan_update(x_sum: T, x: T, c: T) -> tuple[T, T]:
    """Add ``x`` to ``x_sum`` with Kahan (compensated) summation.

    Args:
        x_sum: Running sum so far.
        x: Term to add.
        c: Compensation carried over from the previous update.

    Returns:
        ``(new_sum, new_compensation)``. The arithmetic is delegated to
        ``tree_add`` / ``tree_sub`` (presumably leaf-wise pytree helpers;
        confirm against ``nix.utils.tree_utils``).
    """
    # Apply the error left over from the previous step to the incoming term.
    compensated = tree_sub(x, c)
    total = tree_add(x_sum, compensated)
    # (total - x_sum) is what was actually added; its difference from
    # `compensated` is the rounding error to carry into the next call.
    return total, tree_sub(tree_sub(total, x_sum), compensated)


class UpdateTuple(NamedTuple, Generic[T]):
    """Pair produced by one accumulator update inside `MovingAverage.add`.

    `MovingAverage.add` also uses the type as an ``is_leaf`` marker to split
    the mapped results back into separate state and correction trees.
    """

    # Accumulator after the update.
    new_state: T
    # Kahan compensation term after the update.
    new_correction: T


@dataclass(frozen=True)
class MovingAverage(Average[T, MovingAverageState[T, S], S]):
    """Sliding-window sums/means for several window sizes at once.

    Every entry of ``config`` names a window size. A single ring buffer of
    length ``max_size`` stores past samples; each window keeps its own running
    accumulator that gains the new sample and loses the one that just fell out
    of its window. Accumulators are updated with Kahan summation.
    """

    config: dict[str, int]
    state_init: Callable[[T], S] = functools.partial(jtu.tree_map, jnp.zeros_like)  # type: ignore
    to_state: Callable[[T, S, bool], S] = lambda x, _, __: x  # type: ignore
    normalize: bool = True

    @property
    def sizes(self) -> list[int]:
        """All window sizes found in ``config``."""
        return jtu.tree_leaves(self.config)

    @property
    def max_size(self) -> int:
        """Largest window size; also the ring-buffer length."""
        return max(self.sizes)

    def init(self, x: T):
        """Build an empty state whose buffers match the shapes/dtypes of ``x``."""
        capacity = self.max_size

        def empty_buffer(leaf):
            return jnp.zeros((capacity, *leaf.shape), dtype=leaf.dtype)

        per_window = jtu.tree_map(lambda *_: self.state_init(x), self.config)
        return MovingAverageState[T, S](
            history=jtu.tree_map(empty_buffer, x),
            states=per_window,
            idx=jnp.zeros((), dtype=jnp.int32),
            kahan_corrections=jtu.tree_map(jnp.zeros_like, per_window),
        )

    def add(self, state: MovingAverageState[T, S], x: T):
        """Return a new state with sample ``x`` pushed into every window."""
        capacity = self.max_size
        step = state.idx

        def is_update(node) -> bool:
            return isinstance(node, UpdateTuple)

        def advance(size: int, running: S, correction: S) -> tuple[S, S]:
            # First add the contribution of the incoming sample ...
            incoming = self.to_state(x, running, False)
            running, correction = kahan_update(running, incoming, correction)
            # ... then subtract the sample that is `size` steps old. The slot
            # is read from the *old* history, before the new sample is written.
            slot = (step - size) % capacity
            expired = jtu.tree_map(lambda buf: buf[slot], state.history)
            outgoing = jtu.tree_map(
                jnp.negative, self.to_state(expired, running, True)
            )
            running, correction = kahan_update(running, outgoing, correction)
            return UpdateTuple(running, correction)

        updates = jtu.tree_map(
            advance, self.config, state.states, state.kahan_corrections
        )
        write_at = step % capacity
        next_history = jtu.tree_map(
            lambda buf, new: buf.at[write_at].set(new), state.history, x
        )
        # `updates` holds one UpdateTuple per window; split it into two trees.
        next_states = jtu.tree_map(
            lambda u: u.new_state, updates, is_leaf=is_update
        )
        next_corrections = jtu.tree_map(
            lambda u: u.new_correction, updates, is_leaf=is_update
        )
        return MovingAverageState[T, S](
            history=next_history,
            states=next_states,
            idx=step + 1,
            kahan_corrections=next_corrections,
        )

    def norms(self, state: MovingAverageState) -> dict[str, jax.Array]:
        """Number of samples currently inside each window: ``min(idx, size)``."""
        seen = state.idx
        return jtu.tree_map(lambda size: jnp.minimum(seen, size), self.config)

    def values(
        self, state: MovingAverageState[T, S], x: T | None = None
    ) -> dict[str, S]:
        """Read out every window.

        With ``normalize`` the accumulators are divided by ``norms``; otherwise
        the raw sums are returned. When ``x`` is given and nothing has been
        added yet (``idx == 0``), ``to_state(x, ...)`` is returned instead.
        """

        def window_value(norm, acc):
            factor = 1 / norm if self.normalize else 1
            if x is None:
                return jtu.tree_map(lambda leaf: leaf * factor, acc)
            fallback = self.to_state(x, acc, False)
            return jtu.tree_map(
                lambda leaf, alt: jnp.where(state.idx > 0, leaf * factor, alt),
                acc,
                fallback,
            )

        return jtu.tree_map(window_value, self.norms(state), state.states)


class EMA(NamedTuple, Generic[T]):
    """Unnormalised exponential moving average and its normaliser."""

    # Decayed sum of the values seen so far (updated as data * decay + value).
    data: T
    # Decayed count of updates (updated as weight * decay + 1); 0 means empty.
    weight: jax.Array


def ema_make(tree: T) -> EMA[T]:
    """Create an empty EMA: zeros shaped like ``tree`` and a scalar zero weight."""
    zero_data = jtu.tree_map(jnp.zeros_like, tree)
    zero_weight = jnp.zeros(())
    return EMA(zero_data, zero_weight)


@jit
def ema_update(ema: EMA[T], value: T, decay: ArrayLike) -> EMA[T]:
    """Decay the accumulated sum and weight by ``decay`` and add ``value``.

    Returns a new ``EMA``; the weight grows by one per update so that
    ``data / weight`` is the bias-corrected average.
    """

    def decay_and_add(acc, new):
        return acc * decay + new

    next_data = jtu.tree_map(decay_and_add, ema.data, value)
    next_weight = ema.weight * decay + 1
    return EMA(next_data, next_weight)


@jit
def ema_value(ema: EMA[T], backup: T | None = None) -> T:
    """Return the normalised average ``data / weight``.

    While the EMA is still empty (``weight == 0``) the division is undefined,
    so ``backup`` is returned instead; without a backup the (zero) ``data``
    itself is used.
    """
    fallback = ema.data if backup is None else backup
    weight = ema.weight
    is_empty = weight == 0

    def pick(acc, alt):
        return jnp.where(is_empty, alt, acc / weight)

    return jtu.tree_map(pick, ema.data, fallback)


@dataclass(frozen=True)
class ExponentialMovingAverage(Average):
    """Exponential moving averages for several time constants at once.

    Each entry ``name -> n`` of ``config`` tracks an EMA with decay
    ``1 - 1 / n``. The state is a dict with one ``EMA`` per config entry.
    """

    config: dict[str, int]
    state_init: Callable[[T], S] = functools.partial(jtu.tree_map, jnp.zeros_like)  # type: ignore
    # Takes (sample, current EMA, flag) like every call site below; the default
    # simply accumulates the sample itself.
    to_state: Callable[[T, S, bool], S] = lambda x, _, __: x  # type: ignore
    normalize: bool = True

    @property
    def alphas(self):
        """Per-entry decay factor ``1 - 1 / n``."""
        return jtu.tree_map(lambda x: 1 - 1 / x, self.config)

    def init(self, x):
        """Create one empty ``EMA`` per config entry, shaped by ``state_init(x)``."""
        return jtu.tree_map(lambda _: ema_make(self.state_init(x)), self.config)

    def add(self, state, x):
        """Return a new state with sample ``x`` folded into every EMA."""
        return jtu.tree_map(
            lambda alpha, tree: ema_update(tree, self.to_state(x, tree, False), alpha),
            self.alphas,
            state,
        )

    def values(self, state, x=None):
        """Read out every EMA.

        Args:
            state: State returned by ``init`` / ``add``.
            x: Optional sample; ``to_state(x, ...)`` is used as the fallback
                for EMAs that have not received any update yet.

        Returns:
            A dict with the same keys as ``config``. With ``normalize`` the
            entries are ``data / weight``; otherwise they are multiplied back
            by the weight, giving the unnormalised decayed sums.
        """
        # Mapping over `config` first makes each whole EMA (a subtree of
        # `state`) arrive as a single argument.
        result = jtu.tree_map(
            lambda _, tree: ema_value(tree, self.to_state(x, tree, False)),
            self.config,
            state,
        )
        if not self.normalize:
            # `norms` is a method and must be called to obtain the per-entry
            # weights; each weight then scales every leaf of its entry.
            result = jtu.tree_map(
                lambda a, t: jtu.tree_map(lambda x: a * x, t),
                self.norms(state),
                result,
            )
        return result

    def norms(self, state):
        """Return the accumulated weight of every EMA."""
        return jtu.tree_map(lambda _, tree: tree.weight, self.config, state)


class RunningAverage:
    """Stateful NumPy mean over the last ``history_size`` samples.

    Calling the instance with a pytree of array-likes pushes the sample into a
    ring buffer, updates the running sum incrementally and returns the current
    mean. Until the buffer is full, the mean is taken over the samples seen so
    far.
    """

    history_size: int
    # Ring buffer: each leaf has a leading axis of length `history_size`.
    history: PyTree = None
    # Running sum of the samples currently in the window.
    _value: PyTree = None
    # Ring-buffer slot that the next sample will be written to.
    step: int = 0
    # Number of samples currently in the window (capped at `history_size`).
    normalizer: int = 0

    def __init__(self, history_size: int):
        """Create an empty average.

        Raises:
            ValueError: If ``history_size`` is smaller than 1.
        """
        if history_size < 1:
            raise ValueError(f"history_size must be >= 1, got {history_size}")
        self.history_size = history_size

    def __call__(self, data: PyTree):
        """Add ``data`` to the window and return the updated mean."""
        data = jtu.tree_map(np.asarray, data)

        if self._value is None:
            # First sample: the sum is the sample itself and the ring buffer
            # is allocated lazily from its shapes. `step` stays at 0 so that
            # the write below lands in a valid slot for every history_size.
            self._value = data
            self.history = jtu.tree_map(
                lambda x: np.zeros((self.history_size, *x.shape), dtype=x.dtype), data
            )
        else:
            slot = self.step

            def update_value(value, new_data, history):
                # The slot about to be overwritten holds the sample leaving
                # the window (or zeros while the buffer is still filling).
                return value + new_data - history[slot]

            self._value = jtu.tree_map(
                update_value,
                self._value,
                data,
                self.history,
            )

        def fill_history(hist, new_data):
            hist[self.step] = new_data

        # In-place write into the ring buffer; the mapped result is discarded.
        jtu.tree_map(fill_history, self.history, data)
        self.step = (self.step + 1) % self.history_size
        self.normalizer = min(self.normalizer + 1, self.history_size)
        return self.value

    @property
    def value(self):
        """Current mean: running sum divided by the number of samples held."""
        normalizer = min(self.normalizer, self.history_size)
        return jtu.tree_map(lambda x: x / normalizer, self._value)
