from typing import NamedTuple, TypeVar, Callable
import jaxtyping as jt
import typeguard
from flax import struct

F = TypeVar("F", bound=Callable)


def typed(function: F) -> F:
    """Return ``function`` wrapped with jaxtyping runtime checking.

    The callable is handed to ``jt.jaxtyped`` with
    ``typeguard.typechecked`` as the type checker, and whatever
    ``jaxtyped`` produces is returned to the caller.
    """
    checker = typeguard.typechecked
    wrapped = jt.jaxtyped(function, typechecker=checker)
    return wrapped


# Recurrent block
# ExpandedActivations = jt.Float[jt.Array, "*b t e"]
# RNNDiagonal = jt.Float[jt.Array, "e"]
# RNNState = jt.Float[jt.Array, "*b e"]
# Conv1DState = jt.Float[jt.Array, "*b w e"]
# Reset = jt.Bool[jt.Array, "*b t"]

# Machiatto Recurrent Types:
# Each alias is a jaxtyping array annotation. The shape string names one
# axis per whitespace-separated token.
# NOTE(review): the axis letters are presumably B=batch, H=heads,
# L=latent/state size, D=feature dim, W=window length -- TODO confirm.
Alpha = jt.Float[jt.Array, "B H L"]
Nu = jt.Float[jt.Array, "B H L D"]
PrevMax = jt.Float[jt.Array, "B H L"]
RNNState = jt.Float[jt.Array, "B L"]
# NOTE(review): ids/positions are often integer-valued; confirm that a
# floating-point dtype is really intended here.
SegmentIds = jt.Float[jt.Array, "B"]
# Axes must be space-separated: the previous "BWD" declared a single axis
# named "BWD" (i.e. a rank-1 array) instead of three separate axes.
WindowAttState = jt.Float[jt.Array, "B W D"]


# @typed
@struct.dataclass
class LatteCache:
    """Cache for the latent attention part.

    Declared with flax's ``struct.dataclass`` and made up of three array
    fields, each annotated with one of this module's jaxtyping aliases.
    """

    # Annotated with the ``Alpha`` alias.
    # NOTE(review): meaning of this quantity is not visible here -- confirm.
    alpha: Alpha
    # Annotated with the ``Nu`` alias, which carries one more trailing axis
    # than ``Alpha``/``PrevMax``.
    nu: Nu
    # Annotated with the ``PrevMax`` alias (same shape string as ``Alpha``).
    # NOTE(review): presumably a running maximum kept for numerically stable
    # normalisation -- TODO confirm against the code that updates it.
    prev_max: PrevMax

# @typed
@struct.dataclass
class GemmaMachiattoCache:
    """The cache for a gemma-machiatto block.

    Declared with flax's ``struct.dataclass``. Bundles the per-block state
    fields below, including a nested ``LatteCache``.
    """

    # Annotated with the ``SegmentIds`` alias.
    # NOTE(review): the field is called ``positions`` while the alias is
    # named ``SegmentIds`` -- confirm which quantity is actually stored.
    positions: SegmentIds
    # Annotated with the ``RNNState`` alias.
    # NOTE(review): presumably the convolution state -- TODO confirm.
    conv: RNNState
    # Nested cache for the latent attention part.
    latte: LatteCache
    # Annotated with the ``WindowAttState`` alias.
    # NOTE(review): presumably the windowed-attention state -- TODO confirm.
    window_att: WindowAttState
