import bisect
from typing import List, Tuple


def cumulative_index_to_multi_index(cumulative_index: int, cumulative_sizes: List[int]) -> Tuple[int, int]:
    """Convert a cumulative (flat) index into a jagged array into a multi-index.

    Suppose you have a jagged array like::

        [
            [0, 1, 2],
            [3, 4],
        ]

    A cumulative index counts elements across all the sub-arrays in order.
    For example, cumulative index 3 refers to the first element (index 0) of
    the second sub-array (index 1), so the result is ``(1, 0)``.

    Args:
        cumulative_index: The cumulative index across all the arrays. Must
            satisfy ``0 <= cumulative_index < cumulative_sizes[-1]``.
        cumulative_sizes: The cumulative sizes of the arrays (a non-decreasing
            sequence). If `sizes` is a list of the sizes of the individual
            arrays, then this argument is equal to `np.cumsum(sizes)`.
            Zero-sized arrays are allowed; they are skipped over.

    Returns:
        A pair ``(array_index, element_index)``: which sub-array the element
        lives in, and the element's index within that sub-array.

    Raises:
        ValueError: If `cumulative_sizes` is empty.
        IndexError: If `cumulative_index` is negative or not less than the
            total number of elements (``cumulative_sizes[-1]``).

    >>> cumulative_index_to_multi_index(3, [3, 5])
    (1, 0)
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    # Use len() rather than truthiness so NumPy arrays (e.g. from np.cumsum)
    # are accepted without an "ambiguous truth value" error.
    if len(cumulative_sizes) == 0:
        raise ValueError("cumulative_sizes must be non-empty")
    total = cumulative_sizes[-1]
    if not 0 <= cumulative_index < total:
        raise IndexError(
            f"cumulative index {cumulative_index} out of range for {total} elements"
        )
    # bisect_right finds the first array whose cumulative size is strictly
    # greater than the index. Using the "right" variant moves past arrays that
    # end exactly at the index and past zero-sized arrays (repeated values).
    array_index = bisect.bisect_right(cumulative_sizes, cumulative_index)
    # Total number of elements in all arrays preceding `array_index`.
    offset = cumulative_sizes[array_index - 1] if array_index > 0 else 0
    return array_index, cumulative_index - offset
