function [w, primal_loss, dual_loss, gap, criterion] = Primal_LADMM(Xdata, Ydata, lambda, beta, tau, option)
% PRIMAL_LADMM  Distributed linearized ADMM over data split across workers.
%
% Inputs:
%        Xdata:   cell array with one entry per worker; Xdata{k} is an
%                 nk(k)-by-d matrix (one row per sample).
%        Ydata:   cell array with one entry per worker holding the targets /
%                 labels matching the rows of Xdata{k}.
%       lambda:   regularization weight of the penalty term.
%         beta:   step parameter; used in the server update and, as
%                 beta/tau, in the local (per-worker) updates.
%          tau:   divides beta in the local updates.
%       option:   (optional) struct with any of the fields
%                   .epoch       number of outer iterations (default 100)
%                   .penalty     'l1', anything else selects the l2 penalty
%                                (default 'l2')
%                   .loss        'ls' for least squares, anything else
%                                selects the hinge (SVM) loss (default 'svm')
%                   .eval_every  print progress every this many rounds
%                                (default 1)
% Outputs:
%            w:   d-by-1 final primal iterate.
%  primal_loss:   epoch-by-1 primal objective per round.
%    dual_loss:   epoch-by-1 dual objective per round.
%          gap:   epoch-by-1 relative primal/dual gap per round.
%    criterion:   epoch-by-1 KKT residual per round.

% Kept from the original: resets the global random number generator.
rng('default');

numWorkers=length(Xdata);
if numWorkers == 0
    error('Primal_LADMM:emptyData', 'Xdata must contain at least one worker block.');
end

% Per-worker sample counts and zero-initialized dual variables.
nk=zeros(numWorkers,1);
v_agent=cell(numWorkers,1);
for k=1:numWorkers
    nk(k)=size(Xdata{k},1);
    v_agent{k}=zeros(nk(k),1);
end
n=sum(nk);
d=size(Xdata{1},2);

% Defaults. 'l2' and 'svm' are the branches the code falls through to for
% any value other than 'l1' / 'ls', so they are used when nothing is given.
epoch=100;
eval_every=1;
penalty='l2';
obj='svm';

% User's option
%------------------------------------------------
% option is the 6th argument, so it is only present when nargin > 5.
if nargin > 5
    if isfield(option,'epoch')
        epoch = option.epoch;
    end
    if isfield(option,'penalty')
        penalty = option.penalty;
    end
    if isfield(option,'loss')
        obj = option.loss;
    end
    if isfield(option,'eval_every')
        eval_every = option.eval_every;
    end
end

% The problem type does not change between rounds; decide it once.
useLS = matches(obj,'ls');
useL1 = matches(penalty,'l1');

% Aggregate full dataset (used for the server step and loss/criterion evaluation)
X = vertcat(Xdata{:});
Y = vertcat(Ydata{:});

primal_loss=zeros(epoch, 1);
dual_loss=primal_loss;
criterion=primal_loss;
gap=primal_loss;
w=zeros(d,1);

% Proximal mapping of the penalty
if useL1
    % l1-norm: soft-thresholding
    Proxg=@(z, lambda) sign(z).*max(abs(z)-lambda,0);
else
    % squared l2-norm: uniform shrinkage
    Proxg=@(z, lambda) z/(lambda+1);
end

% display iteration information
fprintf('----------------Training with {%3.0d} workers------------\n', numWorkers);
fprintf('  round   primal loss     dual loss    gap    relative KKT residual     \n');

dv=v_agent;
for i=1:epoch

    % On All Agents: update the local dual block v_agent{k}; dv{k} receives
    % the extrapolated value 2*v_new - v_old returned by the local solver.
    for k=1:numWorkers
        if useLS
            [dv{k}, v_agent{k}] = ls_Localupdate(Xdata{k}, Ydata{k}, w, v_agent{k}, n, beta/tau);
        else
            [dv{k}, v_agent{k}] = svm_Localupdate(Xdata{k}, Ydata{k}, w, v_agent{k}, n, beta/tau);
        end
    end

    % On Parameter Server: proximal step on w using the stacked dv blocks.
    w=Proxg(w-X'*vertcat(dv{:})/(n*beta*numWorkers), lambda/(beta*numWorkers));

    % Quantities shared by every loss/criterion formula below.
    v=vertcat(v_agent{:});
    Xw=X*w;
    temp=X'*v;

    if useLS
        primal = norm(Y-Xw)^2/(2*n);
        dual = -mean(Y.*v)-norm(v)^2/(2*n);
        % KKT residual: fixed-point residual of a proximal-gradient step
        % (not normalized in this branch).
        kkt = norm(w-Proxg(w-X'*(Xw-Y)/n,lambda ));
    else
        primal = mean(max(0, 1-Y.*Xw) );
        dual = -mean(Y.*v);
        % relative KKT residual
        kkt = (norm(Xw-prox_l1svm(Xw+v, 1, Y))+norm( w- Proxg(w-temp/n ,lambda)))/(1+norm(w)+norm(v));
    end

    if useL1
        primal = primal+lambda*norm(w, 1);
    else
        primal = primal+lambda/2*norm(w)^2;
        dual = dual-lambda/2*norm(temp/(n*lambda) )^2;
    end

    primal_loss(i)=primal;
    dual_loss(i)=dual;
    gap(i)=abs(primal-dual)/(1+abs(primal)+abs(dual));
    criterion(i)=kkt;

    if mod(i,eval_every)==0
        fprintf('%5d  \t %5d \t %5d \t %5d \t  %5d \t  \n', ...
                i, primal, dual, gap(i),criterion(i));
    end

end

end



function u1=prox_l1svm(u, t, y)
% Evaluates u - t .* prox_l1svm_conj(u./t, 1/t, y), i.e. the result is
% obtained from the conjugate-side operator applied to the rescaled input.
%   u : input vector
%   t : scaling parameter
%   y : labels, passed through to prox_l1svm_conj

    conjPart = prox_l1svm_conj(u./t, 1/t, y);
    u1 = u - t.*conjPart;

end


function u1=prox_l1svm_conj(u, t, y)
% Shifts the label-signed input y.*u by t, clips the result elementwise to
% the interval [-1, 0], and multiplies by the labels again.
%   u : input vector
%   t : shift applied before clipping
%   y : labels (same size as u)

    shifted = y.*u - t;
    clipped = min(0, shifted);    % upper bound 0
    clipped = max(-1, clipped);   % lower bound -1
    u1 = y.*clipped;

end


function [temp, v_new]=svm_Localupdate(X, Y, w, v_old, n, beta)
% Local dual update for the hinge-loss (SVM) case.
%   X, Y  : this worker's samples and labels
%   w     : current primal iterate
%   v_old : this worker's current dual block
%   n     : sample count supplied by the caller
%   beta  : step parameter supplied by the caller
% Returns
%   v_new : updated dual block
%   temp  : extrapolated block 2*v_new - v_old
%
% Primal-side form kept from the original comments for reference:
%   temp  = X*w + v_old/(n*beta);
%   u_new = prox_l1svm(temp, 1/(n*beta), Y);
%   v_new = v_old + n*beta*(X*w - u_new);

    sigma = n*beta;
    dualArg = sigma*X*w + v_old;
    v_new = prox_l1svm_conj(dualArg, sigma, Y);
    temp = 2*v_new - v_old;

end


function [temp, v_new]=ls_Localupdate(X, Y, w, v_old, n, beta)
% Local update for the least-squares case.
%   X, Y  : this worker's samples and targets
%   w     : current primal iterate
%   v_old : this worker's current dual block
%   n     : sample count supplied by the caller
%   beta  : step parameter supplied by the caller
% Returns
%   v_new : updated dual block
%   temp  : extrapolated block 2*v_new - v_old

    sigma = n*beta;
    Xw = X*w;

    % Auxiliary variable: closed-form average of Y and the shifted prediction.
    shifted = Xw + v_old/sigma;
    u_new = (Y + sigma*shifted)/(sigma + 1);

    % Dual ascent step followed by extrapolation.
    v_new = v_old + sigma*(Xw - u_new);
    temp = 2*v_new - v_old;

end

