from typing import Callable, Tuple

import chex
from chex import Array, PRNGKey
from flax.core.frozen_dict import FrozenDict
from jumanji.types import TimeStep
from optax._src.base import OptState
from typing_extensions import NamedTuple

from mava.types import MavaObservation, State


class LearnerState(NamedTuple):
    """Bundle of everything the learner carries from one update to the next.

    Field order is part of the interface: instances may be built or unpacked
    positionally, so the order below must not change.
    """

    # Flax parameter tree (FrozenDict) for the network(s) being trained.
    params: FrozenDict
    # Optax optimiser state paired with ``params``.
    opt_state: OptState
    # PRNG key carried in the state; presumably split/consumed during updates.
    key: chex.PRNGKey
    # Environment state object (Mava ``State``).
    env_state: State
    # Jumanji ``TimeStep``; presumably the most recent one observed — confirm.
    timestep: TimeStep


class MATNetworkConfig(NamedTuple):
    """Hyper-parameters that define the MAT network architecture.

    All fields are required; there are no defaults. The positional order below
    is part of the interface and must be preserved.
    """

    # NOTE(review): the meanings below are inferred from field names only;
    # confirm against the network implementation that consumes this config.
    n_block: int  # presumably the number of transformer blocks
    n_head: int  # presumably the number of attention heads per block
    embed_dim: int  # presumably the embedding / model width
    use_swiglu: bool  # presumably toggles a SwiGLU feed-forward variant
    use_rmsnorm: bool  # presumably toggles RMSNorm in place of another norm


# Actor apply signature: (params, observation, key) -> tuple of four arrays.
# NOTE(review): what each returned array holds is not visible in this file;
# confirm against the network that produces these outputs.
ActorApply = Callable[
    [FrozenDict, MavaObservation, PRNGKey],
    Tuple[Array, Array, Array, Array],
]

# Learner apply signature: (params, observation, array, key) -> three arrays.
# NOTE(review): the extra ``Array`` input is presumably the actions to be
# evaluated; the meaning of the three outputs should be confirmed.
LearnerApply = Callable[
    [FrozenDict, MavaObservation, Array, PRNGKey],
    Tuple[Array, Array, Array],
]
