import jax
import jax.numpy as jnp 
import jax.random as random
from typing import Sequence, Union
from jax import vmap


# Code is heavily adapted from https://github.com/acfr/PLNet/blob/main/surrogte_loss/rosenbrock_utils.py

def ackley(x, a=20, b=0.2, c=2 * jnp.pi):
    """Ackley function, using means over all entries of ``x``.

    Computes ``-a*exp(-b*sqrt(mean(x**2))) - exp(mean(cos(c*x))) + a + e``,
    which is 0 at the origin.

    Args:
        x: array of coordinates (every entry is included in the means; a 0-d
            scalar is accepted).
        a, b, c: the usual Ackley constants.

    Returns:
        Scalar function value.
    """
    # NOTE: the derivative of sqrt is infinite at 0, so jax.grad of this
    # function evaluated exactly at the origin yields NaN.
    sum_sq_term = -b * jnp.sqrt(jnp.mean(x**2))
    cos_term = jnp.mean(jnp.cos(c * x))
    return -a * jnp.exp(sum_sq_term) - jnp.exp(cos_term) + a + jnp.exp(1)

def schaffer_7(x):
    """Cyclic Schaffer F7-style function, averaged over the leading axis.

    Each entry is paired with its successor along axis 0, and the last entry
    wraps back to the first. With ``r = sqrt(x_i**2 + x_{i+1}**2)`` the
    summand is ``sqrt(r) * (1 + sin(50 * r**0.2)**2)``. The sum is divided by
    ``x.shape[0]``.
    """
    successor = jnp.roll(x, -1, axis=0)
    radius = jnp.sqrt(x**2 + successor**2)
    ripple = 1 + jnp.sin(50 * radius**0.2) ** 2
    return jnp.sum(radius**0.5 * ripple) / x.shape[0]

def xin_she_yang_3(x):
    """Xin-She Yang N.3-style function (beta = 15, m = 5), shifted up by 1.

    Evaluates
    ``1 + [exp(-sum((x/15)**10)) - 2*exp(-sum(x**2))] * prod(cos(x)**2)``,
    which equals 0 at the origin.

    NOTE(review): here the cosine product multiplies the whole bracket.
    Benchmark catalogues differ on whether it multiplies only the second
    exponential, so confirm this is the intended variant.
    """
    flat_plateau = jnp.exp(-jnp.sum((x / 15) ** 10))
    central_well = 2 * jnp.exp(-jnp.sum(x**2))
    cos_product = jnp.prod(jnp.cos(x) ** 2)
    return 1 + (flat_plateau - central_well) * cos_product

def schwefel(x):
    """Schwefel-type function on inputs rescaled by 300.

    Returns ``-sum(z * sin(sqrt(|z|)))`` with ``z = 300 * x``. The usual
    constant offset is not added.

    The original note put the global minimum at z ~= 420.98 per coordinate;
    because of the rescaling that corresponds to x ~= 420.98 / 300 ~= 1.40.
    Without box constraints the function has no single global minimum: the
    amplitude of ``z * sin(sqrt(|z|))`` grows with ``|z|``, so it is
    unbounded below.
    """
    scaled = x * 300
    return -jnp.sum(scaled * jnp.sin(jnp.sqrt(jnp.abs(scaled))))

def xin_she_yang_1(x):
    """Return ``sum(|x|) * exp(-sum(sin(x**2)))``, which is 0 at the origin.

    NOTE(review): several benchmark catalogues list this formula as
    "Xin-She Yang N.2"; the name is kept for compatibility with callers.
    """
    l1_norm = jnp.sum(jnp.abs(x))
    damping = jnp.exp(-jnp.sum(jnp.sin(x**2)))
    return l1_norm * damping

def expanded_schaffer_f6(x):
    """Expanded Schaffer F6 function, summed over cyclic neighbour pairs.

    Each pair ``(x_i, x_{i+1})`` contributes
    ``0.5 + (sin(r)**2 - 0.5) / (1 + 0.001 * r**2)**2`` with
    ``r = sqrt(x_i**2 + x_{i+1}**2)``. The last coordinate is paired back with
    the first. The result is 0 at the origin.

    Args:
        x: 1-D array of coordinates.

    Returns:
        Scalar sum of the pairwise terms.
    """
    # jnp.roll pairs every coordinate with its successor and wraps the last
    # one around to the first. The previous version concatenated a 1-D array
    # with a 0-d scalar, which jnp.concatenate always rejects.
    radius = jnp.sqrt(x**2 + jnp.roll(x, -1, axis=0) ** 2)
    return jnp.sum(0.5 + (jnp.sin(radius)**2 - 0.5) / (1 + 0.001 * radius**2)**2)

def alpine_1(x):
    """Alpine N.1 function: ``sum(|x*sin(x) + 0.1*x|)``, which is 0 at the origin."""
    per_coordinate = x * jnp.sin(x) + 0.1 * x
    return jnp.abs(per_coordinate).sum()

def griewank(x):
    """Griewank function: ``1 + sum(x**2)/4000 - prod(cos(x_i / sqrt(i)))``.

    The index ``i`` runs from 1 to ``x.size``. The result is 0 at the origin.
    """
    divisors = jnp.sqrt(jnp.arange(1, x.size + 1))
    quadratic = jnp.sum(x**2) / 4000
    oscillation = jnp.prod(jnp.cos(x / divisors))
    return 1 + quadratic - oscillation

def rastrigin(x):
    """Rastrigin function without the usual ``10 * d`` constant offset.

    Returns ``sum(x**2 - 10*cos(2*pi*x))``. Because the offset is omitted, the
    value at the origin is ``-10 * x.size`` rather than 0.

    NOTE(review): the original author questioned whether this surface can be
    learned; that question is unresolved.
    """
    ripple = 10 * jnp.cos(2 * jnp.pi * x)
    return jnp.sum(x**2 - ripple)

def Rosenbrock(x):
    """Rescaled Rosenbrock-style function averaged over consecutive coordinate pairs.

    For each pair ``(x_i, x_{i+1})`` along the last axis, the term
    ``100 * (x_i - 1)**2 + 0.5 * (x_{i+1} - x_i**2)**2`` is computed. The terms
    are averaged and the mean is mapped to ``mean / 26 - 5``.

    NOTE(review): the ``100 * (x_i - 1)**2`` weight differs from the
    ``(x - a)**2 / 200`` used by ``PRosenbrock`` in this file; confirm it is
    intended.

    Args:
        x: array-like of shape ``(..., N)`` with ``N >= 2``. Leading axes are
            treated as batch axes, and a 1-D input gives a scalar.

    Returns:
        Array of shape ``x.shape[:-1]``, equal to -5 at ``x = (1, ..., 1)``.

    Raises:
        ValueError: if ``x`` is 0-d or its last axis has fewer than 2 entries.
    """
    x = jnp.asarray(x)
    if x.ndim == 0 or x.shape[-1] < 2:
        raise ValueError(
            f"Rosenbrock needs at least 2 coordinates on the last axis, got shape {x.shape}"
        )
    # Evaluate all N-1 pairs at once with slices. A Python loop would add one
    # traced op per pair, so tracing and JIT cost would grow with N.
    head, tail = x[..., :-1], x[..., 1:]
    terms = 100 * (head - 1.) ** 2. + 0.5 * (tail - head ** 2) ** 2
    return jnp.mean(terms, axis=-1) / 26 - 5

def PRosenbrock(x, p):
    """Parameterised Rosenbrock-style function averaged over consecutive coordinate pairs.

    For each pair ``(x_i, x_{i+1})`` along the last axis of ``x``, the term
    ``(x_i - a)**2 / 200 + 0.5 * (x_{i+1} - b * x_i**2)**2`` is computed, with
    ``a = p[..., 0]`` and ``b = p[..., 1]``. The terms are averaged over the
    pairs.

    Args:
        x: array-like of shape ``(..., N)`` with ``N >= 2``. A 1-D input with
            1-D ``p`` gives a scalar.
        p: array-like whose last axis holds ``a`` and ``b`` at positions 0 and
            1. Its leading axes broadcast against the leading axes of ``x``;
            for example, ``p`` of shape ``(2,)`` is shared by a whole batch.

    Returns:
        Array of the broadcast leading shape (0 when every pair satisfies
        ``x_i == a`` and ``x_{i+1} == b * x_i**2``).

    Raises:
        ValueError: if ``x`` is 0-d or its last axis has fewer than 2 entries.
    """
    x = jnp.asarray(x)
    p = jnp.asarray(p)
    if x.ndim == 0 or x.shape[-1] < 2:
        raise ValueError(
            f"PRosenbrock needs at least 2 coordinates on the last axis, got shape {x.shape}"
        )
    head, tail = x[..., :-1], x[..., 1:]
    # Slicing with 0:1 / 1:2 keeps a trailing length-1 axis, so each
    # parameter broadcasts across all N-1 pairs.
    a, b = p[..., 0:1], p[..., 1:2]
    terms = (head - a) ** 2 / 200. + 0.5 * (tail - b * head ** 2) ** 2
    return jnp.mean(terms, axis=-1)

def Sine(x):
    """Separable sinusoidal surface averaged over consecutive coordinate pairs.

    For each pair ``(x_i, x_{i+1})`` along the last axis, the term
    ``0.25 * (sin(8*(x_i-1) - pi/2) + sin(8*(x_{i+1}-1) - pi/2) + 2)`` is
    computed. Each term lies in ``[0, 1]`` and is 0 when both coordinates
    equal 1. The terms are averaged over the pairs.

    Args:
        x: array-like of shape ``(..., N)`` with ``N >= 2``. Leading axes are
            treated as batch axes, and a 1-D input gives a scalar.

    Returns:
        Array of shape ``x.shape[:-1]``.

    Raises:
        ValueError: if ``x`` is 0-d or its last axis has fewer than 2 entries.
    """
    x = jnp.asarray(x)
    if x.ndim == 0 or x.shape[-1] < 2:
        raise ValueError(
            f"Sine needs at least 2 coordinates on the last axis, got shape {x.shape}"
        )
    # Evaluate all pairs at once with slices instead of one traced op per pair.
    head, tail = x[..., :-1], x[..., 1:]
    terms = 0.25 * (jnp.sin(8 * (head - 1.0) - jnp.pi / 2)
                    + jnp.sin(8 * (tail - 1.0) - jnp.pi / 2) + 2.0)
    return jnp.mean(terms, axis=-1)

def Sampler(
        rng: random.PRNGKey,
        batches: int,
        data_dim: int,
        x_min: Union[float, jnp.ndarray] = -2.,
        x_max: Union[float, jnp.ndarray] = 2.,
):
    """Draw a ``(batches, data_dim)`` array of uniform samples from ``[x_min, x_max)``.

    Args:
        rng: JAX PRNG key; the same key always gives the same samples.
        batches: number of rows (sample points).
        data_dim: number of coordinates per sample.
        x_min, x_max: lower/upper bounds, forwarded to ``jax.random.uniform``
            as ``minval``/``maxval``.
    """
    sample_shape = (batches, data_dim)
    samples = random.uniform(rng, sample_shape, minval=x_min, maxval=x_max)
    return samples

def MeshField(
    x_range: Sequence[float] = (-2., 2.),
    y_range: Sequence[float] = (-1., 3.),
    n_grid: int = 200
):
    """Build an ``n_grid x n_grid`` evaluation grid over a 2-D rectangle.

    Args:
        x_range: ``(min, max)`` for the first coordinate (inclusive endpoints).
        y_range: ``(min, max)`` for the second coordinate (inclusive endpoints).
        n_grid: number of points per axis.

    Returns:
        ``(z, xx, yy)`` where ``xx`` and ``yy`` are the ``(n_grid, n_grid)``
        arrays from ``jnp.meshgrid`` (default ``'xy'`` indexing), and ``z`` is
        an ``(n_grid**2, 2)`` array of the same points flattened in row-major
        order, so ``z[k] == (xx.ravel()[k], yy.ravel()[k])``.
    """
    # Tuple defaults instead of lists avoid shared mutable default arguments.
    xs = jnp.linspace(x_range[0], x_range[1], n_grid)
    ys = jnp.linspace(y_range[0], y_range[1], n_grid)
    xx, yy = jnp.meshgrid(xs, ys)
    z = jnp.stack([xx.ravel(), yy.ravel()], axis=1)
    return z, xx, yy

 
