from typing import Any, NamedTuple

from jax import Array
from numpy import bool_, integer, floating
from numpy.typing import NDArray

# NumPy array aliases, one per scalar-type family. These are static typing
# aids only; nothing in this module validates dtypes at runtime.
BoolArray = NDArray[bool_]
FloatArray = NDArray[floating]
IntArray = NDArray[integer]
# Either a JAX array or a NumPy floating-point array.
# NOTE(review): shares its name with the `ArrayLike` exported by numpy.typing
# and jax.typing but is a different type -- check which one a module imports.
ArrayLike = Array | FloatArray

# Stand-in for tf.summary.SummaryWriter. Because it is `Any`, type checkers
# will not check how values annotated with it are used.
SummaryWriter = Any  # Type alias for tf.summary.SummaryWriter


class OfflineData(NamedTuple):
    """Immutable record grouping the arrays of an offline dataset.

    Being a ``NamedTuple``, the fields are also accessible positionally, in
    the (alphabetical) order they are declared below; reordering or inserting
    fields would break positional construction and unpacking.
    """

    actions: FloatArray
    # NOTE(review): `dones` and `terminals` are both boolean flag arrays; how
    # they differ (e.g. timeouts vs. true terminal states) is not visible
    # here -- confirm against the code that builds this record.
    dones: BoolArray
    observations: FloatArray
    rewards: FloatArray
    terminals: BoolArray


class OfflineDataWithInfos(NamedTuple):
    """Pair an ``OfflineData`` record with a dict of extra information.

    The two fields can be unpacked positionally as ``data, infos``.
    """

    data: OfflineData
    # String-keyed mapping with arbitrary values; its schema is not defined
    # in this module.
    infos: dict[str, Any]


class QLearningBatch(NamedTuple):
    """Immutable batch of transition arrays: (s, a, r, s', done).

    The name suggests use in Q-learning style updates; the consuming code is
    not visible here. Fields are declared alphabetically, and that is also
    their positional order in the tuple.
    """

    actions: FloatArray
    dones: BoolArray
    next_observations: FloatArray
    observations: FloatArray
    rewards: FloatArray


class RegressionBatch(NamedTuple):
    """Immutable pair of input ``features`` and regression ``targets``.

    Can be unpacked positionally as ``features, targets``.
    """

    features: FloatArray
    targets: FloatArray


class SaBatch(NamedTuple):
    """Immutable pair of ``actions`` and ``observations`` arrays.

    "Sa" presumably abbreviates state-action -- confirm with callers. Note
    the positional order is ``actions, observations`` (alphabetical), not
    state-then-action.
    """

    actions: FloatArray
    observations: FloatArray


class SarsaBatch(NamedTuple):
    """Immutable batch of transition arrays: (s, a, r, s', a', done).

    Carries ``next_actions`` alongside the current-step fields, matching the
    SARSA naming. Fields are declared alphabetically, and that is also their
    positional order in the tuple.
    """

    actions: FloatArray
    dones: BoolArray
    next_actions: FloatArray
    next_observations: FloatArray
    observations: FloatArray
    rewards: FloatArray


class VLearningBatch(NamedTuple):
    """Immutable batch of action-free transition arrays: (s, r, s', done).

    Contains no action fields; the name suggests use in state-value (V)
    learning, though the consuming code is not visible here. Fields are
    declared alphabetically, and that is also their positional order.
    """

    dones: BoolArray
    next_observations: FloatArray
    observations: FloatArray
    rewards: FloatArray
