import numpy as np
from scipy.stats import multivariate_normal
from scipy.stats import gaussian_kde
from sklearn.neural_network import MLPRegressor
from sklearn.neighbors import NearestNeighbors


def Psi_Trans(X):
    """Expand a feature vector with all of its degree-2 products.

    The result is a list holding the original entries of ``X`` followed by
    ``X[i] * X[j]`` for every index pair with ``i <= j``, ordered by ``i``
    and then by ``j``.

    Parameters
    ----------
    X : sequence
        One-dimensional indexable collection of numbers.

    Returns
    -------
    list
        ``len(X) + len(X) * (len(X) + 1) / 2`` values.
    """
    n_features = len(X)
    # Upper-triangular pairs (including the diagonal), so each product
    # appears exactly once.
    second_order = [
        X[row] * X[col]
        for row in range(n_features)
        for col in range(row, n_features)
    ]
    return list(X) + second_order


def MC_TV_Baseline(mu_1, Sigma_1, mu_2, Sigma_2, n_samples=100000, seed=None):
    """Monte Carlo estimate of the total-variation distance of two Gaussians.

    Draws ``n_samples`` points from each of N(mu_1, Sigma_1) and
    N(mu_2, Sigma_2), pools them (so the pooled sample follows the equal
    mixture of the two), and averages ``|P - Q| / (P + Q)`` over the pooled
    points, where ``P`` and ``Q`` are the two densities.

    Parameters
    ----------
    mu_1, Sigma_1 : array-like
        Mean vector and covariance matrix of the first Gaussian (P).
    mu_2, Sigma_2 : array-like
        Mean vector and covariance matrix of the second Gaussian (Q).
    n_samples : int, optional
        Number of points drawn from *each* Gaussian. Defaults to 100000,
        the value that used to be hard-coded.
    seed : int or None, optional
        When given, sampling uses a private ``RandomState(seed)`` so the
        estimate is reproducible. When ``None`` (default) the global
        ``np.random`` state is used, exactly as before.

    Returns
    -------
    float
        Estimated total-variation distance, in [0, 1].
    """
    rng = np.random if seed is None else np.random.RandomState(seed)
    x_sample = np.concatenate([
        rng.multivariate_normal(mean=mu_1, cov=Sigma_1, size=n_samples),
        rng.multivariate_normal(mean=mu_2, cov=Sigma_2, size=n_samples),
    ])

    mvn_P = multivariate_normal(mean=mu_1, cov=Sigma_1)  # P(x)
    mvn_Q = multivariate_normal(mean=mu_2, cov=Sigma_2)  # Q(x)

    # The frozen distributions evaluate a whole (n, d) batch in one call, so
    # no per-row Python loop is needed.
    log_ratio = mvn_P.logpdf(x_sample) - mvn_Q.logpdf(x_sample)

    # (P - Q) / (P + Q) == tanh((log P - log Q) / 2).  Working with log
    # densities avoids the 0 / 0 = NaN that the direct ratio produces when
    # both pdfs underflow to zero far out in the tails.
    TV_est_baseline = np.abs(np.tanh(0.5 * log_ratio)).mean()
    return TV_est_baseline


def Dist_TV(x_train, x_test, y_train, y_test, seed=1):
    """Classifier-based estimate of total-variation distance.

    Every row of ``x_train`` / ``x_test`` is expanded with ``Psi_Trans``
    (original features plus all pairwise products), a one-hidden-unit
    logistic ``MLPRegressor`` is fitted on the training part, and its test
    predictions are thresholded at 0.5 to obtain 0/1 labels. The estimate is
    ``|1 - 2 * misclassification_rate|`` on the test part.

    Parameters
    ----------
    x_train, x_test : array-like, 2-D
        Feature matrices, one sample per row.
    y_train, y_test : array-like
        Labels for the rows above. Predictions are mapped to 0/1, so
        ``y_test`` is expected to hold 0/1 values -- presumably marking
        which of the two samples a row came from; verify against callers.
    seed : int, optional
        ``random_state`` passed to the regressor.

    Returns
    -------
    float
        The estimate ``|1 - 2 * error|``.

    Raises
    ------
    ValueError
        If ``y_test`` does not contain one label per test row.
    """
    # CL_TV
    # x - transformation
    x_train_trans = np.apply_along_axis(Psi_Trans, axis=1, arr=x_train)
    x_test_trans = np.apply_along_axis(Psi_Trans, axis=1, arr=x_test)
    f_hat = MLPRegressor(hidden_layer_sizes=1,
                         tol=1e-7, alpha=0.00001,
                         activation='logistic', random_state=seed, max_iter=10000)
    f_hat.fit(x_train_trans, y_train)

    # predict labels for the testing data
    scores = np.asarray(f_hat.predict(x_test_trans)).ravel()
    y_pred = (scores > 0.5).astype(int)

    # Compare as flat arrays.  A list-vs-list ``!=`` yields a single bool and
    # an (n, 1) label column would broadcast against (n,) predictions into an
    # (n, n) matrix; both silently corrupt the error rate.
    y_true = np.asarray(y_test).ravel()
    if y_true.shape != y_pred.shape:
        raise ValueError(
            "y_test has %d labels but x_test has %d rows"
            % (y_true.shape[0], y_pred.shape[0])
        )

    # calculate misclassification rate
    misclassification_rate = np.mean(y_pred != y_true)
    TV_est_CL = abs(1 - 2 * misclassification_rate)
    return TV_est_CL


def KDE_TV(x_real, x_syn):
    """Kernel-density plug-in estimate of total-variation distance.

    A Gaussian KDE with Silverman bandwidth is fitted to each input. 25000
    points are then resampled from each fitted KDE, pooled, and the quantity
    ``|p_real - p_syn| / (p_real + p_syn)`` is averaged over the pooled
    points using the two fitted densities.

    Parameters
    ----------
    x_real, x_syn : numpy.ndarray
        Sample matrices with one observation per row (they are transposed
        before being handed to ``gaussian_kde``).

    Returns
    -------
    float
        The estimated total-variation distance.
    """
    real_density = gaussian_kde(x_real.T, bw_method='silverman')
    syn_density = gaussian_kde(x_syn.T, bw_method='silverman')

    total_draws = 50000
    draws_per_kde = int(total_draws / 2)

    # Draw from the real KDE first, then from the synthetic one, and stack
    # the two batches row-wise into one evaluation set.
    from_real = real_density.resample(size=draws_per_kde).T
    from_syn = syn_density.resample(size=draws_per_kde).T
    eval_points = np.concatenate([from_real, from_syn])

    p_real = real_density(eval_points.T)
    p_syn = syn_density(eval_points.T)

    mixture = p_real + p_syn
    return np.abs(p_real / mixture - p_syn / mixture).mean()

def _kth_neighbor_distance(reference, queries, k):
    """Return, for each row of ``queries``, the distance to its k-th nearest row of ``reference``."""
    nbrs = NearestNeighbors(n_neighbors=k, algorithm='auto').fit(reference)
    distances, _ = nbrs.kneighbors(queries)
    # Columns are sorted by increasing distance, so the last one is the k-th.
    return distances[:, -1]


def KNN(x_real, x_syn):
    """k-nearest-neighbour density-ratio estimate of total-variation distance.

    ``x_syn`` is split into a query set (its first ``M`` rows) and a
    reference set (the remaining ``N = T - M`` rows); the first ``M`` rows of
    ``x_real`` form the other reference set. For every query point the
    density ratio real/synthetic is estimated from the k-th neighbour
    distances ``rho_1`` (to the real reference set) and ``rho_2`` (to the
    synthetic reference set) as ``(N / M) * (rho_2 / rho_1) ** d``, and
    ``0.5 * |ratio - 1|`` is averaged over the query points.

    Parameters
    ----------
    x_real, x_syn : numpy.ndarray
        Sample matrices with one observation per row and ``d`` columns.

    Returns
    -------
    float
        The estimated total-variation distance.

    Raises
    ------
    ValueError
        If either input has fewer than two rows, so that no query/reference
        split is possible.

    Notes
    -----
    A query point that coincides with a real reference point gives
    ``rho_1 == 0`` and therefore an infinite ratio; inputs with exact
    duplicates across the two samples are not handled specially.
    """
    T = x_syn.shape[0]  # number of samples in x_syn
    M = min(T // 2, x_real.shape[0] // 2)
    if M < 1:
        raise ValueError("KNN needs at least two rows in both x_real and x_syn")
    N = T - M  # size of the synthetic reference set Y_2

    X_sample = x_syn[:M]   # query points, taken from x_syn
    Y_1 = x_real[:M]       # reference set from x_real (M rows)
    Y_2 = x_syn[M:T]       # reference set from the rest of x_syn (N rows)
    d = x_real.shape[1]    # dimension
    k = int(M ** 0.5)      # number of neighbors

    rho_1 = _kth_neighbor_distance(Y_1, X_sample, k)
    rho_2 = _kth_neighbor_distance(Y_2, X_sample, k)

    # A k-NN density estimate is proportional to k / (n * rho**d), where n is
    # the size of the reference set.  The ratio of the two estimates therefore
    # carries a factor N / M, which only cancels when both reference sets
    # have the same size (it was previously dropped, biasing the result
    # whenever N != M).
    L = (N / M) * (rho_2 / rho_1) ** d
    g = 0.5 * np.abs(L - 1)
    value = np.mean(g)
    return value


def NNRE(x_real, x_syn):
    """Nearest-neighbour-ratio estimate of total-variation distance.

    The two samples are stacked into ``Z = [x_real; x_syn]``. For every point
    its ``k`` nearest neighbours in ``Z`` (excluding the point itself) are
    found, and the number that come from ``x_real`` is counted. For the
    points belonging to ``x_syn`` the ratio

        (# real neighbours) / (# synthetic neighbours + 1)

    scaled by ``M / N`` serves as the density-ratio estimate, and
    ``0.5 * |ratio - 1|`` is averaged over those points.

    Parameters
    ----------
    x_real : numpy.ndarray
        ``N`` samples, one per row, ``d`` columns.
    x_syn : numpy.ndarray
        ``M`` samples, one per row, ``d`` columns.

    Returns
    -------
    float
        The estimated total-variation distance.
    """
    Z = np.vstack((x_real, x_syn))
    N = x_real.shape[0]  # number of samples in x_real
    M = x_syn.shape[0]   # number of samples in x_syn
    k = int(N ** 0.5)    # number of neighbors

    # Query k + 1 neighbours because each point of Z is returned as its own
    # nearest neighbour (first column); drop that column afterwards.
    nbrs = NearestNeighbors(n_neighbors=k + 1).fit(Z)
    _, indices = nbrs.kneighbors(Z)
    neighbor_idx = indices[:, 1:]

    # Rows 0 .. N-1 of Z are x_real, so with 0-based indices a neighbour is
    # real iff its index is strictly below N.  (``<= N`` would count the
    # first synthetic sample as real.)
    is_real = neighbor_idx < N
    real_count = np.sum(is_real, axis=1)

    # k - real_count is the number of synthetic neighbours; the +1 keeps the
    # denominator positive when all k neighbours are real.
    Rat = real_count / (k - real_count + 1)

    # Average over the rows of Z that belong to x_syn.
    Temp3 = 0.5 * np.abs(M / N * Rat[N:N + M] - 1)
    value = np.mean(Temp3)
    return value


