function demo_L0PCA
% DEMO_L0PCA  Run four solvers on the sparse-PCA model below and save the plots.
%
% Model (author's notation):
%   min_X  0.5/m*||A-A*X*X'||_F^2 + rho ( ||X||_1 - ||X||_{topk} ), s.t. X'X = I
% which is rewritten as
%   min_X - 0.5/m*mdot(A,A*X*X') + const + rho ( ||X||_1 - ||X||_{topk} ), s.t. X'X = I
%   const = 0.5/m*fnorm(A)^2
%
% For every entry of i_data the four solvers are started from the same X0
% with the same time limit. Their recorded objective values are drawn with
% semilogy and the figure is written to <mfilename>_<id>.eps/.png/.pdf in
% the current folder.

clc;close all;clear all;
addpath('util','data','solver');
% Legacy seeding syntax is kept on purpose: switching to another seeding
% call would change the random stream and therefore X0 and the multipliers.
rand('seed',1);
randn('seed',1);

% Earlier per-data-set settings; they were always overwritten by the active
% settings below, so they are kept here only for reference.
% time_limits = [30 60 120 120 40 80 30 60];
% rs          = [50 50 70 70 30 30 30 30];

time_limits = [30 30 30 30 30 30 30 30];
rs          = [10 10 10 10 10 10 10 10];
i_data      = [11 12 21 22 31 32 41 42];

% Iterate over whatever is listed in i_data instead of a hard-coded count.
for iii = 1:numel(i_data)

    iwhich = i_data(iii);
    r = rs(iii);
    timeLimit = time_limits(iii);

    A = GetDataSetL0SPCA(iwhich);

    [m,n] = size(A);

    NormA = MatrixSpectralNorm(A);
    X0 = orth(randn(n,r));
    const = 0.5*fnorm(A)^2/m;
    rho1 = 10;
    rho2 = 10;
    k = round(n*r*0.5);
    beta0 = 100;
    max_iter = 1e100;

    timeIntervel = 2; % Record the objective value in every timeIntervel second
%     [X_manpg] = manpg_orth_sparse(0.5/m*A'*A,const,n,r,rho1,X0,1e-5,10,timeLimit);

    % Only the objective histories and time stamps are needed for the plot.
    [~,fobjs1,ts1] = L0SPCA_OADMM_EP(X0,A,NormA,const,rho1,rho2,k,max_iter,beta0,timeIntervel,timeLimit);
    [~,fobjs2,ts2] = L0SPCA_OADMM_RR(X0,A,NormA,const,rho1,rho2,k,max_iter,beta0,timeIntervel,timeLimit);
    [~,fobjs3,ts3] = L0SPCA_SubGrad(X0,A,const,rho1,rho2,k,max_iter,timeIntervel,timeLimit);
    [~,fobjs4,ts4] = L0SPCA_Smoothing(X0,A,NormA,const,rho1,rho2,k,max_iter,beta0,timeIntervel,timeLimit);

    figure('color','w');
    myplot = @semilogy;
    palette = loadcolor; % renamed from pcolor so the builtin pcolor is not shadowed
    myplot(ts1,fobjs1,'--ks','LineWidth',5,'MarkerSize',3,'color', palette.red); hold on;
    myplot(ts2,fobjs2,'-.ms','LineWidth',5,'MarkerSize',3,'color',  palette.blue); hold on;
    myplot(ts3,fobjs3,':k*','LineWidth',5,'MarkerSize',3,'color', palette.purple); hold on;
    myplot(ts4,fobjs4,'-b+','LineWidth',5,'MarkerSize',3,'color', palette.green); hold on;

    grid on;
    hleg=legend('OADMM-EP','OADMM-RR','SubGrad','PDM');
    set(hleg,'FontSize',20,'FontWeight','normal');
    set(hleg,'Fontname','times new Roman');
    set(hleg,'Location','NorthEast');
    set(gca,'Fontsize', 19);
    xlabel('Time (seconds)','FontSize',20)
    ylabel('Objective','FontSize',20,'interpreter','latex')

    % (:) makes the concatenation independent of whether a solver returns
    % its history as a row or as a column.
    all_fobj = [fobjs1(:);fobjs2(:);fobjs3(:);fobjs4(:)];
    y_lo = min(all_fobj);
    y_hi = max(all_fobj);
    if all(isfinite([y_lo,y_hi])) && y_hi > y_lo
        axis([0,timeLimit+1,y_lo,y_hi]);
    else
        % axis() rejects non-increasing or non-finite limits; in that case
        % only the time axis is fixed and the y-range is left automatic.
        xlim([0,timeLimit+1]);
    end
    fprintf('\n');
    set(gcf,'paperpositionmode','auto')
    print(sprintf('%s_%d.eps',mfilename,iwhich),'-depsc2','-loose');
    print(sprintf('%s_%d.png',mfilename,iwhich),'-dpng');
    print(sprintf('%s_%d.pdf',mfilename,iwhich),'-dpdf', '-r0');
end

function [X,fobjs,ts] = L0SPCA_OADMM_EP(X,A,NormA,const,rho1,rho2,data_k,max_iter,beta0,timeIntervel,timeLimit)
% L0SPCA_OADMM_EP  ADMM-type solver whose X-step is an extrapolated gradient
% step followed by OrthProj.
%
% Problem (author's notation):
%   min_X  0.5/m*||A-A*X*X'||_F^2 + rho1*||X||_1 - rho2*||X||_{topk},  s.t. X'X = I
% Split form with the extra variable Y = X:
%   L(Y,X,Z) = -0.5/m*mdot(A,A*X*X') + const + rho1*||Y||_1 - rho2*||X||_{topk}
%              + <X-Y,Z> + 0.5*beta*||Y-X||_F^2,  s.t. X'X = I
%
% Inputs
%   X            starting point, n-by-r (n = size(A,2))
%   A            data matrix, m-by-n
%   NormA        scalar; NormA^2/m is used as the Lipschitz estimate of the
%                smooth part (the caller passes MatrixSpectralNorm(A))
%   const        constant forwarded to L0SPCA_ComputeObj
%   rho1, rho2   weights of the l1 term and of the top-k term
%   data_k       k of the top-k term
%   max_iter     iteration budget; the loop runs min(max_iter,30000) times
%   beta0        initial penalty; beta = beta0*(1+xi*iter^(1/3))
%   timeIntervel seconds between two recorded objective values
%   timeLimit    the loop stops at the first record taken after this many seconds
%
% Outputs
%   X      last iterate
%   fobjs  column vector of recorded objective values (first entry: start point)
%   ts     column vector of the matching elapsed times in seconds
%
% Changes w.r.t. the earlier version: beta0 and max_iter are now honoured
% (beta0 used to be overwritten with 100 and max_iter was ignored). The
% 30000 cap is kept so that a huge max_iter behaves as before.

[m,n] = size(A);
r = size(X,2);
Lsmooth = NormA^2/m;
initt = clock;
last_rec_clock = initt;
HandleObj = @(X)L0SPCA_ComputeObj(X,A,rho1,rho2,data_k,const,m);

Y = X;
Z = randn(n,r);

% Fixed algorithm parameters.
sigma = 1.5;                                   % dual step factor
xi = 0.9;                                      % growth rate of the penalty
theta = 1.1;                                   % scaling of the step size
alpha = 0.5*(theta-1) / ((theta+1)*(xi+2));    % extrapolation weight
c0 = 1 + xi/(sigma^2) + xi / (2*sigma^2);
chi = sqrt(4*sigma^2*c0 / ((2-sigma)^2));      % mu = chi/beta
p = 1/3;                                       % exponent of the penalty growth
num_iter = min(max_iter,30000);

X_old = X;

fobjs = HandleObj(X);
ts = etime(clock,initt);

for iter = 1:num_iter

    beta = beta0*(1+xi*iter^p);
    mu = chi/beta;

    % Extrapolated point.
    X_bar = X + alpha * (X - X_old);

    % Update X: gradient of the smooth part of L at X_bar minus a
    % subgradient of the top-k term taken at the current X, then one step
    % of length 1/(theta*(Lsmooth+beta)) followed by OrthProj.
    grad_X = -A'*(A*X_bar)/m + Z + beta*(X_bar-Y) - rho2*top_k_subgrad(X,data_k);
    Lipschit_constant = Lsmooth + beta;
    X_old = X;
    X = OrthProj(X_bar - grad_X/(theta*Lipschit_constant));

    % Update Y:
    %   rho1*||Y||_1 + <X-Y,Z> + 0.5*beta*||Y-X||_F^2
    % = rho1*||Y||_1 + 0.5*beta*||Y-(X+Z/beta)||_F^2 + terms without Y,
    % solved with the smoothed l1 term: min_Y [||Y||_1]_{mu} + 0.5*beta_bar*||Y-B||_F^2
    B = X+Z/beta;
    beta_bar = beta/rho1;
    Y = smooth_prox_l1(B,mu,beta_bar);

    % Update Z
    gap = X-Y;
    Z = Z + sigma*beta*gap;

    e = norm(gap,'fro');

    % Record the objective at most once every timeIntervel seconds; the
    % time limit is checked only when a record is taken.
    cur_clock = clock;
    if(etime(cur_clock,last_rec_clock) > timeIntervel)
        fobj = HandleObj(X);
        ElasTime =  etime(cur_clock,initt);
        fobjs  = [fobjs;fobj];
        ts = [ts;ElasTime];
        last_rec_clock = cur_clock;
        if ElasTime > timeLimit
            break;
        end
        fprintf('iter:%d, fobj:%f, dist:%f, beta:%f\n',iter,fobj,e,beta);
    end

end

function [X,fobjs,ts] = L0SPCA_OADMM_RR(X,A,NormA,const,rho1,rho2,data_k,max_iter,beta0,timeIntervel,timeLimit)
% L0SPCA_OADMM_RR  ADMM-type solver whose X-step is a backtracking step along
% -(G - X*G'*X) mapped back with retr.
%
% Problem (author's notation):
%   min_X  0.5/m*||A-A*X*X'||_F^2 + rho1*||X||_1 - rho2*||X||_{topk},  s.t. X'X = I
% Split form with the extra variable Y = X:
%   L(Y,X,Z) = -0.5/m*mdot(A,A*X*X') + const + rho1*||Y||_1 - rho2*||X||_{topk}
%              + <X-Y,Z> + 0.5*beta*||Y-X||_F^2,  s.t. X'X = I
%
% Inputs
%   X            starting point, n-by-r (n = size(A,2))
%   A            data matrix, m-by-n
%   NormA        not used by this solver (kept for a uniform signature)
%   const        constant forwarded to L0SPCA_ComputeObj
%   rho1, rho2   weights of the l1 term and of the top-k term
%   data_k       k of the top-k term
%   max_iter     iteration budget; the loop runs min(max_iter,30000) times
%   beta0        NOT used: the penalty schedule starts from the fixed value
%                10 (see beta_base below), exactly as in the earlier version
%   timeIntervel seconds between two recorded objective values
%   timeLimit    the loop stops at the first record taken after this many seconds
%
% Outputs
%   X      last iterate
%   fobjs  column vector of recorded objective values (first entry: start point)
%   ts     column vector of the matching elapsed times in seconds
%
% Changes w.r.t. the earlier version: max_iter is honoured (capped at 30000,
% so a huge max_iter behaves as before) and locals that were assigned but
% never read have been removed. The order of the random draws is unchanged.

[m,n] = size(A);
r = size(X,2);
initt = clock;
last_rec_clock = initt;
HandleObj = @(X)L0SPCA_ComputeObj(X,A,rho1,rho2,data_k,const,m);

% NOTE(review): the beta0 argument has always been overwritten with 10 here.
% Using the caller's value would change the results, so the override is
% preserved under its own name instead of silently reusing the parameter.
beta_base = 10;
beta_cap  = 10000;   % upper bound for both beta_base and beta

% Multiplier first, then Y: same order of random draws as before.
Z = randn(n,r);
Y = orth(randn(n,r));

% Fixed algorithm parameters.
sigma = 1.5;    % dual step factor
xi = 0.9;       % growth rate of the penalty
chi = 1;        % mu = chi/beta
p = 1/3;        % exponent of the penalty growth
tau = 0.01;     % base step of the backtracking search
gamma = 0.1;    % shrink factor of the backtracking search
num_iter = min(max_iter,30000);

fobjs = HandleObj(X);
ts = etime(clock,initt);

for iter = 1:num_iter

    % The base penalty is doubled every 50 iterations.
    if(~mod(iter,50))
        beta_base = min(beta_cap,beta_base *2);
    end
    beta = beta_base*(1+xi*iter^p);
    beta = min(beta,beta_cap);
    mu = chi/beta;

    % X-dependent part of L for the current beta, Y and Z.
    fun = @(X) ComputeObj(X,A,m,data_k,beta,Y,rho2,Z);

    [FP, G] = fun(X);
    dtX = G - X*G'*X;
    nrmG  = norm(dtX, 'fro');
    XP = X;

    % Backtracking: shrink eta by gamma until the decrease test holds or
    % eta is negligible. A two-point step-size rule for tau was tried by
    % the author in an earlier revision; the fixed tau above is what runs.
    for jj = 1:1000
        eta = beta_base/beta * tau*gamma^jj;
        X = retr(XP,-eta*dtX);
        F = fun(X);
        if abs(eta)<1e-10 || F <= FP-eta*1e-12*nrmG^2
            break;
        end
    end

    % Update Y:
    %   rho1*||Y||_1 + <X-Y,Z> + 0.5*beta*||Y-X||_F^2
    % = rho1*||Y||_1 + 0.5*beta*||Y-(X+Z/beta)||_F^2 + terms without Y,
    % solved with the smoothed l1 term: min_Y [||Y||_1]_{mu} + 0.5*beta_bar*||Y-B||_F^2
    B = X+Z/beta;
    beta_bar = beta/rho1;
    Y = smooth_prox_l1(B,mu,beta_bar);

    % Update Z
    gap = X-Y;
    Z = Z + sigma* beta*gap;

    e = norm(gap,'fro');

    % Record the objective at most once every timeIntervel seconds; the
    % time limit is checked only when a record is taken.
    cur_clock = clock;
    if(etime(cur_clock,last_rec_clock) > timeIntervel)
        fobj = HandleObj(X);
        ElasTime =  etime(cur_clock,initt);
        fobjs  = [fobjs;fobj];
        ts = [ts;ElasTime];
        last_rec_clock = cur_clock;
        if ElasTime > timeLimit
            break;
        end

        fprintf('iter:%d, fobj:%f, dist:%f, beta:%f\n',iter,fobj,e,beta);
    end
end


function X = retr(X,Delta)
% RETR  Move from X along Delta and rescale the columns jointly.
%
%   X_new = (X + Delta) * (I + Delta'*Delta)^(-1/2)
%
% When X'*X = I and X'*Delta is skew-symmetric, (X+Delta)'*(X+Delta) equals
% I + Delta'*Delta, so the result again has orthonormal columns.
%
% Alternatives the author experimented with (not active): a Cayley-type
% update built from W(Delta), QR-based maps via myQR / myQR2, SVD- or
% eig-based normalisation of X+Delta, and OrthProj(X+Delta).

num_cols = size(Delta,2);
gram = Delta'*Delta + eye(num_cols);
moved = X + Delta;
X = moved*gram^(-0.5);

    
function [F,G] = ComputeObj(X,A,m,data_k,beta,Y,rho2,Z)
% COMPUTEOBJ  X-dependent part of the augmented Lagrangian and its (sub)gradient.
%
%   F = -0.5/m*<A,A*X*X'> + <X,Z> + 0.5*beta*||Y-X||_F^2 - rho2*tksum(X,data_k)
%   G = -A'*(A*X)/m + Z + beta*(X-Y) - rho2*top_k_subgrad(X,data_k)
%
% <A,A*X*X'> = trace(X'*A'*A*X) = ||A*X||_F^2 holds for every X, so the
% m-by-n product A*X*X' is never formed; A*X is computed once and shared
% by F and G. (This relies on mdot being the matrix inner product, which is
% how the formulas in this file's comments use it.)
%
% G is only evaluated when the caller asks for a second output.

AX = A*X;
F = -0.5/m*norm(AX,'fro')^2 + mdot(X,Z) + 0.5*beta*norm(Y-X,'fro')^2 - rho2* tksum(X,data_k);
if nargout > 1
    G = -A'*AX/m + Z + beta*(X-Y)  - rho2*top_k_subgrad(X,data_k);
end



function Y = smooth_prox_l1(B,mu,beta)
% SMOOTH_PROX_L1  Closed-form solution, as stated by the author, of
%   min_Y  [||Y||_1]_{\mu} + 0.5 beta ||Y - B||_F^2
%
% B is soft-thresholded at level (1+mu*beta)/beta and the result is blended
% with B using the weights 1/mu and beta.
%
% A quick numerical check used by the author: evaluate
%   l1_norm_mu(Y1,mu) + 0.5*beta*norm(Y1-B,'fro')^2
% at the returned Y and at random points and compare the values.

shrink_level = (1+mu*beta)/beta;
y_bar = prox_l1(B,shrink_level);

weighted_sum = y_bar/mu + beta*B;
total_weight = 1/mu + beta;
Y = weighted_sum / total_weight;

function f = l1_norm_mu(B,mu)
% L1_NORM_MU  Value of  min_Y ||Y||_1 + 0.5/mu*||Y-B||_F^2.
%
% The minimiser is the same as that of  mu*||Y||_1 + 0.5*||Y-B||_F^2,
% i.e. prox_l1(B,mu); the objective is then evaluated at that point.

minimiser = prox_l1(B,mu);
l1_part   = norm(minimiser(:),1);
quad_part = 0.5/mu* norm(minimiser-B,'fro')^2;
f = l1_part + quad_part;


function [Q,R] = myQR2(X)
% MYQR2  Economy-size QR factorisation of X (thin wrapper around qr).
%   [Q,R] = myQR2(X) returns the factors of qr(X,0), so that Q*R = X.
economy_flag = 0;
[Q, R] = qr(X,economy_flag);

function [Q, RR] = myQR(XX,k)
% MYQR  Economy-size QR factorisation with a non-negative diagonal in RR.
%
%   [Q,RR] = myQR(XX,k) computes qr(XX,0) and flips the sign of every
%   column of Q whose diagonal entry of RR is negative. The matching rows
%   of RR are flipped as well, so Q*RR = XX still holds for the returned
%   pair (previously only Q was flipped, which left RR inconsistent with Q).
%
%   k is accepted for backward compatibility and is not used.

[Q, RR] = qr(XX, 0);

% Read the diagonal by linear index: diag() would build a matrix instead of
% extracting the diagonal when RR has a single row.
num_rows = size(RR,1);
num_diag = min(size(RR));
diag_entries = RR(1:(num_rows+1):(1+(num_diag-1)*(num_rows+1)));
ndr = diag_entries(:) < 0;

if any(ndr)
    Q(:,ndr)  = -Q(:,ndr);
    RR(ndr,:) = -RR(ndr,:);
end


function [x] = prox_l1(a,lambda)
% PROX_L1  Element-wise soft-thresholding.
%   Solves  min_{x} 0.5||x-a||^2 + lambda * sum(abs(x)):
%   every entry of a is moved towards zero by lambda and set to zero if its
%   magnitude does not exceed lambda.
shrunk_magnitude = max(0,abs(a)-lambda);
x = sign(a).*shrunk_magnitude;





function [theta1,alpha1,sigma,theta2,alpha2,xi] = FindSuitableParametersOADMMEP()
% FINDSUITABLEPARAMETERSOADMMEP  Produce one fixed set of step parameters and
% check a consistency inequality between them.
%
% Outputs
%   theta1, alpha1  parameters for the first n-1 blocks
%   theta2, alpha2  parameters for the last block
%   sigma           dual step factor (1.5)
%   xi              min(eps1, eps2*sigma)
%
% An error is raised when the inequality on the last-block parameters fails.

% Safety margins.
eps0 = 1e-5;
eps1 = 0.01;
eps2 = 0.01;
eps3 = 0.001;

% First n-1 blocks: alpha1 is chosen so that
%   theta - 1 - (2+eps1)*alpha*theta > eps0.
theta1 = 1.05;
alpha1 = (theta1-1-eps0) / ((2+eps1)*(theta1));

% Last block.
theta2 = 1.001;
alpha2 = (theta2-1)/((2+eps1)*(1+theta2)) * (1-eps1);
gamma  = 0.5*( theta2 - 1 - (2+eps1)*alpha2*theta2 );

sigma  = 1.5;
sigma1 = sigma / ( (1 - abs(1-sigma))^2);

delta = 1 + eps2;
tau   = (1+eps1)*alpha2^2;
chi   = (1+eps3)*theta2;
gamma_n_primal = gamma*(1-eps3);

% The left-hand side must not exceed gamma_n_primal.
lhs = 8*sigma1*delta*(1+eps3)*( (chi-1)^2  + tau * chi );
if(lhs - gamma_n_primal>0)
    error('NO!');
end

xi = min(eps1,eps2*sigma);



