"""Lookahead optimizer with adaptive slow step size.

This implementation is adapted from, and extends, the optax implementation of the
lookahead optimizer.

References:
- Zhang et al, "Lookahead Optimizer: k steps forward, 1 step back", 2019
    https://arxiv.org/abs/1907.08610
- https://github.com/google-deepmind/optax/tree/main
"""

from dataclasses import dataclass
from typing import NamedTuple

import jax
import jax.numpy as jnp
from optax._src import base
from optax._src.transform import ScaleByAdamState


@dataclass
class LookaheadConfig:
    """Configuration for the lookahead optimizer wrapper.

    Attributes:
        slow_step_size: Step size for slow updates (typically 0.5). When
            adaptive_slow_step_size=True, this is instead the lower bound the
            adaptive step size is clipped to (the upper bound is 1.0).
        sync_period: Number of fast-optimizer steps between slow parameter
            updates (typically 6). Must be >= 1; this is checked when the
            wrapper is built, not when the config is constructed.
        adaptive_slow_step_size: Compute the slow step size on each sync from
            Adam's second-moment estimate, used as a diagonal curvature
            estimate (Section 2.1 of the paper). Requires the wrapped
            optimizer's state to contain a ScaleByAdamState.
        reset_state: Reset the fast optimizer state to its initial value on
            each sync step.
        adam_b2: Beta2 used to bias-correct Adam's second moment when it is
            extracted as curvature.
            NOTE(review): presumably this should equal the b2 the wrapped Adam
            optimizer was created with -- confirm at the call sites.
    """

    # Fields without defaults come first; their order is part of the
    # positional constructor signature.
    slow_step_size: float
    sync_period: int
    adaptive_slow_step_size: bool = True
    reset_state: bool = False
    adam_b2: float = 0.999


class LookaheadState(NamedTuple):
    """State carried between updates of the lookahead optimizer.

    Attributes:
        fast_state: State of the wrapped (inner, fast) optimizer.
        steps_since_sync: Scalar int32 counter of fast steps taken since the
            last slow/fast sync; it wraps around modulo the sync period.
        slow_params: Slow parameters (used for inference). Initialized as a
            copy of the initial parameters and changed only on sync steps.
    """

    fast_state: base.OptState
    steps_since_sync: jnp.ndarray
    slow_params: base.Params


def _extract_adam_curvature(fast_state: base.OptState, b2: float) -> base.Params | None:
    """Find Adam's second moment in an optimizer state and return it bias-corrected.

    The state is searched depth-first through nested tuples and lists; the
    first `ScaleByAdamState` encountered wins.

    Args:
        fast_state: Optimizer state, possibly a nested tuple/list of states.
        b2: Second-moment decay rate used for the bias correction.

    Returns:
        A tree shaped like the Adam state's `nu`, holding
        `nu / (1 - b2**count + 1e-8)`, or None when no `ScaleByAdamState` is
        found.
    """
    # A ScaleByAdamState must be recognized before the generic container
    # check below, so that its own fields are never searched.
    if isinstance(fast_state, ScaleByAdamState):
        debias = 1 - b2**fast_state.count

        def _debiased(second_moment):
            # The epsilon keeps the division finite when the correction is 0.
            return second_moment / (debias + 1e-8)

        return jax.tree.map(_debiased, fast_state.nu)

    if not isinstance(fast_state, (tuple, list)):
        return None

    # Lazily search the children in order, stopping at the first hit.
    candidates = (_extract_adam_curvature(child, b2) for child in fast_state)
    return next((found for found in candidates if found is not None), None)


def _compute_optimal_alpha(
    slow_params: base.Params,
    fast_params: base.Params,
    grads: base.Updates,
    curvature: base.Params,
    alpha_low: float,
) -> jax.Array:
    """Return the clipped slow step size from a diagonal quadratic model.

    Implements the rule from Section 2.1, "Selecting the Slow Weights Step
    Size":

        alpha* = clip(1 + grad·diff / ||diff||²_A, alpha_low, 1)

    with diff = slow - fast and A the diagonal curvature (empirical Fisher).

    Args:
        slow_params: Slow weights.
        fast_params: Fast weights; same tree structure as `slow_params`.
        grads: Gradient tree with the same structure as the parameters.
        curvature: Diagonal curvature tree with the same structure.
        alpha_low: Lower clipping bound for the step size.

    Returns:
        The step size, clipped to [alpha_low, 1].
    """

    def _sum_leaves(tree):
        # Fold the per-leaf scalar contributions into a single total.
        return jax.tree.reduce(lambda acc, term: acc + term, tree)

    # Displacement pointing from the fast weights back to the slow weights.
    delta = jax.tree.map(lambda slow, fast: slow - fast, slow_params, fast_params)

    # Numerator: grad · diff.
    directional_grad = _sum_leaves(
        jax.tree.map(lambda g, d: jnp.sum(g * d), grads, delta)
    )
    # Denominator: ||diff||²_A = sum(A_i * diff_i**2).
    quad_form = _sum_leaves(
        jax.tree.map(lambda a, d: jnp.sum(a * d**2), curvature, delta)
    )

    # The epsilon guards against a zero displacement (slow == fast).
    unclipped = 1.0 + directional_grad / (quad_form + 1e-8)
    return jnp.clip(unclipped, alpha_low, 1.0)


def wrap_with_lookahead(
    fast_optimizer: base.GradientTransformation,
    config: LookaheadConfig,
) -> base.GradientTransformationExtraArgs:
    """Lookahead optimizer wrapper with optional adaptive slow step size.

    Wraps a fast optimizer (e.g., Adam) and maintains slow parameters that are
    updated every `sync_period` steps. The slow parameters typically generalize
    better and should be used for inference; they live in the optimizer state
    (see `get_slow_params`).

    The returned updates are meant to be applied to the fast parameters. On a
    sync step they also pull the fast parameters back onto the new slow
    parameters.

    Args:
        fast_optimizer: Inner optimizer (e.g., optax.adam). With
            `config.adaptive_slow_step_size=True` its state must contain a
            `ScaleByAdamState`.
        config: Lookahead configuration.

    Returns:
        A GradientTransformationExtraArgs with init and update functions.

    Raises:
        ValueError: If `config.sync_period < 1`, or (while tracing/running the
            update) if adaptive mode is on and no Adam state is found.
    """
    # Raise explicitly instead of asserting: asserts vanish under `python -O`.
    if config.sync_period < 1:
        raise ValueError('sync_period must be >= 1.')

    def init_fn(params: base.Params) -> LookaheadState:
        return LookaheadState(
            fast_state=fast_optimizer.init(params),
            steps_since_sync=jnp.zeros((), dtype=jnp.int32),
            # Copy so the slow parameters never alias the caller's arrays.
            slow_params=jax.tree.map(lambda x: jnp.array(x, copy=True), params),
        )

    def update_fn(
        updates: base.Updates,
        state: LookaheadState,
        params: base.Params,
        **extra_args,
    ) -> tuple[base.Updates, LookaheadState]:
        raw_grads = updates  # Save before transformation

        updates, fast_state = fast_optimizer.update(
            updates, state.fast_state, params, **extra_args
        )
        # True on the step that completes a sync period.
        sync_next = state.steps_since_sync == (config.sync_period - 1)

        if config.adaptive_slow_step_size:

            def compute_adaptive_alpha():
                curvature = _extract_adam_curvature(fast_state, config.adam_b2)
                if curvature is None:
                    raise ValueError('Adam state not found')
                # Uses the fast parameters and gradients from *before* this
                # step's fast update, with curvature from the updated state.
                return _compute_optimal_alpha(
                    state.slow_params, params, raw_grads, curvature, config.slow_step_size
                )

            # `lax.cond` requires both branches to return the same dtype. A
            # bare Python float traces as the default float type, which differs
            # from the adaptive branch for e.g. bfloat16 parameters or in x64
            # mode, so cast the fallback to the adaptive branch's dtype.
            alpha_dtype = jax.eval_shape(compute_adaptive_alpha).dtype

            # The reductions behind the adaptive step only run on sync steps.
            alpha = jax.lax.cond(
                sync_next,
                compute_adaptive_alpha,
                lambda: jnp.asarray(config.slow_step_size, dtype=alpha_dtype),
            )
        else:
            alpha = config.slow_step_size

        # On sync: slow moves toward fast, fast resets to new slow position
        # last_diff = fast + updates - slow (position after fast update relative to slow)
        last_diff = jax.tree.map(
            lambda f, u, s: f + u - s, params, updates, state.slow_params
        )

        # slow_update = alpha * sync_next * last_diff (zero off-sync)
        slow_updates = jax.tree.map(lambda d: alpha * sync_next * d, last_diff)
        # fast_update = updates - (1 - alpha) * sync_next * last_diff, so that
        # on sync fast + fast_update == slow + slow_update.
        fast_updates = jax.tree.map(
            lambda u, d: u - sync_next * (1 - alpha) * d, updates, last_diff
        )

        next_slow_params = jax.tree.map(
            lambda p, u: p + u, state.slow_params, slow_updates
        )

        if config.reset_state:
            initial_state = fast_optimizer.init(params)
            # Select rather than blend arithmetically: `0 * nan` would leak
            # non-finite values through a reset, and multiplying would promote
            # boolean state leaves to integers.
            fast_state = jax.tree.map(
                lambda cur, init: jnp.where(sync_next, init, cur),
                fast_state,
                initial_state,
            )

        new_state = LookaheadState(
            fast_state=fast_state,
            steps_since_sync=(state.steps_since_sync + 1) % config.sync_period,
            slow_params=next_slow_params,
        )

        return fast_updates, new_state

    return base.GradientTransformationExtraArgs(init_fn, update_fn)  # type: ignore


def get_slow_params(opt_state: base.OptState) -> base.Params | None:
    """Return the slow parameters stored in a lookahead optimizer state.

    The state is searched depth-first through nested tuples and lists; the
    first `LookaheadState` encountered wins.

    Args:
        opt_state: A `LookaheadState`, or a tuple/list of optimizer states
            that may contain one at any depth.

    Returns:
        The `slow_params` of the first `LookaheadState` found, or None when
        there is none.
    """
    # A LookaheadState must be recognized before the generic container check
    # below, so that its own fields are never searched.
    if isinstance(opt_state, LookaheadState):
        return opt_state.slow_params

    if not isinstance(opt_state, (tuple, list)):
        return None

    for child in opt_state:
        if (found := get_slow_params(child)) is not None:
            return found
    return None
