from functools import partial

import jax


def grid_search_1d(objective_fun, lookup_range, **kwargs):
    """Return the grid point that minimises ``objective_fun`` and its value.

    The objective is evaluated at every entry along the leading axis of
    ``lookup_range`` using ``jax.vmap`` and the entry with the smallest
    objective value is selected.

    Args:
        objective_fun: Callable invoked as ``objective_fun(x, **kwargs)``
            where ``x`` is one entry of ``lookup_range``. It is wrapped in
            ``jax.jit`` and ``jax.vmap``, so it has to be traceable by JAX.
            Assumed to return a scalar per grid point, since the flat
            ``argmin`` index is used to index the grid's leading axis.
        lookup_range: Candidate points. Anything ``jax.numpy.asarray``
            accepts (JAX/NumPy array, list, tuple of numbers).
        **kwargs: Extra keyword arguments bound to ``objective_fun`` before
            tracing; they are the same for every grid point.

    Returns:
        A tuple ``(best_point, best_value)``: the entry of the grid with the
        lowest objective value and that minimum value.

    Raises:
        ValueError: If ``lookup_range`` contains no candidates.
    """
    # vmap treats a Python list/tuple as a pytree of scalar leaves and cannot
    # map over it, so coerce the grid into a single array first.
    grid = jax.numpy.asarray(lookup_range)
    if grid.size == 0:
        raise ValueError("lookup_range must contain at least one candidate")

    jit_objective_fun = jax.jit(partial(objective_fun, **kwargs))
    objective_values = jax.vmap(jit_objective_fun)(grid)
    best_index = jax.numpy.argmin(objective_values)
    return grid[best_index], objective_values.min()
