function [w, loss, criterion,dist] = PG_EXTRA(Xdata, Ydata, lambda, alpha, option,P,opt)
% PG_EXTRA  Decentralized proximal-gradient EXTRA iterations over a set of agents.
%
%   [w, loss, criterion, dist] = PG_EXTRA(Xdata, Ydata, lambda, alpha, option, P, opt)
%
%   Inputs
%     Xdata  - cell array, one matrix per agent; rows are samples, columns are
%              the d features (d is taken from Xdata{1}).
%     Ydata  - cell array, one target/label column vector per agent.
%     lambda - regularization weight used when evaluating the reported primal
%              loss and residual.
%     alpha  - step size of the iterations.
%     option - struct with optional fields:
%                .epoch      number of iterations            (default 100)
%                .penalty    'l1', anything else means l2    (default 'l1')
%                .loss       'ls', anything else means hinge (default 'ls')
%                .eval_every print period in iterations      (default 1)
%     P      - numWorkers-by-numWorkers mixing matrix; (P + I)/2 is used as
%              the second mixing matrix of the recursion.
%     opt    - reference point; dist reports the distance of the agents to it.
%
%   Outputs
%     w         - average of the agents' iterates after the last iteration
%                 (after the initial step when epoch is 0).
%     loss      - epoch-by-1, primal objective at the averaged iterate.
%     criterion - epoch-by-1, norm of the proximal-gradient fixed-point
%                 residual for 'ls'; equal to the primal value for hinge.
%     dist      - epoch-by-1, sum over agents of norm(opt - u_agent{k}).
%
%   Side effect: resets the global random stream with rng('default') so the
%   random starting point is reproducible.
%
%   NOTE(review): the iterations themselves never read lambda. The prox
%   threshold alpha/(numWorkers*n) and the hinge-branch gradient
%   u/(n*numWorkers) correspond to a regularization weight of 1/n, so the
%   reported loss matches the problem being solved only when lambda == 1/n.
%   Confirm with the caller before changing.

rng('default');
numWorkers = length(Xdata);
[~, d] = size(Xdata{1});

% Total number of samples over all agents.
n = 0;
for k = 1:numWorkers
    n = n + size(Xdata{k}, 1);
end

% Second mixing matrix of the EXTRA recursion.
A = (P + eye(numWorkers))/2;

% Defaults; every one of them can be overridden through `option`. The
% defaults for penalty/obj prevent an undefined-variable error when the
% corresponding field is absent.
epoch = 100;
eval_every = 1;
penalty = 'l1';
obj = 'ls';

% User's option
%------------------------------------------------
if nargin > 4
    if isfield(option,'epoch')
        epoch = option.epoch;
    end
    if isfield(option,'penalty')
        penalty = option.penalty;
    end
    if isfield(option,'loss')
        obj = option.loss;
    end
    if isfield(option,'eval_every')
        eval_every = option.eval_every;
    end
end

% The problem type does not change during the run: decide once.
useLS = matches(obj,'ls');
useL1 = matches(penalty,'l1');

% Aggregate full dataset (only for loss/criterion evaluation)
X = vertcat(Xdata{:});
Y = vertcat(Ydata{:});

loss = zeros(epoch, 1);
criterion = loss;
dist = loss;

if useL1
    % l1-norm proximal mapping operator (soft-thresholding)
    Proxg = @(z, lambda) sign(z).*max(abs(z)-lambda,0);
else
    % squared l2-norm proximal mapping operator (shrinkage)
    Proxg = @(z, lambda) z/(lambda+1);
end

% display iteration information
fprintf('----------------Training with {%3.0d} workers------------\n', numWorkers);
fprintf('   round       primal loss     relative KKT residual   \n');

% Random starting point x^0 for every agent.
u_agent = cell(numWorkers,1);
for k = 1:numWorkers
    u_agent{k} = randn(d, 1);
end
temp3 = u_agent;                 % x^{k}: iterate before the current one
temp2 = cell(numWorkers,1);      % z^{k+1/2}: last pre-prox point

% Initial step: x^1 = prox(P*x^0 - alpha*grad f(x^0)).
tempp = pgextra_mix(P, u_agent, d);
for k = 1:numWorkers
    if useLS
        u_agent{k} = tempp{k}-alpha*(ls_gradient(Xdata{k},Ydata{k},u_agent{k}))/n;
        temp2{k} = u_agent{k};
        u_agent{k} = Proxg(u_agent{k},alpha/(numWorkers*n));
    else
        u_agent{k} = tempp{k}-alpha*(u_agent{k})/(n*numWorkers);
        temp2{k} = u_agent{k};
        u_agent{k} = prox_weighted_hinge(Xdata{k}, Ydata{k}, u_agent{k}, alpha, 1/n);
    end
end

% Keeps w defined even when epoch == 0 (the loop below is then skipped).
w = zeros(d,1);
for k = 1:numWorkers
    w = w+u_agent{k}/numWorkers;
end

for i = 1:epoch
    % On All Agents:
    temp1 = u_agent;                          % remember x^{k+1}
    tempu1 = pgextra_mix(A, temp3, d);        % ((P+I)/2) * x^{k}
    tempu2 = pgextra_mix(P, u_agent, d);      % P * x^{k+1}

    % z^{k+3/2} = P x^{k+1} + z^{k+1/2} - ((P+I)/2) x^{k}
    %             - alpha*(grad f(x^{k+1}) - grad f(x^{k}))
    for k = 1:numWorkers
        if useLS
            u_agent{k} = tempu2{k}-tempu1{k}+temp2{k}-alpha*(ls_gradient(Xdata{k},Ydata{k},u_agent{k})-ls_gradient(Xdata{k},Ydata{k},temp3{k}))/n;
        else
            u_agent{k} = tempu2{k}-tempu1{k}+temp2{k}-alpha*(u_agent{k}-temp3{k})/(n*numWorkers);
        end
    end
    temp2 = u_agent;
    temp3 = temp1;

    % Proximal step on every agent.
    for k = 1:numWorkers
        if useLS
            u_agent{k} = Proxg(u_agent{k},alpha/(numWorkers*n));
        else
            u_agent{k} = prox_weighted_hinge(Xdata{k}, Ydata{k}, u_agent{k}, alpha, 1/n);
        end
    end

    % Network average, used only for reporting.
    w = zeros(d,1);
    for k = 1:numWorkers
        w = w+u_agent{k}/numWorkers;
    end

    if useLS
        if useL1
            primal = norm(Y-(X*w))^2/(2*n)+lambda*norm(w, 1);
        else
            primal = norm(Y-(X*w))^2/(2*n)+lambda/2*norm(w)^2;
        end
        % Fixed-point residual of one unit-step proximal-gradient update.
        kkt = norm(w-Proxg(w-X'*(X*w-Y)/n,lambda ));
    else
        if useL1
            primal = mean(max(0, 1-Y.*(X*w)) )+lambda*norm(w, 1);
        else
            primal = mean(max(0, 1-Y.*(X*w)) )+lambda/2*norm(w)^2;
        end
        % No separate residual is computed for the hinge loss.
        kkt = primal;
    end

    loss(i) = primal;
    criterion(i) = kkt;
    for k = 1:numWorkers
        dist(i) = dist(i)+norm(opt-u_agent{k});
    end
    if mod(i,eval_every)==0
        fprintf('%5d\t   %5d\t   %5d\t %5d\t \n', ...
                i, primal, kkt,dist(i));
    end

end

end

function mixed = pgextra_mix(W, vecs, d)
% PGEXTRA_MIX  Weighted neighbour combination used by PG_EXTRA.
%   mixed{k} = sum_j W(j,k) * vecs{j}, accumulated in increasing j starting
%   from zeros(d,1).
    numAgents = numel(vecs);
    mixed = cell(numAgents,1);
    for k = 1:numAgents
        acc = zeros(d,1);
        for j = 1:numAgents
            acc = acc+W(j,k)*vecs{j};
        end
        mixed{k} = acc;
    end
end



function grad_sum = ls_gradient(X, y, w)
% LS_GRADIENT  Unnormalized least-squares gradient.
%   grad_sum = X' * (X*w - y), i.e. the gradient of 0.5*norm(X*w - y)^2 with
%   respect to w for real-valued data. No division by the number of rows is
%   applied here; any scaling is left to the caller.
    grad_sum = X' * (X*w - y);
end

function alpha = compute_alpha_safe(Xdata, P)
% COMPUTE_ALPHA_SAFE  Step size from a smoothness bound and the mixing matrix.
%   alpha = 0.9 * 2 * lambda_min / Ls, where
%     lambda_min is the smallest eigenvalue of (P + I)/2, and
%     Ls is the largest, over agents, of norm(Xi'*Xi, 2) / (rows of Xi).
%   The factor 0.9 keeps alpha strictly below 2*lambda_min/Ls.

    agentCount = length(Xdata);

    % Smallest eigenvalue of the averaged mixing matrix.
    mixAvg = (P + eye(agentCount)) / 2;
    lambda_min = min(eig(mixAvg));

    % Per-agent smoothness constants of the row-averaged least-squares loss.
    smoothness = zeros(agentCount, 1);
    for a = 1:agentCount
        localX = Xdata{a};
        smoothness(a) = norm(localX' * localX, 2) / size(localX, 1);
    end
    Ls = max(smoothness);

    alpha = 0.9 * 2 * lambda_min / Ls;
end
function u_opt = prox_weighted_hinge(X, y, v, tau, k)
% PROX_WEIGHTED_HINGE  Proximal point of a weighted hinge-loss sum via quadprog.
%   Solves, over u (length d) and slack variables s (length n),
%       min  (1/(2*tau))*||u||^2 - (1/tau)*v'*u + k*sum(s)
%       s.t. y(i)*X(i,:)*u + s(i) >= 1,   s(i) >= 0,
%   and returns the u-part of the solution. Up to a constant this is
%       min_u (1/(2*tau))*||u - v||^2 + k*sum(max(0, 1 - y.*(X*u))).
%
%   X is n-by-d, y has n entries, v has d entries, tau and k are scalars.

    [numSamples, dim] = size(X);
    total = dim + numSamples;
    slackCols = dim + (1:numSamples);
    invStep = 1/tau;

    % Quadratic term acts on u only; the slack block stays zero.
    Q = zeros(total);
    Q(1:dim, 1:dim) = invStep * eye(dim);

    % Linear term: -(1/tau)*v on u, weight k on every slack.
    c = zeros(total, 1);
    c(1:dim) = -invStep * v;
    c(slackCols) = k * ones(numSamples, 1);

    % Inequalities G*z <= h with z = [u; s].
    G = zeros(2*numSamples, total);
    h = zeros(2*numSamples, 1);

    % Margin rows: -y(i)*X(i,:)*u - s(i) <= -1.
    G(1:numSamples, 1:dim) = bsxfun(@times, -y(:), X);
    G(sub2ind(size(G), 1:numSamples, slackCols)) = -1;
    h(1:numSamples) = -1;

    % Non-negativity rows: -s(i) <= 0 (right-hand side is already zero).
    G(numSamples+1:end, slackCols) = -eye(numSamples);

    opts = optimoptions('quadprog', ...
    'OptimalityTolerance', 1e-10, ...
    'ConstraintTolerance', 1e-10, ...
    'Display', 'None');
    z = quadprog(Q, c, G, h, [], [], [], [], [], opts);

    u_opt = z(1:dim);
end