import numpy as np
from scipy.spatial.distance import cdist
from sklearn.preprocessing import StandardScaler
from joblib import Parallel, delayed


## Helper functions for DMF
from sklearn.neighbors import NearestNeighbors
from pydiffmap import diffusion_map as dm

def value2trans(matrix):
    """
    Row-wise value-to-rank transformation.

    Every entry is replaced by ``sqrt(r / p)``, where ``r`` is its 1-based
    rank within its row (ascending order, ties broken in ``np.argsort``
    order) and ``p`` is the number of columns. Each row's smallest value
    therefore maps to ``sqrt(1 / p)`` and its largest to 1.

    Args:
        matrix (array-like): 2-D data of shape (n, p).

    Returns:
        numpy.ndarray: Transformed array of shape (n, p). Floating inputs keep
        their dtype; any other input (e.g. integer counts) yields float64.
    """
    matrix = np.asarray(matrix)
    n_cols = matrix.shape[1]

    # Column order that sorts each row ascending (default kind, as before,
    # so tied values are ranked in the same order).
    order = np.argsort(matrix, axis=1)
    # The argsort of a permutation is its inverse: the 0-based rank of each
    # entry at its original position. This equals scattering 1..p back
    # through `order`, done here without a Python-level row loop.
    ranks = np.argsort(order, axis=1) + 1

    # Never write fractional ranks into an integer/bool buffer: reusing the
    # input dtype would truncate every value but the row maximum to 0.
    out_dtype = matrix.dtype if np.issubdtype(matrix.dtype, np.floating) else np.float64
    return ((ranks / n_cols) ** 0.5).astype(out_dtype, copy=False)

def diffusion_neighborhood(X, knn, K = 15, epsilon = 'bgh'):
    """
    Find the `knn` nearest neighbors of every row of `X` in diffusion-map space.

    Args:
        X (numpy.ndarray): Input data, one sample per row.
        knn (int): Number of neighbor indices returned per sample. The fitted
            points are queried again, so each sample typically appears in its
            own neighbor list.
        K (int): Number of diffusion-map eigenvectors to keep; values below 4
            are raised to 4.
        epsilon: Kernel bandwidth forwarded to pydiffmap (default 'bgh').

    Returns:
        numpy.ndarray: Integer array of shape (n_samples, knn) with the indices
        of the Euclidean nearest neighbors in diffusion coordinates.
    """
    n_evecs = max(K, 4)  # enforce a minimum embedding dimension of 4
    print("calculating diffusion_neighborhood", X.shape)

    embedding = dm.DiffusionMap.from_sklearn(n_evecs=n_evecs, epsilon=epsilon).fit_transform(X)

    finder = NearestNeighbors(n_neighbors=knn).fit(embedding)
    return finder.kneighbors(embedding, return_distance=False)

def diff_corr_SNN_neighborhood(sample, k, diffusion = True, K_diff = 15, epsilon = 'bgh'):
    """
    Computes shared-nearest-neighbor (SNN) neighborhoods from either a diffusion-based
    or a correlation-based base neighborhood.

    A base list of k neighbors is built for every sample. For a pair (i, j) where j is
    in i's list or i is in j's list, the dissimilarity is (k - s_ij) / k, with s_ij the
    number of indices the two lists share. All other pairs get dissimilarity 1. The
    k samples with the smallest dissimilarity (ties resolved by ``np.argsort``) form
    the returned neighborhood of each sample.

    Args:
        sample (numpy.ndarray): The input data matrix where rows represent samples
                                and columns represent features. Shape is (N, D), where
                                N is the number of samples and D is the number of features.
        k (int): The number of nearest neighbors to consider for each sample.
        diffusion (bool, optional): If True, base neighbors come from `diffusion_neighborhood`
                                    (Euclidean kNN in diffusion-map coordinates). If False, they
                                    are the k rows with the highest Pearson correlation
                                    (``np.corrcoef`` over rows). Default is True.
        K_diff (int, optional): Number of diffusion-map eigenvectors, forwarded to
                                `diffusion_neighborhood`. Only used when `diffusion` is True.
                                Default is 15.
        epsilon (optional): Diffusion-kernel bandwidth forwarded to `diffusion_neighborhood`.
                            Only used when `diffusion` is True. Default is 'bgh'.

    Returns:
        numpy.ndarray: An (N, k) array of neighbor indices for each sample after applying
                       the SNN re-ranking (fewer columns if k exceeds N in the correlation branch).

    Example Usage:
        # Use diffusion-based SNN
        neighbors_diff = diff_corr_SNN_neighborhood(sample_data, k=15, diffusion=True, K_diff=7)

        # Use correlation-based SNN
        neighbors_corr = diff_corr_SNN_neighborhood(sample_data, k=15, diffusion=False)
    """
    from scipy.sparse import csr_matrix

    if diffusion:
        # Base neighbors: Euclidean kNN in diffusion-map coordinates.
        Nb_dist = diffusion_neighborhood(sample, k, K_diff, epsilon = epsilon)
    else:
        corr_matrix = np.corrcoef(sample)
        Nb_dist = np.argsort(-corr_matrix, axis=1)[:, :k]  # take k nearest neighbors with highest correlation

    # Build the shared-neighbor count matrix from the base neighborhood (either branch).
    # A[i, j] = 1 iff j is in i's base list. Each row holds distinct indices
    # (argsort / kneighbors output), so (A @ A.T)[i, j] is exactly the number of
    # neighbors shared by i and j. Masking with A keeps that count only where j is
    # in i's own list; every other entry stays 0. This matches the previous
    # per-pair np.intersect1d loop entry for entry.
    n, m = Nb_dist.shape
    rows = np.repeat(np.arange(n), m)
    A = csr_matrix((np.ones(n * m), (rows, Nb_dist.ravel())), shape=(n, n))
    DI = (A @ A.T).multiply(A).toarray()

    # Symmetrize, then convert shared counts into a dissimilarity in [0, 1].
    D = (k - np.maximum(DI, DI.T)) / k

    # The k least-dissimilar samples form the SNN neighborhood.
    return np.argsort(D, axis=1)[:, :k]



### Run DMF

def DMF(sample, knn = 15, trans = True, K_diff=15, epsilon = 'bgh', n_jobs=-1):
    '''
        Applies manifold fitting to the input data `sample`.

        Args:
            sample (numpy.ndarray): The input data to process. Shape is expected to be (n, p)
                                    where n is the number of samples and p is the dimensionality.
            knn (int): The number of nearest neighbors to consider for Shared Nearest Neighbors (SNN).
            trans (bool, optional): If True, the value-to-rank transformation (`value2trans`) is applied
                                    to the data before manifold fitting to control variance. Default is True.
            K_diff (int): The number of diffusion map components retained (significant directions
                          of the underlying manifold structure in the data).
            epsilon (optional): Diffusion-kernel bandwidth passed through to the diffusion map.
                                Default is 'bgh'.
            n_jobs (int, optional): Number of joblib workers for the per-sample loop. Default -1
                                    (all cores, the previous hard-coded value); use 1 to run serially.

        Returns:
            Mout (numpy.ndarray): The manifold-fitted result, a float array with the same shape as `sample`.

        Details:
            1. Diffusion-map based SNN neighborhoods are computed from the raw `sample`
               (before any transformation).
            2. If `trans` is True, the `value2trans` transformation is applied and fitting is
               done in the transformed space.
            3. For each point x, with neighbor mean xbar and d = xbar - x, the candidates
               xbar and xbar + w * d (w in -0.1, -0.05, 0, 0.05, 0.1) are compared. The one with
               the smallest total squared distance to the neighbors is kept; ties go to the
               later candidate.
            4. joblib parallelizes step 3 over all sample points.

        NOTE(review): the mean xbar itself minimizes the total squared distance to the
        neighbors, so in exact arithmetic step 3 always keeps xbar. Only floating-point
        rounding can select another candidate. Confirm whether a different objective was intended.

        Example Usage:
            result = DMF(sample_data, knn=5, trans=True)
    '''

    print(sample.shape)
    Mout = np.zeros(sample.shape)
    N = sample.shape[0]
    # Neighborhoods are always derived from the untransformed data.
    Nb_dist = diff_corr_SNN_neighborhood(sample, knn, diffusion = True, K_diff = K_diff, epsilon = epsilon)
    # Fitting happens in the value-to-rank transformed space when requested.
    sample_ = value2trans(sample) if trans else sample

    # Candidate step sizes along d, built once instead of once per sample point.
    weights = np.array([-0.1, -0.05, 0, 0.05, 0.1])

    # Fit one sample point; returns its index with the fitted location.
    def process_sample_v1(ii):
        BNbr = sample_[Nb_dist[ii, :], :]

        xbar = np.mean(BNbr, axis=0)
        d = xbar - sample_[ii, :]

        x_final = xbar
        ds_final = np.sum(cdist([x_final], BNbr)**2)

        for w in weights:
            x_temp = xbar + w * d
            ds = np.sum(cdist([x_temp], BNbr)**2)

            # `<=`: on ties the later candidate wins, as before.
            if ds <= ds_final:
                x_final = x_temp
                ds_final = ds

        return ii, x_final
    print("Proceed the main loop", sample_.shape)

    results_v1 = Parallel(n_jobs=n_jobs)(delayed(process_sample_v1)(ii) for ii in range(N))

    for ii, x_final in results_v1:
        Mout[ii, :] = x_final
    return Mout

def yao2(sample, knn = 15, trans = True, n_jobs=-1):
    '''
        Manifold fitting with correlation-based SNN neighborhoods.

        Runs the same per-point fitting as `DMF`, but the base neighborhood is given by
        Pearson correlation between rows (`diff_corr_SNN_neighborhood(..., diffusion=False)`)
        instead of a diffusion map.

        Args:
            sample (numpy.ndarray): Input data of shape (n, p), one sample per row.
            knn (int): Number of neighbors per sample for the SNN neighborhood. Default 15.
            trans (bool, optional): If True, fitting is done on `value2trans(sample)`.
                                    Neighborhoods are always computed from the raw `sample`.
                                    Default is True.
            n_jobs (int, optional): Number of joblib workers for the per-sample loop. Default -1
                                    (all cores, the previous hard-coded value); use 1 to run serially.

        Returns:
            Mout (numpy.ndarray): Float array with the same shape as `sample`. Row i is the
            candidate among xbar and xbar + w * (xbar - x_i), for w in -0.1, -0.05, 0, 0.05, 0.1,
            with the smallest total squared distance to i's neighbors. xbar is the neighbors'
            mean; ties go to the later candidate.

        NOTE(review): xbar already minimizes the total squared distance to the neighbors,
        so in exact arithmetic the candidate search always keeps xbar.
    '''
    print(sample.shape)
    Mout = np.zeros(sample.shape)
    N = sample.shape[0]
    Nb_dist = diff_corr_SNN_neighborhood(sample, knn, diffusion = False)
    sample_ = value2trans(sample) if trans else sample

    # Candidate step sizes along d, built once instead of once per sample point.
    weights = np.array([-0.1, -0.05, 0, 0.05, 0.1])

    # Fit one sample point; returns its index with the fitted location.
    def process_sample_v1(ii):
        BNbr = sample_[Nb_dist[ii, :], :]

        xbar = np.mean(BNbr, axis=0)
        d = xbar - sample_[ii, :]

        x_final = xbar
        ds_final = np.sum(cdist([x_final], BNbr)**2)

        for w in weights:
            x_temp = xbar + w * d
            ds = np.sum(cdist([x_temp], BNbr)**2)

            # `<=`: on ties the later candidate wins, as before.
            if ds <= ds_final:
                x_final = x_temp
                ds_final = ds

        return ii, x_final

    results_v1 = Parallel(n_jobs=n_jobs)(delayed(process_sample_v1)(ii) for ii in range(N))

    for ii, x_final in results_v1:
        Mout[ii, :] = x_final
    return Mout

