function kkt_res = compute_kkt_residual(w, z, X, y, lambda_val, rho, eta)
% COMPUTE_KKT_RESIDUAL  Scalar KKT residual evaluated on the full dataset.
%
% The residual is the largest of three quantities:
%   * stationarity : norm(grad_f + grad_c*lambda + sub_h - sub_g)
%   * feasibility  : abs(norm(w)^2 - 1)
%   * proximity    : norm(w - prox_g(z))
% where grad_f is the gradient of the mean logistic loss, the multiplier of
% the constraint norm(w)^2 = 1 is estimated as rho * (norm(w)^2 - 1), sub_h
% is obtained from the soft-thresholding (L1) proximal map of w, and sub_g
% from the block soft-thresholding (L2) proximal map of z.
%
% Inputs:
%   w          - current primal variable (n x 1)
%   z          - current auxiliary variable (n x 1)
%   X          - feature matrix (m x n)
%   y          - label vector (m x 1)
%   lambda_val - regularization parameter
%   rho        - penalty parameter
%   eta        - proximal parameter (used as a divisor; eta = 0 yields NaN)
%
% Output:
%   kkt_res    - scalar KKT residual. NaN if any of the three component
%                residuals is NaN, so that an invalid iterate is never
%                reported as a small (converged-looking) residual.

    %% 1. Gradient of the mean logistic loss over all m samples
    m = size(X, 1);
    margins = y .* (X * w);
    % 1 - sigmoid(t) is evaluated as 1 ./ (1 + exp(t)). The two forms are
    % algebraically identical, but this one avoids the cancellation that
    % rounds the weight to exactly 0 for large positive margins. exp may
    % overflow to Inf, in which case the weight is simply 0 (no NaN).
    neg_sigmoid = 1 ./ (1 + exp(margins));
    grad_f = -X' * (y .* neg_sigmoid) / m;

    %% 2. Constraint term grad c(w) * lambda, with c(w) = norm(w)^2 - 1
    constraint_val = norm(w)^2 - 1;
    lambda_multiplier = rho * constraint_val;   % multiplier estimate
    grad_constraint = 2 * lambda_multiplier * w;

    %% 3. Subgradient of the L1 term h via its proximal map
    % sub_h = (w - prox_h(w)) / eta, prox_h = soft-thresholding at
    % lambda_val * eta.
    w_prox_h = sign(w) .* max(abs(w) - lambda_val * eta, 0);
    subgrad_h = (w - w_prox_h) / eta;

    %% 4. Subgradient of the L2 term g at u = prox_g(z)
    % (u is the prox output; it is unrelated to the label vector y.)
    norm_z = norm(z);
    if norm_z < 1e-10
        % z is numerically zero: prox output is 0 and the zero vector is
        % used as the subgradient.
        y_prox_g = zeros(size(z));
        subgrad_g = zeros(size(w));
    else
        % Block soft-thresholding: shrink z towards the origin.
        y_prox_g = max(1 - lambda_val * eta / norm_z, 0) * z;
        norm_prox = norm(y_prox_g);
        if norm_prox < 1e-10
            subgrad_g = zeros(size(w));
        else
            subgrad_g = lambda_val * (y_prox_g / norm_prox);
        end
    end

    %% 5. Stationarity residual: norm(grad_f + grad_c*lambda + sub_h - sub_g)
    lagrangian_grad = grad_f + grad_constraint + subgrad_h - subgrad_g;
    stationarity_res = norm(lagrangian_grad);

    %% 6. Constraint violation: abs(norm(w)^2 - 1)
    feasibility_res = abs(constraint_val);

    %% 7. DC proximal distance: norm(w - prox_g(z))
    proximity_res = norm(w - y_prox_g);

    %% 8. Combined KKT residual (maximum of the three)
    residuals = [stationarity_res, feasibility_res, proximity_res];
    if any(isnan(residuals))
        % max() skips NaN entries by default, which would hide a NaN
        % component behind the remaining finite ones; propagate it instead.
        kkt_res = NaN;
    else
        kkt_res = max(residuals);
    end

end