'''
Matlab Code by Shira Kritchman and Boaz Nadler, 2008, Weizmann Institute of Science
Converted to Python 3.7 by Jaewoong Choi, 2022, Korea Institute for Advanced Study

FOR MORE DETAILS SEE:
    S. Kritchman and B. Nadler. Determining the number of components in a factor model
    from limited noisy data, Chemometrics and Intelligent Laboratory Systems, 2008.
'''
import numpy as np

def KN_rankEst(ell, n, beta, alpha=0.5, max_kk=None):
    '''
    DESCRIPTION:
        This function gets as input the eigenvalues of an SCM
        (sample covariance matrix), and outputs an estimation of its
        pseudorank, under the assumption of uncorrelated homoscedastic noise.
        Eigenvalues are tested one at a time from the largest down: the
        kk-th eigenvalue counts as a signal if n*ell[kk-1] exceeds the
        Tracy-Widom threshold built from the noise estimate that assumes
        kk signals. The first failed test ends the search.

    INPUT:
        ell     -  eigenvalues of the SCM, of length p (any array-like).
                   They are expected in decreasing order: ell[kk-1] is used
                   as the kk-th largest eigenvalue and ell[kk:] as noise.
        n       -  number of samples
        beta    -  indicator for real (1) or complex (2) valued observations
        alpha   -  significance level of each test in percent, i.e. the
                   target tail probability alpha/100 passed to KN_s_Wishart
                   (default 0.5)
        max_kk  -  maximal possible value for the pseudorank; it is always
                   capped at min(p, n) - 1

    OUTPUT:
        K         - pseudorank estimation for the SCM, 0 <= K <= max_kk
        sigma_hat - estimate of the noise variance: the KN_noiseEst value
                    for kk = K, or the mean eigenvalue when K == 0

    RAISES:
        ValueError - if ell is empty
    '''
    ell = np.asarray(ell, dtype=float)
    p = len(ell)
    if p == 0:
        raise ValueError('ell must contain at least one eigenvalue')

    if max_kk is None:
        max_kk = min(n, p) - 1
    max_kk = min(max_kk, min(p, n) - 1)

    s_Wishart = KN_s_Wishart(alpha, beta)
    # One noise estimate per tested kk; the length is clamped at 0 so that
    # degenerate inputs (p == 1, n <= 1) give an empty array, not an error.
    sigma_arr = np.zeros(max(max_kk, 0))

    # K is the number of consecutive leading eigenvalues that passed the
    # test. Tracking it explicitly (instead of deriving it from the loop
    # variable after the loop) keeps K defined when the loop does not run,
    # and yields K == max_kk when every tested eigenvalue is significant.
    K = 0
    for kk in range(1, max_kk + 1):
        mu_np, sigma_np = KN_mu_sigma(n, p - kk, beta)
        sig_hat_kk = KN_noiseEst(ell, n, kk)
        sigma_arr[kk - 1] = sig_hat_kk
        # n * ell rescales the SCM eigenvalue to the Wishart scale used by
        # KN_mu_sigma (assumes the SCM is normalized by 1/n -- confirm).
        at_least_kk_signals = n * ell[kk - 1] > sig_hat_kk * (mu_np + s_Wishart * sigma_np)
        if not at_least_kk_signals:
            break
        K = kk

    if K > 0:
        sigma_hat = sigma_arr[K - 1]
    else:
        sigma_hat = np.sum(ell) / p
    return K, sigma_hat

def KN_mu_sigma(n, p, beta):
    '''
    DESCRIPTION:
        Returns the centering constant mu_np and the scaling constant
        sigma_np for ell_1, the largest eigenvalue of an n-sample,
        p-dimensional white Wishart matrix. After centering and scaling,
        ell_1 is approximately Tracy-Widom distributed:
        Pr{ell_1 <= mu_np + s * sigma_np} --> F_beta(s)
        Real case (beta=1): Johnstone's constants built from n-1/2, p-1/2.
        Complex case (beta=2): El Karoui's constants, a weighted combination
        of the edge constants for dimensions (N-1, P) and (N, P-1), where
        N = max(n, p) and P = min(n, p).
        The rank-estimation routine KN_rankEst relies on these values.

    INPUT:
        n       -  number of samples
        p       -  dimension of samples
        beta    -  1 for real, 2 for complex valued observations
                   (any other value fails the assertion)

    OUTPUT:
        (mu_np, sigma_np)
    --------------------------------------------
    FOR MORE DETAILS ON THE COMPUTATION OF mu_np AND sigma_np SEE:
        I. M. Johnstone, High Dimensional Statistical Inference and Random
        Matrices, Proc. International Congress of Mathematicians, 2006.

        N. El Karoui, A rate of convergence result for the largest eigenvalue of
        complex white Wishart matrices, Annals of Probability, 34(6):2077-2117,
        2006.
    --------------------------------------------
    '''
    assert beta in [1, 2]

    if beta == 1:
        # Real observations: half-integer shifted dimensions.
        n_half, p_half = n - 1/2, p - 1/2
        mu_np = (np.sqrt(n_half) + np.sqrt(p_half))**2
        sigma_np = np.sqrt(mu_np) * (1/np.sqrt(n_half) + 1 / np.sqrt(p_half))**(1/3)
        return mu_np, sigma_np

    def edge_constants(a, b):
        # Centering and scaling for the shifted dimension pair (a, b).
        root_sum = np.sqrt(a) + np.sqrt(b)
        return root_sum**2, root_sum * (1/np.sqrt(a) + 1/np.sqrt(b))**(1/3)

    small, large = min(n, p), max(n, p)
    mu_a, sig_a = edge_constants(large - 1 + 1/2, small + 1/2)  # dims (N-1, P)
    mu_b, sig_b = edge_constants(large + 1/2, small - 1 + 1/2)  # dims (N, P-1)

    # Combine the two edge approximations with El Karoui's weights.
    gamma = (mu_a * sig_b**0.5) / (mu_b * sig_a**0.5)
    sigma_np = (1 + gamma) / (1/sig_a + gamma/sig_b)
    mu_np = (1/sig_a**0.5 + 1/sig_b**0.5) / \
        (1/(mu_a * sig_a**0.5) + 1/(mu_b * sig_b**0.5))
    return mu_np, sigma_np

def KN_s_Wishart(alpha, beta):
    '''
    DESCRIPTION:
        This function computes an approximate inverse of the
        TW (Tracy-Widom) distribution F_beta:
        if X ~ F_beta, then the function returns a value s_Wishart such that
        Pr{X > s_Wishart} ~ alpha/100.
        Concretely, it solves
            exp(-(2*beta/3) * s^(3/2)) / C_beta = alpha/100,
        with C_1 = 4*sqrt(pi) and C_2 = 16*pi. This is a simplified form of
        the large-s tail asymptotics of F_beta (see the reference below),
        giving
            s_Wishart = (-(3/(2*beta)) * log(C_beta * alpha/100))^(2/3).
        A real solution exists only when C_beta*alpha/100 lies in (0, 1].
        That means 0 < alpha <= ~14.1 for beta=1 and 0 < alpha <= ~1.99
        for beta=2. Outside that range a ValueError is raised instead of
        returning NaN.
        This value is used in the algorithm for rank estimation, KN_rankEst.

    INPUT:
        alpha   -  significance level (given in percentage)
        beta    -  indicator for real (1) or complex (2) TW distribution

    OUTPUT:
        s_Wishart - threshold value for TW distribution with level alpha

    RAISES:
        ValueError - if alpha is outside (0, 100/C_beta]
    --------------------------------------------
    Here we use the asymptotics of the TW distribution for large values of x
    These can be found, for example, in
    J. Baik, R. Buckingham and J. DiFranco,
    Asymptotics of the Tracy-Widom distributions and the total integral of a Painleve II function,
    Comm. Math. Phys., vol. 280, no. 2, pp. 463--497, 2008.
    --------------------------------------------
    '''
    assert beta in [1, 2]

    if beta == 1:
        tail_const, rate = 4*np.sqrt(np.pi), 3/2
    else:
        tail_const, rate = 16 * np.pi, 3/4
    log_arg = tail_const * alpha/100

    # log(log_arg) must be <= 0, otherwise a negative number is raised to
    # the power 2/3 and numpy silently returns NaN.
    arg_arr = np.asarray(log_arg)
    if not np.all((arg_arr > 0) & (arg_arr <= 1)):
        raise ValueError(
            f'alpha={alpha!r} is outside the range where the Tracy-Widom tail '
            f'approximation gives a real threshold for beta={beta}: '
            f'need 0 < alpha <= {100 / tail_const:.4g}')

    s_Wishart = (-rate * np.log(log_arg))**(2/3)
    return s_Wishart

def KN_noiseEst(ell, n, kk):
    '''
    DESCRIPTION:
        This function gets as input the eigenvalues of an SCM
        (sample covariance matrix) and an assumed value of its pseudorank,
        and outputs an estimation of the noise variance,
        under the assumption of uncorrelated homoscedastic noise.
        Starting guess: the mean of the trailing p-kk eigenvalues, inflated
        by 1/(1 - kk/n). The estimate is then refined by a fixed-point
        iteration, which repeats two steps:
          1. For each of the first kk eigenvalues, solve
                 rho^2 - tmp*rho + ell*sigma = 0,
                 tmp = ell + sigma - (p-kk)/n * sigma,
             and keep the larger root rho.
          2. Set sigma = (sum(ell[kk:]) + sum(ell[:kk] - rho)) / (p - kk).
        The iteration stops when:
          - the relative change of sigma drops below 1e-5;
          - 30 iterations have run; or
          - the quadratic has complex roots.
        In every case the last accepted sigma is returned.
        This value is used in the algorithm for rank estimation, KN_rankEst

    INPUT:
        ell     -  eigenvalues of the SCM, of length p (any array-like),
                   expected in decreasing order
        n       -  number of samples
        kk      -  assumed rank, 0 <= kk < min(p, n)

    OUTPUT:
        sig_hat_kk - estimate of the unknown noise variance sigma^2.
                     For kk == 0 this is the mean eigenvalue.

    RAISES:
        ValueError   - if kk is outside [0, min(p, n) - 1]
        RuntimeError - if an estimated rho exceeds its eigenvalue
                       (inconsistent noise estimate)
    '''
    max_iter = 30
    eps_threshold = 1e-5

    ell = np.asarray(ell, dtype=float)
    p = len(ell)
    if not 0 <= kk < min(p, n):
        raise ValueError(
            f'kk={kk} must satisfy 0 <= kk < min(p, n) = {min(p, n)}')

    # Sum of the eigenvalues treated as pure noise; it is loop-invariant.
    noise_sum = np.sum(ell[kk:])
    sigma_0 = 1/(p-kk) * noise_sum * 1 / (1-kk / n)
    if kk == 0:
        # No signal eigenvalues to correct for: the starting guess (the
        # mean eigenvalue) is the estimate.
        return sigma_0

    signal = ell[:kk]
    for _ in range(max_iter):
        # Solve the quadratic for rho, given sigma and the eigenvalues.
        tmp = signal + sigma_0 - (p-kk)/n*sigma_0
        discriminant = tmp**2 - 4 * signal * sigma_0
        if discriminant.min() < 0:
            # Complex-valued roots: stop and keep the current estimate.
            break
        rho_est = (tmp + np.sqrt(discriminant)) / 2

        delta_l_rho = signal - rho_est
        if delta_l_rho.min() < 0:
            raise RuntimeError(
                f'KN_noiseEst: inconsistent noise estimate for kk={kk}: '
                f'an estimated rho exceeds its eigenvalue')
        sigma_new = 1/(p-kk) * (noise_sum + np.sum(delta_l_rho))
        if abs(sigma_new - sigma_0)/sigma_0 < eps_threshold:
            break
        sigma_0 = sigma_new
    sig_hat_kk = sigma_0
    return sig_hat_kk

