import jax
import jax.numpy as jnp

from flax import linen as nn


import functools
from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Tuple, Union

# Type alias for JAX arrays, used in this module's type annotations.
Array = jnp.ndarray

def spatial_broadcast(x: Array, resolution: Sequence[int]) -> Array:
  """Tile a batch of flat feature vectors across a 2D spatial grid.

  Args:
    x: Array of shape (batch_size, features).
    resolution: Grid size; only the first two entries (height, width) are
      read.

  Returns:
    Array of shape (batch_size, resolution[0], resolution[1], features) in
    which every grid cell holds a copy of that example's input vector.
  """
  height, width = resolution[0], resolution[1]
  # Insert two singleton spatial axes: (B, F) -> (B, 1, 1, F).
  grid = jnp.expand_dims(x, axis=(1, 2))
  # Replicate only along the spatial axes; batch and feature axes keep size.
  return jnp.tile(grid, (1, height, width, 1))
