from src.utils import *

def mean_along_cov(
        samples_R, samples_RE, samples_F,
        probas_R, probas_RE, probas_F,
        center,
        width_r, width_s, width_query, width_far,
        num_dims, project=True):
    """Return two points placed symmetrically about the RE mean on its principal axis.

    The mean and covariance of ``samples_RE`` come from ``mean_and_cov``
    (presumably weighted by ``probas_RE`` -- confirm in ``src.utils``). The
    covariance eigenvector with the largest eigenvalue, scaled by that
    eigenvalue, is added to / subtracted from the mean to give ``A`` / ``B``.

    Args:
        samples_R, samples_F, probas_R, probas_F, width_r, width_s,
        width_far, num_dims: unused here; every strategy in this module
            shares this signature.
        samples_RE, probas_RE: passed to ``mean_and_cov``.
        center, width_query: forwarded to the module-level ``project``
            helper when ``project`` is truthy.
        project: whether to pass ``A`` and ``B`` through ``project``.

    Returns:
        Tuple ``(A, B)``.
    """
    mean, cov = mean_and_cov(samples_RE, probas_RE)

    # eigh returns eigenvalues in ascending order and the eigenvectors as the
    # COLUMNS of the second output, so the principal axis is the last column.
    eig_vals, eig_vecs = torch.linalg.eigh(cov)
    max_eigvec = eig_vecs[:, -1]
    max_eigval = eig_vals[-1]
    # NOTE(review): the offset is scaled by the variance (eigenvalue), not the
    # standard deviation (its square root) -- confirm this is intended.
    delta = max_eigvec * max_eigval
    A = mean + delta
    B = mean - delta

    if project:
        # The boolean parameter shadows the module-level `project` helper;
        # look the helper up explicitly instead of calling the bool.
        project_fn = globals()["project"]
        A = project_fn(A, center, width_query)
        B = project_fn(B, center, width_query)

    return A, B

def mean_along_cov_mindist(
        samples_R, samples_RE, samples_F,
        probas_R, probas_RE, probas_F,
        center,
        width_r, width_s, width_query, width_far,
        num_dims, project=True, min_dist=0.2):
    """Like ``mean_along_cov`` but keep the two points at least a minimum offset apart.

    The offset from the RE mean is the principal covariance eigenvector scaled
    by its eigenvalue. If that offset is shorter than ``width_r * min_dist``, it
    is stretched to exactly that length. A zero-length offset (degenerate
    covariance) falls back to the unit eigenvector scaled to that length.

    Args:
        samples_R, samples_F, probas_R, probas_F, width_s, width_far,
        num_dims: unused here; every strategy in this module shares this
            signature.
        samples_RE, probas_RE: passed to ``mean_and_cov``.
        width_r, min_dist: the minimum offset length is ``width_r * min_dist``.
        center, width_query: forwarded to the module-level ``project``
            helper when ``project`` is truthy.
        project: whether to pass ``A`` and ``B`` through ``project``.

    Returns:
        Tuple ``(A, B) = (mean + delta, mean - delta)``.
    """
    mean, cov = mean_and_cov(samples_RE, probas_RE)

    # eigh returns eigenvalues in ascending order and the eigenvectors as the
    # COLUMNS of the second output, so the principal axis is the last column.
    eig_vals, eig_vecs = torch.linalg.eigh(cov)
    max_eigvec = eig_vecs[:, -1]
    max_eigval = eig_vals[-1]
    delta = max_eigvec * max_eigval

    min_len = width_r * min_dist
    len_delta = torch.linalg.norm(delta)
    if len_delta < min_len:
        if len_delta > 0:
            delta = delta * (min_len / len_delta)
        else:
            # Zero eigenvalue: rescaling would divide by zero (NaNs), so use
            # the unit-norm eigenvector directly.
            delta = max_eigvec * min_len

    A = mean + delta
    B = mean - delta

    if project:
        # The boolean parameter shadows the module-level `project` helper;
        # look the helper up explicitly instead of calling the bool.
        project_fn = globals()["project"]
        A = project_fn(A, center, width_query)
        B = project_fn(B, center, width_query)

    return A, B


def mean_with_random(
        samples_R, samples_RE, samples_F,
        probas_R, probas_RE, probas_F,
        center,
        width_r, width_s, width_query, width_far,
        num_dims, project=True):
    """Pair the RE mean with a single point drawn by ``unif``.

    Only the mean of ``samples_RE`` is computed (``compute_cov=False``). The
    second point is ``unif(1, center, width_s, num_dims)``. All other
    parameters are unused; every strategy in this module shares this
    signature.

    Returns:
        Tuple ``(A, B)`` = (RE mean, random draw).
    """
    re_mean = mean_and_cov(samples_RE, probas_RE, compute_cov=False)[0]
    random_point = unif(1, center, width_s, num_dims)
    return re_mean, random_point

def random_with_random(
        samples_R, samples_RE, samples_F,
        probas_R, probas_RE, probas_F,
        center,
        width_r, width_s, width_query, width_far,
        num_dims, project=True):
    """Return two points from a single ``unif(2, center, width_s, num_dims)`` call.

    Entries 0 and 1 of the draw are each reshaped to ``(-1, num_dims)``. All
    sample/probability arguments are unused; every strategy in this module
    shares this signature.

    Returns:
        Tuple ``(A, B)``.
    """
    pair = unif(2, center, width_s, num_dims)
    first, second = (pair[k].reshape(-1, num_dims) for k in (0, 1))
    return first, second


def mean_with_peak(
        samples_R, samples_RE, samples_F,
        probas_R, probas_RE, probas_F,
        center,
        width_r, width_s, width_query, width_far,
        num_dims, project=True):
    """Pair the RE mean with the RE sample of highest probability.

    Only the mean of ``samples_RE`` is computed (``compute_cov=False``). The
    second point is ``samples_RE`` indexed by ``torch.argmax(probas_RE)``.
    All other parameters are unused; every strategy in this module shares
    this signature.

    Returns:
        Tuple ``(A, B)`` = (RE mean, peak-probability sample).
    """
    re_mean = mean_and_cov(samples_RE, probas_RE, compute_cov=False)[0]
    peak_sample = samples_RE[torch.argmax(probas_RE)]
    return re_mean, peak_sample
