function test_sparse_phase_retrieval_capped_l1
% Benchmark driver for capped-L1 regularized sparse phase retrieval.
%
% For every (lambda, dataset) pair it solves
%   min_w  1/m * sum_{i=1}^m ( (w'*x_i)^2 - y_i^2 )^2 + lambda * sum(max(delta,|w|))
%   s.t.   ||w||_inf <= R
% with eight stochastic solvers started from the same random point x0, each
% given a dataset-dependent wall-clock budget. The recorded objective traces
% are passed through normF (presumably a normalization, matching the
% 'Relative Objective' axis label) and plotted on a log scale; each figure is
% written to <mfilename>_<lambdaIndex>_<datasetId>.eps and .png.
% GetData, normF and loadcolor8 are expected on the paths added below.

clc;clear all;close all;
addpath('solver','util','data');
rand('seed',10); randn('seed',10);   % legacy seeding, kept for reproducibility of earlier runs

lambdaValues = [0.01;0.001];
budgets      = [8 8 8 8 12 12 12 8];   % seconds per solver, one entry per dataset
datasetIds   = [1 2 3 4  5  6  7 8];

legendNames = {'ProxSARAH','SpiderBoost','SpiderBoost-M','SGP-SPIDER', ...
    'AEPG-SPIDER(0)','AEPG-SPIDER(0.1)','AEPG-SPIDER(0.5)','AEPG-SPIDER(0.9)'};
nSolvers = numel(legendNames);
aepgThetas = [0.0 0.1 0.5 0.9];        % extrapolation parameters for the AEPG runs

for il = 1:length(lambdaValues)
    % Legend placement depends only on which lambda is being run.
    if il == 1
        legendLocation = 'NorthEast';
    else
        legendLocation = 'SouthWest';
    end

    for jd = 1:length(datasetIds)
        idat      = datasetIds(jd);
        timeLimit = budgets(jd);

        [dat_X,dat_y] = GetData(idat);   % dat_X: m x d, dat_y: m x 1
        lambda = lambdaValues(il);
        delta  = 0.1;
        R      = 10;
        [dat_m,dat_d] = size(dat_X);

        x0 = randn(dat_d,1);

        hSmooth    = @(x,Batch)ObjectiveFunctionMiniBatch(x,dat_X,dat_y,Batch);
        hNonSmooth = @(x)Compute_h(x,lambda,delta,R);
        hGProx     = @(a,v)GenProximalMapping(a,v,lambda,delta,R);
        Lipschitz  = 10;
        logEvery   = timeLimit/20;       % seconds between objective recordings

        % Arguments shared by every solver, in their common positional order.
        shared = {x0,hSmooth,hNonSmooth,hGProx,timeLimit,logEvery,Lipschitz,dat_m};
        F = cell(1,nSolvers);            % objective traces
        T = cell(1,nSolvers);            % matching time stamps

        tic;[~,F{1},T{1}] = ProximalSARAH(shared{:});toc;
        tic;[~,F{2},T{2}] = SpiderBoost(shared{:});toc;
        tic;[~,F{3},T{3}] = SpiderBoostM(shared{:});toc;
        tic;[~,F{4},T{4}] = SpiderSGP(shared{:},lambda,delta,R);toc;
        for k = 1:numel(aepgThetas)
            tic;[~,F{4+k},T{4+k}] = AEPGSpider(shared{:},aepgThetas(k));toc;
        end

        [F{:}] = normF(F{:});

        close all;
        pause(0.01)
        palette = loadcolor8;
        lineColors = {palette.lime, palette.teal, palette.lavender, palette.skyblue, ...
            palette.mustard, palette.olive, palette.turquoise, palette.coral};
        for k = 1:nSolvers
            semilogy(T{k},F{k},'-','LineWidth',3,'MarkerSize',3,'color',lineColors{k});
            hold on;
        end

        hleg = legend(legendNames{:}, 'Interpreter', 'latex');
        set(hleg,'FontSize',13,'FontWeight','normal');
        set(hleg,'Fontname','times new Roman');
        set(hleg,'Location',legendLocation);
        set(gca,'Fontsize', 12);
        set(hleg, 'Color', 'none');
        xlabel('Time (seconds)','FontSize',12)
        ylabel('Relative Objective','FontSize',12,'interpreter','latex')

        allF = vertcat(F{:});
        axis([0,timeLimit+1,min(allF),max(allF)]);
        fprintf('\n');
        set(gcf,'paperpositionmode','auto')
        print(sprintf('%s_%d_%d.eps',mfilename,il,idat),'-depsc2','-loose');
        print(sprintf('%s_%d_%d.png',mfilename,il,idat),'-dpng');
    end
end





function [f,grad] = ObjectiveFunctionMiniBatch(w,dat_X,dat_y,Batch)
% Mini-batch phase-retrieval loss and its gradient.
%
%   f(w)    = 1/|B| * sum_{i in B} ( (x_i'*w)^2 - y_i^2 )^2
%   grad(w) = 4/|B| * sum_{i in B} ( (x_i'*w)^2 - y_i^2 ) * (x_i'*w) * x_i
%
% Inputs:  w     - d x 1 parameter vector
%          dat_X - m x d data matrix, row i is x_i'
%          dat_y - m x 1 measurements
%          Batch - row indices of dat_X / dat_y forming the batch B
% Outputs: f     - scalar batch loss
%          grad  - d x 1 batch gradient

Xb    = dat_X(Batch,:);          % |B| x d
nb    = size(Xb,1);
proj  = Xb*w;                    % x_i'*w for every i in B
yb    = dat_y(Batch);
resid = proj.*proj - yb.*yb;     % squared-magnitude mismatch per sample

f    = 1/nb*norm(resid)^2;
grad = (4 / nb) * (Xb' * (resid .* proj));





function fobj = Compute_h(X,lambda,delta,R)
% Capped-L1 penalty  lambda * sum(max(delta,|z|)),  where z is X(:) after
% entries above R are set to R and, then, entries below -R are set to -R.
z = X(:);
tooLarge = z > R;
z(tooLarge) = R;
tooSmall = z < -R;               % evaluated after the upper clip, as a separate pass
z(tooSmall) = -R;
capped = max(delta, abs(z));     % each magnitude is floored at delta
fobj = lambda * sum(capped);

function [fobj,grad] = CappedL1(x,lambda,delta)
% Capped-L1 value  lambda*sum(max(delta,|x|))  and one subgradient of it.
% The subgradient is +lambda where x > delta, -lambda where x < -delta and
% 0 on the flat region in between.
fobj = lambda*sum(max(delta, abs(x)));

above = (x > delta);
below = (x < -delta);
grad = zeros(size(x));
grad(above) = +lambda;
grad(below) = -lambda;           % applied second, as in the original order


function [x] = GenProximalMapping(a,v,lambda,delta,R)
% Weighted proximal mapping of the box-constrained capped-L1 penalty,
% solved exactly and elementwise:
%
%   x = argmin_{-R <= x <= R}  0.5*||x - a||_v^2 + lambda*sum(max(delta,|x|))
%
% with ||z||_v^2 = sum(v.*z.^2). v is a scalar or has numel(a) entries and
% is assumed positive (TODO confirm at all call sites). The result has the
% same size as a.
%
% Per coordinate the objective is a continuous piecewise quadratic on
% [-R,-delta], [-delta,delta] and [delta,R]. On each piece the minimiser is
% either the piece's stationary point or one of its endpoints:
%   |x| < delta :  v.*(x - a) = 0            ->  x = a
%   x < -delta  :  v.*(x - a) - lambda = 0   ->  x = a + lambda./v
%   x >  delta  :  v.*(x - a) + lambda = 0   ->  x = a - lambda./v
% The global minimiser is therefore among {-R, R, -delta, delta, a,
% a + lambda./v, a - lambda./v}, each projected onto [-R,R]. The kinks
% +/-delta must be candidates. For delta < a < delta + lambda./v the
% minimiser is exactly x = delta, which none of the other candidates reach.

sizea = size(a);
a = a(:);
v = v(:);
n = numel(a);
one = ones(n,1);
Proj = @(z)min(max(z,-R),R);
step = lambda*(one./v);          % n x 1

cand = [-R*one, R*one, Proj(a), Proj(a + step), Proj(a - step), ...
        Proj(-delta)*one, Proj(delta)*one];

obj = @(z)0.5.*v.*((z-a).*(z-a)) + lambda*max(delta,abs(z));
vals = zeros(n,size(cand,2));
for k = 1:size(cand,2)
    vals(:,k) = obj(cand(:,k));
end

% Elementwise argmin over the candidates (ties go to the earliest column).
[~,best] = min(vals,[],2);
x = cand(sub2ind(size(cand),(1:n)',best));
x = reshape(x,sizea);

function [x] = GenMapping(a,gamma,lambda,delta,R)
% Proximal mapping (Euclidean metric, step gamma) of the box-constrained
% capped-L1 penalty, solved exactly and elementwise:
%
%   x = argmin_{-R <= x <= R}  0.5*||x - a||_2^2 + gamma*lambda*sum(max(delta,|x|))
%
% The result has the same size as a.
%
% With t = gamma*lambda, each coordinate's objective is a continuous
% piecewise quadratic on [-R,-delta], [-delta,delta] and [delta,R]. Its
% minimiser on each piece is that piece's stationary point (a, a + t or
% a - t) or an endpoint (+/-delta, +/-R). Every candidate is projected onto
% [-R,R], and the one with the smallest objective is kept. The kinks
% +/-delta must be included: for delta < a < delta + t the minimiser is
% exactly x = delta.

sizea = size(a);
a = a(:);
n = numel(a);
one = ones(n,1);
t = gamma*lambda;
Proj = @(z)min(max(z,-R),R);

cand = [-R*one, R*one, Proj(a), Proj(a - t), Proj(a + t), ...
        Proj(-delta)*one, Proj(delta)*one];

obj = @(z)0.5.*((z-a).*(z-a)) + t.*max(delta,abs(z));
vals = zeros(n,size(cand,2));
for k = 1:size(cand,2)
    vals(:,k) = obj(cand(:,k));
end

% Elementwise argmin over the candidates (ties go to the earliest column).
[~,best] = min(vals,[],2);
x = cand(sub2ind(size(cand),(1:n)',best));
x = reshape(x,sizea);



function [x] = g_threadholding_l1(a,lambda)
% Soft-thresholding followed by clipping to [-1,1]. For lambda >= 0 this
% solves the separable box-constrained problem
%   min_{-1 <= x <= 1}  0.5*||x - a||^2 + lambda*sum(abs(x))
% because each coordinate's objective is convex, so clipping the
% unconstrained soft-threshold solution gives the constrained one.
shrunk = sign(a).*max(0, abs(a) - lambda);
x = min(max(shrunk, -1), 1);



function [x,fobjs,ts] = AEPGSpider(x,HandleObjSmoothMiniBatch,HandleObjNonSmooth,hProx,timeLimit,timeIntervel,Lipschitz,dat_m,theta)
% Adaptive, extrapolated proximal gradient method driven by a SPIDER/SARAH
% variance-reduced gradient estimator.
%
% Inputs:
%   x                        - starting point
%   HandleObjSmoothMiniBatch - @(x,Batch) returning [f,grad] of the smooth
%                              loss restricted to the sample indices Batch
%   HandleObjNonSmooth       - @(x) value of the nonsmooth term
%   hProx                    - proximal step, called as hProx(a,v) with the
%                              per-coordinate weights v
%   timeLimit                - wall-clock budget in seconds
%   timeIntervel             - seconds between objective recordings
%   Lipschitz                - initial weights are v0 = Lipschitz/2
%   dat_m                    - number of samples
%   theta                    - extrapolation parameter (0 gives y = x)
% Outputs:
%   x     - last iterate
%   fobjs - full objective recorded at start and about every timeIntervel seconds
%   ts    - elapsed seconds matching fobjs
%
% The weights v never decrease: v <- sqrt(v.^2 + s), where s combines the
% squared scaled step r.^2 per coordinate (beta) and summed over all
% coordinates (alpha). The full gradient is recomputed every
% qqq = round(sqrt(dat_m)) iterations. In between, SARAH differences on
% mini-batches of bbb = round(sqrt(dat_m)) samples update the estimate.

fullBatch = 1:dat_m;
hObjF = @(z)HandleObjSmoothMiniBatch(z,fullBatch) + HandleObjNonSmooth(z);
initt = clock;
last_rec_clock = initt;
fobjs = hObjF(x);
ts = etime(clock,initt);

v0 = Lipschitz*0.5;
beta = 1;
alpha = beta*0.01;
v = v0*ones(size(x));

sigma = theta;
y = x;
y_old = y;

bbb = round(sqrt(dat_m));   % mini-batch size
qqq = round(sqrt(dat_m));   % full-gradient refresh period

% The recursive estimator must start from an exact gradient. Starting from
% zeros would make every step of the first epoch follow
% grad(y_t) - grad(y_0) instead of grad(y_t).
[~,grad] = HandleObjSmoothMiniBatch(y,fullBatch);

for iter = 1:3000000000

    if (mod(iter,qqq)==0)
        [~,grad] = HandleObjSmoothMiniBatch(y,fullBatch);
    else
        Batch = randperm(dat_m,bbb);
        [~,grad_bat] = HandleObjSmoothMiniBatch(y,Batch);
        [~,grad_bat_old] = HandleObjSmoothMiniBatch(y_old,Batch);
        grad = grad + grad_bat - grad_bat_old;
    end

    xt = x;
    x  = hProx(y-grad./v,v);
    d = x - xt;

    % Grow the step weights using the scaled step r = v.*d.
    r = v.*d;
    rr = r.*r;
    s = alpha*sum(rr(:)) + beta*rr;
    vt = v;
    v = sqrt(v.*v + s);

    % Extrapolation weight, damped by how much v grew this iteration.
    sigma = theta*(1-sigma)*min(min(vt./v));
    y_old = y;
    y = x + sigma*d;

    cur_clock = clock;
    if(etime(cur_clock,last_rec_clock) > timeIntervel)
        F    = hObjF(x);
        fprintf('iter:%d, minv:%.10e, Diff :%.3e, F: %.3e\n',...
            iter,    min(v(:)),  norm(d,'fro'),   F);
        ElasTime =  etime(cur_clock,initt);
        fobjs  = [fobjs;F];
        ts = [ts;ElasTime];
        last_rec_clock = cur_clock;
        if ElasTime > timeLimit
            break;
        end
    end
end





function [x,fobjs,ts] = SpiderSGP(x,HandleObjSmoothMiniBatch,HandleObjNonSmooth,hProx,timeLimit,timeIntervel,Lipschitz,dat_m,reg_para_lambda,reg_para_delta,reg_para_R)
% Projected stochastic subgradient method (SGP) with a SPIDER/SARAH
% estimator for the smooth part:
%
%   x <- box_proj( y - (g + s)./v, -R, R ),   y = x
%
% g is the variance-reduced gradient of the smooth loss. s is the capped-L1
% subgradient from CappedL1(y, reg_para_lambda, reg_para_delta).
% v = Lipschitz is a fixed per-coordinate step weight.
% hProx is accepted only so the signature matches the other solvers; it is
% not used. Outputs follow the other solvers: the last iterate, and the full
% objective plus elapsed seconds recorded about every timeIntervel seconds
% until timeLimit.

fullBatch = 1:dat_m;
hObjF = @(z)HandleObjNonSmooth(z) + HandleObjSmoothMiniBatch(z,fullBatch);
initt = clock;
last_rec_clock = initt;
fobjs = hObjF(x);
ts = etime(clock,initt);

v = Lipschitz*ones(size(x));

y = x;
y_old = y;
bbb = round(sqrt(dat_m));   % mini-batch size
qqq = round(sqrt(dat_m));   % full-gradient refresh period

% The recursive estimator must start from an exact gradient (not zeros).
[~,grad] = HandleObjSmoothMiniBatch(y,fullBatch);

for iter = 1:3000000000

    if (mod(iter,qqq)==0)
        [~,grad] = HandleObjSmoothMiniBatch(y,fullBatch);
    else
        Batch = randperm(dat_m,bbb);
        [~,grad_bat] = HandleObjSmoothMiniBatch(y,Batch);
        [~,grad_bat_old] = HandleObjSmoothMiniBatch(y_old,Batch);
        grad = grad + grad_bat - grad_bat_old;
    end

    % The regulariser's subgradient goes into the step direction. Adding it
    % to both SPIDER difference terms would cancel it exactly, and the
    % capped-L1 term would be ignored.
    [~,subgrad] = CappedL1(y,reg_para_lambda,reg_para_delta);

    xt = x;
    x  = y - (grad + subgrad)./v;
    x = box_proj(x,-reg_para_R,reg_para_R);
    d = x - xt;
    y_old = y;
    y = x ;

    cur_clock = clock;
    if(etime(cur_clock,last_rec_clock) > timeIntervel)
        F    = hObjF(x);
        fprintf('iter:%d, minv:%.10e, Diff :%.3e, F: %.3e\n',...
            iter,    min(v(:)),  norm(d,'fro'), F);
        ElasTime =  etime(cur_clock,initt);
        fobjs  = [fobjs;F];
        ts = [ts;ElasTime];
        last_rec_clock = cur_clock;
        if ElasTime > timeLimit
            break;
        end
    end
end

function [x,fobjs,ts] = ProximalSARAH(x,HandleObjSmoothMiniBatch,HandleObjNonSmooth,hProx,timeLimit,timeIntervel,Lipschitz,dat_m)
% Proximal SARAH.
%
% Each outer iteration:
%   1. computes the full gradient at x,
%   2. takes one proximal step hProx(x - grad./v, v),
%   3. takes qqq further proximal steps driven by the recursive estimator
%        grad <- grad + g_B(x_k) - g_B(x_{k-1})
%      on random mini-batches B of bbb samples.
% v = Lipschitz is a fixed per-coordinate weight. The objective is recorded
% (at most once per outer iteration) about every timeIntervel seconds until
% timeLimit. Outputs: the last iterate, the objective values and the
% matching elapsed seconds.

fullBatch = 1:dat_m;
hObjF = @(z)HandleObjNonSmooth(z) + HandleObjSmoothMiniBatch(z,fullBatch);
initt = clock;
last_rec_clock = initt;
fobjs = hObjF(x);
ts = etime(clock,initt);

v = Lipschitz*ones(size(x));

bbb = round(sqrt(dat_m));   % mini-batch size
qqq = round(sqrt(dat_m));   % inner steps per outer iteration

for iter = 1:3000000000

    [~,grad] = HandleObjSmoothMiniBatch(x,fullBatch);
    x_start = x;            % epoch anchor, only used for the logged Diff
    x_prev = x;
    x = hProx(x-grad./v,v);

    for k = 1:qqq
        B = randperm(dat_m,bbb);
        % SARAH needs the difference between consecutive iterates. Using a
        % frozen anchor would accumulate g_B(x_k) - g_B(x_0) at every step.
        [~,g_prev] = HandleObjSmoothMiniBatch(x_prev,B);
        [~,g_x] = HandleObjSmoothMiniBatch(x,B);
        grad = grad + g_x - g_prev;
        x_prev = x;
        x = hProx(x-grad./v,v);
    end

    d = x_start-x;

    cur_clock = clock;
    if(etime(cur_clock,last_rec_clock) > timeIntervel)
        F    = hObjF(x);
        fprintf('iter:%d, minv:%.10e, Diff :%.3e, F: %.3e\n',...
            iter,    min(v(:)),  norm(d,'fro'), F);
        ElasTime =  etime(cur_clock,initt);
        fobjs  = [fobjs;F];
        ts = [ts;ElasTime];
        last_rec_clock = cur_clock;
        if ElasTime > timeLimit
            break;
        end
    end
end


function [x,fobjs,ts] = SpiderBoost(x,HandleObjSmoothMiniBatch,HandleObjNonSmooth,hProx,timeLimit,timeIntervel,Lipschitz,dat_m)
% Proximal SpiderBoost: proximal gradient steps x <- hProx(x - grad./v, v)
% with the fixed weight v = Lipschitz.
%
% grad is a SPIDER/SARAH estimator. It is refreshed with the full gradient
% every qqq = round(sqrt(dat_m)) iterations. In between it is updated with
% mini-batch differences g_B(x_k) - g_B(x_{k-1}) on bbb = round(sqrt(dat_m))
% samples. The objective is recorded about every timeIntervel seconds until
% timeLimit. Outputs: the last iterate, the objective values and the
% matching elapsed seconds.

fullBatch = 1:dat_m;
hObjF = @(z)HandleObjNonSmooth(z) + HandleObjSmoothMiniBatch(z,fullBatch);
initt = clock;
last_rec_clock = initt;
fobjs = hObjF(x);
ts = etime(clock,initt);

v = Lipschitz*ones(size(x));

y = x;
y_old = y;

bbb = round(sqrt(dat_m));   % mini-batch size
qqq = round(sqrt(dat_m));   % full-gradient refresh period

% The recursive estimator must start from an exact gradient (not zeros).
[~,grad] = HandleObjSmoothMiniBatch(y,fullBatch);

for iter = 1:3000000000

    if (mod(iter,qqq)==0)
        [~,grad] = HandleObjSmoothMiniBatch(y,fullBatch);
    else
        Batch = randperm(dat_m,bbb);
        [~,grad_bat] = HandleObjSmoothMiniBatch(y,Batch);
        [~,grad_bat_old] = HandleObjSmoothMiniBatch(y_old,Batch);
        grad = grad + grad_bat - grad_bat_old;
    end

    xt = x;
    x  = hProx(y-grad./v,v);
    d = x - xt;
    y_old = y;
    y = x ;

    cur_clock = clock;
    if(etime(cur_clock,last_rec_clock) > timeIntervel)
        F    = hObjF(x);
        fprintf('iter:%d, minv:%.10e, Diff :%.3e, F: %.3e\n',...
            iter,    min(v(:)),  norm(d,'fro'), F);
        ElasTime =  etime(cur_clock,initt);
        fobjs  = [fobjs;F];
        ts = [ts;ElasTime];
        last_rec_clock = cur_clock;
        if ElasTime > timeLimit
            break;
        end
    end
end

function [x,fobjs,ts] = SpiderBoostM(x,HandleObjSmoothMiniBatch,HandleObjNonSmooth,hProx,timeLimit,timeIntervel,Lipschitz,dat_m)
% SpiderBoost with a fixed extrapolation (momentum) step.
%
%   x <- hProx(y - grad./v, v),   y <- x + sigma*(x - x_prev),   sigma = 0.5
%
% v = Lipschitz is fixed. grad is a SPIDER/SARAH estimator: it is refreshed
% with the full gradient every qqq = round(sqrt(dat_m)) iterations, and in
% between it is updated with mini-batch differences on
% bbb = round(sqrt(dat_m)) samples. The objective is recorded about every
% timeIntervel seconds until timeLimit. Outputs: the last iterate, the
% objective values and the matching elapsed seconds.

fullBatch = 1:dat_m;
hObjF = @(z)HandleObjNonSmooth(z) + HandleObjSmoothMiniBatch(z,fullBatch);
initt = clock;
last_rec_clock = initt;
fobjs = hObjF(x);
ts = etime(clock,initt);

v = Lipschitz*ones(size(x));
sigma = 0.5;                % fixed extrapolation weight

y = x;
y_old = y;

bbb = round(sqrt(dat_m));   % mini-batch size
qqq = round(sqrt(dat_m));   % full-gradient refresh period

% The recursive estimator must start from an exact gradient (not zeros).
[~,grad] = HandleObjSmoothMiniBatch(y,fullBatch);

for iter = 1:3000000000

    if (mod(iter,qqq)==0)
        [~,grad] = HandleObjSmoothMiniBatch(y,fullBatch);
    else
        Batch = randperm(dat_m,bbb);
        [~,grad_bat] = HandleObjSmoothMiniBatch(y,Batch);
        [~,grad_bat_old] = HandleObjSmoothMiniBatch(y_old,Batch);
        grad = grad + grad_bat - grad_bat_old;
    end

    xt = x;
    x  = hProx(y-grad./v,v);
    d = x - xt;
    y_old = y;
    y = x + sigma*d;

    cur_clock = clock;
    if(etime(cur_clock,last_rec_clock) > timeIntervel)
        F    = hObjF(x);
        fprintf('iter:%d, minv:%.10e, Diff :%.3e, F: %.3e\n',...
            iter,    min(v(:)),  norm(d,'fro'), F);
        ElasTime =  etime(cur_clock,initt);
        fobjs  = [fobjs;F];
        ts = [ts;ElasTime];
        last_rec_clock = cur_clock;
        if ElasTime > timeLimit
            break;
        end
    end
end



