import numpy as np
import scipy as sp
import safeopt


def transform_to_2d(an_array):
    """Return *an_array* with at least two dimensions.

    Inputs that are not already array-like (lists, tuples, Python scalars, ...)
    are converted with ``np.asarray``. A 0-d input becomes shape ``(1, 1)``; a
    1-d input of length ``n`` becomes a single row of shape ``(1, n)``. Inputs
    with two or more dimensions are returned unchanged (the very same object).

    Parameters
    ----------
    an_array : array_like
        Scalar, sequence or array to promote.

    Returns
    -------
    numpy.ndarray or array-like
        The promoted array (or the original object when it already has
        ``ndim >= 2``).
    """
    # Duck-type on ``ndim`` rather than testing ``type(...) == list``: tuples,
    # list subclasses and plain Python scalars previously crashed on ``.ndim``.
    if not hasattr(an_array, "ndim"):
        an_array = np.asarray(an_array)
    if an_array.ndim == 1:
        # Treat a flat vector as a single row.
        return np.expand_dims(an_array, axis=0)
    if an_array.ndim == 0:
        return np.atleast_2d(an_array)
    return an_array


def sample_unit_vec(dim):
    """Draw a random unit vector in ``R^dim`` using NumPy's global RNG.

    A standard-normal vector is rotation-invariant, so dividing it by its
    Euclidean length yields a direction distributed uniformly on the sphere.

    Parameters
    ----------
    dim : int
        Dimension of the vector.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(dim,)`` with unit Euclidean norm.
    """
    direction = np.random.randn(dim)
    length = np.linalg.norm(direction)
    return direction / length


def inner_product(a, b):
    """Return the sum of element-wise products ``sum(a[i] * b[i])``.

    Works on lists, tuples and NumPy arrays. For arrays, iteration runs over
    the first axis, so multi-dimensional inputs yield the sum of row-wise
    products.

    Parameters
    ----------
    a, b : sized iterables
        Sequences of equal length.

    Returns
    -------
    The accumulated sum, starting from ``0`` (so two empty inputs give ``0``).

    Raises
    ------
    ValueError
        If ``a`` and ``b`` have different lengths. Previously a shorter ``a``
        raised ``IndexError`` while a longer ``a`` was silently truncated.
    """
    if len(a) != len(b):
        raise ValueError(
            "inner_product: length mismatch ({} vs {})".format(len(a), len(b)))
    return sum(x * y for x, y in zip(a, b))

def try_sample_utility_func(random_sample, num_knots, problem_dim, bounds,
                            kernel, numerical_epsilon, noise_var,
                            discrete_num_list):
    """Sample a random logarithmic utility-style function.

    The function is ``f(x) = -0.5 * log(1 + x / c)``, where the scale
    ``c = 0.01 + 3 * u**2`` uses one draw ``u`` from NumPy's global uniform
    RNG. The reported norm is the fixed constant ``3``.

    All parameters are accepted but unused. The signature matches
    ``try_sample_gp_func``, presumably so callers can swap samplers.

    Returns
    -------
    tuple
        ``(3, f)``.
    """
    noise_scale = 0.01 + 3 * np.random.rand() ** 2

    def sampled_f(x):
        return -0.5 * np.log(1 + x / noise_scale)

    return 3, sampled_f



def try_sample_gp_func(random_sample, num_knots, problem_dim, bounds,
                       kernel, numerical_epsilon, noise_var,
                       discrete_num_list):
    """Sample a random function from a GP prior, represented as a kernel interpolant.

    Knot locations ``X`` are chosen either at random or on a grid. Function
    values ``y`` at the knots come from ``safeopt.sample_gp_function``. The
    returned callable is ``f(x) = k(x, X) @ alpha``, with
    ``alpha = (K(X) + numerical_epsilon * I)^{-1} y``.

    Parameters
    ----------
    random_sample : bool
        If true, draw ``num_knots`` knots uniformly inside the box ``bounds``.
        Otherwise use ``safeopt.linearly_spaced_combinations(bounds,
        discrete_num_list)``.
    num_knots : int
        Number of random knots (only used when ``random_sample`` is true).
    problem_dim : int
        Input dimension; only the first ``problem_dim`` entries of ``bounds``
        are used for random knots.
    bounds : sequence of (lower, upper) pairs
        Box bounds, one pair per input dimension.
    kernel : object
        Object exposing ``K(X)`` and ``K(X, X2)``. Presumably a GPy kernel;
        confirm against callers.
    numerical_epsilon : float
        Diagonal jitter added to the knot covariance before Cholesky
        factorisation.
    noise_var : float
        Forwarded to ``safeopt.sample_gp_function``.
    discrete_num_list : int or sequence of int
        Grid resolution forwarded to safeopt (used for the grid knots and by
        ``sample_gp_function``).

    Returns
    -------
    func_norm : numpy.float64
        ``sqrt(y^T (K + eps I)^{-1} y)`` for the knot values ``y``. Up to the
        jitter term, this is the RKHS norm of the interpolant.
    sampled_f : callable
        Maps ``x`` (promoted with ``np.atleast_2d``) to the interpolant values,
        passed through ``np.squeeze`` (a single point gives a 0-d array).
    """
    # Explicit import: on older SciPy releases ``import scipy`` alone does not
    # load the ``scipy.linalg`` subpackage, so ``sp.linalg`` could be missing.
    from scipy.linalg import cho_factor, cho_solve

    if random_sample:
        # Single draw of shape (num_knots, problem_dim). It should consume the
        # global RNG in the same row-major order as one draw per knot.
        box = np.asarray(bounds, dtype=float)[:problem_dim]
        lower, upper = box[:, 0], box[:, 1]
        unit = np.random.uniform(0.0, 1.0, (num_knots, problem_dim))
        knot_x = lower + (upper - lower) * unit
    else:
        knot_x = safeopt.linearly_spaced_combinations(
            bounds,
            discrete_num_list
        )

    # Jitter keeps the Gram matrix numerically positive definite for Cholesky.
    knot_cov = kernel.K(knot_x) + np.eye(knot_x.shape[0]) * numerical_epsilon
    knot_cov_cho = cho_factor(knot_cov)

    fun = safeopt.sample_gp_function(kernel,
                                     bounds,
                                     noise_var,
                                     discrete_num_list
                                     )
    # NOTE(review): safeopt's sampled function appears to add observation
    # noise (scaled by noise_var) by default, so knot values may be noisy.
    # Confirm this is intended.
    knot_y = fun(knot_x)
    alpha = cho_solve(knot_cov_cho, knot_y)

    # vdot flattens both operands, so this works for (n,) or (n, 1) knot_y.
    func_norm = np.sqrt(np.vdot(knot_y, alpha))

    def sampled_f(x):
        x = np.atleast_2d(x)
        y = kernel.K(x, knot_x).dot(alpha)
        return np.squeeze(y)

    return func_norm, sampled_f
