import jax.numpy as jnp
import jax


def line_search(x, cost_func, grad, intp_no=20):
    """Pick the lowest-cost point among evenly spaced steps along ``grad``.

    Candidates ``x + t * grad`` are built for ``intp_no`` step sizes ``t``
    spaced linearly in ``[1e-5, 1]``, ``cost_func`` is evaluated on each
    candidate through ``jax.vmap``, and the candidate with the smallest cost
    is returned.

    NOTE(review): the step is *added* to ``x``, so ``grad`` acts as the search
    direction; callers presumably pass a descent direction (e.g. a negated
    gradient) -- confirm at the call sites.

    Args:
        x: Current point(s). Assumed to have the same shape as ``grad``
            (vectors live on the last axis) -- TODO confirm against callers.
        cost_func: Callable evaluated on one candidate with the shape of
            ``x``. It may return either the cost array itself or a
            tuple/list whose first element is the cost. The cost may be a
            scalar (a single step size is chosen for the whole input) or
            have the leading batch dimensions of ``x`` (a step size is
            chosen independently per batch element).
        grad: Search direction. Vectors along the last axis whose norm
            exceeds 1 are rescaled to unit norm; shorter ones are kept.
        intp_no: Number of candidate step sizes to evaluate.

    Returns:
        The selected candidate(s), with the shape of one candidate.
    """
    # Limit the direction to at most unit length along the last axis to
    # avoid numerical instability from very large steps.
    grad_norm = jnp.linalg.norm(grad, axis=-1, keepdims=True)
    grad = jnp.where(grad_norm > 1.0, grad / grad_norm, grad)

    # Step sizes shaped (intp_no, 1, ..., 1) so they broadcast against grad.
    steps = jnp.linspace(1e-5, 1, intp_no)
    steps = steps.reshape((intp_no,) + (1,) * grad.ndim)
    x_intp = x[None] + steps * grad[None]

    # cost_func may return `cost` or `(cost, aux...)`; only the cost is used.
    outputs = jax.vmap(cost_func, 0)(x_intp)
    cost_intp = outputs[0] if isinstance(outputs, (tuple, list)) else outputs

    # Index of the best step, one per leading batch element of the cost.
    min_idx = jnp.argmin(cost_intp, axis=0)

    # Align the index with x_intp: a leading singleton for the candidate
    # axis, the cost's batch dimensions in place, then trailing singletons
    # that broadcast over the remaining dimensions. Appending every new axis
    # at the end would misalign batched indices with the candidate axis.
    trailing = x_intp.ndim - 1 - min_idx.ndim
    min_idx = min_idx.reshape((1,) + min_idx.shape + (1,) * trailing)
    return jnp.take_along_axis(x_intp, min_idx, axis=0).squeeze(0)