function [w, loss, criterion,dist] = NIDS(Xdata, Ydata, lambda,alpha, option,P,opt)
% NIDS  Decentralized proximal-gradient iteration over a network of workers.
%
%   [w, loss, criterion, dist] = NIDS(Xdata, Ydata, lambda, alpha, option, P, opt)
%
%   Each worker k holds (Xdata{k}, Ydata{k}) and a local iterate x_k. Per
%   epoch the workers perform
%       z_k <- z_k - x_k + sum_j A(k,j) * (2*x_j - x_j_prev
%                                          - alpha*(g_j(x_j) - g_j(x_j_prev)))
%       x_k <- prox(z_k)
%   with A(i,j) = c*alpha*P(i,j) off the diagonal, 1 - c*alpha + c*alpha*P(i,i)
%   on it, and c = 1/(2*alpha).
%
%   Inputs
%     Xdata, Ydata  cell arrays (one entry per worker) of feature matrices
%                   and targets/labels; all Xdata{k} must share the column
%                   count of Xdata{1}.
%     lambda        regularization weight used when EVALUATING the reported
%                   objective and KKT residual.
%                   NOTE(review): lambda does not enter the iteration itself;
%                   the proximal steps use alpha/(numWorkers*n) ('ls') or the
%                   weight 1/n ('hinge'). Confirm this matches the intended
%                   problem scaling.
%     alpha         step size (used for every worker).
%     option        optional struct with fields
%                     .epoch    number of iterations      (default 100)
%                     .penalty  'l1', anything else = l2  (default 'l1')
%                     .loss     'ls', anything else = hinge (default 'ls')
%     P             numWorkers-by-numWorkers mixing matrix.
%     opt           optional reference point (d-by-1) for the distance trace.
%
%   Outputs
%     w          average of the workers' iterates after the last epoch
%                (average of the initial iterates when epoch == 0).
%     loss       epoch-by-1 objective value at w.
%     criterion  epoch-by-1 residual norm(w - prox(w - grad)) for 'ls';
%                equal to the objective value for the hinge loss.
%     dist       epoch-by-1 sum over workers of norm(opt - x_k); NaN when
%                opt is not supplied.
%
%   Side effect: resets the global random stream with rng('default').

    rng('default');
    numWorkers=length(Xdata);
    [~,d]=size(Xdata{1});

    x_agent=cell(numWorkers,1);   % current local iterates x_k
    z_agent=cell(numWorkers,1);   % pre-prox variables z_k
    tempx=cell(numWorkers,1);     % previous local iterates
    msg=cell(numWorkers,1);       % per-agent quantity that gets mixed by A

    % Total number of samples across all workers.
    n=0;
    for k=1:numWorkers
        n=n+size(Xdata{k},1);
    end

    % Per-agent step sizes (currently all equal to alpha).
    a=cell(numWorkers,1);
    for i=1:numWorkers
        a{i}=alpha;
    end

    % Mixing matrix; c*a{i} is 1/2 up to rounding, so A ~ (P+I)/2.
    c=1/(2*alpha);
    A=zeros(numWorkers,numWorkers);
    for i=1:numWorkers
        for j=1:numWorkers
            if i==j
                A(i,j)=1-c*a{i}+c*a{i}*P(i,j);
            else
                A(i,j)=c*a{i}*P(i,j);
            end
        end
    end

    % Defaults. penalty/obj previously had none, which made the function
    % fail with an undefined-variable error when the field was omitted.
    epoch=100;
    eval_every=1;
    penalty='l1';
    obj='ls';

    % User's option
    %------------------------------------------------
    if nargin > 4
        if isfield(option,'epoch')
            epoch = option.epoch;
        end
        if isfield(option,'penalty')
            penalty = option.penalty;
        end
        if isfield(option,'loss')
            obj = option.loss;
        end
    end
    isLS = matches(obj,'ls');
    isL1 = matches(penalty,'l1');

    % The distance trace needs a reference point; without one it is NaN.
    haveRef = nargin >= 7 && ~isempty(opt);

    % Aggregate full dataset (only for loss/criterion evaluation)
    X = vertcat(Xdata{:});
    Y = vertcat(Ydata{:});

    loss=zeros(epoch, 1);
    criterion=loss;
    dist=loss;

    if isL1
        % l1-norm proximal mapping operator (soft-thresholding)
        Proxg=@(z, lambda) sign(z).*max(abs(z)-lambda,0);
    else
        % squared l2-norm proximal mapping operator (shrinkage)
        Proxg=@(z, lambda) z/(lambda+1);
    end

    % display iteration information
    fprintf('----------------Training with {%3.0d} workers------------\n', numWorkers);
    fprintf('   round       primal loss     relative KKT residual   \n');

    % Random start; the draw order (worker 1..numWorkers) is kept so the
    % sequence after rng('default') is unchanged.
    for k = 1:numWorkers
        x_agent{k} = randn(d, 1);
        tempx{k}=x_agent{k};
    end

    % w is defined even if the main loop does not run (epoch == 0).
    w=zeros(d,1);
    for k=1:numWorkers
        w=w+x_agent{k}/numWorkers;
    end

    % First step: plain proximal-gradient step from the random start.
    for k=1:numWorkers
        if isLS
            z_agent{k}=x_agent{k}-a{k}*(ls_gradient(Xdata{k},Ydata{k},x_agent{k})/n);
            x_agent{k}=Proxg(z_agent{k},a{k}/(numWorkers*n));
        else
            % Smooth part here is the quadratic term with gradient
            % x/(n*numWorkers); the hinge term is handled by the prox.
            z_agent{k}=x_agent{k}-a{k}*(x_agent{k})/(n*numWorkers);
            x_agent{k}=prox_weighted_hinge(Xdata{k}, Ydata{k}, z_agent{k}, a{k}, 1/n);
        end
    end

    for i=1:epoch
        % The mixed quantity depends only on agent j, so build it once per
        % agent instead of once per (k,j) pair.
        for j=1:numWorkers
            if isLS
                gradDiff=ls_gradient(Xdata{j},Ydata{j},x_agent{j})-ls_gradient(Xdata{j},Ydata{j},tempx{j});
                msg{j}=2*x_agent{j}-tempx{j}-a{j}*gradDiff/n;
            else
                msg{j}=2*x_agent{j}-tempx{j}-a{j}*(x_agent{j}-tempx{j})/(n*numWorkers);
            end
        end

        % msg is frozen above, so mixing and the local update can share a loop.
        for k=1:numWorkers
            mixed=zeros(d,1);
            for j=1:numWorkers
                mixed=mixed+A(k,j)*msg{j};
            end
            z_agent{k}=z_agent{k}-x_agent{k}+mixed;
            tempx{k}=x_agent{k};
            if isLS
                x_agent{k}=Proxg(z_agent{k},a{k}/(numWorkers*n));
            else
                x_agent{k}=prox_weighted_hinge(Xdata{k}, Ydata{k}, z_agent{k}, a{k}, 1/n);
            end
        end

        % Network average, used for reporting only.
        w=zeros(d,1);
        for k=1:numWorkers
            w=w+x_agent{k}/numWorkers;
        end

        % Regularization value at w for the selected penalty.
        if isL1
            reg=lambda*norm(w, 1);
        else
            reg=lambda/2*norm(w)^2;
        end

        if isLS
            primal = norm(Y-(X*w))^2/(2*n)+reg;
            % Fixed-point residual of the proximal-gradient map (not normalized).
            kkt = norm(w-Proxg(w-X'*(X*w-Y)/n,lambda ));
        else
            primal=mean(max(0, 1-Y.*(X*w)) )+reg;
            kkt=primal;   % no separate residual is computed for the hinge loss
        end

        loss(i)=primal;
        criterion(i)=kkt;
        if haveRef
            for k=1:numWorkers
                dist(i)=dist(i)+norm(opt-x_agent{k});
            end
        else
            dist(i)=NaN;
        end
        if mod(i,eval_every)==0
            fprintf('%5d\t   %5d\t   %5d\t %5d\t \n', ...
                    i, primal, kkt,dist(i));
        end

    end
end



function grad_sum = ls_gradient(X, y, w)
% LS_GRADIENT  Gradient of 0.5*||X*w - y||^2 with respect to w.
%   Returns X' * (X*w - y), summed over the rows of X and not divided by
%   the sample count (callers apply their own scaling).
    grad_sum = X' * (X * w - y);
end

function alpha = compute_alpha_safe(Xdata, P)
% COMPUTE_ALPHA_SAFE  Step size derived from data smoothness and the mixing matrix.
%
%   alpha = compute_alpha_safe(Xdata, P) returns
%       0.9 * 2 * lambda_min / Ls
%   where
%     Ls         = max over workers of norm(Xi'*Xi, 2) / ni, with ni the
%                  number of rows of Xi = Xdata{i};
%     lambda_min = smallest eigenvalue of (P + I)/2.
%
%   Xdata  cell array of per-worker feature matrices (dense or sparse).
%   P      numWorkers-by-numWorkers mixing matrix.
%
%   NOTE(review): if (P + I)/2 has a zero eigenvalue the result is 0;
%   callers should check alpha > 0 before using it as a step size.

    numWorkers = length(Xdata);
    Ls_list = zeros(numWorkers, 1);

    for i = 1:numWorkers
        Xi = Xdata{i};
        ni = size(Xi, 1);
        % norm(S, 2) is not supported for sparse S, so densify the Gram
        % matrix first; for dense input full() is a no-op.
        Ls_list(i) = norm(full(Xi' * Xi), 2) / ni;
    end

    Ls = max(Ls_list);

    W_tilde = (P + eye(numWorkers)) / 2;
    lambda_min = min(eig(W_tilde));

    alpha = 0.9 * 2 * lambda_min / Ls;
end
function u_opt = prox_weighted_hinge(X, y, v, tau, k)
% PROX_WEIGHTED_HINGE  Proximal step for a weighted sum of hinge losses.
%
%   Solves, with slack variables xi (one per row of X),
%       min_{u, xi}  (1/(2*tau))*||u - v||^2 + k*sum(xi)
%       s.t.         y(i)*X(i,:)*u + xi(i) >= 1,   xi(i) >= 0
%   as a dense quadratic program via quadprog, and returns the u part.
%
%   X    n-by-d feature matrix      y    n labels
%   v    d-by-1 prox center         tau  prox step size
%   k    weight on each hinge term

    [numSamples, dim] = size(X);
    invStep = 1/tau;
    numVars = dim + numSamples;

    % Quadratic term acts on u only; the slack block of H stays zero.
    H = zeros(numVars);
    H(1:dim, 1:dim) = invStep * eye(dim);

    % Linear term: -(1/tau)*v for u, constant weight k for every slack.
    f = zeros(numVars, 1);
    f(1:dim) = -invStep * v;
    f(dim+1:end) = k * ones(numSamples, 1);

    % Inequalities A*[u; xi] <= b, filled block by block.
    labels = y(1:numSamples);
    labels = labels(:);
    A = zeros(2*numSamples, numVars);
    A(1:numSamples, 1:dim) = (-labels) .* X;        % margin rows: -y_i * x_i
    A(1:numSamples, dim+1:end) = -eye(numSamples);  % ... - xi_i <= -1
    A(numSamples+1:end, dim+1:end) = -eye(numSamples);  % -xi_i <= 0

    b = zeros(2*numSamples, 1);
    b(1:numSamples) = -1;

    opts = optimoptions('quadprog', ...
    'OptimalityTolerance', 1e-10, ...
    'ConstraintTolerance', 1e-10, ...
    'Display', 'None');
    sol = quadprog(H, f, A, b, [], [], [], [], [], opts);

    u_opt = sol(1:dim);
end
