function [w, primal_loss, dual_loss, gap, criterion] = DLADMM(Xdata, Ydata, lambda, rho, eta, option)
%DLADMM Distributed linearized ADMM for l1/l2-regularized least squares or
% hinge-loss SVM, with per-round primal/dual/gap/KKT monitoring.
%
% Inputs:
%        Xdata:   cell array of per-worker data matrices; Xdata{k} is
%                 nk(k)-by-d and every block must have the same d
%        Ydata:   cell array of per-worker targets; Ydata{k} has nk(k) rows
%                 (for 'svm' the labels are presumably in {-1,+1} -- confirm)
%       lambda:   regularization weight
%          rho:   penalty parameter; the server step uses rho/n and rho*lambda
%          eta:   proximal weight of the local updates (local step n/(rho*eta))
%       option:   optional struct of algorithm parameters
%        epoch:   total communication rounds (default 100)
%   eval_every:   print progress every eval_every rounds (default 1)
%         loss:   'ls' or 'svm' (default 'ls')
%      penalty:   'l1' or 'l2' (default 'l1')
% Outputs:
%            w:   final model (d-by-1)
%  primal_loss:   primal objective after each round (epoch-by-1)
%    dual_loss:   dual objective after each round (epoch-by-1)
%          gap:   relative duality gap |P-D|/(1+|P|+|D|) per round
%    criterion:   KKT residual per round

% Reset the global random number generator. No random numbers are drawn in
% this file; kept so callers observe the same RNG state as before.
rng('default');

numWorkers = numel(Xdata);
if numWorkers == 0
    error('DLADMM:noWorkers', 'Xdata must contain at least one worker block.');
end

% Per-worker sample counts and dual variables (one entry per local sample).
nk = zeros(numWorkers, 1);
v_agent = cell(numWorkers, 1);
for k = 1:numWorkers
    nk(k) = size(Xdata{k}, 1);
    v_agent{k} = zeros(nk(k), 1);
end
n = sum(nk);
d = size(Xdata{1}, 2);

% Defaults, overridable through the option struct.
epoch = 100;
eval_every = 1;
penalty = 'l1';
obj = 'ls';
if nargin > 5
    if isfield(option, 'epoch'),      epoch = option.epoch;           end
    if isfield(option, 'eval_every'), eval_every = option.eval_every; end
    if isfield(option, 'penalty'),    penalty = option.penalty;       end
    if isfield(option, 'loss'),       obj = option.loss;              end
end
% Canonicalize (case-insensitive) and reject unknown values instead of
% silently falling through to the 'l2' / 'svm' branches.
penalty = validatestring(penalty, {'l1', 'l2'}, 'DLADMM', 'option.penalty');
obj = validatestring(obj, {'ls', 'svm'}, 'DLADMM', 'option.loss');
isL1 = strcmp(penalty, 'l1');
isLS = strcmp(obj, 'ls');

% Aggregate full dataset (used by the server step and for loss/criterion evaluation)
X = vertcat(Xdata{:});
Y = vertcat(Ydata{:});

primal_loss = zeros(epoch, 1);
dual_loss = primal_loss;
criterion = primal_loss;
gap = primal_loss;
w = zeros(d, 1);
beta = w;
%eta=0.5*numWorkers; % our proposed method parameter

% display iteration information
fprintf('----------------Training with {%3.0d} workers------------\n', numWorkers);
fprintf('  round   primal loss     dual loss         gap    relative KKT residual     \n');

if isL1
    % l1-norm proximal mapping operator (soft thresholding)
    Proxg = @(z, lambda) sign(z).*max(abs(z)-lambda, 0);
else
    % proximal mapping of (lambda/2)*||z||^2
    Proxg = @(z, lambda) z/(lambda+1);
end

for i = 1:epoch
    % On all agents. Each update reads only its own data block, beta and its
    % own previous dual variable, so this loop could be turned into a parfor.
    for k = 1:numWorkers
        if isLS
            v_agent{k} = ls_Localupdate(Xdata{k}, Ydata{k}, beta, v_agent{k}, n, rho, eta);
        else
            v_agent{k} = svm_Localupdate(Xdata{k}, Ydata{k}, beta, v_agent{k}, n, rho, eta);
        end
    end

    % On parameter server:
    v = vertcat(v_agent{:});          % stacked dual variables, n-by-1
    temp = X'*v;
    dw = Proxg(w - rho/n*temp, rho*lambda) - w;
    w = w + dw;
    beta = -dw - w;                   % parameters transferred to the agents

    [primal, dual, kkt] = dladmm_metrics(X, Y, w, v, temp, n, lambda, isLS, isL1, Proxg);

    primal_loss(i) = primal;
    dual_loss(i) = dual;
    gap(i) = abs(primal-dual)/(1+abs(primal)+abs(dual));
    criterion(i) = kkt;

    if mod(i, eval_every) == 0
        % %e is stated explicitly: %d on non-integers is overridden by MATLAB anyway.
        fprintf('%5d  \t %12.6e \t %12.6e \t %12.6e \t  %12.6e \t  \n', ...
                i, primal, dual, gap(i), criterion(i));
    end
end

end

function [primal, dual, kkt] = dladmm_metrics(X, Y, w, v, temp, n, lambda, isLS, isL1, Proxg)
% Evaluate the primal objective, dual objective and KKT residual at (w, v).
% temp is X'*v as computed by the server step in DLADMM.
Xw = X*w;
if isLS
    primal = norm(Y - Xw)^2/(2*n);
    dual = -mean(Y.*v) - norm(v)^2/(2*n);
    kkt = norm(w - Proxg(w - X'*(Xw - Y)/n, lambda)); % relative KKT residual
else
    primal = mean(max(0, 1 - Y.*Xw));
    dual = -mean(Y.*v);
    kkt = norm(Xw - prox_l1svm(Xw + v, 1, Y)) + norm(w - Proxg(w - temp/n, lambda));
end
if isL1
    primal = primal + lambda*norm(w, 1);
else
    % The l2 penalty contributes to the primal objective and (via temp) to the dual objective.
    primal = primal + lambda/2*norm(w)^2;
    dual = dual - lambda/2*norm(temp/(n*lambda))^2;
end
end



function u1=prox_l1svm(u, t, y)
%PROX_L1SVM Proximal map of t*max(0, 1 - y.*z), applied elementwise.
%   Uses the Moreau decomposition prox_{t f}(u) = u - t.*prox_{f*/t}(u./t)
%   with the hinge conjugate handled by prox_l1svm_conj.
%   t may be a scalar or an array the same size as u: all operations are
%   elementwise, so 1./t is used rather than matrix right-division 1/t
%   (which gives a wrong-shaped result for a vector t).
%   Assumes y has entries in {-1,+1} -- confirm at call sites.

    u1 = u - t.*prox_l1svm_conj(u./t, 1./t, y);

end


function u1=prox_l1svm_conj(u, t, y)
%PROX_L1SVM_CONJ Proximal map of t*h*(s), where h*(s) = y.*s restricted to
%   y.*s in [-1, 0] (the conjugate of the hinge loss max(0, 1 - y.*z)).
%   Substituting r = y.*s (valid when y is +/-1) reduces the problem to a
%   clipped shift r = clip(y.*u - t, -1, 0), and the result is s = y.*r.

    shifted = y.*u - t;
    clipped = max(-1, min(0, shifted));   % project onto [-1, 0]
    u1 = y.*clipped;

end

function u1=prox_ls(u, t, y)
%PROX_LS Proximal map of (t/2)*(z - y).^2, applied elementwise:
%   argmin_z (t/2)*(z - y)^2 + (1/2)*(z - u)^2 = (u + t*y)/(1 + t).

    numer = t.*y + u;
    denom = 1 + t;
    u1 = numer./denom;

end

function u1=prox_ls_conj(u, t, y)
%PROX_LS_CONJ Proximal map of t*f*, where f(z) = (1/2)*(z - y).^2, via the
%   Moreau decomposition prox_{t f*}(u) = u - t.*prox_{f/t}(u./t).
%   The inner least-squares prox (u + t*y)./(t + 1) is written out inline
%   with argument u./t and step 1./t.

    a = u./t;                   % scaled argument
    s = 1./t;                   % reciprocal step
    p = (a + s.*y)./(s + 1);    % least-squares prox of a with step s
    u1 = u - t.*p;

end

function v=svm_Localupdate(X, Y, beta, v_old, n, rho, eta)
%SVM_LOCALUPDATE Closed-form agent (dual) update for the hinge-loss SVM.
%   With c = rho*eta/n, solves
%       min_v  (Y + X*beta)'*v + (c/2)*||v - v_old||^2
%       s.t.   Y.*v in [-1, 0]   (i.e. v in [-1,0] where Y==1, [0,1] where Y==-1)
%   which is the box-constrained QP previously solved with quadprog
%   (H = c*I, f = Y - c*v_old + X*beta). Completing the square gives
%   prox_l1svm_conj(v_old - X*beta/c, 1/c, Y).
%   Assumes Y has entries in {-1,+1} -- confirm at call sites.

step = n/(rho*eta);                 % = 1/c
% Multiply X*beta first so no scaled nk-by-d copy of X is allocated.
temp = v_old - step*(X*beta);
v = prox_l1svm_conj(temp, step, Y);

end

function v=ls_Localupdate(X, Y, beta, v_old, n, rho, eta)
%LS_LOCALUPDATE Closed-form agent (dual) update for least squares.
%   With step = n/(rho*eta), returns prox_{step*f*}(v_old - step*X*beta),
%   where f(z) = (1/2)*(z - Y).^2 and f* is its convex conjugate
%   (computed by prox_ls_conj).

step = n/(rho*eta);
% Multiply X*beta first so no scaled nk-by-d copy of X is allocated.
temp = v_old - step*(X*beta);
v = prox_ls_conj(temp, step, Y);

end




