close all
clear all

% ---- Synthetic-data parameters ----
n = 100;           % Z_0 is n-by-r and the observed matrix M is n-by-n
r = 1;             % number of columns of Z_0; also the truncation rank (rSVD = r)
density = 0.1;     % density argument handed to sprand when drawing Z_0
SNR = 0.48;        % enters the noise scale c_noise through sqrt(4/SNR)
% Disabled alternative: fixed noise scale c_noise = 0.015 (original note: "100-0.015").

% ---- Objective / algorithm parameters ----
lambda = 0.007;    % weight multiplying Y in grad_X and X in grad_Y
tauCoeff = 0.7;    % tau = tauCoeff * trace(Z_0*Z_0')
eta = 1;           % step size; disabled alternative: 1/(2*sqrt(lambda^2+1))
T = 2000;          % iterations per experiment
numExp = 10;       % number of independent experiments that are averaged

% ---- Accumulators (each experiment adds value/numExp) ----
% Data generation.
[avg_signalToNoiseRatio, avg_c_noise] = deal(0);

% Initial point X obtained from the truncated eigendecomposition of M.
[avg_XinitialNorm1, avg_errorInitial_LR] = deal(0);

% Eigenvalue-condition bookkeeping for the Z and X updates.
[avg_numIterCondHolds_Z, avg_firstIterConditionHolds_Z, ...
    avg_numIterCondHolds_X, avg_firstIterConditionHolds_X] = deal(0);

% Statistics of the ergodic (averaged) iterate.
[avg_Xnorm1_ergodic, avg_error_ergodic, avg_rank_ergodic, ...
    avg_gap_ergodic, avg_dual_gap_ergodic] = deal(0);

% Statistics of the iterate with the smallest observed dual gap.
[avg_Xnorm1_dualGapMin, avg_error_dualGapMin, ...
    avg_gap_minDualGap, avg_dual_gap_dualGapMin] = deal(0);


% Monte-Carlo driver: runs numExp independent experiments and accumulates
% run-averaged statistics in the avg_* variables.
%
% Each experiment
%   1. draws a sparse, unit-Frobenius-norm Z_0 and a symmetric noisy
%      observation M = Z_0*Z_0' + (c_noise/2)*(N + N');
%   2. initialises X from the top-rSVD eigenpairs of M and sets Y = sign(X);
%   3. runs T extra-gradient iterations on (X,Y), with intermediate point
%      (Z,W), tracking the ergodic averages of (Z,W) and the iterate pair
%      with the smallest dual gap seen so far;
%   4. records error / norm / eigen-gap / dual-gap statistics.
%
% The "first iteration at which the condition holds" statistics are averaged
% only over the experiments in which the condition held at least once.
% find() returns an empty array otherwise, and adding an empty array to the
% running average would turn it into [] for the rest of the script.
% NOTE(review): projsplx is defined outside this file; it is presumably a
% projection of the eigenvalue vector onto a simplex of size tau - confirm.

sum_firstIterConditionHolds_Z = 0;   % sum of first-hit iterations (Z update)
numRunsCondHolds_Z = 0;              % experiments where the Z condition held
sum_firstIterConditionHolds_X = 0;   % sum of first-hit iterations (X update)
numRunsCondHolds_X = 0;              % experiments where the X condition held

for i = 1:numExp

    % ---- Synthetic data -------------------------------------------------
    Z_0 = sprand(n,r,density);
    Z_0(Z_0~=0) = randi([1,10],size(nonzeros(Z_0)));   % integer nonzeros in 1..10
    norm_Z0 = norm(Z_0,'fro');
    Z_0 = Z_0/norm_Z0;                                  % unit Frobenius norm
    P = Z_0*Z_0';                                       % ground-truth matrix
    norm_P = norm(P,'fro');

    N = randn(n,n)+0.5;                                 % standard normal entries shifted by 0.5
    N_sym = N + N';                                     % symmetrised noise direction
    % Disabled alternative without the norm(P) factor:
    %     c_noise = sqrt(4/SNR)*(1/norm(N_sym,'fro'));
    c_noise = sqrt(4/SNR)*(norm_P/norm(N_sym,'fro'));
    avg_c_noise = avg_c_noise + c_noise/numExp;

    noise = (c_noise/2)*N_sym;
    M = P + noise;                                      % observed matrix
    rSVD = r;
    tau = tauCoeff*trace(P);

    signalToNoiseRatio = norm_P^2/norm(noise,'fro')^2;
    avg_signalToNoiseRatio = avg_signalToNoiseRatio + signalToNoiseRatio/numExp;


    % ---- EXTRA-GRADIENT: initial point ----------------------------------
    [U,S] = eigs(M,rSVD,'largestreal'); %low rank approximation
    X = U*diag(projsplx(diag(S),tau))*U';
    X_dualGapMin = X;
    X_initial = X;                 % copy of the starting point (not read again below)

    Y = sign(X);
    Y_dualGapMin = Y;

    % Dual gap of the initial pair; used as the baseline for the
    % "smallest dual gap" bookkeeping.
    G = X-M+lambda*Y;
    [U2,~] = eigs(G,1,'smallestreal');
    dual_gapX_dualGapMin = trace((X-tau*(U2(:,1)*U2(:,1)'))'*G);
    dual_gapY_dualGapMin = trace((Y-sign(lambda*X))'*(lambda*X));
    minDualGap = dual_gapX_dualGapMin - dual_gapY_dualGapMin;

    % Relative squared error of the rescaled initial point w.r.t. P.
    errorInitial_LR = norm((trace(P)/tau)*X-P,'fro')^2 / norm_P^2;
    avg_errorInitial_LR = avg_errorInitial_LR + errorInitial_LR ./ numExp;

    Xinitial_norm1 = sum(sum(abs(X)));                  % entrywise l1 norm
    avg_XinitialNorm1 = avg_XinitialNorm1 + Xinitial_norm1/numExp;

    checkCondition_Z = zeros(T,1);
    checkCondition_X = zeros(T,1);
    Z_ergodic = zeros(n);
    W_ergodic = zeros(n);


    % ---- EXTRA-GRADIENT: main iterations --------------------------------
    for t = 1:T

        % Intermediate primal point Z: step from X along grad at (X,Y),
        % then keep the top rSVD eigenpairs with projected eigenvalues.
        grad_X = X-M+lambda*Y;
        [U,S] = eigs(X-eta*grad_X,rSVD+1,'largestreal');
        eigens = projsplx(diag(S(1:rSVD,1:rSVD)),tau);
        Z = U(:,1:rSVD)*diag(eigens)*U(:,1:rSVD)';

        % Record whether the sum of the top rSVD eigenvalues is at least
        % tau + rSVD * (the (rSVD+1)-th eigenvalue).
        % NOTE(review): presumably the condition under which the truncated
        % projection equals the exact one - confirm.
        if sum(diag(S(1:rSVD,1:rSVD))) >= tau+rSVD*S(rSVD+1,rSVD+1)
            checkCondition_Z(t) = 1;
        end

        Z_ergodic = Z_ergodic + Z/T;

        % Intermediate dual point W: ascent step from Y, clipped to [-1,1].
        grad_Y = lambda*X;
        W = Y + eta*grad_Y;
        W(W>1) = 1;
        W(W<-1) = -1;

        W_ergodic = W_ergodic + W/T;

        % New primal point X: step from the old X along grad at (Z,W).
        grad_Z = Z-M+lambda*W;
        [U,S] = eigs(X-eta*grad_Z,rSVD+1,'largestreal');
        eigens = projsplx(diag(S(1:rSVD,1:rSVD)),tau);
        X = U(:,1:rSVD)*diag(eigens)*U(:,1:rSVD)';

        if sum(diag(S(1:rSVD,1:rSVD))) >= tau+rSVD*S(rSVD+1,rSVD+1)
            checkCondition_X(t) = 1;
        end

        % New dual point Y: ascent step from the old Y using Z, clipped to [-1,1].
        grad_W = lambda*Z;
        Y = Y + eta*grad_W;
        Y(Y>1) = 1;
        Y(Y<-1) = -1;

        % Dual gap at the new pair (X,Y); keep it if it is the best so far.
        G = X-M+lambda*Y;
        [U2,~] = eigs(G,1,'smallestreal');
        dual_gapX_dualGapMin = trace((X-tau*(U2(:,1)*U2(:,1)'))'*G);
        dual_gapY_dualGapMin = trace((Y-sign(lambda*X))'*(lambda*X));
        dual_gap_dualGapMin = dual_gapX_dualGapMin - dual_gapY_dualGapMin;
        if dual_gap_dualGapMin < minDualGap
            minDualGap = dual_gap_dualGapMin;
            X_dualGapMin = X;
            Y_dualGapMin = Y;
        end

        % Same check at the intermediate pair (Z,W).
        G = Z-M+lambda*W;
        [U2,~] = eigs(G,1,'smallestreal');
        dual_gapX_dualGapMin = trace((Z-tau*(U2(:,1)*U2(:,1)'))'*G);
        dual_gapY_dualGapMin = trace((W-sign(lambda*Z))'*(lambda*Z));
        dual_gap_dualGapMin = dual_gapX_dualGapMin - dual_gapY_dualGapMin;
        if dual_gap_dualGapMin < minDualGap
            minDualGap = dual_gap_dualGapMin;
            X_dualGapMin = Z;
            Y_dualGapMin = W;
        end

    end

    % ---- Per-experiment statistics --------------------------------------
    numIterCondHolds_Z = nnz(checkCondition_Z);
    avg_numIterCondHolds_Z = avg_numIterCondHolds_Z + numIterCondHolds_Z ./ numExp;
    firstIterConditionHolds_Z = find(checkCondition_Z,1,'first');
    if isempty(firstIterConditionHolds_Z)
        % Condition never held: skip this run instead of adding [].
        warning('extragrad:conditionNeverHeld', ...
            'Experiment %d: condition for the Z update never held in %d iterations.', i, T);
    else
        sum_firstIterConditionHolds_Z = sum_firstIterConditionHolds_Z + firstIterConditionHolds_Z;
        numRunsCondHolds_Z = numRunsCondHolds_Z + 1;
    end

    numIterCondHolds_X = nnz(checkCondition_X);
    avg_numIterCondHolds_X = avg_numIterCondHolds_X + numIterCondHolds_X ./ numExp;
    firstIterConditionHolds_X = find(checkCondition_X,1,'first');
    if isempty(firstIterConditionHolds_X)
        warning('extragrad:conditionNeverHeld', ...
            'Experiment %d: condition for the X update never held in %d iterations.', i, T);
    else
        sum_firstIterConditionHolds_X = sum_firstIterConditionHolds_X + firstIterConditionHolds_X;
        numRunsCondHolds_X = numRunsCondHolds_X + 1;
    end

    X_norm1_ergodic = sum(sum(abs(Z_ergodic)));
    avg_Xnorm1_ergodic = avg_Xnorm1_ergodic + X_norm1_ergodic/numExp;

    X_norm1_dualGapMin = sum(sum(abs(X_dualGapMin)));
    avg_Xnorm1_dualGapMin = avg_Xnorm1_dualGapMin + X_norm1_dualGapMin/numExp;

    error_ergodic = norm((trace(P)/tau)*Z_ergodic-P,'fro')^2 / norm_P^2;
    avg_error_ergodic = avg_error_ergodic + error_ergodic ./ numExp;

    error_dualGapMin = norm((trace(P)/tau)*X_dualGapMin-P,'fro')^2 / norm_P^2;
    avg_error_dualGapMin = avg_error_dualGapMin + error_dualGapMin ./ numExp;

    rank_ergodic = rank(Z_ergodic);
    avg_rank_ergodic = avg_rank_ergodic + rank_ergodic ./ numExp;

    % Eigen-gap of the negated gradient at the ergodic pair, and its dual gap.
    G_ergodic = Z_ergodic-M+lambda*W_ergodic;
    [U1,S1] = eigs(-G_ergodic,rSVD+1,'largestreal');
    gap_ergodic = S1(rSVD,rSVD) - S1(rSVD+1,rSVD+1);
    avg_gap_ergodic = avg_gap_ergodic + gap_ergodic ./ numExp;


    dual_gapX_ergodic = trace((Z_ergodic-tau*(U1(:,1)*U1(:,1)'))'*G_ergodic);
    dual_gapY_ergodic = trace((W_ergodic-sign(lambda*Z_ergodic))'*(lambda*Z_ergodic));
    dual_gap_ergodic = dual_gapX_ergodic - dual_gapY_ergodic;
    avg_dual_gap_ergodic = avg_dual_gap_ergodic + dual_gap_ergodic/numExp;

    % Same two quantities at the pair with the smallest dual gap.
    G_dualGapMin = X_dualGapMin-M+lambda*Y_dualGapMin;
    [U2,S2] = eigs(-G_dualGapMin,rSVD+1,'largestreal');
    gap_minDualGap = S2(rSVD,rSVD) - S2(rSVD+1,rSVD+1);
    avg_gap_minDualGap = avg_gap_minDualGap + gap_minDualGap ./ numExp;

    dual_gapX_dualGapMin = trace((X_dualGapMin-tau*(U2(:,1)*U2(:,1)'))'*G_dualGapMin);
    dual_gapY_dualGapMin = trace((Y_dualGapMin-sign(lambda*X_dualGapMin))'*(lambda*X_dualGapMin));
    dual_gap_dualGapMin = dual_gapX_dualGapMin - dual_gapY_dualGapMin;
    avg_dual_gap_dualGapMin = avg_dual_gap_dualGapMin + dual_gap_dualGapMin/numExp;

end

% Average first-hit iteration over the runs in which the condition held.
% 0/0 yields NaN when it never held in any run, which flags the statistic
% as undefined instead of silently collapsing it to [].
avg_firstIterConditionHolds_Z = sum_firstIterConditionHolds_Z/numRunsCondHolds_Z;
avg_firstIterConditionHolds_X = sum_firstIterConditionHolds_X/numRunsCondHolds_X;

    
% Report the run-averaged statistics. The statements are deliberately left
% without a trailing semicolon so that MATLAB prints each value.

% Data generation: achieved signal-to-noise ratio and noise scale.
avg_signalToNoiseRatio, avg_c_noise

% Initial point: entrywise l1 norm and relative squared error.
avg_XinitialNorm1, avg_errorInitial_LR

% Eigenvalue condition: number of iterations it held and first such
% iteration, for the Z update and then for the X update.
avg_numIterCondHolds_Z, avg_firstIterConditionHolds_Z
avg_numIterCondHolds_X, avg_firstIterConditionHolds_X

% Ergodic iterate (display of avg_rank_ergodic is disabled).
avg_Xnorm1_ergodic, avg_error_ergodic, avg_gap_ergodic, avg_dual_gap_ergodic

% Iterate with the smallest observed dual gap.
avg_Xnorm1_dualGapMin, avg_error_dualGapMin, avg_gap_minDualGap, avg_dual_gap_dualGapMin