function [w, primal_loss, dual_loss, gap, criterion] = Consensus_ADMM(Xdata, Ydata, lambda, beta,  option)
% CONSENSUS_ADMM  Multi-worker ADMM-style solver for regularized least-squares
% or hinge-loss problems, with per-round primal/dual diagnostics.
%
% Inputs:
%        Xdata:   cell array; Xdata{k} is the n_k-by-d data matrix of worker k
%        Ydata:   cell array; Ydata{k} is the n_k-by-1 target/label vector of
%                 worker k
%       lambda:   regularization weight
%         beta:   step/penalty parameter used by the local and server updates
%       option:   (optional) struct of algorithm parameters
%        epoch:   total communication rounds (default 100)
%         loss:   'ls' selects the least-squares branch; any other value
%                 selects the hinge-loss (svm) branch (default 'ls')
%      penalty:   'l1' selects the l1 proximal operator; any other value
%                 selects the l2 one (default 'l1')
% Outputs:
%            w:   d-by-1 iterate after the last round
%  primal_loss:   epoch-by-1, primal objective per round
%    dual_loss:   epoch-by-1, dual objective per round
%          gap:   epoch-by-1, |primal-dual|/(1+|primal|+|dual|) per round
%    criterion:   epoch-by-1, KKT residual per round

rng('default');

numWorkers=length(Xdata);
nk=zeros(numWorkers,1);
v_agent=cell(numWorkers,1);
for k=1:numWorkers
    [nk(k),~]=size(Xdata{k});
    v_agent{k}=zeros(nk(k),1);   % per-worker dual variable, one entry per sample
end
n=sum(nk);


[~,d]=size(Xdata{1});

% Defaults. loss/penalty previously had none, so omitting `option` (or one of
% these fields) failed later with an undefined-variable error.
epoch=100;
eval_every=1;
penalty='l1';
obj='ls';

% User's option
%------------------------------------------------
if nargin > 4  %  then option is an argument
    if isfield(option,'epoch')
        epoch = option.epoch;
    end
    if isfield(option,'penalty')
        penalty = option.penalty;
    end
    if isfield(option,'loss')
        obj = option.loss;
    end
end

% Branch selectors are loop-invariant, so evaluate them once.
useL1 = matches(penalty,'l1');
useLS = matches(obj,'ls');


% Aggregate full dataset (only for loss/criterion evaluation)
X = vertcat(Xdata{:});
Y = vertcat(Ydata{:});

primal_loss=zeros(epoch, 1);
dual_loss=primal_loss;
criterion=primal_loss;
gap=primal_loss;
w=zeros(d,1);
%eta=0.5*numWorkers; % our proposed method parameter



% display iteration information
fprintf('----------------Training with {%3.0d} workers------------\n', numWorkers);
fprintf('  round   primal loss     dual loss         gap    relative KKT residual     \n');


% NOTE(review): the worker loop below is a serial `for`, not `parfor`;
% confirm whether these pool constants are still needed.
sharedXdata = parallel.pool.Constant(Xdata);
sharedYdata = parallel.pool.Constant(Ydata);


if useL1
    % l1-norm proximal mapping operator (soft-thresholding)
    Proxg=@(z, lambda) sign(z).*max(abs(z)-lambda,0);
else
    % l2-norm proximal mapping operator
    Proxg=@(z, lambda) z/(lambda+1);
end

temp_prev=X'*vertcat(v_agent{:});


for i=1:epoch

    % On All Agents:
    for k=1:numWorkers

        if useLS
            v_agent{k}=ls_Localupdate(sharedXdata.Value{k}, sharedYdata.Value{k}, w, v_agent{k}, n, beta);
        else
            v_agent{k}=svm_Localupdate(sharedXdata.Value{k}, sharedYdata.Value{k}, w, v_agent{k}, n, beta);
        end

    end

    % Stack the dual variables once per round; reused by every metric below.
    v_all=vertcat(v_agent{:});
    temp=X'*v_all;
    % On Parameter Server:
    w=Proxg( w+(temp_prev-2*temp)/(n*beta*numWorkers) , lambda/(beta*numWorkers) );

    % Predictions with the new w; reused by the loss and KKT computations.
    Xw=X*w;

    if useLS

        if useL1
            primal = norm(Y-Xw)^2/(2*n)+lambda*norm(w, 1);
            dual = -mean(Y.*v_all)-norm(v_all)^2/(2*n);
            kkt = norm(w-Proxg(w-X'*(Xw-Y)/n,lambda )); % proximal-gradient fixed-point residual (not normalized)
        else
            primal = norm(Y-Xw)^2/(2*n)+lambda/2*norm(w)^2;
            dual = -mean(Y.*v_all)-norm(v_all)^2/(2*n)-lambda/2*norm(temp/(n*lambda) )^2;
            kkt = norm(w-Proxg(w-X'*(Xw-Y)/n,lambda )); % proximal-gradient fixed-point residual (not normalized)
        end

    else

        if useL1
            primal=mean(max(0, 1-Y.*Xw) )+lambda*norm(w, 1);
            dual=-mean(Y.*v_all);
            % relative KKT residual: normalized by 1+||w||+||v||
            kkt = (norm(Xw-prox_l1svm(Xw+v_all,  Y))+norm( w- Proxg(w-temp/n ,lambda)))/(1+norm(w)+norm(v_all));
        else
            primal=mean(max(0, 1-Y.*Xw) )+lambda/2*norm(w)^2;
            dual=-mean(Y.*v_all)-lambda/2*norm(temp/(n*lambda) )^2;
            % relative KKT residual: normalized by 1+||w||+||v||
            kkt = (norm(Xw-prox_l1svm(Xw+v_all,  Y))+norm( w- Proxg(w-temp/n ,lambda)))/(1+norm(w)+norm(v_all));
        end

    end

    primal_loss(i)=primal;
    dual_loss(i)=dual;
    gap(i)=abs(primal-dual)/(1+abs(primal)+abs(dual));
    criterion(i)=kkt;
    temp_prev=temp;   % the server update needs the previous round's X'*v

    if mod(i,eval_every)==0
        fprintf('%5d  \t %5d \t %5d \t %5d \t  %5d \t  \n', ...
                i, primal, dual, gap(i),criterion(i));
    end


end




end



function u1=prox_l1svm(u,  y)
% PROX_L1SVM  Elementwise map u -> u - y.*clip(y.*u - 1, [-1, 0]).
%   u, y : arrays of compatible size (y is used as an elementwise multiplier).
%   u1   : result, same size as the elementwise expression.

    shifted = y.*u - 1;
    % Clamp to [-1, 0]: cap at 0 first, then floor at -1.
    capped = min(0, shifted);
    clipped = max(-1, capped);
    u1 = u - y.*clipped;

end





function v=svm_Localupdate(X, Y, w, v_old, n, beta)
% SVM_LOCALUPDATE  Local update of one worker's variable v for the hinge loss.
%   Solves the box-constrained quadratic program
%       min_v  0.5*v'*H*v + f'*v    subject to  lb <= v <= ub
%   with H = (X*X')/(n*beta) and f = Y - (X*X')*v_old/(n*beta) - X*w.
%
%   X     : nk-by-d local data matrix
%   Y     : nk-by-1 local labels (bounds below are set for values 1 and -1)
%   w     : d-by-1 current server iterate
%   v_old : nk-by-1 previous local variable (also passed to quadprog as x0)
%   n     : scalar used together with beta to scale the Gram matrix
%   beta  : step/penalty parameter
%   v     : nk-by-1 solution returned by quadprog

options = optimoptions('quadprog','Display','off','OptimalityTolerance',1e-10);
[nk,~]=size(X);

% Box constraints: v in [-1,0] where Y==1, v in [0,1] where Y==-1,
% and v fixed at 0 for any other label value.
lb=zeros(nk,1);
ub=zeros(nk,1);
lb(Y==1)=-1;
ub(Y==-1)=1;

% Form the Gram matrix once; it is needed by both H and f.
G=X*X';
H=sparse(G/(n*beta));
f=Y-1/(n*beta)*G*v_old-X*w;
v = quadprog(H,f,[],[],[],[],lb,ub,v_old,options);

end

function v_new=ls_Localupdate(X, Y, w, v_old, n, beta)
% LS_LOCALUPDATE  Local update of one worker's variable v for least squares.
%   Solves the linear system
%       (K + I) * v_new = K*v_old + X*w - Y,   K = (X*X')/(n*beta).
%
%   X     : nk-by-d local data matrix
%   Y     : nk-by-1 local targets
%   w     : d-by-1 current server iterate
%   v_old : nk-by-1 previous local variable
%   n     : scalar used together with beta to scale the Gram matrix
%   beta  : step/penalty parameter
%   v_new : nk-by-1 solution of the system above

[nk,~]=size(X);

% Scaled Gram matrix, formed once and reused on both sides of the system.
K=1/(n*beta)*(X*X');
v_new=(K+eye(nk))\(K*v_old+X*w-Y);


end




