import numpy as np
import os
from sklearn.decomposition import PCA
from scipy.stats import mode
from skdim.id import lPCA, MLE, TwoNN, DANCo, CorrInt, ESS, MADA
from sklearn.decomposition import PCA
from knn import *

# Helper functions #

def exvar_pca(samples, precomputed_knn,
                manifold_name = 'lin_nsp',
                m_params=(100,42),
                params=(100,0),
                root_dir = './data/ide/,',
                savefile=True):
    '''
    Compute PCA explained-variance ratios, per neighbourhood or globally.

    Parameters
    ----------
    samples : 2-D array, one row per sample.
    precomputed_knn : indexable whose element [1] is a 2-D integer array;
        row i lists the neighbour indices of sample i. Only read when a
        local (per-neighbourhood) PCA is performed.
    manifold_name, m_params : only used to build the output file name.
    params : (k, seed). k is the number of neighbour columns used for each
        local PCA. If k >= number of samples, or k == -1, a single PCA is
        fitted on the whole data set instead. seed is passed to
        np.random.seed before fitting.
    root_dir : prefix of the output file name.
        NOTE(review): the default ends with a ',' which becomes part of the
        file name; it looks like a typo but is kept to preserve the
        interface.
    savefile : when True the result is written with np.savetxt.

    Returns
    -------
    Local PCA: array with one row of explained-variance ratios per sample.
    Global PCA: column vector of shape (n_components, 1).
    '''
    k, seed = params
    n_samples, _ = samples.shape

    # k == -1 is the sentinel for "PCA on the entire data set"; it is also
    # used when the requested neighbourhood would cover every sample.
    use_global = k == -1 or k >= n_samples

    np.random.seed(seed)

    if use_global:
        # One PCA on the full sample, stored as a column vector.
        ratios = PCA().fit(samples).explained_variance_ratio_
        evar_pw = np.reshape(ratios, (ratios.shape[0], 1))
    else:
        # Keep only the first k neighbour columns of the index array.
        neighbour_idx = precomputed_knn[1][:, :k]
        evar_pw = np.array([
            PCA().fit(samples[neighbour_idx[i]]).explained_variance_ratio_
            for i in range(n_samples)
        ])

    if savefile:
        filename = f'{root_dir}_{manifold_name}_p_{m_params}_pca_p_{params}_exvar.txt'
        # Create the target directory first so that the computed ratios are
        # not lost to a FileNotFoundError at the very end.
        out_dir = os.path.dirname(filename)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        np.savetxt(filename, evar_pw)

    return evar_pw

# Example usage 
# N, d1, d2 = 1000, 10, 20
# X = np.random.randn(N,d1)
# print(X.shape)

def abids(X, knns, raw = False, n_jobs=1):
    '''
    Angle-based local intrinsic-dimension estimate for every row of X.

    Code credit : Erik Thordsen, Erich Schubert

    X    : 2-D array, one row per point.
    knns : pair (distances, indices); only the indices are used. Row i of
           the index array lists the neighbours of point i.
    raw  : selects the (k**2 - k) / (ssq - k) variant instead of
           k**2 / ssq.
    n_jobs : accepted but not used here.

    Returns a 1-D array with one estimate per point.
    '''
    n_points, _ = X.shape
    _, nbr_idx = knns

    def _estimate(i):
        # Neighbour offsets relative to point i.
        offsets = X[nbr_idx[i]] - X[i]
        lengths = np.linalg.norm(offsets, axis=1)

        # Normalise to unit vectors, dropping neighbours at distance 0.
        keep = lengths > 0
        units = offsets[keep] / lengths[keep][:, None]

        # The formula needs the effective neighbourhood size.
        k = units.shape[0]

        if k <= units.shape[1]:
            # Pairwise cosines (dot products of unit vectors).
            gram = units.dot(units.T)
        else:
            # Cheaper matrix with the same sum of squared eigenvalues as
            # the cosine matrix.
            gram = units.T.dot(units)
        ssq_eigvals = np.sum(np.square(gram))

        if raw:
            return (k**2-k)/(ssq_eigvals-k)
        return k**2 / ssq_eigvals

    return np.array([_estimate(i) for i in range(n_points)])

### Helper functions end ###


### lPCA ###

def ide_lpca(samples=None, precomputed_knn=None,
            manifold_name = 'lin_nsp',
            m_params=(100,42), 
            ver = 'maxgap', params = (100,0.1,0),
            n_jobs=1, root_dir = '.data/ide_data/'):
    '''
    Local-PCA intrinsic-dimension estimates from explained-variance ratios.

    The ratio matrix is read from a cache file when it exists; otherwise it
    is computed with exvar_pca (which needs samples and, for local PCA,
    precomputed_knn) using the same root_dir, so that the file written is
    the one looked up on the next call.

    Params : (k: int, eps: float or None, seed: int)

    ver selects how a row of ratios is turned into a dimension:
      'maxgap' : 1 + index of the largest ratio between consecutive
                 explained-variance ratios.
      'ratio'  : number of leading components whose normalised cumulative
                 sum first exceeds 1 - eps.
      'fo'     : number of components whose ratio is strictly larger than
                 (1 - eps) times the largest ratio of that row.

    n_jobs is accepted but not used here.

    Returns a dict with keys 'lides', 'params', 'ver', 'manifold',
    'm_params' and 'exp_var_ratios'.

    Raises ValueError for an unknown ver.
    '''
    k, eps, seed = params
    exvar_params = (k, seed)
    
    filename = f'{root_dir}_{manifold_name}_p_{m_params}_pca_p_{exvar_params}_exvar.txt'
    
    if os.path.exists(filename):
        # ndmin=2 keeps the matrix layout even if the cache holds a single
        # row or a single column (np.loadtxt would otherwise return 1-D).
        evar_pw = np.loadtxt(filename, ndmin=2)
    else:
        assert samples is not None, 'File does not exist, need samples to proceed!'
        # Make sure the cache location exists before exvar_pca saves to it.
        cache_dir = os.path.dirname(filename)
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)
        # Forward root_dir so the saved file matches `filename` above.
        evar_pw = exvar_pca(samples, precomputed_knn, manifold_name,
                            m_params=m_params, params=exvar_params,
                            root_dir=root_dir)
    
    if ver == 'maxgap':
        gaps = evar_pw[:,:-1]/(evar_pw[:,1:] + 1e-6)
        lides = np.nanargmax(gaps, axis=1) + 1
    elif ver == 'ratio':
        cum_ratio = np.cumsum(evar_pw,axis=1)/np.sum(evar_pw,axis=1,keepdims=True)
        # One estimate per row of the ratio matrix (cumulative sums are
        # ascending, so searchsorted is valid here).
        lides = np.array(
            [np.searchsorted(row, 1-eps, side='right')+1 for row in cum_ratio],
            dtype=float,
        )
    elif ver == 'fo':
        # Explained-variance ratios come in descending order, so a binary
        # search does not apply; count the significant components directly.
        threshold = np.max(evar_pw, axis=1, keepdims=True)*(1-eps)
        lides = np.sum(evar_pw > threshold, axis=1).astype(float)
    else:
        available_ver = ['maxgap', 'ratio', 'fo' ]
        raise ValueError(f'{ver} not implemented. Available methods are {available_ver}')

    info_dict = {'lides' : lides,
                 'params' : params,
                 'ver' : ver,
                 'manifold': manifold_name,
                 'm_params':m_params,
                 'exp_var_ratios' : evar_pw}
    
    return info_dict

# Example usage
# X = np.random.randn(1000,30)
# knn = precompute_knn(X,kmax = 20)
# info_dict = ide_lpca(X,knn,ver='maxgap')
# lides = info_dict['lides']

def ide_mle(samples=None, precomputed_knn=None,
             manifold_name = 'lin_nsp',
             m_params=(100,42), 
             params = (100,0.0,0),
             n_jobs=1):
    '''
    Maximum Likelihood Estimation of pointwise intrinsic dimension.

    Params: (k: int, sigma: float, seed: int)
    Only sigma is forwarded to the MLE estimator; k and seed are unpacked
    but not used in this function.

    Returns a dict with keys 'lides', 'params', 'manifold', 'm_params'.
    '''
    # Unpacking also enforces the three-element layout of params.
    _, sigma, _ = params

    fitted = MLE(sigma=sigma).fit(
        X=samples,
        precomputed_knn_arrays=precomputed_knn,
        n_jobs=n_jobs,
    )

    return {
        'lides': fitted.dimension_pw_,
        'params': params,
        'manifold': manifold_name,
        'm_params': m_params,
    }

# Example usage
# X = np.random.randn(1000,30)
# knn = precompute_knn(X,kmax = 30)
# knn = (knn[0][:,1:],knn[1][:,1:])
# info_dict = ide_mle(X,knn)
# lides = info_dict['lides']

def ide_danco(samples=None, precomputed_knn=None,
             manifold_name = 'lin_nsp',
             m_params=(100,42), 
             params = (5,7,0),
             n_jobs=1):
    '''
    Distance and Angle-Norm concentration (DANCo), fitted pointwise.

    Params: (k : int, k1 : int, seed: int)
    Ensure k1 > k+1
    The estimator is built with k - 2; k1 and seed are unpacked but not
    used in this function.

    Returns a dict with keys 'lides', 'params', 'manifold', 'm_params'.
    '''
    # Unpacking also enforces the three-element layout of params.
    k, _, _ = params

    danco = DANCo(k=k-2)

    # NOTE(review): shape[0] is the row count of the neighbour-index array
    # (presumably the number of samples, not neighbours per sample) --
    # confirm this neighbourhood size is intended.
    n_rows = precomputed_knn[1].shape[0]
    fitted = danco.fit_pw(X=samples, n_neighbors=n_rows-2, n_jobs=n_jobs)

    return {
        'lides': fitted.dimension_pw_,
        'params': params,
        'manifold': manifold_name,
        'm_params': m_params,
    }

# Example usage
# X = np.random.randn(50,30)
# knn = precompute_knn(X,kmax = 20)
# info_dict = ide_danco(X,knn)
# lides = info_dict['lides']

def ide_twonn(samples=None, precomputed_knn=None,
             manifold_name = 'normal',
             m_params=(100,42), 
             params = (100,0.1,0),
             n_jobs=1):
    '''
    TwoNN estimator (ratio of the two nearest-neighbour distances), fitted
    pointwise.

    Params: (k: int, dFrac : float, seed: int)
    Only dFrac is forwarded (as discard_fraction); k and seed are unpacked
    but not used in this function.

    Returns a dict with keys 'lides', 'params', 'manifold', 'm_params'.
    '''
    # Unpacking also enforces the three-element layout of params.
    _, dFrac, _ = params

    twonn = TwoNN(discard_fraction=dFrac)

    # The neighbour-index array is handed to the estimator as-is.
    # NOTE(review): n_neighbors is derived from its row count (shape[0]),
    # presumably the number of samples -- confirm this is intended.
    knn_idx = precomputed_knn[1]
    fitted = twonn.fit_pw(
        X=samples,
        precomputed_knn=knn_idx,
        n_neighbors=precomputed_knn[1].shape[0]-2,
        n_jobs=n_jobs,
    )

    return {
        'lides': fitted.dimension_pw_,
        'params': params,
        'manifold': manifold_name,
        'm_params': m_params,
    }

# Example usage
# X = np.random.randn(1000,60)
# knn = precompute_knn(X,kmax = 20)
# knn = (knn[0][:,1:], knn[1][:,1:])
# info_dict = ide_twonn(X,knn)
# lides = info_dict['lides']

def ide_corrint(samples=None, precomputed_knn=None,
             manifold_name = 'normal',
             m_params=(100,42), 
             params = (10,20,0),
             n_jobs=1):
    '''
    Fractal (correlation-integral) dimension estimation, fitted pointwise.

    Params: (k1, k2, seed)
    k1 and k2 are forwarded to CorrInt; seed is unpacked but not used in
    this function.

    Returns a dict with keys 'lides', 'params', 'manifold', 'm_params'.
    '''
    # Unpacking also enforces the three-element layout of params.
    k1, k2, _ = params

    corrint = CorrInt(k1=k1,k2=k2)

    # NOTE(review): shape[0] is the row count of the neighbour-index array
    # (presumably the number of samples, not neighbours per sample) --
    # confirm this neighbourhood size is intended.
    n_rows = precomputed_knn[1].shape[0]
    fitted = corrint.fit_pw(X=samples, n_neighbors=n_rows-2, n_jobs=n_jobs)

    return {
        'lides': fitted.dimension_pw_,
        'params': params,
        'manifold': manifold_name,
        'm_params': m_params,
    }

# Example usage
# X = np.random.randn(1000,30)
# knn = precompute_knn(X,kmax = 20)
# knn = (knn[0][:,1:], knn[1][:,1:])
# info_dict = ide_corrint(X,knn)
# lides = info_dict['lides']

def ide_abid(samples=None, precomputed_knn=None,
             manifold_name = 'normal',
             m_params=(100,42), 
             params = (20,0),
             n_jobs=1):
    '''
    Estimation based on moments of cosine distributions (delegates to
    abids).

    Params: (k, seed)
    Both entries are unpacked but not used in this function; the
    neighbourhoods come from precomputed_knn.

    Returns a dict with keys 'lides', 'params', 'manifold', 'm_params'.
    '''
    # Unpacking enforces the two-element layout of params.
    _, _ = params

    estimates = abids(X=samples, knns=precomputed_knn, n_jobs=n_jobs)

    return {
        'lides': estimates,
        'params': params,
        'manifold': manifold_name,
        'm_params': m_params,
    }

# Example usage
# X = np.random.randn(1000,30)
# knn = precompute_knn(X,kmax = 20)
# info_dict = ide_abid(X,knn)
# lides = info_dict['lides']
