# Add all the operations used in the model

import jax.numpy as jnp

def extend_and_repeat(tensor, axis, repeat):
    """Insert a new axis at ``axis`` and tile ``tensor`` ``repeat`` times along it.

    Expansion only happens when ``repeat`` is greater than 1. For any other
    value (e.g. 1, 0, or negative) the input object is returned as-is and no
    new axis is inserted.

    NOTE(review): because of that short-circuit, the output rank depends on
    the value of ``repeat`` (rank + 1 when ``repeat > 1``, unchanged
    otherwise). Confirm callers expect this before relying on the shape.

    Args:
        tensor: Array-like input accepted by ``jnp.expand_dims``.
        axis: Position of the inserted axis; also the axis repeated along.
        repeat: Number of copies to stack along the new axis.

    Returns:
        The expanded-and-repeated array when ``repeat > 1``, otherwise
        ``tensor`` itself.
    """
    if not (repeat > 1):
        # Nothing to tile: hand back the original object untouched.
        return tensor
    # Create a size-1 axis first so that repeating along it stacks copies.
    with_new_axis = jnp.expand_dims(tensor, axis)
    return jnp.repeat(with_new_axis, repeat, axis=axis)