# safe_dataset.py

from torch.utils.data import IterableDataset
from transformers.utils import logging

# Module-level logger obtained from the transformers logging utilities.
# SafeIterator uses it to emit a warning for each error it skips.
logger = logging.get_logger(__name__)

class SafeIterator:
    """
    Iterator wrapper that skips elements whose retrieval raises an exception.

    Each call to ``next()`` is forwarded to the wrapped iterator. If that call
    raises anything other than ``StopIteration``, a warning is logged and the
    wrapper tries again, so a single bad sample does not abort the consumer.

    Note:
        Skipping only helps when the wrapped iterator can keep producing
        elements after it has raised. A Python generator cannot: once an
        exception propagates out of it, the generator is finished and the
        following ``next()`` raises ``StopIteration``, which ends iteration
        here as well.

    Args:
        iterator: The underlying iterator to draw elements from.
        max_consecutive_errors: Upper bound on how many errors in a row are
            skipped before the next one is re-raised to the caller. ``0``
            re-raises the very first error. ``None`` (the default) skips
            without limit, which means an iterator that fails on every call
            is retried forever.

    Raises:
        ValueError: If ``max_consecutive_errors`` is negative.
    """
    def __init__(self, iterator, max_consecutive_errors=None):
        if max_consecutive_errors is not None and max_consecutive_errors < 0:
            raise ValueError("max_consecutive_errors must be non-negative or None")
        self.iterator = iterator
        self.max_consecutive_errors = max_consecutive_errors
        # Number of errors skipped since the last successfully returned element.
        self._consecutive_errors = 0

    def __iter__(self):
        return self

    def __next__(self):
        """
        Return the next element that can be retrieved without an error.

        Raises:
            StopIteration: When the wrapped iterator is exhausted.
            Exception: The wrapped iterator's own error, once
                ``max_consecutive_errors`` errors in a row have already been skipped.
        """
        while True:  # Retry until an element is retrieved, iteration ends, or the error limit is hit.
            try:
                item = next(self.iterator)
            except StopIteration:
                # Normal end of iteration; must reach the caller untouched.
                raise
            except Exception as e:
                limit = self.max_consecutive_errors
                if limit is not None and self._consecutive_errors >= limit:
                    # Too many failures in a row: surface the original error
                    # instead of spinning on a permanently broken source.
                    raise
                self._consecutive_errors += 1
                # Lazy %-formatting: the message is only built if the record is emitted.
                logger.warning("A skippable error was encountered during data loading: %s", e)
                # Fall through to the next loop iteration and try again.
            else:
                # A success breaks the run of consecutive failures.
                self._consecutive_errors = 0
                return item

class SafeIterableDataset(IterableDataset):
    """
    IterableDataset adapter that hands out error-tolerant iterators.

    Wraps another iterable dataset. Every iteration over this object goes
    through a SafeIterator built around a fresh iterator of the wrapped
    dataset, so errors raised while fetching elements are logged and skipped
    by that wrapper instead of propagating to the consumer.
    """
    def __init__(self, dataset):
        super().__init__()
        # Keep a reference to the wrapped dataset; it is only iterated on demand.
        self.dataset = dataset

    def __iter__(self):
        # A new underlying iterator is created on every call, then wrapped.
        inner = iter(self.dataset)
        return SafeIterator(inner)