from typing import NamedTuple

import jax
import optax
from optax._src.transform import ScaleByAdamState


class EmptyState(NamedTuple):
    """Field-less state used by transformations that keep nothing between steps."""


def init_empty_state(params) -> EmptyState:
    """Build the state of a stateless :class:`GradientTransformation`.

    Args:
      params: Unused; the argument is discarded immediately.

    Returns:
      A new :class:`EmptyState` instance.
    """
    del params  # Nothing is derived from the parameters.
    return EmptyState()


def scale(step_size: float) -> optax.GradientTransformation:
    """Build a stateless transformation that multiplies updates by `step_size`.

    Args:
      step_size: Fixed scalar factor applied to every leaf of the updates.

    Returns:
      A :class:`optax.GradientTransformation` whose init function is
      :func:`init_empty_state` and whose update function rescales the
      updates and passes the state through unchanged.
    """

    def _rescale(leaf):
        return step_size * leaf

    def update_fn(updates, state, params=None):
        del params  # The scaling does not depend on the parameters.
        return jax.tree.map(_rescale, updates), state

    return optax.GradientTransformation(init_empty_state, update_fn)


def scale_adam_momentum(opt_state: optax.OptState, factor: float) -> optax.OptState:
    """Scale the first moment buffers in an Adam-style state by ``factor``.

    Walks a nested optax state and, for every :class:`ScaleByAdamState` it
    reaches, multiplies each leaf of ``mu`` by ``factor`` while carrying
    ``count`` and ``nu`` over unchanged.

    The walk descends through:
      * an ``inner_state`` attribute (wrapper states),
      * a ``fast_state`` attribute (lookahead wrapper),
      * the elements of lists and tuples (including NamedTuple states).

    Anything else (for example dicts) is returned untouched and is not
    searched. A wrapper that exposes ``inner_state``/``fast_state`` but has no
    ``_replace`` method is returned as-is, because there is no way here to
    rebuild it with the scaled child.

    Args:
      opt_state: The (possibly nested) optimizer state to transform.
      factor: Multiplier applied to every leaf of each ``mu`` tree.

    Returns:
      A state with the same nesting as ``opt_state`` in which every reached
      Adam ``mu`` has been scaled.
    """

    def _rebuild_tuple(original, children):
        """Rebuild a tuple-like state of ``original``'s type from ``children``."""
        # Exact-type check on purpose: a plain tuple takes ONE iterable, so
        # splatting the children would unpack a lone child into the outer
        # tuple (corrupting 1-element states) and raise for longer tuples.
        if type(original) is tuple:
            return tuple(children)
        try:
            # NamedTuple-style classes take their fields positionally.
            return type(original)(*children)
        except TypeError:
            return tuple(children)

    def _scale(state):
        if isinstance(state, EmptyState):
            return state
        if isinstance(state, ScaleByAdamState):
            return ScaleByAdamState(
                count=state.count,
                mu=jax.tree.map(lambda m: factor * m, state.mu),
                nu=state.nu,
            )

        # Wrappers holding a nested optimizer state: `inner_state` for the
        # common wrappers, `fast_state` for lookahead. Checked in that order.
        for attr in ('inner_state', 'fast_state'):
            if hasattr(state, attr):
                new_child = _scale(getattr(state, attr))
                if hasattr(state, '_replace'):
                    return state._replace(**{attr: new_child})
                # No known way to rebuild this wrapper; leave it unchanged.
                return state

        if isinstance(state, list):
            return [_scale(s) for s in state]

        if isinstance(state, tuple):
            # Children are computed exactly once, outside any try/except, so
            # a failed constructor call never re-walks the subtree.
            return _rebuild_tuple(state, [_scale(s) for s in state])

        return state

    return _scale(opt_state)


def get_adam_momentum(opt_state: optax.OptState) -> optax.Updates | None:
    """Find Adam's first moment buffer (``mu``) inside a nested optax state.

    The search is depth-first and stops at the first
    :class:`ScaleByAdamState` reached. At each node it looks, in order, at
    the ``inner_state`` attribute, the ``fast_state`` attribute (lookahead
    wrapper), and then the elements of the node if it is a list or tuple.

    Args:
      opt_state: The (possibly nested) optimizer state to search.

    Returns:
      The ``mu`` tree of the first Adam state found, or ``None`` when the
      search finds none.
    """

    def _find_mu(node):
        if isinstance(node, ScaleByAdamState):
            return node.mu

        # Wrapper attributes first, in the same priority order every time.
        for attr in ('inner_state', 'fast_state'):
            if hasattr(node, attr):
                hit = _find_mu(getattr(node, attr))
                if hit is not None:
                    return hit

        # Then sequence containers (chains, NamedTuple states, lists).
        if isinstance(node, (list, tuple)):
            for child in node:
                hit = _find_mu(child)
                if hit is not None:
                    return hit

        return None

    return _find_mu(opt_state)


def set_adam_momentum(
    opt_state: optax.OptState, new_momentum: optax.Updates
) -> optax.OptState:
    """Set Adam's first moment buffer (``mu``) inside a nested optax state.

    Walks ``opt_state`` and replaces ``mu`` in every
    :class:`ScaleByAdamState` it reaches with ``new_momentum``; ``count`` and
    ``nu`` are carried over unchanged. The walk descends through an
    ``inner_state`` attribute, a ``fast_state`` attribute (lookahead
    wrapper), and the elements of lists and tuples. Anything else (for
    example dicts) is returned untouched and is not searched. A wrapper
    without a ``_replace`` method is returned as-is.

    Args:
        opt_state: The (possibly nested) optimizer state to update.
        new_momentum: Tree to install as ``mu``. Its pytree structure must be
            exactly that of the existing ``mu``; leaf shapes and dtypes are
            not checked.

    Returns:
        A state with the same nesting as ``opt_state`` and the new ``mu``.

    Raises:
        ValueError: If no Adam state is found or the momentum tree is incompatible.
    """
    found = False

    def _rebuild_tuple(original, children):
        """Rebuild a tuple-like state of ``original``'s type from ``children``."""
        # Exact-type check on purpose: a plain tuple takes ONE iterable, so
        # splatting the children would unpack a lone child into the outer
        # tuple (corrupting 1-element states) and raise for longer tuples.
        if type(original) is tuple:
            return tuple(children)
        try:
            # NamedTuple-style classes take their fields positionally.
            return type(original)(*children)
        except TypeError:
            return tuple(children)

    def _set(state):
        nonlocal found
        if isinstance(state, EmptyState):
            return state
        if isinstance(state, ScaleByAdamState):
            found = True
            # Compare full structures: `jax.tree.map` over both trees would
            # also accept `new_momentum` being a mere prefix of `mu` (e.g. a
            # single array) and silently install a mismatched tree.
            if jax.tree.structure(new_momentum) != jax.tree.structure(state.mu):
                raise ValueError(
                    'New momentum tree is incompatible with the Adam state.'
                )
            return ScaleByAdamState(
                count=state.count,
                # Identity map gives fresh containers, so the stored tree
                # does not alias the caller's (possibly mutable) containers.
                mu=jax.tree.map(lambda leaf: leaf, new_momentum),
                nu=state.nu,
            )

        # Wrappers holding a nested optimizer state: `inner_state` for the
        # common wrappers, `fast_state` for lookahead. Checked in that order.
        for attr in ('inner_state', 'fast_state'):
            if hasattr(state, attr):
                new_child = _set(getattr(state, attr))
                if hasattr(state, '_replace'):
                    return state._replace(**{attr: new_child})
                # No known way to rebuild this wrapper; leave it unchanged.
                return state

        if isinstance(state, list):
            return [_set(s) for s in state]

        if isinstance(state, tuple):
            # Children are computed exactly once, outside any try/except, so
            # a failed constructor call never re-walks the subtree.
            return _rebuild_tuple(state, [_set(s) for s in state])

        return state

    updated_state = _set(opt_state)
    if not found:
        raise ValueError('No Adam momentum found in optimizer state.')
    return updated_state
