import numpy as np
import scipy.sparse as sp

from sklearn import linear_model
from sklearn.neighbors import NearestNeighbors
from scipy.optimize import curve_fit

from typing import List, Union, Tuple

from tqdm import tqdm

# This function is based on the following Stack Overflow answer:
# https://stackoverflow.com/questions/31097247/remove-duplicate-rows-of-a-numpy-array
def unique (matrix, axis = 0):
    '''
        Remove duplicated rows (axis = 0) or columns (axis = 1) from a dense
        2-D array.

        This is used in the ID estimate because it relies on nearest-neighbour
        information: duplicated points produce zero distances, which yield
        infinities or NaNs in the mu = r2 / r1 ratio.

        Args:
        matrix: 2-D array-like of shape (N, D).
        axis: 0 to remove duplicated rows, 1 to remove duplicated columns.

        Returns:
        A dense numpy array holding the unique rows (axis = 0) or the unique
        columns (axis = 1) of matrix, in lexicographically sorted order. The
        orientation of the input is kept: for axis = 1 the result has as many
        rows as matrix and one column per distinct column of matrix.
    '''
    matrix = np.asarray (matrix)

    # Work on rows only: the unique columns are the unique rows of the transpose
    if axis == 1:
        matrix = matrix.T

    # Nothing to deduplicate; also avoids indexing an empty array with the
    # length-1 mask built below
    if matrix.shape [0] == 0:
        return matrix.T if axis == 1 else matrix

    # A lexicographic sort brings identical rows next to each other. lexsort
    # takes one key per column (passed as rows of matrix.T), last key primary.
    sorted_idx  = np.lexsort (matrix.T)
    sorted_data = matrix [sorted_idx, :]

    # Keep a row if it differs from its predecessor in at least one entry;
    # the first row is always kept
    row_mask = np.append ([True], np.any (np.diff (sorted_data, axis = 0), axis = 1))

    result = sorted_data [row_mask]

    # Restore the original orientation so axis = 1 really returns columns
    return result.T if axis == 1 else result

# This function is based on the following Stack Overflow answer:
# https://stackoverflow.com/questions/46126840/get-unique-rows-from-a-scipy-sparse-matrix
def sp_unique (sp_matrix, axis = 0):
    '''
        Remove duplicated rows (axis = 0) or columns (axis = 1) from a scipy
        sparse matrix.

        This is used in the ID estimate because it relies on nearest-neighbour
        information: duplicated points produce zero distances, which yield
        infinities or NaNs in the mu = r2 / r1 ratio.

        Args:
        sp_matrix: scipy sparse matrix of shape (N, D).
        axis: 0 to remove duplicated rows, 1 to remove duplicated columns.

        Returns:
        A new sparse matrix with the unique rows (axis = 0) or unique columns
        (axis = 1) of sp_matrix and the same dtype. It is converted (via
        asformat) back to the format of the matrix that was deduplicated --
        the transpose of sp_matrix when axis = 1, transposed back before
        returning. The original row/column order is not preserved, and the
        result does not share its row lists with the input.
    '''
    if axis == 1:
        sp_matrix = sp_matrix.T

    old_format = sp_matrix.getformat ()
    dtype      = sp_matrix.dtype
    ncols      = sp_matrix.shape [1]

    # In LIL format row i is described by two Python lists: lil.data[i] (its
    # non-zero values) and lil.rows[i] (their column indices)
    lil = sp_matrix if old_format == 'lil' else sp_matrix.tolil ()

    # Adding the two object arrays concatenates the lists row by row, giving
    # one comparable key [values..., columns...] per row; two rows are equal
    # exactly when their keys are equal. return_index picks one
    # representative row per distinct key.
    _, ind = np.unique (lil.data + lil.rows, return_index = True)

    # Copy the selected lists: reusing them would make the result alias the
    # input's storage when the input is already LIL, so editing one matrix
    # would silently edit the other
    data = np.empty (len (ind), dtype = object)
    rows = np.empty (len (ind), dtype = object)
    for k, i in enumerate (ind):
        data [k] = list (lil.data [i])
        rows [k] = list (lil.rows [i])

    uniq = sp.lil_matrix ((len (ind), ncols), dtype = dtype)
    uniq.data = data
    uniq.rows = rows

    ret = uniq.asformat (old_format)
    return ret.T if axis == 1 else ret

class IntrinsicDim():
    '''
        Intrinsic Dimension (ID) estimator based on the ratio of the distances
        to the second and first nearest neighbours of every point (TwoNN
        approach). It estimates the ID of data represented by an (N x D)
        matrix X, where N is the number of data points and D is the dimension
        of the original space. The constructor accepts the following optional
        parameters.

        Kwargs:
        frac [0.9] : fraction of the data considered for the linear fit which
                     yields the Intrinsic Dimension estimate. The largest
                     values of mu = r2 / r1 are discarded because they pollute
                     the linear regression.

        metric ['euclidean'] : metric used for the nearest-neighbour search;
                               must be one accepted by
                               sklearn.neighbors.NearestNeighbors.

        squeeze [True] : if True, duplicated points are removed before the
                         estimate (they would give zero distances).

        verbose [False] : if True, block_estimate shows a tqdm progress bar.
    '''

    def __init__(self,
                 frac    : float = 0.9,
                 metric  : str   = 'euclidean',
                 squeeze : bool  = True,
                 verbose : bool  = False) -> None:

        # Store the essential parameters
        self.frac    = frac
        self.metric  = metric
        self.squeeze = squeeze
        # Read by block_estimate to decide whether to wrap its loop in tqdm
        self.verbose = verbose

    def __call__(self, X : np.ndarray, return_full_info : bool = False) -> Union[float, Tuple]:
        '''
            Estimate the Intrinsic Dimension of the (N x D) data matrix X.

            Args:
            X: 2-D matrix (N, D), dense or scipy sparse, representing N points
               in a D-dimensional space.
            return_full_info: if True return the tuple described below
               instead of the ID alone.

            Returns:
            If return_full_info is False: the Intrinsic Dimension estimate d.
            Otherwise the tuple (d, R, log_mu, y, lr) where
              d      : Intrinsic Dimension estimate (slope of the linear fit),
              R      : R-squared score of the linear fit,
              log_mu : log of the sorted ratios mu = r2 / r1 (all points),
              y      : -log(1 - F_emp) for all points,
              lr     : the fitted sklearn LinearRegression object.
        '''
        # Remove duplicated points: they give r1 = 0 and hence infinite mu
        if self.squeeze:
            X = sp_unique (X) if sp.issparse (X) else unique (X)

        N = np.shape (X) [0]

        # Query 3 neighbours of every point of the fitted set: the first one is
        # the point itself, the next two its first and second nearest neighbours
        nn = NearestNeighbors (n_neighbors = 3, metric = self.metric, n_jobs = -1)
        M, _ = nn.fit (X).kneighbors (X)
        r1, r2 = M [:, 1], M [:, 2]

        # Ratio of second to first neighbour distance, in increasing order
        mu     = np.sort (r2 / r1, kind = 'quicksort')
        log_mu = np.log (mu)

        # Empirical cumulative distribution evaluated at the sorted mu values
        F_emp = np.arange (N, dtype = np.float64) / N
        y_all = -np.log (1. - F_emp)

        # Fit only the first frac of the sorted data: the largest mu values
        # are the noisiest and would spoil the regression
        n_fit = int (self.frac * N)
        x = log_mu [:n_fit].reshape (-1, 1)
        y = y_all [:n_fit].reshape (-1, 1)

        # NOTE(review): this fit includes an intercept, whereas the published
        # TwoNN estimator fits a line through the origin -- confirm intended.
        lr = linear_model.LinearRegression (n_jobs = -1).fit (x, y)

        # The ID is the slope of the fit; its quality is measured by R-squared
        ID, score = lr.coef_ [0, 0], lr.score (x, y)

        return (ID, score, log_mu, y_all, lr) if return_full_info else ID


    def block_estimate (self, X : np.ndarray, B : List[int]):
        '''
            Block analysis of the Intrinsic Dimension estimate: for every b in
            B, randomly permute the N points of X, split them into b nearly
            equal blocks and estimate the ID of each block separately. The
            fraction of data used in each linear fit is the instance's frac.

            Args:
            X: 2-D matrix (N, D) representing N points in a D-dimensional
               space (anything supporting X[index_array] row selection).
            B: list with the numbers of blocks into which X is subdivided.

            Returns:
            d:     list, one entry per b, of the mean ID estimate over blocks.
            std_d: list, one entry per b, of the standard deviation of the
                   block ID estimates.
            R:     list, one entry per b, of the mean R-squared of the fits.
        '''
        N = np.shape (X) [0]

        d, std_d, R = [], [], []

        for b in (tqdm (B) if self.verbose else B):
            # Randomly assign the points to b sub-blocks
            sigma = np.random.permutation (N)
            idxs  = np.array_split (sigma, b)

            _d = np.zeros (b)
            _R = np.zeros (b)

            for i, idx in enumerate (idxs):
                # ID and R-squared of this block: first two items of the full info
                _d [i], _R [i] = self (X [idx], return_full_info = True) [:2]

            d.append (np.mean (_d))
            R.append (np.mean (_R))
            std_d.append (np.std (_d))

        return d, std_d, R


    def _estimate_curv (self, X, metric = 'euclidean', squeeze = True):
        '''
            Alternative ID estimate: fit the normalized histogram of
            mu = r2 / r1 with the model
            f(mu; d) = -mu^(-3 - d) * (d - (2 + d) * mu^2).

            Args:
            X: 2-D matrix (N, D) of data points, dense or scipy sparse.
            metric: metric for the nearest-neighbour search.
            squeeze: if True, remove duplicated points first.

            Returns:
            (d, sigma_d): best-fit d and its standard error, taken from the
            covariance matrix returned by scipy.optimize.curve_fit.
        '''
        # Model density; d is the only free parameter. An extra unused
        # parameter would make the fit covariance singular (infinite error).
        def f (r, d):
            return -r**(-3 - d) * (d - (2 + d) * r**2)

        # Remove duplicated points: they give r1 = 0 and hence infinite mu
        if squeeze:
            X = sp_unique (X) if sp.issparse (X) else unique (X)

        # The first neighbour returned is the point itself; keep the next two
        nn = NearestNeighbors (n_neighbors = 3, metric = metric, n_jobs = -1)
        M, _ = nn.fit (X).kneighbors (X)
        r1, r2 = M [:, 1], M [:, 2]

        # Empirical density of mu = r2 / r1
        f_mu, bins = np.histogram (r2 / r1, bins = 'auto', density = True)

        # Bin centres are the abscissae of the empirical density
        x = 0.5 * (bins [:-1] + bins [1:])

        popt, pcov = curve_fit (f, x, f_mu)

        return popt [0], np.sqrt (pcov [0, 0])
