import numpy as np


def function_with_freq(X, X_max, X_min, num_waves, frequencies, phase_shifts, amplitudes):
    """Evaluate a weighted sum of separable sine-product waves at the points in ``X``.

    For every point ``x`` (taken along the last axis of ``X``) this computes::

        f(x) = sum_i amplitudes[i]
                     * prod_d sin(2*pi*frequencies[i][d]*x[d] / (X_max - X_min)
                                  + phase_shifts[i][d])

    Dividing by ``X_max - X_min`` means ``frequencies[i][d]`` counts full periods
    across a span of width ``X_max - X_min``. ``X`` itself is not shifted by
    ``X_min``; such a shift would only add a constant to each phase.

    Args:
        X: Array of points. The last axis holds the input dimensions; any
            leading axes are treated as batch axes (a 2-D
            ``(num_points, num_input_dims)`` array is the usual case).
        X_max, X_min: Domain bounds. Only their difference is used (as a
            divisor), so scalars or arrays broadcastable against ``X`` work.
        num_waves: Number of waves to sum; the first ``num_waves`` entries of
            ``frequencies``, ``phase_shifts`` and ``amplitudes`` are used.
        frequencies: Per-wave frequencies, indexed ``frequencies[i]`` and
            broadcast against ``X`` (e.g. shape ``(num_waves, num_input_dims)``).
        phase_shifts: Per-wave phase offsets in radians, same layout as
            ``frequencies``.
        amplitudes: Per-wave weights, indexed ``amplitudes[i]``.

    Returns:
        Float array of shape ``X.shape[:-1]``.

    Raises:
        ValueError: If ``X_max - X_min`` is zero anywhere (the scaling would
            otherwise silently produce inf/NaN).
    """
    X = np.asarray(X)
    span = X_max - X_min
    if np.any(np.asarray(span) == 0):
        raise ValueError("X_max and X_min must differ")

    # Leading axes are batch axes; the last (input-dimension) axis is reduced below.
    f_X = np.zeros(X.shape[:-1])
    for i in range(num_waves):
        phase = 2 * np.pi * frequencies[i] * X / span + phase_shifts[i]
        # Multiplying the per-dimension sines makes each wave separable across inputs.
        wave = np.prod(np.sin(phase), axis=-1)
        f_X += amplitudes[i] * wave
    return f_X

def get_random_f(freq_min, freq_max, num_waves, num_input_dims, X_max, X_min, y_max, rng=None):
    """Sample a random sum-of-sine-products function and return it as a callable.

    Each of the ``num_waves`` waves gets:

    * per-dimension frequencies drawn uniformly from ``[freq_min, freq_max)``,
      shape ``(num_waves, num_input_dims)``;
    * an amplitude drawn uniformly from ``[0, y_max)``;
    * per-dimension phase shifts drawn uniformly from ``[0, 2*pi)``.

    Because each wave is a product of sines (bounded by 1 in magnitude), ``|f(X)|``
    is at most ``sum(amplitudes)``. ``y_max`` therefore bounds each wave's
    amplitude, not the overall output.

    Args:
        freq_min, freq_max: Range of the sampled frequencies.
        num_waves: Number of sine-product waves to sum.
        num_input_dims: Number of input dimensions (size of ``X``'s last axis).
        X_max, X_min: Domain bounds passed through to ``function_with_freq``.
        y_max: Upper bound (exclusive) of each wave's amplitude.
        rng: Optional ``numpy.random.Generator``. When omitted, a fresh
            generator is created for this call.

    Returns:
        A function ``f(X)`` evaluating the sampled waves at ``X``.
    """
    # A ``np.random.default_rng()`` default would be built once at import time
    # and shared (as mutable state) by every call that omits ``rng``.
    if rng is None:
        rng = np.random.default_rng()

    # Keep the draw order (frequencies, amplitudes, phases) fixed so a seeded
    # ``rng`` reproduces the same function.
    frequencies = freq_min + rng.random((num_waves, num_input_dims)) * (freq_max - freq_min)
    amplitudes = rng.random(num_waves) * y_max
    phase_shifts = rng.random((num_waves, num_input_dims)) * 2 * np.pi

    def f(X):
        """Evaluate the sampled sine-wave function at the points ``X``."""
        return function_with_freq(
            X, X_max, X_min, num_waves, frequencies, phase_shifts, amplitudes
        )

    return f
