from dataclasses import dataclass
from typing import Iterable, Optional

from torchtyping import TensorType


@dataclass()
class GeneralElement:
    """
    Generic element produced by a data pipeline.

    Declares no fields of its own; instances are empty dataclass objects.
    """


@dataclass
class RLElement:
    """
    Batch element for an RL model.

    Every field defaults to ``None``, so each annotation is ``Optional[...]``
    to match (PEP 484 no longer treats a ``None`` default as implying
    ``Optional``, and type checkers reject the bare form).
    """

    state: Optional[Iterable[str]] = None  # Context/prompts
    # Tokens generated by the model given the prompts; annotated as a tensor
    # with a single named dimension "N".
    action: Optional[TensorType["N"]] = None
    reward: Optional[float] = None  # Reward obtained for that generation


@dataclass()
class BatchElement:
    """
    Generic batch element that any transformer can consume in its forward pass.

    Both fields are required (no defaults) and are annotated with the named
    dimensions ``BATCH`` x ``SEQ_LEN``.
    """

    # Token tensor shaped (BATCH, SEQ_LEN) per its annotation; presumably
    # integer token ids — confirm dtype against the producer.
    tokens: TensorType["BATCH", "SEQ_LEN"]
    # Mask tensor with the same annotated dims as ``tokens``; presumably an
    # attention/padding mask — verify against the consuming model.
    masks: TensorType["BATCH", "SEQ_LEN"]
