from jax import numpy as np


def get_dataset(dataset, noise=0, random_key=None, random_targets=False):
    """Build a small synthetic 1D regression dataset.

    Training inputs are 20 evenly spaced points on [-4, 4]; test inputs are
    1000 evenly spaced points on the same interval. Both are returned as
    column arrays (a trailing axis of size 1 is added), and the targets are
    the chosen function evaluated element-wise on those inputs.

    Supported values of ``dataset``:

    * ``"1D-sine"``  -- ``sin(2 * x)``
    * ``"1D-poly5"`` -- ``x**5 / 32 - x**3 / 2 + 2 * x - 1``

    Args:
        dataset: Name of the dataset to generate (see the list above).
        noise: Currently ignored by this function; kept so existing callers
            that pass it keep working.
        random_key: Currently ignored by this function.
        random_targets: Currently ignored by this function.

    Returns:
        A tuple ``(data, targets, test)`` where ``test`` is itself a tuple
        ``(test_data, test_targets)``.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    if dataset == "1D-sine":

        def target_fn(x):
            return np.sin(x * 2)

    elif dataset == "1D-poly5":

        def target_fn(x):
            return (x**5) / 32 - (x**3) / 2 + 2 * x - 1

    else:
        # Fail loudly with a useful message instead of hitting an
        # UnboundLocalError on the return statement below.
        raise ValueError(
            f"Unknown dataset {dataset!r}; expected '1D-sine' or '1D-poly5'."
        )

    # Both datasets share the same input grids; only the target function
    # differs, so build the grids once and apply the function to each.
    data = np.expand_dims(np.linspace(-4, 4, num=20), axis=1)
    test_data = np.expand_dims(np.linspace(-4, 4, num=1000), axis=1)

    targets = target_fn(data)
    test = (test_data, target_fn(test_data))
    return data, targets, test
