"""
From: https://gist.github.com/GaelVaroquaux/ead9898bd3c973c40429

Non-parametric computation of entropy and mutual-information

Adapted by G Varoquaux from code created by R Brette, itself based on
several papers (see the references in the code).

These computations rely on nearest-neighbor statistics
"""
import numpy as np

from scipy.special import gamma, psi
from scipy import ndimage
from scipy.linalg import det
from numpy import pi

if True:
    from sklearn.neighbors import NearestNeighbors
elif False:
    from pykeops.numpy import LazyTensor
else:
    from cuml import NearestNeighbors

from sklearn.utils import shuffle

# Names exported by ``from <module> import *``.
__all__ = ["entropy", "mutual_information", "entropy_gaussian"]

# Machine epsilon of the Python float type: a tiny positive offset.
EPS = np.finfo(float).eps


def fit_knn(X, max_k=1):
    """Create a nearest-neighbour index and fit it on ``X``.

    The index is configured for ``max_k + 1`` neighbours per query, which
    leaves room for the query point itself when the fitted points are later
    used as queries.

    Parameters
    ----------
    X : array, shape (n_samples, n_features)
        Points to index.
    max_k : int, optional
        Largest neighbour rank that will be queried.

    Returns
    -------
    The fitted ``NearestNeighbors`` instance.
    """
    neighbor_count = max_k + 1
    if True:
        # Active configuration: larger leaves and all available workers.
        index = NearestNeighbors(n_neighbors=neighbor_count, leaf_size=256, n_jobs=-1)
    else:
        # Minimal-argument construction -- presumably for a backend that does
        # not accept ``leaf_size`` / ``n_jobs``; confirm before enabling.
        index = NearestNeighbors(n_neighbors=neighbor_count)

    index.fit(X)
    return index


def nearest_distances(knn, X, k=1):
    """Return the distance from every row of ``X`` to its k-th nearest neighbour.

    Parameters
    ----------
    knn : fitted neighbour index
        Object exposing ``kneighbors(X, n_neighbors=...)`` that returns
        ``(distances, indices)`` with each row of ``distances`` sorted in
        ascending order (e.g. the object returned by ``fit_knn``).
    X : array, shape (N, M)
        ``N`` query points in ``M`` dimensions.
    k : int, optional
        Rank of the neighbour whose distance is returned.

    Returns
    -------
    array, shape (N,)
        Distance to the k-th neighbour of each query point.

    Notes
    -----
    ``k + 1`` neighbours are requested because, when the queries are the
    fitted points, the first neighbour of each point is the point itself.
    The neighbour count is passed explicitly so that ``k`` is honoured even
    when the index was fitted for a larger number of neighbours.
    """
    distances, _ = knn.kneighbors(X, n_neighbors=k + 1)
    # Column 0 is the query point itself, so column k is the k-th neighbour.
    return distances[:, k]


def entropy_gaussian(C):
    """Return the differential entropy (in nats) of a Gaussian variable.

    Parameters
    ----------
    C : float or array-like, shape (n, n)
        Variance of a 1-D Gaussian (scalar or 0-d array), or the covariance
        matrix of an ``n``-dimensional Gaussian.

    Returns
    -------
    float
        ``0.5 * n * (1 + log(2 * pi)) + 0.5 * log(|det(C)|)``, with ``n = 1``
        and ``det(C) = C`` in the scalar case.
    """
    C = np.asarray(C, dtype=float)
    if C.ndim == 0:  # C is the variance
        return 0.5 * (1 + np.log(2 * pi)) + 0.5 * np.log(C)

    n = C.shape[0]  # dimension
    # slogdet yields log|det(C)| directly, so the determinant itself is never
    # formed and cannot overflow/underflow for large or badly scaled matrices.
    _, log_abs_det = np.linalg.slogdet(C)
    return 0.5 * n * (1 + np.log(2 * pi)) + 0.5 * log_abs_det


def density(X, max_k=1):
    """Build a per-point nearest-neighbour estimator fitted on the sample ``X``.

    Despite the name, the returned callable does not return a density. For
    each query point it evaluates
    ``d * log(r) + log(V_d) + psi(n) - psi(max_k)``, where ``r`` is the
    neighbour distance reported by ``nearest_distances``, ``V_d`` the volume
    of the ``d``-dimensional unit ball and ``n`` the number of fitted samples.
    This is the expression whose mean over the sample is returned by
    ``entropy``.

    Parameters
    ----------
    X : array, shape (n_samples, n_features)
        Reference sample. Must expose ``shape`` and a floating-point
        ``dtype`` (its machine epsilon guards the logarithm).
    max_k : int, optional
        Neighbour rank used for the distance ``r``.

    Returns
    -------
    estimator : callable
        ``estimator(Xhat)`` returns one value per row of ``Xhat``.
    """
    knn = fit_knn(X, max_k)
    n, d = X.shape
    volume_unit_ball = (pi ** (0.5 * d)) / gamma(0.5 * d + 1)
    eps = np.finfo(X.dtype).eps  # keeps log() finite for zero distances

    def estimator(Xhat):
        # Distance to the max_k-th nearest neighbour. The original code passed
        # an undefined name ``k`` here, which raised NameError on every call.
        # NOTE(review): nearest_distances assumes the first neighbour is the
        # query point itself; for points outside X the rank may be off by one.
        r = nearest_distances(knn, Xhat, max_k)
        # Reference:
        # F. Perez-Cruz, (2008). Estimation of Information Theoretic Measures
        # for Continuous Random Variables. Advances in Neural Information
        # Processing Systems 21 (NIPS). Vancouver (Canada), December.
        # That paper's variant uses log(n - 1) - log(k); the digamma form
        # psi(n) - psi(max_k) is used here instead.
        return d * np.log(r + eps) + np.log(volume_unit_ball) + psi(n) - psi(max_k)

    return estimator


def split_entropy(X, k=1, num_splits=10):
    """Average the ``entropy`` estimate over random disjoint chunks of ``X``.

    The rows of ``X`` are shuffled (no ``random_state`` is passed, so the
    permutation is not seeded here), cut into ``num_splits`` chunks with
    ``np.array_split`` and ``entropy`` is evaluated on every chunk.

    Parameters
    ----------
    X : array, shape (n_samples, n_features)
        Data whose entropy is estimated.
    k : int, optional
        Neighbour rank forwarded to ``entropy``.
    num_splits : int, optional
        Number of chunks.

    Returns
    -------
    The unweighted mean of the per-chunk estimates.
    """
    chunks = np.array_split(shuffle(X), num_splits)
    total = 0
    for chunk in chunks:
        total = total + entropy(chunk, k)
    # TODO: also compute variance?
    return total / num_splits


def entropy(X, k=1):
    """ Returns the nearest-neighbour estimate of the entropy of X (in nats).

    Parameters
    ===========

    X : array-like, shape (n_samples, n_features)
        The data the entropy of which is computed. Non-floating input
        (e.g. integer arrays or nested lists of ints) is converted to float.

    k : int, optional
        number of nearest neighbors for density estimation

    Returns
    ========

    float
        ``d * mean(log(r)) + log(V_d) + psi(n) - psi(k)`` where ``r`` is the
        distance of every sample to its k-th nearest neighbour, ``V_d`` the
        volume of the d-dimensional unit ball and ``n`` the sample count.

    Notes
    ======

    Kozachenko, L. F. & Leonenko, N. N. 1987 Sample estimate of entropy
    of a random vector. Probl. Inf. Transm. 23, 95-101.
    See also: Evans, D. 2008 A computationally efficient estimator for
    mutual information, Proc. R. Soc. A 464 (2093), 1203-1215.
    and:
    Kraskov A, Stogbauer H, Grassberger P. (2004). Estimating mutual
    information. Phys Rev E 69(6 Pt 2):066138.
    """
    # The code below needs ``X.shape`` and ``X.dtype``, and ``np.finfo`` only
    # accepts inexact dtypes, so normalise array-likes and integer arrays.
    X = np.asarray(X)
    if not np.issubdtype(X.dtype, np.inexact):
        X = X.astype(float)

    if True:
        knn = fit_knn(X, k)
        # Distance to the kth nearest neighbor, as reported by the index
        # (plain Euclidean distance for scikit-learn's default metric).
        r = nearest_distances(knn, X, k)
    else:
        # Disabled alternative based on KeOps LazyTensor.
        # NOTE(review): this branch reduces *squared* distances, unlike the
        # branch above -- reconcile before enabling it.
        X_i = LazyTensor(X[:, None, :])
        Y_j = LazyTensor(X[None, :, :])
        D_ij = ((X_i - Y_j) ** 2).sum(-1)
        r = D_ij.Kmin(k + 1, dim=1).max(-1)

    n, d = X.shape
    volume_unit_ball = (pi ** (0.5 * d)) / gamma(0.5 * d + 1)
    # Reference:
    # F. Perez-Cruz, (2008). Estimation of Information Theoretic Measures
    # for Continuous Random Variables. Advances in Neural Information
    # Processing Systems 21 (NIPS). Vancouver (Canada), December.
    # That paper's variant uses log(n - 1) - log(k); the digamma form
    # psi(n) - psi(k) is used here instead. Machine epsilon is added to r so
    # duplicated points (zero distance) do not produce log(0).
    return d * np.mean(np.log(r + np.finfo(X.dtype).eps)) + np.log(volume_unit_ball) + psi(n) - psi(k)


# NOTE: This is the Total Correlation for multiple variables and not the Interaction Information.
def mutual_information(variables, k=1):
    """
    Estimate the mutual information shared by two or more variables.

    The result is the sum of the ``entropy`` estimates of the individual
    variables minus the ``entropy`` estimate of their column-wise
    concatenation.

    Parameters
    ----------
    variables : sequence of arrays
        Each entry is a matrix of shape (n_samples, n_features); the number
        of features may differ between entries.

    k : int, optional
        number of nearest neighbors for density estimation

    Raises
    ------
    AttributeError
        If fewer than two variables are given.

    Example: mutual_information((X, Y)), mutual_information((X, Y, Z), k=5)
    """
    if len(variables) < 2:
        raise AttributeError("Mutual information must involve at least 2 variables")

    # Stack first so incompatible inputs fail before any estimate is computed.
    joint_sample = np.hstack(variables)

    marginal_total = 0
    for variable in variables:
        marginal_total = marginal_total + entropy(variable, k=k)

    return marginal_total - entropy(joint_sample, k=k)


def mutual_information_2d(x, y, sigma=1, normalized=False, bins=256):
    """
    Computes (normalized) mutual information between two 1D variates from a
    Gaussian-smoothed joint histogram.

    Parameters
    ----------
    x : 1D array
        first variable

    y : 1D array
        second variable

    sigma: float
        sigma for Gaussian smoothing of the joint histogram

    normalized: bool
        if True, return the normalized variant
        ``(H(x) + H(y)) / H(x, y) - 1`` instead of the plain mutual
        information ``H(x) + H(y) - H(x, y)``

    bins: int or (int, int)
        bin specification forwarded to ``np.histogram2d``; the default of
        256 reproduces the previously hard-coded 256 x 256 grid

    Returns
    -------
    mi: float
        the computed similarity measure (natural logarithm, i.e. nats, for
        the non-normalized variant)

    """
    jh = np.histogram2d(x, y, bins=bins)[0]

    # smooth the jh with a gaussian filter of given sigma
    jh = ndimage.gaussian_filter(jh, sigma=sigma, mode="constant")

    # Offset by machine epsilon so empty cells do not yield log(0), then
    # normalise the histogram into a joint probability table.
    jh = jh + np.finfo(jh.dtype).eps
    jh = jh / np.sum(jh)

    # Marginal distributions. Only element-wise sums are taken below, so no
    # reshaping is required and non-square grids are handled as well.
    s1 = np.sum(jh, axis=0)
    s2 = np.sum(jh, axis=1)

    # Each term is a negative entropy: sum(p * log(p)).
    neg_h_joint = np.sum(jh * np.log(jh))
    neg_h1 = np.sum(s1 * np.log(s1))
    neg_h2 = np.sum(s2 * np.log(s2))

    # Normalised Mutual Information of:
    # Studholme,  jhill & jhawkes (1998).
    # "A normalized entropy measure of 3-D medical image alignment".
    # in Proc. Medical Imaging 1998, vol. 3338, San Diego, CA, pp. 132-143.
    if normalized:
        mi = ((neg_h1 + neg_h2) / neg_h_joint) - 1
    else:
        mi = neg_h_joint - neg_h1 - neg_h2

    return mi


###############################################################################
# Tests


def test_entropy():
    """Check ``entropy`` against the closed form for a correlated 3-D Gaussian."""
    rng = np.random.RandomState(0)
    n_samples = 50000
    n_dims = 3
    mixing = np.array([[1, 0, 0], [0, 1, 0.5], [0, 0, 1]])
    covariance = np.dot(mixing, mixing.T)
    white_noise = rng.randn(n_dims, n_samples)
    samples = np.dot(mixing, white_noise)
    # Every sample is repeated once, so the point cloud holds exact duplicates.
    samples = np.hstack([samples, samples])

    H_th = entropy_gaussian(covariance)
    # Average the estimate over neighbour counts k = 3 .. 19.
    neighbor_counts = range(3, 20)
    H_est = sum(entropy(samples.T, k=k) for k in neighbor_counts) / len(neighbor_counts)

    # The estimate is expected to undershoot the analytical value,
    # but by less than 10 %.
    print(H_est, H_th)
    np.testing.assert_array_less(H_est, H_th)
    np.testing.assert_array_less(0.9 * H_th, H_est)


def test_mutual_information():
    """Check ``mutual_information`` on two correlated Gaussian variables."""
    n_samples = 50000
    rng = np.random.RandomState(0)
    mixing = np.array([[1, 0], [0.5, 1]])
    covariance = np.dot(mixing, mixing.T)
    white_noise = rng.randn(2, n_samples)
    mixed = np.dot(mixing, white_noise).T
    # One column vector of shape (n_samples, 1) per variable.
    first = mixed[:, 0].reshape(-1, 1)
    second = mixed[:, 1].reshape(-1, 1)

    # Both values are in nats (natural logarithms throughout).
    MI_est = mutual_information((first, second), k=5)
    MI_th = (
        entropy_gaussian(covariance[0, 0])
        + entropy_gaussian(covariance[1, 1])
        - entropy_gaussian(covariance)
    )
    # The estimator is expected to undershoot, since the 2-D joint entropy
    # is underestimated more than the 1-D marginals, but by less than 0.3.
    print((MI_est, MI_th))
    np.testing.assert_array_less(MI_est, MI_th)
    np.testing.assert_array_less(MI_th, MI_est + 0.3)


def test_degenerate():
    """Check that the estimators stay finite on perfectly dependent data."""
    rng = np.random.RandomState(0)
    x = rng.randn(50000)
    column = x[:, np.newaxis]

    # Two identical columns: the samples lie on a line in 2-D.
    duplicated = np.column_stack([x, x])
    assert np.isfinite(entropy(duplicated))

    # A variable paired with itself.
    assert np.isfinite(mutual_information((column, column)))

    mi_hist = mutual_information_2d(x, x)
    assert 2.9 < mi_hist < 3.1


def test_mutual_information_2d():
    """Check the histogram-based estimate on two strongly correlated Gaussians."""
    n_samples = 50000
    rng = np.random.RandomState(0)
    mixing = np.array([[1, 0], [0.9, 0.1]])
    covariance = np.dot(mixing, mixing.T)
    white_noise = rng.randn(2, n_samples)
    mixed = np.dot(mixing, white_noise).T

    # Both values are in nats (natural logarithms throughout).
    MI_est = mutual_information_2d(mixed[:, 0].ravel(), mixed[:, 1].ravel())
    MI_th = (
        entropy_gaussian(covariance[0, 0])
        + entropy_gaussian(covariance[1, 1])
        - entropy_gaussian(covariance)
    )
    print((MI_est, MI_th))
    # The estimate is expected to fall below the analytical value,
    # but by less than 0.2.
    np.testing.assert_array_less(MI_est, MI_th)
    np.testing.assert_array_less(MI_th, MI_est + 0.2)


if __name__ == "__main__":
    # Execute the self-tests, in definition order, when run as a script.
    for run_check in (
        test_entropy,
        test_mutual_information,
        test_degenerate,
        test_mutual_information_2d,
    ):
        run_check()
