import jax
import math
import numpy as np

def noisy_levy(x, std=0):
    """Evaluate the Levy test function at ``x`` with optional noise.

    With ``w = 1 + (x - 1) / 4`` the deterministic part is::

        sin(pi * w[0])**2
        + sum over i < dim - 1 of (w[i] - 1)**2 * (1 + 10 * sin(pi * w[i] + 1)**2)
        + (w[-1] - 1)**2 * (1 + sin(2 * pi * w[-1])**2)

    divided by ``sqrt(dim)``. The noise term is ``noise * mean(x**2)``,
    where ``noise`` is ``std`` times one standard-normal draw from NumPy's
    global RNG (``np.random.normal``), not from a JAX PRNG key.

    Args:
        x: Array of coordinates, indexed along its first axis (expected to
            be 1-D and non-empty).
        std: Scale of the noise draw; the default ``0`` makes the noise
            term zero.

    Returns:
        The scaled Levy value plus the noise term.
    """
    dim = len(x)
    # Drawn even when std == 0, so the global RNG state advances on every call.
    noise = std * np.random.normal(0, 1)

    jnp = jax.numpy
    w = 1 + (x - 1.0) / 4.0
    # The middle sum covers every coordinate except the last one, including
    # index 0 (the first coordinate also appears in the leading sine term).
    inner, last = w[:-1], w[-1]

    head = jnp.sin(jnp.pi * w[0]) ** 2
    middle = jnp.sum((inner - 1) ** 2 * (1 + 10 * jnp.sin(jnp.pi * inner + 1) ** 2))
    tail = (last - 1) ** 2 * (1 + jnp.sin(2 * jnp.pi * last) ** 2)
    return (head + middle + tail) / math.sqrt(dim) + noise * jnp.mean(x ** 2)

def noisy_ackley(x, std=0):
    """Evaluate an Ackley-style test function at ``x`` with optional noise.

    Uses ``a = 20``, ``b = 0.2`` and ``c = pi``. Over the last axis of ``x``
    the deterministic part is::

        -a * exp(-b * ||x|| / sqrt(dim)) - exp(mean(cos(c * x))) + a + e

    The noise term is ``noise * mean(x**2)``, where ``noise`` is ``std``
    times one standard-normal draw from NumPy's global RNG
    (``np.random.normal``), not from a JAX PRNG key.

    All reductions and ``dim`` refer to the last axis, so a 1-D input yields
    a scalar and an input with leading batch axes yields one value per
    point. A single noise draw is shared by every point in a batch.

    Args:
        x: Array whose last axis holds the coordinates of a point.
        std: Scale of the noise draw; the default ``0`` makes the noise
            term zero.

    Returns:
        The Ackley value plus the noise term, reduced over the last axis.
    """
    # Take the dimension from the reduced axis so the sqrt(dim) scaling
    # matches norm(..., axis=-1); identical to len(x) for 1-D input.
    dim = x.shape[-1]
    # Drawn even when std == 0, so the global RNG state advances on every call.
    noise = std * np.random.normal(0, 1)

    jnp = jax.numpy
    a, b, c = 20, 0.2, math.pi
    radial = -a * jnp.exp(-b / math.sqrt(dim) * jnp.linalg.norm(x, axis=-1))
    cosine = -jnp.exp(jnp.mean(jnp.cos(c * x), axis=-1))
    return radial + cosine + a + math.e + noise * jnp.mean(x ** 2, axis=-1)

def noisy_sphere(x, std=0):
    """Evaluate a scaled-norm "sphere" test function with optional noise.

    The deterministic part is ``||x|| / sqrt(dim)`` over the last axis of
    ``x``: the Euclidean norm divided by the square root of the dimension,
    not the sum of squares. The noise term is ``noise * mean(x**2)``, where
    ``noise`` is ``std`` times one standard-normal draw from NumPy's global
    RNG (``np.random.normal``), not from a JAX PRNG key.

    All reductions and ``dim`` refer to the last axis, so a 1-D input yields
    a scalar and an input with leading batch axes yields one value per
    point. A single noise draw is shared by every point in a batch.

    Args:
        x: Array whose last axis holds the coordinates of a point.
        std: Scale of the noise draw; the default ``0`` makes the noise
            term zero.

    Returns:
        The scaled norm plus the noise term, reduced over the last axis.
    """
    # Take the dimension from the reduced axis so the sqrt(dim) scaling
    # matches norm(..., axis=-1); identical to len(x) for 1-D input.
    dim = x.shape[-1]
    # Drawn even when std == 0, so the global RNG state advances on every call.
    noise = std * np.random.normal(0, 1)

    scaled_norm = jax.numpy.linalg.norm(x, axis=-1) / math.sqrt(dim)
    return scaled_norm + noise * jax.numpy.mean(x ** 2, axis=-1)

def noisy_rosenbrock(x, std=0):
    """Evaluate a dimension-averaged Rosenbrock function with optional noise.

    The deterministic part sums, over consecutive coordinate pairs,
    ``100 * (x[i + 1] - x[i]**2)**2 + (1 - x[i])**2`` and divides the total
    by ``len(x)``. The noise term is ``noise * mean(x**2)``, where ``noise``
    is ``std`` times one standard-normal draw from NumPy's global RNG
    (``np.random.normal``).

    Args:
        x: Array of coordinates, sliced along its first axis.
        std: Scale of the noise draw; the default ``0`` makes the noise
            term zero.

    Returns:
        The averaged Rosenbrock value plus the noise term.
    """
    dim = len(x)
    # Drawn even when std == 0, so the global RNG state advances on every call.
    noise = std * np.random.normal(0, 1)

    jnp = jax.numpy
    current, following = x[:-1], x[1:]
    valley = 100.0 * (following - current ** 2.0) ** 2.0
    offset = (1 - current) ** 2.0
    total = jnp.sum(valley + offset)
    return total / dim + noise * jnp.mean(x ** 2)