from typing import Iterable
import tensorflow as tf


def repeat_to_shape(
    x: tf.Tensor,
    target_shape: Iterable,
    axis: int
) -> tf.Tensor:
    """Grow ``x`` by one repeated axis per entry of ``target_shape``.

    For every ``repeats`` value drawn from ``target_shape`` (in iteration
    order), a new length-1 axis is inserted into the running result at
    position ``axis`` and then repeated ``repeats`` times along that axis.
    The returned tensor therefore has ``len(target_shape)`` more dimensions
    than ``x``. An empty ``target_shape`` returns ``x`` unchanged.

    NOTE(review): placement of the new axes depends on the sign of ``axis``.
    With a non-negative ``axis``, each new axis lands in front of the ones
    added before it, so the added dimensions appear in REVERSE
    ``target_shape`` order. With a negative ``axis``, the insertion point
    moves right as the rank grows, so they appear in FORWARD order. Confirm
    callers expect this.

    Args:
        x: Tensor to expand.
        target_shape: Iterable of repeat counts, consumed once. Presumably
            non-negative ints, as ``tf.repeat`` requires integer repeats.
        axis: Position at which each new axis is inserted and repeated.

    Returns:
        The expanded tensor.
    """
    result = x
    for repeats in target_shape:
        # A fresh size-1 axis, then tile it out to `repeats` copies.
        widened = tf.expand_dims(result, axis=axis)
        result = tf.repeat(widened, repeats=(repeats,), axis=axis)
    return result
