from typing import List, Optional

from ray.rllib.offline.json_reader import JsonReader
from ray.rllib.policy.sample_batch import SampleBatch


def load_sample_batches(
        inputs: List[str],
        debug_size: Optional[int] = None,
        debug_size_mode: str = "ordered",
) -> SampleBatch:
    """Loads sample batches from inputs into memory and concatenates them into a single `SampleBatch`.

    Args:
        inputs: List of input filepath patterns. Same as what rllib.offline.json_reader.JsonReader takes.
        debug_size: If provided, limits the number of samples to this number. Must be non-negative.
            If None, everything is loaded and `debug_size_mode` is ignored.
        debug_size_mode: Mode for loading debug_size. Only consulted when `debug_size` is provided.
            "ordered": loads the first debug_size elements, which means they are likely correlated.
                This mode should be used when you want to quickly load the data.
            "shuffled": loads all the data then randomly selects debug_size elements.
                This mode should be used when you want to load a limited amount of data, but in a minimally
                correlated fashion.

    Returns:
        A SampleBatch containing the data loaded into memory.

    Raises:
        ValueError: If `debug_size` is provided and is negative, or if `debug_size` is provided and
            `debug_size_mode` is neither "ordered" nor "shuffled".
    """
    if debug_size is not None:
        # Fail fast, before any file is opened. Without this check a misspelled mode
        # (e.g. "shuffle") would silently read every file and still return the first
        # debug_size samples unshuffled.
        if debug_size_mode not in ("ordered", "shuffled"):
            raise ValueError(
                f"debug_size_mode must be 'ordered' or 'shuffled', got {debug_size_mode!r}"
            )
        if debug_size < 0:
            raise ValueError(f"debug_size must be non-negative, got {debug_size!r}")

    reader = JsonReader(inputs)
    if debug_size is None:
        # No limit requested: materialize every batch and join them.
        return SampleBatch.concat_samples(list(reader.read_all_files()))

    # Only "ordered" mode may stop reading early; "shuffled" needs the full data
    # so that the random selection is drawn from everything.
    stop_early = debug_size_mode == "ordered"
    num_samples = 0
    batches = []
    for batch in reader.read_all_files():
        batches.append(batch)
        num_samples += len(batch)
        if stop_early and num_samples >= debug_size:
            break

    samples = SampleBatch.concat_samples(batches)
    if not stop_early:
        # The return value is discarded, so this relies on shuffle() reordering
        # `samples` in place.
        samples.shuffle()
    # The last batch read may overshoot the limit, so trim to exactly debug_size.
    return samples.slice(0, debug_size)
