import numpy as np
import torch
from scipy import special

def mean_and_cov(samples, probas, compute_cov=True):
    """Probability-weighted mean (and optionally covariance) of ``samples``.

    ``probas`` is normalized to sum to one and then used as per-row weights.

    Args:
        samples: 2-D tensor, one sample per row.
        probas: 2-D tensor of unnormalized weights; presumably shape
            (len(samples), 1) so it broadcasts across columns — TODO confirm.
        compute_cov: when False, the covariance is skipped and returned as None.

    Returns:
        ``(mean, cov)`` where ``mean`` is the weighted column sum kept as a
        row (``keepdims=True`` on axis 0) and ``cov`` is the ``torch.cov``
        covariance of the columns with the normalized weights as
        ``aweights`` (or None).
    """
    for arr in (samples, probas):
        assert len(arr.shape) == 2

    weights = probas / probas.sum()
    weighted = samples * weights
    mean = weighted.sum(axis=0, keepdims=True)

    cov = torch.cov(samples.T, aweights=weights.flatten()) if compute_cov else None
    return mean, cov

def project(samples, center, max_dist):
    """Project each row of ``samples`` into the ball of radius ``max_dist`` around ``center``.

    Rows already within ``max_dist`` of ``center`` are returned unchanged;
    rows farther away are moved radially onto the ball's surface.

    Args:
        samples: 2-D torch tensor or numpy array, one point per row.
        center: point broadcastable against ``samples`` (e.g. shape (1, d)).
        max_dist: radius of the ball.

    Returns:
        A new tensor/array with the same shape as ``samples``; the input is
        not modified.
    """
    assert len(samples.shape) == 2

    offsets = samples - center
    if isinstance(offsets, torch.Tensor):
        dist = torch.linalg.norm(offsets, dim=-1, keepdim=True)
        scale = torch.ones_like(dist)
    else:
        offsets = np.asarray(offsets)
        if not np.issubdtype(offsets.dtype, np.floating):
            offsets = offsets.astype(float)
        dist = np.linalg.norm(offsets, axis=-1, keepdims=True)
        scale = np.ones_like(dist)

    # Per-row decision: a scalar `if dist > max_dist` is ambiguous (raises)
    # as soon as there is more than one sample.
    outside = dist > max_dist
    scale[outside] = max_dist / dist[outside]
    return center + offsets * scale


def oracle(A, B, target, gamma):
    """Simulate one noisy answer to the query "A or B?" for a hidden ``target``.

    ``A`` stays first with probability ``proba(A, B, target, gamma)``;
    otherwise the pair is swapped, so the returned pair is winner-first.
    NOTE(review): ``target`` presumably holds a single point — with several
    rows the comparison below has several elements and its truth value is
    ambiguous (torch raises).
    """
    p_keep = proba(A, B, target, gamma)
    return (A, B) if torch.rand(1) < p_keep else (B, A)

def unif(num_samples, center, width, num_dims):
    """Draw ``num_samples`` points uniformly from the ball of radius ``width`` around ``center``.

    Args:
        num_samples: number of points to draw.
        center: ball center, broadcastable against a (num_samples, num_dims)
            tensor (e.g. a scalar or a (1, num_dims) tensor).
        width: ball radius.
        num_dims: dimensionality of each point.

    Returns:
        Tensor of shape (num_samples, num_dims) in torch's default float dtype.
    """
    # Exact sampling instead of rejection from the bounding cube: the cube's
    # acceptance rate (ball/cube volume ratio) decays super-exponentially with
    # num_dims (~2e-14 at 30 dims), so the old loop effectively never ended.
    # A normalized Gaussian vector is a uniform direction; a radius of
    # width * U**(1/num_dims) makes the density uniform over the ball's volume.
    direction = torch.randn((num_samples, num_dims))
    direction = direction / torch.linalg.norm(direction, dim=-1, keepdim=True).clamp_min(1e-12)
    radius = width * torch.rand((num_samples, 1)) ** (1.0 / num_dims)
    # The old version filled a torch.zeros(...) buffer, i.e. the default dtype.
    return (center + direction * radius).to(torch.get_default_dtype())


def proba(A, B, points, gamma):
    """
    the gamma-CKL choice model
    proba(A,B, points)[i] = P( A > B | x_t = points[i])
                          = d_B**gamma / (d_A**gamma + d_B**gamma)
    where d_A, d_B are the distances from points[i] to A and B.

    Returns a tensor with one row per point, shape (-1, 1).
    """
    da = torch.linalg.norm(points - A, dim=-1)
    db = torch.linalg.norm(points - B, dim=-1)

    # Evaluate the equivalent form 1 / (1 + (d_A / d_B)**gamma), which depends
    # only on the distance ratio. Raising each distance to `gamma` separately
    # underflows when we zoom in far enough (d ~ 1e-10, gamma=5 gives 0 in
    # float32), which silently turned informative comparisons into 50/50.
    # Equal distances (including both 0) give exactly 0.5; a zero d_B gives 0
    # and a zero d_A gives 1 via the inf/0 limits of the ratio.
    same = da == db
    ratio = torch.where(same, torch.ones_like(da), da) / torch.where(same, torch.ones_like(db), db)
    return (1.0 / (1.0 + ratio ** gamma)).reshape((-1, 1))


def log_prob(A, B, points, gamma):
    """Log of the gamma-CKL probability P(A > B | x = points[i]), computed in log space.

    log P = -log(1 + (d_A / d_B)**gamma) = -softplus(gamma * (log d_A - log d_B)),
    which stays finite for extreme distance ratios where taking ``log`` of
    the linear-space probability gives -inf (e.g. d_A = 1, d_B = 1e-12,
    gamma = 5, where d_B**gamma underflows to 0 in float32).

    Returns a tensor with one row per point, shape (-1, 1).
    NOTE(review): assumes gamma != 0 whenever a point coincides with exactly
    one of A/B (gamma * inf would be NaN there) — confirm gamma is nonzero.
    """
    da = torch.linalg.norm(points - A, dim=-1)
    db = torch.linalg.norm(points - B, dim=-1)
    log_ratio = torch.log(da) - torch.log(db)
    # Equal distances (incl. both zero, where -inf - -inf is NaN) mean 50/50.
    log_ratio = torch.where(da == db, torch.zeros_like(log_ratio), log_ratio)
    return -torch.nn.functional.softplus(gamma * log_ratio).reshape((-1, 1))


# Log-likelihood of each sample given all observed query outcomes;
# it has to be computed for every newly created sample.
def log_likelihood(samples, outcomes, gamma):
    """Sum the per-outcome log-probabilities for every sample.

    Args:
        samples: points to score, one per row.
        outcomes: iterable of ``(A, B)`` query pairs, each scored with
            ``log_prob(A, B, samples, gamma)``.
        gamma: exponent of the gamma-CKL choice model.

    Returns:
        Tensor of shape (len(samples), 1); all zeros when ``outcomes`` is empty.
    """
    per_outcome = (log_prob(A, B, samples, gamma) for A, B in outcomes)
    return sum(per_outcome, torch.zeros((len(samples), 1)))


def sphere_volume(radius, num_dims):
    """Volume of a ``num_dims``-dimensional ball with the given ``radius``.

    V_d(r) = pi**(d/2) / Gamma(d/2 + 1) * r**d. Works elementwise on arrays.
    """
    # The Gamma factor belongs in the denominator; multiplying by it (as the
    # previous version did) only coincides with the correct value for d = 2.
    return np.pi ** (num_dims / 2) / special.gamma(num_dims / 2 + 1) * radius ** num_dims


def generate_outcomes(target, center, width_query, num_dims, num_outcomes=1, gamma=5):
    """Simulate ``num_outcomes`` pairwise queries answered by the noisy ``oracle``.

    Each query draws two points with ``unif(2, center, width_query, num_dims)``
    and lets ``oracle`` order them given ``target`` under the gamma-CKL model.

    Returns:
        List of ``(A, B)`` pairs, each point reshaped to (1, num_dims) and
        ordered so that ``(A > B)`` is the query outcome.
    """
    outcomes = []
    for _ in range(num_outcomes):
        A, B = unif(2, center, width_query, num_dims)
        # Ask the oracle and keep the pair winner-first; this is the same
        # proba + torch.rand draw that was previously duplicated inline.
        outcomes.append(oracle(A.reshape((1, -1)), B.reshape((1, -1)), target, gamma))
    return outcomes
