import jax.numpy as jnp
import numpy as np

def fourier_features(
    values: jnp.ndarray,
    base_period: float,
    num_frequencies: int,
    ) -> jnp.ndarray:
  """Maps values to sin/cos features for a range of frequencies.

  Args:
    values: Values to compute Fourier features for. Anything accepted by
      `jnp.asarray` works, including JAX/NumPy arrays, Python scalars and
      (nested) lists of numbers.
    base_period: The base period to use. This should be greater than or equal
      to the range of the values, or to the period if the values have periodic
      semantics (e.g. 2pi if they represent angles). Frequencies used will be
      integer multiples of 1/base_period.
    num_frequencies: The number of frequencies to use, we will use integer
      multiples of 1/base_period from 1 up to num_frequencies inclusive. (We
      don't include a zero frequency as this would just give constant features
      which are redundant if a bias term is present).

  Returns:
    Array with same shape as values except with an extra trailing dimension
    of size 2*num_frequencies. Along that dimension, the first num_frequencies
    entries are the cosine features (in order of increasing frequency) and the
    last num_frequencies entries are the corresponding sine features.
  """
  # Accept scalars and lists as well as arrays; this is a no-op for inputs
  # that are already JAX arrays (including tracers under jit).
  values = jnp.asarray(values)
  # Frequencies k / base_period for k = 1..num_frequencies; k = 0 is skipped
  # because it would only produce constant features.
  frequencies = np.arange(1, num_frequencies + 1) / base_period
  angular_frequencies = jnp.array(2 * np.pi * frequencies)
  # Broadcast each value against every angular frequency along a new
  # trailing axis.
  values_times_angular_freqs = values[..., None] * angular_frequencies
  # Layout of the trailing axis: [cos(w_1 x) .. cos(w_K x), sin(w_1 x) ..
  # sin(w_K x)].
  return jnp.concatenate(
      [jnp.cos(values_times_angular_freqs),
       jnp.sin(values_times_angular_freqs)],
      axis=-1)