import numpy as np
import numpy.linalg as LA
from sympy import re, im, I, E, Symbol, sqrt
import argparse
from sympy.functions.elementary.miscellaneous import cbrt


def g_CauchyK_num(S):
    """Build the regularised Cauchy transform of the symmetrised values ``S``.

    Returns the SymPy expression, in the symbol ``z = Symbol('z')``::

        g(z) = 1/(2N) * sum_j [ 1/(z + s_j - i*eta) + 1/(z - s_j - i*eta) ]

    with ``N = len(S)`` and ``eta = sqrt(1/(2N))``.  Every value ``s_j``
    contributes a pole pair at ``-s_j + i*eta`` and ``+s_j + i*eta``, i.e. the
    values are symmetrised around zero before averaging.

    NOTE(review): ``Estimator`` substitutes ``z = s_i - i*eta`` into this
    expression, so the imaginary offset from each pole there is ``2*eta`` —
    confirm that this double regularisation is intended.

    Args:
        S: sequence of N real values (in ``main``: the singular values of the
            observation matrix).

    Returns:
        A SymPy expression whose only free symbol is ``Symbol('z')``.

    Raises:
        ZeroDivisionError: if ``S`` is empty.
    """
    # Local import: the module header only pulls a few specific sympy names.
    from sympy import Add

    z = Symbol('z')
    N = len(S)

    # Loop-invariant imaginary shift.  For N == 0 this raises
    # ZeroDivisionError, exactly as the final 1/(2N) normalisation would.
    shift = I*np.sqrt(1/(2*N))

    terms = []
    for s in S:
        terms.append(1/(z + s - shift))
        terms.append(1/(z - s - shift))

    # One Add(...) call flattens all 2N terms at once; growing the sum with
    # ``+=`` re-flattens and re-sorts it on every step (quadratic in N).
    return Add(*terms)/(2*N)

def Estimator(S_s, gS, SNR, a, c):
    """Closed-form singular-value estimate for every observed singular value.

    For each observed value ``s_i`` this evaluates the transform ``gS`` at
    ``zz = s_i - i*eta`` (``eta = sqrt(1/(2N))``, ``N = len(S_s)``), plugs
    ``g = gS(zz)`` into a Cardano-style closed form ``q4`` (two cube roots
    sharing one square-root radicand) and returns

        Im(q4) / (sqrt(SNR) * Im(g)).

    ``main`` uses the result as the RIE singular values of Y.  The derivation
    of ``q4`` is not in this file — presumably the relevant root of a cubic;
    TODO confirm against the source derivation before editing the formula.

    Args:
        S_s: length-N sequence of observed singular values.
        gS: SymPy expression in ``Symbol('z')`` (as built by ``g_CauchyK_num``).
        SNR: signal-to-noise ratio; ``sqrt(SNR)`` is a divisor, so presumably > 0.
        a: aspect ratio (``main`` uses ``M = int(N/a)`` columns).
        c: shift of the noise matrix (``main`` adds ``c*I`` to X).

    Returns:
        ``np.ndarray`` of shape ``(N,)`` with the estimates.
    """
    N = len(S_s)
    output = np.zeros(N)
    if N == 0:
        # Nothing to estimate; also keeps the hoisted 1/(2*N) below safe.
        return output

    z = Symbol('z')            # must be the symbol gS was built in
    eta = np.sqrt(1/(2*N))     # loop-invariant regularisation width
    sqrt_snr = np.sqrt(SNR)
    three_pow_2_3 = 3**(2/3)

    for i in range(N):
        zz = S_s[i] - I*eta
        g = gS.subs(z, zz).evalf()

        # Sub-expressions shared by several parts of the closed form; they
        # were previously rebuilt up to four times per singular value.
        a_zz = a * zz
        k_fac = -1 + a * (1 - g * zz + zz**2)
        m_fac = -1 + a + a * zz**2
        radicand = (
            (729 * (c**2) * (g**2) * k_fac**2) / ((a**2) * zz**2)
            + (-3 * c**2 + 6 * (2 + g**2 - (g * m_fac) / a_zz))**3
        )
        root = sqrt(radicand)

        numerator = three_pow_2_3 * (
            a * (-4 + c**2) * zz - 2 * a * (g**2) * zz + 2 * g * m_fac
        )
        denominator = a_zz * cbrt((9 * c * g * k_fac) / a_zz + (1/3) * root)
        second_cbrt = cbrt((27 * c * g * k_fac) / a_zz + root)

        q4 = ((-3 * c + numerator / denominator + second_cbrt) / 6).evalf()

        output[i] = (im(q4) / (sqrt_snr * im(g))).evalf()

    return output


def _sample_signal(prior, N, M):
    """Draw the N x M signal matrix Y under ``prior`` ('Gaussian' or 'Uniform').

    'Gaussian': i.i.d. N(0, 1/N) entries.
    'Uniform':  singular vectors from an i.i.d. N(0, 1/N) matrix, singular
                values drawn uniformly from [1, 3).
    Assumes N <= M.
    """
    if prior == 'Gaussian':
        return np.random.randn(N, M) / np.sqrt(N)
    if prior == 'Uniform':
        G = np.random.randn(N, M) / np.sqrt(N)
        # Thin SVD suffices: Y = U diag(s) [I 0] Vh only uses the first N rows of Vh.
        U_y, _, Vh_y = LA.svd(G, full_matrices=False)
        s_y = 2 * np.random.rand(N) + 1
        return (U_y * s_y) @ Vh_y
    raise ValueError(f"unknown prior {prior!r}; expected 'Gaussian' or 'Uniform'")


def _sample_noise(N, c):
    """Draw the symmetric N x N noise matrix X = H / sqrt(N) + c * I.

    H is symmetric with off-diagonal entries N(0, 1) and diagonal entries
    N(0, 2) (GOE normalisation).
    """
    # k=1: strict upper triangle only, so H + H.T does not double the diagonal;
    # the diagonal variance then comes solely from the N(0, 2) draw below.
    upper = np.triu(np.random.normal(0, 1, (N, N)), k=1)
    H = upper + upper.T + np.diag(np.random.normal(loc=0, scale=np.sqrt(2), size=N))
    return H / np.sqrt(N) + c * np.eye(N)


def main():
    """Compare the oracle and RIE estimates of Y from S = sqrt(SNR) X Y + W.

    Command line:
        -a        aspect ratio N/M (required, 0 < a <= 1; M = int(N/a))
        -s        signal-to-noise ratio (required, > 0)
        -p        prior of Y: 'Gaussian' or 'Uniform' (required)
        -N        number of rows N (default 2000)
        --trials  number of Monte-Carlo repetitions (default 10)
        --seed    optional NumPy RNG seed for reproducible runs

    For every trial the relative errors ||Y - Y_hat||_F^2 / ||Y||_F^2 of the
    oracle and of the RIE are recorded; the two arrays are saved to
    ``Y-<prior>_SNR=<SNR>_Oracle.npy`` and ``Y-<prior>_SNR=<SNR>_RIE.npy``.
    """
    p = argparse.ArgumentParser(
        description='Oracle vs. RIE estimation of Y from S = sqrt(SNR) X Y + W.')
    p.add_argument('-a', type=float, required=True, help='aspect ratio N/M, 0 < a <= 1')
    p.add_argument('-s', type=float, required=True, help='signal-to-noise ratio (> 0)')
    p.add_argument('-p', type=str, required=True, choices=['Gaussian', 'Uniform'],
                   help='prior of the signal Y')
    p.add_argument('-N', type=int, default=2000, help='number of rows of Y (default: 2000)')
    p.add_argument('--trials', type=int, default=10,
                   help='number of Monte-Carlo trials (default: 10)')
    p.add_argument('--seed', type=int, default=None, help='NumPy RNG seed (default: unseeded)')
    args = p.parse_args()

    a = args.a
    prior = args.p
    SNR = args.s
    N = args.N
    trials = args.trials

    # The rectangular bookkeeping below assumes M = int(N/a) >= N.
    if not 0 < a <= 1:
        p.error('-a must satisfy 0 < a <= 1')
    # The RIE divides by sqrt(SNR).
    if not SNR > 0:
        p.error('-s must be > 0')
    if N < 1 or trials < 1:
        p.error('-N and --trials must be positive')
    if args.seed is not None:
        np.random.seed(args.seed)

    M = int(N/a)
    c = 3  # shift in X = H/sqrt(N) + c*I; the RIE must be given the same value

    E_oracle = np.zeros(trials)
    E_RIE = np.zeros(trials)

    for t in range(trials):
        Y = _sample_signal(prior, N, M)
        X = _sample_noise(N, c)
        W = np.random.randn(N, M) / np.sqrt(N)

        # Observation
        S = np.sqrt(SNR) * X @ Y + W

        # Thin SVD: U_s is N x N and Vh_s is N x M (M >= N), i.e. exactly the
        # singular-vector pairs that belong to the N values in S_s.
        U_s, S_s, Vh_s = LA.svd(S, full_matrices=False)
        Y_norm = LA.norm(Y)**2

        # Oracle for Y: e_k = u_k^T Y v_k minimises ||Y - sum_k e_k u_k v_k^T||_F
        # among estimators that keep the singular vectors of S.
        e_oracle = np.sum(U_s * (Y @ Vh_s.T), axis=0)
        E_oracle[t] = LA.norm(Y - (U_s * e_oracle) @ Vh_s)**2 / Y_norm

        # RIE for Y: same singular vectors, singular values from the closed form.
        gS = g_CauchyK_num(S_s)
        e_hat = Estimator(S_s, gS, SNR, a, c)
        E_RIE[t] = LA.norm(Y - (U_s * e_hat) @ Vh_s)**2 / Y_norm

    np.save(f'Y-{prior}_SNR={SNR}_Oracle.npy', E_oracle)
    np.save(f'Y-{prior}_SNR={SNR}_RIE.npy', E_RIE)
    

# Script entry point: run the Monte-Carlo experiment only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
    
