from typing import Callable, Dict, Optional, Tuple

from chex import Array, PRNGKey
from flax.core.frozen_dict import FrozenDict
from jumanji.types import TimeStep
from optax._src.base import OptState
from typing_extensions import NamedTuple


class SableNetworkConfig(NamedTuple):
    """Configuration for the Sable network.

    Immutable record of three integer hyperparameters. The per-field notes
    below are inferred from the field names only; confirm them against the
    network definition that consumes this config.
    """

    # Presumably the number of stacked blocks in the network -- TODO confirm.
    n_block: int
    # Presumably the number of heads used inside each block -- TODO confirm.
    n_head: int
    # Presumably the width of the embedding vectors -- TODO confirm.
    embed_dim: int


class HiddenStates(NamedTuple):
    """Hidden states for the encoder and decoder.

    Groups three arrays: one for the encoder and two for the decoder.
    Shapes and dtypes are not specified here.
    """

    # Hidden state belonging to the encoder.
    encoder: Array
    # Decoder hidden state; "self_retn" presumably abbreviates self-retention
    # -- TODO confirm against the decoder implementation.
    decoder_self_retn: Array
    # Decoder hidden state; "cross_retn" presumably abbreviates cross-retention
    # -- TODO confirm against the decoder implementation.
    decoder_cross_retn: Array


class RecLearnerState(NamedTuple):
    """State of the learner for Memory Sable.

    Carries the same fields as `FFLearnerState` plus `hstates`, the
    encoder/decoder hidden states.
    """

    # Network parameters.
    params: FrozenDict
    # Optax optimiser state(s).
    opt_states: OptState
    # PRNG key carried with the learner state.
    key: PRNGKey
    # NOTE(review): annotated as a plain Array; presumably this is the
    # environment's state object -- confirm the annotation is intentional.
    env_state: Array
    # Jumanji TimeStep; presumably the most recent one from the environment.
    timestep: TimeStep
    # Encoder/decoder hidden states (see `HiddenStates`).
    hstates: HiddenStates


class FFLearnerState(NamedTuple):
    """State of the learner for ff-Sable.

    Same fields as `RecLearnerState` except that no hidden states are
    carried. "ff" presumably stands for feed-forward -- TODO confirm.
    """

    # Network parameters.
    params: FrozenDict
    # Optax optimiser state(s).
    opt_states: OptState
    # PRNG key carried with the learner state.
    key: PRNGKey
    # NOTE(review): annotated as a plain Array; presumably this is the
    # environment's state object -- confirm the annotation is intentional.
    env_state: Array
    # Jumanji TimeStep; presumably the most recent one from the environment.
    timestep: TimeStep


class Transition(NamedTuple):
    """Transition tuple.

    Immutable record with seven required fields and one optional trailing
    field, `done_mask`, which defaults to ``None``. Field order is part of
    the interface because instances can be built positionally.

    The per-field notes below are inferred from the field names; shapes and
    dtypes are not specified here -- confirm against the code that builds
    transitions.
    """

    # Presumably the episode-termination flag(s) for this step.
    done: Array
    # Presumably the action(s) taken.
    action: Array
    # Presumably the value estimate(s) produced for this step.
    value: Array
    # Presumably the reward(s) received.
    reward: Array
    # Presumably the log-probability of `action` under the acting policy.
    log_prob: Array
    # Presumably the observation the action was computed from.
    obs: Array
    # Free-form dictionary of extra per-step information.
    info: Dict
    # Optional field: callers may omit it, in which case it is None. The
    # annotation is Optional so that it agrees with the None default.
    done_mask: Optional[Array] = None


# Callable type for the actor-side apply function. It takes
# (FrozenDict, Array, Array, HiddenStates, PRNGKey) and returns four arrays
# followed by a HiddenStates.
# NOTE(review): the role of each positional argument and return value is not
# visible in this file -- presumably the FrozenDict is the network parameters
# and the returned HiddenStates is the updated state; confirm against the
# network's apply function.
ActorApply = Callable[
    [FrozenDict, Array, Array, HiddenStates, PRNGKey],
    Tuple[Array, Array, Array, Array, HiddenStates],
]
# Callable type for the learner-side apply function. It takes
# (FrozenDict, Array, Array, Array, HiddenStates, Array, PRNGKey) and returns
# three arrays. Unlike ActorApply, no HiddenStates is returned.
# NOTE(review): argument and return roles are not visible in this file --
# confirm against the network's apply function.
LearnerApply = Callable[
    [FrozenDict, Array, Array, Array, HiddenStates, Array, PRNGKey], Tuple[Array, Array, Array]
]
