function [Z, out] = z_update_fista_simplex(K, J, Z0, opts)
% Z_UPDATE_FISTA_SIMPLEX  Accelerated projected-gradient (FISTA) update for Z.
%
%   [Z, out] = z_update_fista_simplex(K, J, Z0, opts) iterates on
%
%       f(Z) = Tr(Z K Z') - 2 Tr(J Z')
%
%   starting from Z0, where every iterate is passed through
%   proj_simplex_cols (which applies EProjSimplex_new to each column).
%   NOTE(review): EProjSimplex_new is defined outside this file; it is
%   presumably a projection onto the probability simplex -- confirm.
%
%   Inputs
%     K    : square matrix, conformable with Z*K. It is symmetrized
%            internally as (K + K')/2.
%     J    : matrix of the same size as Z.
%     Z0   : starting point (used as-is; it is not projected first).
%     opts : optional struct; missing or empty fields take the defaults:
%              maxit     (50)    maximum number of iterations
%              tol       (1e-3)  stop when relchg < tol
%              backtrack (true)  enable the backtracking line search
%              restart   (true)  enable adaptive momentum restart
%              verbose   (false) print progress every 10 iterations
%              L         (2*(normest(K) + 1e-8)) initial step constant
%
%   Outputs
%     Z   : last iterate.
%     out : struct (only built when requested) with fields
%              iters     number of iterations performed
%              relchg    ||Z_new - Z_cur||_F / max(1, ||Z_cur||_F) at the
%                        last iteration (NaN if no iteration ran)
%              L         final step constant
%              fval_last objective value at the last iterate
%              fvals     objective value after each iteration

    if nargin < 4, opts = struct(); end
    if ~isfield(opts,'maxit')     || isempty(opts.maxit),     opts.maxit = 50; end
    if ~isfield(opts,'tol')       || isempty(opts.tol),       opts.tol = 1e-3; end
    if ~isfield(opts,'backtrack') || isempty(opts.backtrack), opts.backtrack = true; end
    if ~isfield(opts,'restart')   || isempty(opts.restart),   opts.restart = true; end
    if ~isfield(opts,'verbose')   || isempty(opts.verbose),   opts.verbose = false; end

    % The gradient formula 2*Y*K - 2*J below is only valid for symmetric K.
    K = (K + K') * 0.5;

    % Initial step constant: the gradient 2*Y*K - 2*J is Lipschitz with
    % constant 2*||K||_2; the small offset keeps L strictly positive.
    if ~isfield(opts,'L') || isempty(opts.L)
        L = 2*(normest(K) + 1e-8);
    else
        L = opts.L;
    end

    Z_old = Z0;
    Z_cur = Z0;
    t = 1;

    fval = @(Z) sum(sum((Z*K).*Z)) - 2*sum(sum(J.*Z));  % Tr(Z K Z') - 2 Tr(J Z')
    fvals = zeros(1, opts.maxit);
    relchg = NaN;
    iters = 0;               % stays 0 when opts.maxit < 1

    inc = 1.5;               % growth factor for the trial step constant
    blend = 0.9;             % weight of the previous L when smoothing in Ltry
    max_backtracks = 200;    % hard cap so a NaN objective cannot loop forever

    for it = 1:opts.maxit
        iters = it;

        % Momentum / extrapolation step.
        t_next = (1 + sqrt(1 + 4*t^2)) / 2;
        Y = Z_cur + ((t - 1)/t_next) * (Z_cur - Z_old);

        % grad f(Y) = 2 Y K - 2 J
        Grad = 2*(Y*K) - 2*J;

        if opts.backtrack
            fY = fval(Y);    % does not depend on Ltry, so evaluate it once
            Ltry = L;
            accepted = false;
            for bt = 1:max_backtracks
                Z_tilde = Y - (1/Ltry) * Grad;
                Z_new = proj_simplex_cols(Z_tilde);

                % Accept once the quadratic model at Y majorizes f at Z_new.
                lhs = fval(Z_new);
                D   = Z_new - Y;
                rhs = fY + sum(sum(Grad.*D)) + (Ltry/2)*norm(D,'fro')^2;

                if lhs <= rhs
                    accepted = true;
                    break;
                end
                Ltry = Ltry * inc;
            end

            if accepted
                % Smooth the accepted constant into L instead of jumping to it.
                L = blend*L + (1-blend)*Ltry;
            else
                % Typically caused by NaN/Inf in K, J or Z0. Keep the last
                % trial point and leave L untouched rather than inflating it.
                warning('z_update_fista_simplex:backtrackFailed', ...
                        'Backtracking did not satisfy the descent condition after %d trials.', ...
                        max_backtracks);
            end
        else
            Z_tilde = Y - (1/L) * Grad;
            Z_new = proj_simplex_cols(Z_tilde);
            lhs = fval(Z_new);
        end

        % Gradient-based adaptive restart: (Y - Z_new) is proportional to the
        % generalized gradient at Y, so a positive inner product with the
        % step (Z_new - Z_cur) means momentum is pushing uphill. Resetting t
        % to 1 makes the next extrapolated point equal the current iterate.
        if opts.restart
            if sum(sum((Y - Z_new).*(Z_new - Z_cur))) > 0
                t_next = 1;
            end
        end

        relchg = norm(Z_new - Z_cur, 'fro') / max(1, norm(Z_cur, 'fro'));
        fvals(it) = lhs;

        Z_old = Z_cur;
        Z_cur = Z_new;
        t = t_next;

        if opts.verbose && mod(it,10)==0
            fprintf('[Z-FISTA] it=%3d  relchg=%.2e  L=%.3g  f=%.6g\n', it, relchg, L, lhs);
        end

        if relchg < opts.tol
            fvals = fvals(1:it);
            break;
        end
    end

    Z = Z_cur;
    if nargout > 1
        if isempty(fvals)
            % No iteration ran; report the objective at the starting point.
            fval_last = fval(Z_cur);
        else
            fval_last = fvals(end);
        end
        out = struct('iters', iters, 'relchg', relchg, 'L', L, ...
                     'fval_last', fval_last, 'fvals', fvals);
    end
end


function Zp = proj_simplex_cols(Z)
% PROJ_SIMPLEX_COLS  Apply EProjSimplex_new to each column of Z.
%
%   Zp = proj_simplex_cols(Z) returns a matrix with the same row and column
%   counts as Z. Column k of Zp is obtained by passing column k of Z, as a
%   row vector, to EProjSimplex_new and transposing the result back.
%   NOTE(review): EProjSimplex_new is defined outside this file; presumably
%   it projects a vector onto the probability simplex -- confirm.
    [numRows, numCols] = size(Z);
    Zp = zeros(numRows, numCols);
    for k = 1:numCols
        rowIn  = Z(:,k)';                  % helper is called with a row vector
        rowOut = EProjSimplex_new(rowIn);
        Zp(:,k) = rowOut';                 % store back as a column
    end
end