import numpy as np
import torch


# Sampling scheme
def sample_initial_inputs(n_samples, search_space, method='uniform'):
    '''
    Sample initial data points from a discrete search space.

    :param n_samples: the number of initial data points to sample
    :param search_space: indexable collection (list, array, tensor, ...) to
        sample data points from; selected elements are returned as-is
    :param method: 'uniform' picks n_samples evenly spaced indices across the
        whole search space (first and last element included; positions are
        truncated to int, so duplicates occur when n_samples exceeds
        len(search_space)); 'random' draws n_samples indices uniformly at
        random, with replacement, using numpy's global RNG
    :return: list of the sampled search-space elements (X values only)
    :raises ValueError: if method is neither 'uniform' nor 'random'
    '''
    size_search_space = len(search_space)

    if method == 'uniform':
        # Evenly spaced positions over [0, size - 1]; int() truncates toward zero.
        indices = [int(x) for x in np.linspace(0, size_search_space - 1, n_samples)]
    elif method == 'random':
        # Independent draws: the same point may be selected more than once.
        indices = [np.random.randint(0, size_search_space) for _ in range(n_samples)]
    else:
        # Previously an unknown method silently produced an empty list.
        raise ValueError(
            f"unknown sampling method {method!r}; expected 'uniform' or 'random'"
        )

    return [search_space[i] for i in indices]


# Define acquisition functions
def highest_variance(upper, lower):
    """
    Rank candidate points by the width of their [lower, upper] band, widest first.

    :param upper: tensor of upper bounds; 1-D gives one width per point,
        otherwise widths are averaged along dim 1 to get one score per point
    :param lower: tensor of lower bounds, subtracted from ``upper``
    :return: tensor of indices ordering the points from largest to smallest
        (mean) width, i.e. the order of points to query next
    """
    spread = upper - lower

    if upper.ndim != 1:
        # Collapse each point's spreads along dim 1 into a single score.
        spread = spread.mean(dim=1)

    return torch.argsort(spread, descending=True)