from src.utils import *



def mean_with_random(
        samples_R, samples_RE, samples_F,
        probas_R, probas_RE, probas_F,
        center,
        width_r, width_s, width_query, width_far,
        num_dims, project=True):
    """Return the probability-weighted mean of ``samples_RE`` and one ``unif`` draw.

    Only ``samples_RE``, ``probas_RE``, ``center``, ``width_s`` and
    ``num_dims`` are read. The remaining parameters are unused here; the
    signature matches ``mean_with_peak``.

    Args:
        samples_RE: Samples, one per row. The ``reshape((-1, 1))`` below
            broadcasts one weight across each row.
        probas_RE: Weights, one per row of ``samples_RE``. They are normalized
            here to sum to 1 (a zero total would produce NaNs).
        center, width_s, num_dims: Forwarded unchanged to ``unif``.

    Returns:
        (A, B): ``A`` is the weighted mean over the rows of ``samples_RE``.
        ``B`` is ``unif(1, center, width_s, num_dims)``. Presumably this is a
        single uniform draw around ``center``. Its exact shape depends on
        ``unif`` and may be (1, num_dims) rather than (num_dims,); TODO
        confirm against src.utils.
    """
    total_mass = probas_RE.sum()
    column_weights = (probas_RE / total_mass).reshape((-1, 1))
    weighted_mean = (samples_RE * column_weights).sum(axis=0)

    random_point = unif(1, center, width_s, num_dims)
    return weighted_mean, random_point


def mean_with_peak(
        samples_R, samples_RE, samples_F,
        probas_R, probas_RE, probas_F,
        center,
        width_r, width_s, width_query, width_far,
        num_dims, project=True):
    """Return the probability-weighted mean of ``samples_RE`` and its most probable row.

    Only ``samples_RE`` and ``probas_RE`` are read. The remaining parameters
    are unused here; the signature matches ``mean_with_random``.

    Args:
        samples_RE: Samples, one per row (torch tensor or numpy array). The
            ``reshape((-1, 1))`` below broadcasts one weight across each row.
        probas_RE: Weights, one per row of ``samples_RE``, of the same array
            library as ``samples_RE``. They are normalized here to sum to 1 for
            the mean (a zero total would produce NaNs). Presumably 1-D; TODO
            confirm.

    Returns:
        (A, B): ``A`` is the weighted mean over the rows of ``samples_RE``.
        ``B`` is the row of ``samples_RE`` whose weight in ``probas_RE`` is
        largest.
    """
    weights = probas_RE / probas_RE.sum()
    A = (samples_RE * weights.reshape((-1, 1))).sum(axis=0)

    # The method form of argmax works for both torch tensors and numpy arrays.
    # For tensors it is equivalent to the former torch.argmax(probas_RE).
    B = samples_RE[probas_RE.argmax()]
    return A, B
