import numpy as np
from sklearn.metrics import mutual_info_score

from .linalg import allpairsdistance, center_matrix

def entropy(x, bins=10, alpha=0.5):
    '''Compute entropy (in bits) of dataset

    The data are histogrammed, the bin counts are Laplace-smoothed with
    ``alpha``, and the Shannon entropy of the resulting discrete distribution
    is returned.

    Parameters
    ----------
    x : np.ndarray
        Dataset. Multi-dimensional input is flattened by np.histogram, so
        every element counts as one sample.
    bins : int or list, optional
        Bins to use for histogram
    alpha : float, optional
        Laplace smoothing constant for empty bins

    Returns
    ----------
    float
        Entropy of x, measured in bits
    '''
    counts = np.histogram(x, bins)[0]
    # Normalize by the number of samples that were actually binned rather
    # than len(x): the two differ for multi-dimensional input and for explicit
    # bin edges that exclude some samples, and using len(x) would then yield
    # "probabilities" that do not sum to one.
    n_total = counts.sum()
    px = (counts + alpha) / (n_total + alpha * len(counts))
    # Bins with zero probability (only possible when alpha == 0) contribute
    # nothing to the entropy; drop them so 0 * log(0) does not produce NaN.
    px = px[px > 0]
    h = -np.sum(px * np.log(px))
    return h / np.log(2)

def mutual_info(x, y, bins=10, alpha=0.5):
    '''Compute mutual information (in bits) of two datasets

    The joint 2-D histogram of x and y is Laplace-smoothed with ``alpha`` and
    normalized into a joint distribution; the mutual information is then
    computed directly from that distribution and its marginals. The
    computation is done with numpy rather than by passing the smoothed table
    to sklearn's ``mutual_info_score``, because a fractional contingency
    table may be coerced to integers there, which silently discards the
    smoothing.

    Parameters
    ----------
    x : np.ndarray
        Dataset
    y : np.ndarray
        Dataset of the same size as x
    bins : int or list, optional
        Bins to use for histogram
    alpha : float, optional
        Laplace smoothing constant for empty bins

    Returns
    ----------
    float
        Mutual information between x and y, measured in bits

    Raises
    ----------
    ValueError
        If x and y do not have the same length
    '''
    if len(x) != len(y):
        raise ValueError('x and y must have the same length')
    count_xy = np.histogram2d(x, y, bins)[0]
    smoothed_count_xy = count_xy + alpha
    # Joint distribution and its two marginals (kept 2-D for broadcasting).
    pxy = smoothed_count_xy / smoothed_count_xy.sum()
    px = pxy.sum(axis=1, keepdims=True)
    py = pxy.sum(axis=0, keepdims=True)
    # Cells with zero joint probability (only possible when alpha == 0)
    # contribute nothing; skip them to avoid log(0).
    nonzero = pxy > 0
    ratio = pxy[nonzero] / (px * py)[nonzero]
    mi = np.sum(pxy[nonzero] * np.log(ratio))
    # MI is non-negative in exact arithmetic; clip tiny negative round-off.
    mi = max(mi, 0.0)
    return mi / np.log(2)

def iqr(x, y, bins=10, alpha=0.5):
    '''Compute information quality ratio (IQR) for two datasets

    The ratio is the mutual information of x and y divided by their joint
    entropy, where the joint entropy is obtained as H(x) + H(y) - MI(x, y).

    Parameters
    ----------
    x : np.ndarray
        Dataset
    y : np.ndarray
        Dataset of the same size as x
    bins : int or list, optional
        Bins to use for histogram
    alpha : float, optional
        Laplace smoothing constant for empty bins

    Returns
    ----------
    float
        Information quality ratio between x and y
    '''
    shared_info = mutual_info(x, y, bins, alpha)
    marginal_entropies = (entropy(x, bins, alpha), entropy(y, bins, alpha))
    # Joint entropy via the identity H(x, y) = H(x) + H(y) - MI(x, y).
    joint_entropy = marginal_entropies[0] + marginal_entropies[1] - shared_info
    return shared_info / joint_entropy

def mmig(x):
    ''' Compute Mean Mutual Information Gap (MMIG)

    Description
    -----------
    The Mean Mutual Information Gap metric measures the disentanglement of a
    representation by computing its mutual information with a known ground-truth
    vector, variable by variable. Specifically, MMIG is the mean difference in
    MI (relative to each ground-truth variable) between the latent variable that
    has the highest MI and the other latent variables in the representation.

    Parameters
    ----------
    x : np.ndarray
        2-D array of mutual information scores, where rows correspond to latent
        variables, columns correspond to ground-truth variables, and values
        are the mutual information (normalized by entropy of the ground-truth
        variable) between the relevant variables. The array must have at least
        two rows, since the metric is only defined for latent representations
        with two or more latent variables.

    Returns
    ----------
    np.ndarray
        Mean mutual information gap for each ground-truth variable (one entry
        per column of x)

    Raises
    ----------
    ValueError
        If x has fewer than two rows
    '''
    x = np.asarray(x)
    n = x.shape[0]
    if n < 2:
        raise ValueError('mmig requires at least two latent variables (rows)')
    best = np.max(x, axis=0)
    # Average MI of the *other* latent variables: the best one must be removed
    # from the column sum before dividing by the n - 1 remaining rows.
    mean_of_others = (np.sum(x, axis=0) - best) / (n - 1)
    return best - mean_of_others

def dcor(X,Y):
    '''
    Compute distance correlation, a measure of the dependence between samples
    from two random variables. Distance correlation ranges from 0
    (independent) to 1 (dependent).

    The statistic is the distance covariance of X and Y divided by the
    geometric mean of their distance standard deviations, all computed from
    the centered matrices returned by the project helpers
    ``allpairsdistance`` and ``center_matrix`` (presumably pairwise distances
    followed by centering -- confirm against ``.linalg``).

    https://en.wikipedia.org/wiki/Distance_correlation
    '''
    dist_x = allpairsdistance(X, X)
    dist_y = allpairsdistance(Y, Y)
    centered_x = center_matrix(dist_x)
    centered_y = center_matrix(dist_y)

    # Square roots of the mean element-wise products give the distance
    # covariance and the two distance standard deviations.
    dist_cov = np.sqrt(np.mean(centered_x * centered_y))
    dist_std_x = np.sqrt(np.mean(centered_x * centered_x))
    dist_std_y = np.sqrt(np.mean(centered_y * centered_y))

    return dist_cov / np.sqrt(dist_std_x * dist_std_y)

def test_mi(display=False):
    '''Check that mutual_info is larger for dependent than independent data

    X and Y are first generated from a shared angle (so they are dependent),
    then Y is shuffled to break the dependence; the mutual information must
    drop.

    Parameters
    ----------
    display : bool, optional
        If True, plot the joint and marginal histograms and print the mutual
        information for both cases (requires matplotlib)
    '''
    N = 5000
    # Sample points on an ellipse, parameterized by a uniformly random angle
    theta = np.random.uniform(0,2*np.pi,N)
    X = np.cos(theta)-1
    Y = 2*np.sin(theta)

    def plot_joint_marginals(X,Y):
        # Imported lazily so the test itself does not require matplotlib.
        import matplotlib.pyplot as plt
        plt.figure()
        plt.subplot(2,2,3)
        bins=50
        plt.hist2d(X,Y,bins=bins, range=[[-4,4],[-4,4]])
        plt.subplot(2,2,1)
        plt.hist(X, bins=bins, range=[-4,4])
        plt.subplot(2,2,4)
        plt.hist(Y, bins=bins, range=[-4,4], orientation="horizontal")
        plt.show()

    mi_dependent = mutual_info(X, Y)
    if display:
        plot_joint_marginals(X,Y)
        print(mi_dependent)

    # Shuffling Y in place destroys its pairing with X.
    np.random.shuffle(Y)
    mi_independent = mutual_info(X, Y)
    if display:
        plot_joint_marginals(X,Y)
        print(mi_independent)
    assert(mi_independent < mi_dependent)

def test_dcor(display=False):
    '''Check that dcor is larger for dependent than independent data

    X and Y are first generated from a shared angle (so they are dependent),
    then Y is shuffled to break the dependence; the distance correlation must
    drop.

    Parameters
    ----------
    display : bool, optional
        If True, plot the joint and marginal histograms and print the distance
        correlation for both cases (requires matplotlib)
    '''
    N = 5000
    # Sample points on an ellipse, parameterized by a uniformly random angle
    theta = np.random.uniform(0,2*np.pi,N)
    X = np.cos(theta)-1
    Y = 2*np.sin(theta)

    def plot_joint_marginals(X,Y):
        # Imported lazily so the test itself does not require matplotlib.
        import matplotlib.pyplot as plt
        plt.figure()
        plt.subplot(2,2,3)
        bins=50
        plt.hist2d(X,Y,bins=bins, range=[[-4,4],[-4,4]])
        plt.subplot(2,2,1)
        plt.hist(X, bins=bins, range=[-4,4])
        plt.subplot(2,2,4)
        plt.hist(Y, bins=bins, range=[-4,4], orientation="horizontal")
        plt.show()

    dcor_dependent = dcor(X, Y)
    if display:
        plot_joint_marginals(X,Y)
        print(dcor_dependent)

    # Shuffling Y in place destroys its pairing with X.
    np.random.shuffle(Y)
    # Compute the value before it is printed below.
    dcor_independent = dcor(X, Y)
    if display:
        plot_joint_marginals(X,Y)
        print(dcor_independent)
    assert(dcor_independent < dcor_dependent)

def main():
    '''Run the module's self-tests and report success.'''
    # No matplotlib import here: a function-local import would not be visible
    # to the test functions anyway, and it would make the self-test fail on
    # machines without matplotlib even though no plots are drawn by default.
    test_mi()
    test_dcor()
    print('All tests passed.')

if __name__ == '__main__':
    main()
