from typing import Callable, NamedTuple

import flax
import jax
import numpy as np
import optax
from flax import struct
from flax.linen.fp8_ops import OVERWRITE_WITH_GRADIENT
from flax.training.train_state import TrainState


class RLTrainState(TrainState):  # type: ignore[misc]
    """``TrainState`` extended with target-network parameters and Q-network selection settings.

    On top of the standard ``TrainState`` fields it stores target parameters, the
    indexes of the networks selected for the target and for the actor loss, and the
    static (non-pytree) callables used to aggregate them.

    ``apply_gradients`` is overridden because ``self.tx.update`` is expected to
    return an iterable of ``(updates, opt_state)`` pairs instead of the single
    ``(updates, opt_state)`` tuple returned by a plain optax transformation.
    """

    target_params: flax.core.FrozenDict  # type: ignore[misc]
    cumulative_losses: int

    # -- Variables for target networks --
    selected_target_idx: jax.Array
    # Function applied to the selected target networks. Static field
    # (``pytree_node=False``): not a pytree leaf, so JAX does not trace it.
    aggregate_target_qf: Callable = struct.field(pytree_node=False)

    # -- Variables for Q networks used in the actor loss --
    selected_policy_idx: jax.Array
    # Epsilon schedule deciding whether the Q networks used in the actor loss
    # (``aggregate_policy_qf``) are sampled from the selected targets or from
    # random network indexes. Static field.
    epsilon_schedule: Callable = struct.field(pytree_node=False)
    # Function applied to the selected networks in the actor loss. Static field.
    aggregate_policy_qf: Callable = struct.field(pytree_node=False)

    def apply_gradients(self, *, grads, **kwargs):
        """Updates ``step``, ``params``, ``opt_state`` and ``**kwargs`` in return value.

        Internally this calls ``self.tx.update()`` and then applies the resulting
        updates leaf by leaf with ``optax.apply_updates()``. Unlike the stock flax
        implementation, ``self.tx.update()`` must return an iterable of
        ``(updates, opt_state)`` pairs. The updates are collected into a list. That
        list must match the pytree structure of ``params`` (or ``params["params"]``
        when ``grads`` contains the ``OVERWRITE_WITH_GRADIENT`` collection). The new
        optimizer states are stored as a list in ``opt_state``.

        Args:
            grads: Gradients that have the same pytree structure as ``.params``.
            **kwargs: Additional dataclass attributes that should be ``.replace()``-ed.

        Returns:
            An updated instance of ``self`` with ``step`` incremented by one, ``params``
            and ``opt_state`` updated by applying ``grads``, and additional attributes
            replaced as specified by ``kwargs``.
        """
        # Evaluated once; ``grads`` is not modified within this method.
        has_owg = OVERWRITE_WITH_GRADIENT in grads
        if has_owg:
            # Only the "params" collection goes through the optimizer.
            grads_with_opt = grads["params"]
            params_with_opt = self.params["params"]
        else:
            grads_with_opt = grads
            params_with_opt = self.params

        # Split the (update, opt_state) pairs returned by the transformation.
        updates = []
        new_opt_state = []
        update_pairs = self.tx.update(grads_with_opt, self.opt_state, params_with_opt)
        for update, member_opt_state in update_pairs:
            updates.append(update)
            new_opt_state.append(member_opt_state)

        # ``jax.tree.map`` walks the leaves of ``params_with_opt``; ``updates`` must
        # share that pytree structure.
        new_params_with_opt = jax.tree.map(optax.apply_updates, params_with_opt, updates)

        # As implied by the OWG name, the gradients are used directly to update the
        # parameters.
        if has_owg:
            new_params = {
                "params": new_params_with_opt,
                OVERWRITE_WITH_GRADIENT: grads[OVERWRITE_WITH_GRADIENT],
            }
        else:
            new_params = new_params_with_opt
        return self.replace(
            step=self.step + 1,
            params=new_params,
            opt_state=new_opt_state,
            **kwargs,
        )


class BatchNormTrainState(TrainState):  # type: ignore[misc]
    """``TrainState`` that additionally carries a ``batch_stats`` collection.

    The inherited ``apply_gradients`` only replaces ``step``, ``params``,
    ``opt_state`` and any explicitly passed keyword fields. ``batch_stats`` must
    therefore be refreshed by the caller, e.g. via ``.replace(batch_stats=...)``
    or ``apply_gradients(..., batch_stats=...)``.
    """

    batch_stats: flax.core.FrozenDict  # type: ignore[misc]


class ReplayBufferSamplesNp(NamedTuple):
    """Named tuple of NumPy arrays holding one batch of replay-buffer samples.

    Fields keep their declaration order, so a batch can be unpacked positionally:
    ``obs, act, next_obs, dones, rewards = batch``.
    """

    # NOTE(review): shapes and dtypes are not constrained here; presumably the
    # leading axis is the batch dimension. Confirm against the producing buffer.
    observations: np.ndarray
    actions: np.ndarray
    next_observations: np.ndarray
    dones: np.ndarray
    rewards: np.ndarray
