function [w, primal_loss, dual_loss, gap, criterion] = CoCoA(Xdata, Ydata, lambda, sigma, gamma, option)
%COCOA  CoCoA-style distributed dual method for L2-regularized problems.
%   [w, primal_loss, dual_loss, gap, criterion] = CoCoA(Xdata, Ydata,
%   lambda, sigma, gamma, option) minimizes, over the row-stacked data
%   X = vertcat(Xdata{:}) and Y = vertcat(Ydata{:}) with n rows in total,
%       'ls' : ||Y - X*w||^2/(2*n) + lambda/2*||w||^2
%       'svm': mean(max(0, 1 - Y.*(X*w))) + lambda/2*||w||^2
%   Each round, every worker k updates its local dual block v_k (the
%   workers are run one after another in this MATLAB process), and the
%   parameter server then maps the stacked dual v to w = -X'*v/(n*lambda).
%
% Inputs:
%        Xdata:   cell array; Xdata{k} is worker k's n_k-by-d data matrix
%                 (d is taken from Xdata{1})
%        Ydata:   cell array; Ydata{k} is worker k's n_k-by-1 target
%                 vector (for 'svm' the labels are presumably +/-1)
%       lambda:   regularization parameter
%        sigma:   scaling of the quadratic term in each local subproblem
%        gamma:   damping of the local step, v <- v + gamma*(v_tilde - v)
%       option:   (optional) structure of algorithm parameters
%          .epoch       total communication rounds (default 100)
%          .loss        'ls' or 'svm' (default 'ls')
%          .eval_every  print progress every this many rounds (default 1)
% Outputs:
%            w:   d-by-1 primal iterate after the last round
%  primal_loss:   epoch-by-1 primal objective per round
%    dual_loss:   epoch-by-1 dual objective per round
%          gap:   epoch-by-1 relative duality gap |P-D|/(1+|P|+|D|)
%    criterion:   epoch-by-1 KKT residual per round (not normalized)

rng('default');

if ~iscell(Xdata) || ~iscell(Ydata) || isempty(Xdata) || numel(Xdata) ~= numel(Ydata)
    error('CoCoA:badInput', ...
          'Xdata and Ydata must be non-empty cell arrays of equal length.');
end

numWorkers = numel(Xdata);
nk = zeros(numWorkers,1);
v_agent = cell(numWorkers,1);
for k = 1:numWorkers
    nk(k) = size(Xdata{k}, 1);
    v_agent{k} = zeros(nk(k),1);     % local dual block, starts at zero
end
n = sum(nk);
d = size(Xdata{1}, 2);

% Defaults (previously `obj` was undefined when no loss was supplied).
epoch = 100;
eval_every = 1;
obj = 'ls';

% User's option
%------------------------------------------------
if nargin > 5  %  then option is an argument
    if isfield(option,'epoch')
        epoch = option.epoch;
    end
    if isfield(option,'loss')
        obj = option.loss;
    end
    if isfield(option,'eval_every')
        eval_every = option.eval_every;
    end
end

if strcmp(obj,'ls')
    isLS = true;
elseif strcmp(obj,'svm')
    isLS = false;
else
    error('CoCoA:badLoss', 'option.loss must be ''ls'' or ''svm''.');
end

% Aggregate full dataset (only for loss/criterion evaluation)
X = vertcat(Xdata{:});
Y = vertcat(Ydata{:});
primal_loss = zeros(epoch, 1);
dual_loss = primal_loss;
criterion = primal_loss;
gap = primal_loss;
w = zeros(d,1);   % consistent with w = -X'*v/(n*lambda) at v = 0

% Proximal map of (lambda/2)*||.||^2 with unit step: z/(1+lambda)
Proxg = @(z, lambda) z/(lambda+1);

% display iteration information
fprintf('----------------Training with {%3.0d} workers------------\n', numWorkers);
fprintf('  round   primal loss     dual loss    gap    relative KKT residual     \n');

for i = 1:epoch

    % On all agents (executed sequentially here): local dual updates.
    for k = 1:numWorkers
        if isLS
            v_agent{k} = ls_Localupdate(Xdata{k}, Ydata{k}, w, v_agent{k}, n, lambda, sigma, gamma);
        else
            v_agent{k} = svm_Localupdate(Xdata{k}, Ydata{k}, w, v_agent{k}, n, lambda, sigma, gamma);
        end
    end

    % On parameter server: stack the dual blocks and map to the primal.
    v = vertcat(v_agent{:});
    temp = X'*v;
    w = -temp/(n*lambda);
    Xw = X*w;

    if isLS
        primal = norm(Y - Xw)^2/(2*n) + lambda/2*norm(w)^2;
        dual = -mean(Y.*v) - norm(v)^2/(2*n) - lambda/2*norm(w)^2;
        % Proximal-gradient fixed-point residual of the primal problem
        kkt = norm(w - Proxg(w - X'*(Xw - Y)/n, lambda));
    else
        primal = mean(max(0, 1 - Y.*Xw)) + lambda/2*norm(w)^2;
        dual = -mean(Y.*v) - lambda/2*norm(w)^2;
        kkt = norm(Xw - prox_l1svm(Xw + v, Y)) + norm(w - Proxg(w - temp/n, lambda));
    end

    primal_loss(i) = primal;
    dual_loss(i) = dual;
    gap(i) = abs(primal - dual)/(1 + abs(primal) + abs(dual));
    criterion(i) = kkt;

    if mod(i, eval_every) == 0
        % %d is only valid for integers; the losses are printed with %e.
        fprintf('%5d  \t %12.6e \t %12.6e \t %12.6e \t %12.6e\n', ...
                i, primal, dual, gap(i), criterion(i));
    end
end

end



function u1=prox_l1svm(u,  y)
%PROX_L1SVM  Elementwise proximal step associated with the hinge loss.
%   u1 = prox_l1svm(u, y) computes the margin residual s = y.*u - 1,
%   clips it to the interval [-1, 0], and returns u - y.*clip(s).
%   For labels y in {-1,+1} this is the unit-step proximal map of
%   z -> max(0, 1 - y.*z). u and y must have compatible sizes.

    residual = y.*u - 1;
    % clip from above first, then from below (same order as min/max nesting)
    clipped = max(-1, min(0, residual));
    shift = y.*clipped;
    u1 = u - shift;

end





function v_new=svm_Localupdate(X, Y, w, v_old, n, lambda, sigma, gamma)
%SVM_LOCALUPDATE  One worker's local dual step for the hinge-loss problem.
%   Solves the box-constrained quadratic subproblem
%       min_v  0.5*v'*H*v + f'*v
%       H = sigma*(X*X')/(n*lambda)
%       f = Y - sigma/(n*lambda)*(X*X')*v_old - X*w
%       s.t. -1 <= v_i <= 0 where Y_i == 1,  0 <= v_i <= 1 where Y_i == -1
%   with quadprog and returns the damped step
%       v_new = v_old + gamma*(v_tilde - v_old).
%   Rows whose label is neither +1 nor -1 get bounds [0, 0], which pins
%   their dual to 0.

[nk,~] = size(X);
if nk == 0
    v_new = v_old;   % empty worker block: nothing to update
    return;
end

options = optimoptions('quadprog','Display','off','OptimalityTolerance',1e-10);
lb = zeros(nk,1);
ub = zeros(nk,1);
lb(Y==1) = -1;
ub(Y==-1) = 1;

K = X*X';                       % local Gram matrix, formed once
H = sparse(sigma*K/(n*lambda));
f = Y - sigma/(n*lambda)*K*v_old - X*w;
v_tilde = quadprog(H,f,[],[],[],[],lb,ub,v_old,options);
if isempty(v_tilde)
    % quadprog returned no point; keep the previous dual block instead of
    % failing on the size mismatch in the update below.
    warning('CoCoA:quadprogFailed', ...
            'quadprog returned no solution; keeping the previous local dual.');
    v_tilde = v_old;
end
v_new = v_old + gamma*(v_tilde - v_old);

end

function v_new=ls_Localupdate(X, Y, w, v_old, n, lambda, sigma, gamma)
%LS_LOCALUPDATE  One worker's local dual step for the least-squares problem.
%   The local subproblem is an unconstrained quadratic. Its minimizer
%   v_tilde solves
%       (sigma/(n*lambda)*(X*X') + I) * v_tilde
%           = X*w + sigma/(n*lambda)*(X*X')*v_old - Y
%   and the returned step is damped:
%       v_new = v_old + gamma*(v_tilde - v_old).

nk = size(X, 1);
SK = sigma/(n*lambda)*(X*X');   % scaled local Gram matrix, formed once
v_tilde = (SK + eye(nk)) \ (X*w + SK*v_old - Y);
v_new = v_old + gamma*(v_tilde - v_old);

end




