import abc


class AbstractTrainBatch:
    """
    Base type for one batch of training data.

    ``AbstractTrainIterator.__next__`` is annotated to return instances of this type. The base class itself stores
    no data; concrete batch classes are expected to add whatever fields the training loop consumes.
    """

    def __init__(self):
        # Nothing to initialise at this level; defer to ``object`` explicitly.
        super().__init__()


class AbstractTrainIterator(abc.ABC):
    """
    Interface for iterators that feed training data to the training loop.

    A train iterator is responsible only for the training trajectories; validation and test data are handled by
    the environment. Example implementations:

    - Trajectory iterator: iterates over whole trajectories and yields auxiliary sub-trajectories.
    - Step iterator: MGN-style iterator that iterates over the individual steps of the trajectories.

    Subclasses must implement ``__iter__``, ``__next__`` and ``refresh_iterator``; the class cannot be
    instantiated until all three are provided.
    """

    def __init__(self, config, train_trajs, device):
        """
        Store the construction arguments for use by subclasses.

        :param config: Configuration object; its structure is defined by the concrete subclass (not visible here).
        :param train_trajs: The training trajectories to iterate over. Stored as the non-public ``_train_trajs``.
        :param device: Target device for produced batches (presumably a torch device or device string — confirm
            against callers).
        """
        self.device = device
        self._train_trajs = train_trajs
        self.config = config

    @abc.abstractmethod
    def __iter__(self):
        """Return the iterator object used by ``for`` loops over this train iterator."""
        raise NotImplementedError

    @abc.abstractmethod
    def __next__(self) -> AbstractTrainBatch:
        """Produce the next training batch; implementations should raise ``StopIteration`` when exhausted."""
        raise NotImplementedError

    @abc.abstractmethod
    def refresh_iterator(self):
        """
        Reset or rebuild the iteration state.

        NOTE(review): the exact contract (e.g. reshuffling between epochs) is not defined here — confirm against
        the concrete implementations.
        """
        raise NotImplementedError
