import numpy as np
import numpy.linalg as LA
from sympy import re, im, I, E, Symbol, sqrt
import argparse


def g_CauchyK_num(S):
    """Return the regularised, symmetrised empirical Cauchy transform of ``S``.

    The result is a sympy expression in the symbol ``z``:

        (1 / (2N)) * sum_j [ 1/(z + s_j - i*eta) + 1/(z - s_j - i*eta) ]

    where ``N = len(S)`` and ``eta = sqrt(1/(2N))``. Each value s_j therefore
    contributes a mass at both +s_j and -s_j, and every pole is shifted by
    ``i*eta`` off the real axis.

    Args:
        S: sequence of (singular) values.

    Returns:
        sympy expression in ``Symbol('z')``.
    """
    z = Symbol('z')
    n_vals = len(S)
    total = 0

    for s in S:
        # Same imaginary shift for every term; it depends only on len(S).
        shift = I*np.sqrt(1/(2*n_vals))
        total += 1/(z + s - shift)
        total += 1/(z - s - shift)

    return total/(2*n_vals)

def Estimator(S_s, gX, gS, SNR, alpha):
    """Compute per-singular-value eigenvalue estimates for X and for X^2.

    For every singular value s of the observation, ``gS`` is evaluated at
    ``s - i*sqrt(dfr/(2N))``. The imaginary offset width ``dfr`` depends on
    ``SNR``: 1..5 map to 64..1024 (doubling), any other value uses 32.

    Args:
        S_s: sequence of singular values; N = len(S_s).
        gX: sympy expression in ``Symbol('z')``, presumably the Cauchy
            transform of the prior on X (TODO confirm against the caller).
        gS: sympy expression in ``Symbol('z')`` for the observation spectrum.
        SNR: signal-to-noise ratio; also selects ``dfr`` as described above.
        alpha: aspect-ratio parameter entering both estimates.

    Returns:
        (output_X, output_XX): two float arrays of length N with the estimated
        eigenvalues of X and of X^2, in the order of ``S_s``.
    """
    n_sv = len(S_s)
    z = Symbol('z')

    estimates_x = np.zeros(n_sv)
    estimates_xx = np.zeros(n_sv)

    dfr_by_snr = {1: 64, 2: 128, 3: 256, 4: 512, 5: 1024}
    dfr = dfr_by_snr.get(SNR, 32)
    offset = I*np.sqrt(dfr/(2*n_sv))

    # Two algebraically equal forms of (1 - alpha)/alpha, kept separate so the
    # floating-point results match each formula exactly.
    zeta_coeff = (1-alpha)/alpha
    xx_coeff = -1 + 1/alpha

    for idx, sv in enumerate(S_s):
        # Optimal eigenvalue for X.
        zz = sv - offset
        g_val = gS.subs(z, zz).evalf()
        zeta = g_val + zeta_coeff*(1/zz)

        big_z = (zz/zeta - 1)/SNR
        root = sqrt(big_z)
        est = gX.subs(z, root).evalf() + gX.subs(z, -root).evalf()

        im_g = im(g_val)
        estimates_x[idx] = im(((est/zeta)/(2*SNR*im_g)).evalf())

        # Optimal eigenvalue for X^2.
        denom = alpha*(im_g**2 + (re(g_val) + xx_coeff/sv)**2)
        estimates_xx[idx] = (-1 + 1/denom)/SNR

    return estimates_x, estimates_xx


def _sample_prior(prior, N, z):
    """Draw one N x N symmetric signal matrix X from the named prior.

    Args:
        prior: 'Wigner' or 'Wishart'.
        N: matrix dimension.
        z: sympy Symbol used as the variable of the returned Cauchy transform.

    Returns:
        (X, gX): the sampled matrix and a sympy expression in ``z`` for the
        Cauchy transform of its limiting spectrum. ``gX`` is None when no
        closed form is implemented (currently the 'Wishart' prior).

    Raises:
        ValueError: if ``prior`` is not recognised.
    """
    if prior == 'Wigner':
        # Strict upper triangle (k=1). The diagonal is drawn separately below,
        # so X + X^T must not also double the triangle's diagonal. This gives
        # off-diagonal variance 1 and diagonal variance 2.
        X = np.triu(np.random.normal(0, 1, (N, N)), k=1)
        X = X + np.transpose(X) + np.diag(np.random.normal(loc=0, scale=np.sqrt(2), size=N))
        X = X/np.sqrt(N)
        # Shift by 3I. gX below has branch points at 1 and 5, i.e. the
        # semicircle of radius 2 centred at 3.
        X = X + 3*np.eye(N)
        gX = (z - 3 - sqrt(z - 5)*sqrt(z - 1))/2
        return X, gX

    if prior == 'Wishart':
        # X is the symmetric square root of G G^T / N, with G of shape (N, 4N).
        G = np.random.randn(N, 4*N)
        evals, evecs = LA.eigh(G @ np.transpose(G)/N)
        X = evecs @ np.diag(np.sqrt(evals)) @ np.transpose(evecs)
        # NOTE(review): no closed-form Cauchy transform is implemented for
        # this prior, so the RIE cannot be evaluated for it.
        return X, None

    raise ValueError(f"unknown prior {prior!r}; expected 'Wigner' or 'Wishart'")


def _rotate(U, eigs):
    """Return U @ diag(eigs) @ U^T without materialising the diagonal matrix."""
    return (U * eigs) @ np.transpose(U)


def _relative_error(A, A_hat, A_norm_sq):
    """Squared Frobenius error ||A - A_hat||_F^2 divided by ``A_norm_sq``."""
    return LA.norm(A - A_hat)**2 / A_norm_sq


def main():
    """Monte Carlo comparison of oracle and RIE estimators of X and X^2.

    Command line:
        -a  aspect ratio; the observation has N rows and M = int(N / a) columns,
            and must satisfy 0 < a <= 1 so that M >= N.
        -s  SNR (> 0); the observation is S = sqrt(SNR) * X @ Y + W.
        -p  prior on X: 'Wigner' or 'Wishart'.

    For each of ``Ex`` trials the relative squared Frobenius errors are
    recorded for:
        - the oracle estimators of X and X^2;
        - the RIE of X;
        - the square root of the RIE of X^2 (as an estimator of X);
        - the RIE of X^2.
    Each error series is saved to a ``.npy`` file in the working directory.
    """
    z = Symbol('z')
    p = argparse.ArgumentParser()

    p.add_argument('-a', type=float, required=True,
                   help='aspect ratio N/M, 0 < a <= 1')
    p.add_argument('-s', type=float, required=True,
                   help='signal-to-noise ratio (> 0)')
    p.add_argument('-p', type=str, required=True, choices=['Wigner', 'Wishart'],
                   help='prior on the signal matrix X')

    args = p.parse_args()

    a = args.a
    prior = args.p
    SNR = args.s

    # The estimator assumes at least as many columns as rows (M >= N).
    # Otherwise S has fewer singular values than U_s has columns and the
    # reconstructions below fail on a shape mismatch.
    if not 0 < a <= 1:
        p.error('-a must satisfy 0 < a <= 1')
    # Estimator divides by SNR, and sqrt(SNR) scales the signal.
    if not SNR > 0:
        p.error('-s must be positive')

    N = 2000
    M = int(N/a)

    Ex = 10

    E_X_oracle = np.zeros(Ex)
    E_X_RIE = np.zeros(Ex)
    E_X_sqXX = np.zeros(Ex)

    E_XX_oracle = np.zeros(Ex)
    E_XX_RIE = np.zeros(Ex)

    for i in range(Ex):

        X, gX = _sample_prior(prior, N, z)
        if gX is None:
            raise NotImplementedError(
                f"no Cauchy transform gX is implemented for prior {prior!r}; "
                "the RIE cannot be computed"
            )

        ## Noise (Y drawn before W, as before)
        Y = np.random.randn(N, M)/np.sqrt(N)
        W = np.random.randn(N, M)/np.sqrt(N)

        ### Observation
        S = np.sqrt(SNR) * X @ Y + W

        ### SVD. M >= N, so the reduced U_s is already N x N and the unused
        ### M x M right factor is never formed.
        U_s, S_s, _ = LA.svd(S, full_matrices=False)

        gS = g_CauchyK_num(S_s)

        XX = X @ X

        X_norm = LA.norm(X)**2
        XX_norm = LA.norm(XX)**2

        ### Oracle estimator for X & X^2: u_k^T A u_k for every left singular
        ### vector u_k, computed for all k at once.
        e_hat_X_oracle = np.sum(U_s * (X @ U_s), axis=0)
        e_hat_XX_oracle = np.sum(U_s * (XX @ U_s), axis=0)

        E_X_oracle[i] = _relative_error(X, _rotate(U_s, e_hat_X_oracle), X_norm)
        E_XX_oracle[i] = _relative_error(XX, _rotate(U_s, e_hat_XX_oracle), XX_norm)

        #### RIE for X & X^2
        e_hat_X, e_hat_XX = Estimator(S_s, gX, gS, SNR, a)

        E_X_RIE[i] = _relative_error(X, _rotate(U_s, e_hat_X), X_norm)

        # X is symmetric, so X^2 is positive semidefinite. Negative estimated
        # eigenvalues of X^2 are estimation noise, so clip them to 0 instead of
        # letting np.sqrt return NaN.
        sq_eigs = np.sqrt(np.clip(e_hat_XX, 0, None))
        E_X_sqXX[i] = _relative_error(X, _rotate(U_s, sq_eigs), X_norm)

        E_XX_RIE[i] = _relative_error(XX, _rotate(U_s, e_hat_XX), XX_norm)

    # File names are '<target>-<prior>_SNR=<SNR>_<method>.npy'
    # (f-string formatting of a float matches str()).
    outputs = [
        ('X', 'Oracle', E_X_oracle),
        ('XX', 'Oracle', E_XX_oracle),
        ('X', 'RIE', E_X_RIE),
        ('X', 'sqXX', E_X_sqXX),
        ('XX', 'RIE', E_XX_RIE),
    ]
    for target, method, errors in outputs:
        np.save(f'{target}-{prior}_SNR={SNR}_{method}.npy', errors)

# Script entry point: run the simulation only when this file is executed directly.
if __name__ == "__main__":
    main()
    
