close all
clear all

% ---- Experiment parameters -------------------------------------------------
n      = 600;           % length of the planted signal z0 (matrices are n-by-n)
m      = n;             % number of random rank-one measurements A{k}
lambda = 2;             % weight on the measurement-operator term of the gradients
T      = 2000;          % extra-gradient iterations per experiment
eta    = 1/(2*lambda);  % step size of every extra-gradient step
numExp = 10;            % number of independent experiments that are averaged
r      = 1;             % rank kept by the truncated eigen-projection (rSVD)
tau    = 1;             % trace bound used by the projection / initial point
SNR    = 0.027;         % target signal-to-noise ratio of the observation M

% ---- Running averages over the numExp experiments (all start at zero) -----
[avg_signalToNoiseRatio, avg_errorInitial_LR, ...
    avg_numIterCondHolds_Z, avg_firstIterConditionHolds_Z, ...
    avg_numIterCondHolds_X, avg_firstIterConditionHolds_X, ...
    avg_rank_ergodic, avg_error_ergodic, avg_gap_ergodic, avg_dual_gap_ergodic, ...
    avg_minDualGap, avg_error_dualGapMin, avg_gap_dualGapMin, avg_dual_gap_dualGapMin, ...
    avg_c_noise, avg_norm_operator] = deal(0);

% Run numExp independent trials of the extra-gradient method and accumulate
% per-trial statistics into the avg_* running averages.
%
% The gradients used below suggest a saddle-point problem of the form
%   f(X, y) = -<M, X> + lambda * <y, operatorA_X_b(A, X, b, m)>
% with X on {X >= 0, trace(X) = tau} and y in the unit ball.
% NOTE(review): confirm against operatorA_X_b / grad_X_operatorA_X_b and
% projsplx, which are defined outside this script.
for i = 1:numExp
    
    % --- Problem instance -------------------------------------------------
    % Unit-norm planted signal z0 and observation M = z0*z0' + symmetric noise.
    % The noise is scaled so that ||z0*z0'||_F^2 / ||noise||_F^2 = SNR.
    z0 = randn(n,1);
    z0 = z0/norm(z0);
    
    N = randn(n,n);
    c_noise = sqrt(4/SNR)*(1/norm((N + N'),'fro'));
    avg_c_noise = avg_c_noise + c_noise/numExp;
    M = z0*z0' + (c_noise/2)*(N + N'); 
    rSVD = r;
     
    signalToNoiseRatio = norm((z0*z0'),'fro')^2/norm((c_noise/2)*(N + N'),'fro')^2;
    avg_signalToNoiseRatio = avg_signalToNoiseRatio + signalToNoiseRatio/numExp;
   
    % Random rank-one measurement matrices A{k} = v*v' with unit-norm v, and
    % measurements b(k) = trace(A{k}*z0*z0').
    A = cell(m,1);
    b = zeros(m,1);
    for k = 1:m
        v = randn(n,1);
        v = v/norm(v);
        A{k} = v*v';
        % trace((v*v')*(z0*z0')) == (v'*z0)^2. This is O(n) instead of the
        % O(n^3) dense matrix product.
        b(k) = (v'*z0)^2;
    end
    
    
    % EXTRA-GRADIENT
    
    % Initial X: tau times the projector onto the leading eigenvector of M.
    [U,~] = eigs(M,1,'largestreal');
    X = tau*(U*U');
    X_dualGapMin = X;
    X_initial = X;
    
    % Initial y: operatorA_X_b at X, rescaled into the unit ball if needed.
    operator = operatorA_X_b(A,X,b,m);
    if norm(operator) > 1
        y = operator/norm(operator);
    else
        y = operator;
    end
    y_dualGapMin = y;
    
    % Partial gradients at the current iterate (X, y). They are refreshed at
    % the END of every iteration. So at the top of iteration t they are
    % always evaluated at the current (X, y), exactly as if recomputed there.
    grad_X = -M+lambda*grad_X_operatorA_X_b(A,y,n,m);
    grad_y = lambda*operator;
    
    % Duality gap at (X, y):
    %   X part: <X - tau*u*u', grad_X>, with u the eigenvector of the
    %           smallest eigenvalue of grad_X (linear minimiser over the
    %           trace-tau PSD set).
    %   y part: maximiser of <grad_y, .> over the unit ball is grad_y/||grad_y||.
    [Ux,~] = eigs(grad_X,1,'smallestreal');
    dual_gapX_dualGapMin = trace((X-tau*(Ux*Ux'))'*grad_X);
    dual_gapY_dualGapMin = (y-(grad_y/norm(grad_y)))'*grad_y;
    minDualGap = dual_gapX_dualGapMin - dual_gapY_dualGapMin;
    
    errorInitial_LR = norm((trace(z0*z0')/tau)*X_initial-(z0*z0'),'fro')^2 / norm(z0*z0','fro')^2;
    avg_errorInitial_LR = avg_errorInitial_LR + errorInitial_LR ./ numExp;
    
    checkCondition_Z = zeros(T,1);
    checkCondition_X = zeros(T,1); 
    Z_ergodic = zeros(n);
    w_ergodic = zeros(m,1);
    
    for t = 1:T
        
        % --- Extrapolation step: (Z, w) from gradients at (X, y) -------------
        [U,S] = eigs(X-eta*grad_X,rSVD+1,'largestreal');
        eigens = projsplx(diag(S(1:rSVD,1:rSVD)),tau);
        Z = U(:,1:rSVD)*diag(eigens)*U(:,1:rSVD)';
        
        % Flag iterations where sum of top-rSVD eigenvalues >= tau + rSVD*(next
        % eigenvalue). Assuming projsplx is the Euclidean projection onto
        % {x >= 0, sum(x) = tau}, this is when the rank-rSVD projection equals
        % the full-rank one.
        if sum(diag(S(1:rSVD,1:rSVD))) >= tau+rSVD*S(rSVD+1,rSVD+1)
            checkCondition_Z(t) = 1;
        end
        
        Z_ergodic = Z_ergodic + Z/T;
        
        w = y + eta*grad_y;
        if norm(w) > 1
            w = w/norm(w);
        end
        
        w_ergodic = w_ergodic + w/T;
        
        % --- Update step: (X, y) from gradients at (Z, w) -------------------
        grad_Z = -M+lambda*grad_X_operatorA_X_b(A,w,n,m);
        [U,S] = eigs(X-eta*grad_Z,rSVD+1,'largestreal');
        eigens = projsplx(diag(S(1:rSVD,1:rSVD)),tau);
        X = U(:,1:rSVD)*diag(eigens)*U(:,1:rSVD)';
        
        if sum(diag(S(1:rSVD,1:rSVD))) >= tau+rSVD*S(rSVD+1,rSVD+1)
            checkCondition_X(t) = 1;
        end
        
        grad_w = lambda*operatorA_X_b(A,Z,b,m);
        y = y + eta*grad_w;
        if norm(y) > 1
            y = y/norm(y);
        end
        
        % Refresh the gradients at the NEW (X, y). The duality gap below must
        % pair X with grad_X(y) and y with grad_y(X) at the same point.
        % Previously the gradients from the old iterate were reused here.
        % These values also serve as the gradients for the next iteration's
        % extrapolation step.
        grad_X = -M+lambda*grad_X_operatorA_X_b(A,y,n,m);
        grad_y = lambda*operatorA_X_b(A,X,b,m);
        
        % Track the iterate with the smallest duality gap: first (X, y) ...
        [U2,~] = eigs(grad_X,1,'smallestreal');
        dual_gapX_dualGapMin = trace((X-tau*(U2*U2'))'*grad_X);
        dual_gapY_dualGapMin = (y-(grad_y/norm(grad_y)))'*grad_y;
        dual_gap_dualGapMin = dual_gapX_dualGapMin - dual_gapY_dualGapMin;
        if dual_gap_dualGapMin < minDualGap
            minDualGap = dual_gap_dualGapMin;
            X_dualGapMin = X;
            y_dualGapMin = y;
        end
        
        % ... then (Z, w), whose gradients grad_Z (at w) and grad_w (at Z)
        % are already consistent.
        [U2,~] = eigs(grad_Z,1,'smallestreal');
        dual_gapX_dualGapMin = trace((Z-tau*(U2*U2'))'*grad_Z);
        dual_gapY_dualGapMin = (w-(grad_w/norm(grad_w)))'*grad_w;
        dual_gap_dualGapMin = dual_gapX_dualGapMin - dual_gapY_dualGapMin;
        if dual_gap_dualGapMin < minDualGap
            minDualGap = dual_gap_dualGapMin;
            X_dualGapMin = Z;
            y_dualGapMin = w;
        end
        
    end
    
    avg_minDualGap = avg_minDualGap + minDualGap/numExp;
    
    % --- Truncation-condition statistics ----------------------------------
    % find(...,1,'first') returns [] when the condition never held. Adding []
    % would turn the running average into [] and discard every experiment.
    % Record T+1 ("not within the horizon") instead, so the average stays
    % numeric. It is then a lower bound on the mean first-hit iteration.
    numIterCondHolds_Z = nnz(checkCondition_Z);
    avg_numIterCondHolds_Z = avg_numIterCondHolds_Z + numIterCondHolds_Z ./ numExp;
    firstIterConditionHolds_Z = find(checkCondition_Z,1,'first');
    if isempty(firstIterConditionHolds_Z)
        firstIterConditionHolds_Z = T+1;
    end
    avg_firstIterConditionHolds_Z = avg_firstIterConditionHolds_Z +firstIterConditionHolds_Z/numExp;
    
    numIterCondHolds_X = nnz(checkCondition_X);
    avg_numIterCondHolds_X = avg_numIterCondHolds_X + numIterCondHolds_X ./ numExp;
    firstIterConditionHolds_X = find(checkCondition_X,1,'first');
    if isempty(firstIterConditionHolds_X)
        firstIterConditionHolds_X = T+1;
    end
    avg_firstIterConditionHolds_X = avg_firstIterConditionHolds_X +firstIterConditionHolds_X/numExp;
    
    % --- Recovery error (relative, Frobenius) -----------------------------
    error_ergodic = norm((trace(z0*z0')/tau)*Z_ergodic-(z0*z0'),'fro')^2 / norm(z0*z0','fro')^2;
    avg_error_ergodic = avg_error_ergodic + error_ergodic ./ numExp;
    
    error_dualGapMin = norm((trace(z0*z0')/tau)*X_dualGapMin-(z0*z0'),'fro')^2 / norm(z0*z0','fro')^2;
    avg_error_dualGapMin = avg_error_dualGapMin + error_dualGapMin ./ numExp;
    
    rank_ergodic = rank(Z_ergodic);
    avg_rank_ergodic = avg_rank_ergodic + rank_ergodic ./ numExp;
    
    % --- Ergodic average (Z_ergodic, w_ergodic): eigen-gap and duality gap --
    grad_Z_ergodic = -M+lambda*grad_X_operatorA_X_b(A,w_ergodic,n,m);
    [U1,S1] = eigs(-grad_Z_ergodic,rSVD+1,'largestreal');
    gap_ergodic = S1(rSVD,rSVD) - S1(rSVD+1,rSVD+1);
    avg_gap_ergodic = avg_gap_ergodic + gap_ergodic ./ numExp;
    
    grad_w_ergodic = lambda*operatorA_X_b(A,Z_ergodic,b,m);
    dual_gapX_ergodic = trace((Z_ergodic-tau*(U1(:,1)*U1(:,1)'))'*grad_Z_ergodic);
    dual_gapY_ergodic = (w_ergodic-(grad_w_ergodic/norm(grad_w_ergodic)))'*grad_w_ergodic;
    dual_gap_ergodic = dual_gapX_ergodic - dual_gapY_ergodic;
    avg_dual_gap_ergodic = avg_dual_gap_ergodic + dual_gap_ergodic/numExp;
    
    % --- Minimum-duality-gap iterate: eigen-gap and duality gap -----------
    grad_X_dualGapMin = -M+lambda*grad_X_operatorA_X_b(A,y_dualGapMin,n,m);
    [U2,S2] = eigs(-grad_X_dualGapMin,rSVD+1,'largestreal');
    gap_minDualGap = S2(rSVD,rSVD) - S2(rSVD+1,rSVD+1);
    avg_gap_dualGapMin = avg_gap_dualGapMin + gap_minDualGap ./ numExp;
    
    grad_y_dualGapMin = lambda*operatorA_X_b(A,X_dualGapMin,b,m);
    dual_gapX_dualGapMin = trace((X_dualGapMin-tau*(U2(:,1)*U2(:,1)'))'*grad_X_dualGapMin);
    dual_gapY_dualGapMin = (y_dualGapMin-(grad_y_dualGapMin/norm(grad_y_dualGapMin)))'*grad_y_dualGapMin;
    dual_gap_dualGapMin = dual_gapX_dualGapMin - dual_gapY_dualGapMin;
    avg_dual_gap_dualGapMin = avg_dual_gap_dualGapMin + dual_gap_dualGapMin/numExp;
    
    avg_norm_operator = avg_norm_operator + norm(operatorA_X_b(A,X_dualGapMin,b,m))/numExp;
    
end

    
% Echo the averaged statistics (variable name and value) to the command window.
% Noise / initialisation:
avg_signalToNoiseRatio
avg_c_noise
avg_errorInitial_LR
% Truncated-projection condition:
avg_numIterCondHolds_Z
avg_firstIterConditionHolds_Z
avg_numIterCondHolds_X
avg_firstIterConditionHolds_X
% Minimum-duality-gap iterate:
avg_error_dualGapMin
avg_gap_dualGapMin
avg_dual_gap_dualGapMin
avg_norm_operator
% Ergodic (running-average) iterate. These are accumulated in the loop but
% were previously never reported:
avg_rank_ergodic
avg_error_ergodic
avg_gap_ergodic
avg_dual_gap_ergodic
