function [w, primal_loss, dual_loss, gap, criterion] = PJADMM(Xdata, Ydata, lambda, rho, eta, option)
% PJADMM  Distributed server/worker solver for l1- or l2-regularized
% least-squares or hinge-loss (SVM) problems.
%
% Each round, every worker k updates its local vector v_agent{k} (one entry
% per local sample). The server then takes a proximal step on w using
% X'*v and sends beta = -dw - w back to the workers.
%
% The primal objective tracked per round is
%   'ls' : ||Y - X*w||^2/(2n) + g(w)
%   'svm': mean(max(0, 1 - Y.*(X*w))) + g(w)
% with g(w) = lambda*||w||_1 ('l1') or lambda/2*||w||^2 ('l2'), where X, Y
% stack all workers' data and n is the total number of samples.
%
% Inputs:
%        Xdata:   cell array; Xdata{k} is the nk-by-d feature matrix of
%                 worker k (all workers must share the same d)
%        Ydata:   cell array; Ydata{k} is the nk-by-1 target vector of
%                 worker k (for 'svm' the labels are expected to be +1/-1)
%       lambda:   regularization weight (presumably a positive scalar;
%                 the 'l2' dual value divides by it)
%          rho:   server step size
%          eta:   proximal weight used in the workers' local subproblems
%       option:   optional struct of algorithm parameters
%        .epoch:       total communication rounds (default 100)
%        .loss:        'ls' or 'svm' (default 'ls'; any value other than
%                      'ls' selects the SVM branch)
%        .penalty:     'l1' or 'l2' (default 'l1'; any value other than
%                      'l1' selects the l2 branch)
%        .eval_every:  print progress every this many rounds (default 1)
% Outputs:
%            w:   final d-by-1 model vector
%  primal_loss:   epoch-by-1 primal objective value per round
%    dual_loss:   epoch-by-1 dual objective value per round
%          gap:   epoch-by-1 relative duality gap per round,
%                 |primal-dual| / (1 + |primal| + |dual|)
%    criterion:   epoch-by-1 KKT residual per round (normalized by
%                 1 + ||w|| + ||v|| for 'svm', unnormalized for 'ls')

rng('default');

numWorkers = length(Xdata);
nk = zeros(numWorkers,1);
v_agent = cell(numWorkers,1);   % per-worker local vectors, one entry per local sample
for k = 1:numWorkers
    [nk(k),~] = size(Xdata{k});
    v_agent{k} = zeros(nk(k),1);
end
n = sum(nk);

[~,d] = size(Xdata{1});

% Defaults. loss and penalty must be defined before they are tested with
% matches() below, even when no option struct is supplied.
epoch = 100;
eval_every = 1;
obj = 'ls';
penalty = 'l1';

% User's option
%------------------------------------------------
if nargin > 5  % then option is an argument
    if isfield(option,'epoch')
        epoch = option.epoch;
    end
    if isfield(option,'penalty')
        penalty = option.penalty;
    end
    if isfield(option,'loss')
        obj = option.loss;
    end
    if isfield(option,'eval_every')
        eval_every = option.eval_every;
    end
end

isLS = matches(obj,'ls');
isL1 = matches(penalty,'l1');

% Aggregate full dataset (used by the server step and for loss/criterion evaluation)
X = vertcat(Xdata{:});
Y = vertcat(Ydata{:});

primal_loss = zeros(epoch, 1);
dual_loss = primal_loss;
criterion = primal_loss;
gap = primal_loss;
w = zeros(d,1);
beta = w;   % vector sent from the server to every worker
% (a previously suggested setting for the method parameter: eta = 0.5*numWorkers)

% display iteration information
fprintf('----------------Training with {%3.0d} workers------------\n', numWorkers);
fprintf('  round   primal loss     dual loss    gap    relative KKT residual     \n');

if isL1
    % l1-norm proximal mapping operator (soft thresholding)
    Proxg = @(z, lambda) sign(z).*max(abs(z)-lambda,0);
else
    % l2-norm proximal mapping operator
    Proxg = @(z, lambda) z/(lambda+1);
end

for iter = 1:epoch

    % On all agents. Runs serially here; each iteration reads only worker
    % k's data and writes only v_agent{k}.
    for k = 1:numWorkers
        if isLS
            v_agent{k} = ls_Localupdate(Xdata{k}, Ydata{k}, beta, v_agent{k}, n, rho, eta);
        else
            v_agent{k} = svm_Localupdate(Xdata{k}, Ydata{k}, beta, v_agent{k}, n, rho, eta);
        end
    end

    % On parameter server:
    v = vertcat(v_agent{:});   % stacked local vectors, n-by-1
    temp = X'*v;
    dw = Proxg(w - rho/n*temp, rho*lambda) - w;
    w = w + dw;
    beta = -dw - w;            % transferred parameters

    % Loss part of the primal/dual objectives and the KKT residual.
    Xw = X*w;
    if isLS
        primal = norm(Y-Xw)^2/(2*n);
        dual = -mean(Y.*v) - norm(v)^2/(2*n);
        % KKT residual (not normalized, unlike the SVM branch)
        kkt = norm(w - Proxg(w - X'*(Xw-Y)/n, lambda));
    else
        primal = mean(max(0, 1-Y.*Xw));
        dual = -mean(Y.*v);
        % relative KKT residual
        kkt = (norm(Xw - prox_l1svm(Xw+v, Y)) + norm(w - Proxg(w-temp/n, lambda))) ...
              / (1 + norm(w) + norm(v));
    end

    % Regularizer contribution.
    if isL1
        primal = primal + lambda*norm(w, 1);
    else
        primal = primal + lambda/2*norm(w)^2;
        dual = dual - lambda/2*norm(temp/(n*lambda))^2;
    end

    primal_loss(iter) = primal;
    dual_loss(iter) = dual;
    gap(iter) = abs(primal-dual)/(1+abs(primal)+abs(dual));
    criterion(iter) = kkt;

    if mod(iter,eval_every)==0
        fprintf('%5d  \t %5d \t %5d \t %5d \t  %5d \t  \n', ...
                iter, primal, dual, gap(iter), criterion(iter));
    end

end

end



function u1=prox_l1svm(u,  y)
% PROX_L1SVM  Elementwise map used in the hinge-loss KKT residual:
%     u1 = u - y .* max(-1, min(0, y.*u - 1))
% For labels y in {+1,-1} this coincides, componentwise, with the
% unit-step proximal map of z -> max(0, 1 - y.*z).
shifted_margin = y.*u - 1;
clipped = max(-1, min(0, shifted_margin));   % restricted to [-1, 0]
u1 = u - y.*clipped;
end





function v=svm_Localupdate(X, Y, beta, v_old, n, rho, eta)
% SVM_LOCALUPDATE  Worker-side subproblem for the hinge-loss case.
% Solves, with quadprog, the box-constrained QP
%     min_v  0.5*v'*H*v + f'*v
%     H = (rho*eta/n)*X*X'
%     f = Y - (rho*eta/n)*X*X'*v_old + X*beta
%     s.t.  -1 <= v_i <= 0  where Y_i == 1
%            0 <= v_i <= 1  where Y_i == -1
% Entries whose label is neither 1 nor -1 get the bounds [0, 0], so labels
% are assumed to be +1/-1.
% Inputs:
%     X      local nk-by-d feature matrix
%     Y      nk-by-1 labels
%     beta   vector received from the server
%     v_old  previous local iterate (also passed to quadprog as x0; the
%            default algorithm may ignore it)
%     n      total number of samples over all workers
%     rho, eta  algorithm parameters
% Output:
%     v      nk-by-1 solution returned by quadprog. A warning is issued
%            when quadprog reports a non-positive exit flag.

options = optimoptions('quadprog','Display','off','OptimalityTolerance',1e-10);
[nk,~] = size(X);
lb = zeros(nk,1);
ub = zeros(nk,1);
lb(Y==1) = -1;
ub(Y==-1) = 1;
K = X*X';   % local Gram matrix, formed once and reused for H and f
H = sparse(rho*eta*K/n);
f = Y - rho*eta/n*K*v_old + X*beta;
[v,~,exitflag] = quadprog(H,f,[],[],[],[],lb,ub,v_old,options);
if exitflag <= 0
    warning('PJADMM:svmLocalUpdate', ...
            'quadprog returned exit flag %d in the SVM local update.', exitflag);
end

end

function v_new=ls_Localupdate(X, Y, beta, v_old, n, rho, eta)
% LS_LOCALUPDATE  Worker-side subproblem for the least-squares case.
% Returns the solution of the linear system
%     ((rho*eta/n)*X*X' + I) * v_new = (rho*eta/n)*X*X'*v_old - X*beta - Y,
% i.e. the minimizer over v of
%     0.5*||v||^2 + Y'*v + beta'*X'*v + (rho*eta/(2n))*||X'*(v - v_old)||^2.
% Inputs:
%     X      local nk-by-d feature matrix
%     Y      nk-by-1 targets
%     beta   vector received from the server
%     v_old  previous local iterate
%     n      total number of samples over all workers
%     rho, eta  algorithm parameters
% Output:
%     v_new  nk-by-1 updated local vector

[nk,~] = size(X);
scaledGram = rho*eta/n*(X*X');   % formed once and reused on both sides
v_new = (scaledGram + eye(nk)) \ (scaledGram*v_old - X*beta - Y);

end





% function [v]=Localupdate(X, Y, beta, v_old, n, rho, p)
% 
% [nk,~]=size(X);
% v=v_old;
% u=X'*(v_old-v)-n*p*beta;
% 
% Xnorm=sum(X.^2,2);
% epoch=50;
% 
% loss_before=v'*Y+rho/(2*n*p)*norm(u)^2;
% for i=1:epoch
%     index_number=randperm(nk);
% 
% 
%     for j=index_number
%         temp=Xnorm(j);
%         delta=Y(j)*max(-1, min(0,  v(j)*Y(j)+X(j,:)*u*Y(j)/temp -n*p/(rho*temp)) )-v(j);
%         v(j)=v(j)+delta;
%         u=u-X(j,:)'*delta;
%     end
% 
%     loss_after=v'*Y+rho/(2*n*p)*norm(u)^2;
%     error=abs(loss_after-loss_before)/abs(loss_before);
%     loss_before=loss_after;
% 
%     if error<=1e-6
%         break;
%     end
% 
% end
% 
% 
% end
% 
% 
