"""Utilities for dealing with sparse TF tensors."""
import tensorflow as tf
from typing import Union

# Type aliases for annotating functions that accept sparse or dense tensors.

# Either a ``tf.sparse.SparseTensor`` or a dense ``tf.Tensor`` that holds the
# indices of a sparse tensor. A dense tensor here is expected to have an
# integer dtype (int32 or int64).
# NOTE(review): presumably laid out like ``SparseTensor.indices``
# (one row per non-zero entry) — confirm against callers.
SparseOrIndices = Union[tf.Tensor, tf.sparse.SparseTensor]

# Either a dense ``tf.Tensor`` or a ``tf.sparse.SparseTensor``.
# At runtime this is the same Union as ``SparseOrIndices``; the separate
# name is presumably kept to document intent at annotation sites.
MaybeSparseTensor = Union[tf.Tensor, tf.sparse.SparseTensor]


def get_sparse_indices(indices_or_sparse: SparseOrIndices) -> tf.Tensor:
    """Extract sparse indices, or pass already-extracted indices through.

    Args:
        indices_or_sparse: Either a ``tf.sparse.SparseTensor`` or an object
            that already holds the indices of a sparse tensor (normally a
            dense integer ``tf.Tensor``).

    Returns:
        ``indices_or_sparse.indices`` when the argument is a
        ``tf.sparse.SparseTensor``. Otherwise, the argument itself is
        returned unchanged, with no conversion or dtype check.
    """
    is_sparse = isinstance(indices_or_sparse, tf.sparse.SparseTensor)
    return indices_or_sparse.indices if is_sparse else indices_or_sparse
