function [x, obj_vals, infeas_vals, kkt_vals, sample_nums] = s2meal_storm(X, y, lambda_val, mu, beta, alpha, rho, eta, max_iter, max_samples, infeas_tol, bs0, bs1)
% S2MEAL_STORM  Stochastic proximal iteration with a recursive-momentum
% gradient estimator and a quadratic penalty on the constraint norm(x)^2 = 1.
%
% Inputs:
%   X           - m-by-n data matrix; rows are sampled as mini-batches.
%   y           - labels indexed by the rows of X (presumably m-by-1 with
%                 entries +/-1 for the logistic loss -- TODO confirm).
%   lambda_val  - weight of the (l1 - l2) term in the reported objective.
%   mu          - step size applied to the search direction D_k.
%   beta        - step size of the auxiliary variable z.
%   alpha       - momentum parameter of the estimator
%                 d_k = g_k(x_k) + (1 - alpha) * (d_{k-1} - g_k(x_{k-1})).
%   rho         - penalty weight passed to grad_c.
%   eta         - threshold passed to prox_h and prox_g.
%   max_iter    - maximum number of iterations.
%   max_samples - sample budget; iteration stops once it is exceeded.
%   infeas_tol  - not used by this function; kept so existing callers
%                 keep working.
%   bs0, bs1    - batch size of the first iteration and of every later
%                 iteration. Values larger than m are clamped to m.
%
% Outputs:
%   x           - last iterate (the random unit-norm start if no
%                 iteration was executed).
%   obj_vals, infeas_vals, kkt_vals
%               - column vectors with one entry per executed iteration.
%                 Each entry is the RUNNING AVERAGE over iterations 1..k
%                 of the objective, of |norm(x)^2 - 1| and of the value
%                 returned by compute_kkt_residual, respectively.
%   sample_nums - cumulative number of sampled gradient evaluations after
%                 each iteration (column vector).
%
% NOTE(review): compute_kkt_residual is not defined in this file; it is
% presumably provided by the './opt/' folder added to the path below.

    [m, n] = size(X);

    % Random starting point, normalized to unit length.
    x = randn(n, 1);
    x = x / norm(x);
    dk = zeros(n, 1);
    z = zeros(n, 1);
    oracle_num = 0;

    % randperm(m, k) errors for k > m, so an oversized request falls back
    % to using every row once.
    bs0 = min(bs0, m);
    bs1 = min(bs1, m);

    % Rows: 1 = samples used, 2 = objective, 3 = infeasibility, 4 = KKT.
    sam2 = zeros(4, max_iter);
    tmp = 0;  % Number of columns of sam2 filled so far

    % Running averages of the per-iteration diagnostics.
    avg_obj = 0;
    avg_infeas = 0;
    avg_kkt = 0;

    fprintf('Starting training...\n');

    addpath('./opt/');

    for k = 1:max_iter
        x_old = x;
        dk_old = dk;

        if k == 1
            % First iteration: plain mini-batch gradient.
            idx = randperm(m, bs0);
            dk = grad_f(x, X(idx,:), y(idx));
            oracle_num = oracle_num + bs0;
        else
            % Later iterations: the SAME batch is evaluated at the current
            % and at the previous point, hence two gradient calls per sample.
            % NOTE(review): x_old is assigned from x at the top of this
            % iteration, before x is updated, so both calls below evaluate
            % at the same point and (dk_old - fk_old) is not the usual
            % previous-iterate correction -- confirm this is intended.
            idx = randperm(m, bs1);
            fk = grad_f(x, X(idx,:), y(idx));
            fk_old = grad_f(x_old, X(idx,:), y(idx));
            dk = fk + (1 - alpha) * (dk_old - fk_old);
            oracle_num = oracle_num + 2*bs1;
        end

        % Search direction: gradient estimate plus penalty gradient.
        D_k = dk + grad_c(x, rho);

        % Both updates read the OLD z; x and z are replaced afterwards.
        x_new = prox_h(z - mu * D_k, eta);
        z_new = z - beta * (prox_g(z, eta) - x_new);
        x = x_new;
        z = z_new;

        % Objective function and constraint violation
        obj_val = obj_func(x, X, y, lambda_val);
        infeas = abs(norm(x)^2 - 1);
        kkt_val = compute_kkt_residual(x, z, X, y, lambda_val, rho, eta);

        % Incremental mean; for k == 1 this reduces to the raw value.
        avg_obj = (avg_obj*(k-1) + obj_val)/k;
        avg_infeas = (avg_infeas*(k-1) + infeas)/k;
        avg_kkt = (avg_kkt*(k-1) + kkt_val)/k;

        % Record to sam2
        tmp = tmp + 1;
        sam2(1,tmp) = oracle_num;
        sam2(2,tmp) = avg_obj;
        sam2(3,tmp) = avg_infeas;
        sam2(4,tmp) = avg_kkt;

        % Stop condition check
        if oracle_num > max_samples
            break;
        end

        % Output information
        if mod(k, 100) == 0 || k <= 5
            fprintf('Iter %4d: Obj=%.6f, Infeas=%.2e, KKT=%.2e, Samples=%d\n', ...
                k, avg_obj, avg_infeas, avg_kkt, oracle_num);
        end
    end

    % Return only the filled columns. tmp is 0 when the loop never ran, so
    % the outputs are empty instead of referencing an undefined index.
    obj_vals = sam2(2,1:tmp)';
    infeas_vals = sam2(3,1:tmp)';
    kkt_vals = sam2(4,1:tmp)';
    sample_nums = sam2(1,1:tmp)';
end

function val = obj_func(w, X, y, lambda_val)
% OBJ_FUNC  Mean logistic loss plus lambda_val * (||w||_1 - ||w||_2).
%
%   w          - weight vector (column, length size(X, 2)).
%   X          - data matrix, one sample per row.
%   y          - labels, multiplied elementwise with X * w.
%   lambda_val - weight of the (l1 - l2) term.
%
% The loss log(1 + exp(-a)) is evaluated as max(-a, 0) + log1p(exp(-|a|)).
% The two forms are mathematically identical, but the second never
% exponentiates a positive number, so a strongly negative margin yields a
% large finite loss instead of Inf.
    margins = y .* (X * w);
    loss = mean(max(-margins, 0) + log1p(exp(-abs(margins))));
    val = loss + lambda_val * (norm(w, 1) - norm(w, 2));
end

function val = penalty(w, rho)
% PENALTY  Quadratic penalty (rho/2) * (norm(w)^2 - 1)^2; it is zero
% exactly when norm(w)^2 equals 1.
    sq_norm_gap = norm(w)^2 - 1;
    val = (rho/2) * sq_norm_gap^2;
end

function grad = grad_f(w, X, y)
% GRAD_F  Gradient with respect to w of mean(log(1 + exp(-y .* (X * w))))
% over the rows of X.
%
%   w - weight vector (column, length size(X, 2)).
%   X - batch of samples, one per row.
%   y - labels matching the rows of X.
    batch_rows = size(X, 1);
    margins = y .* (X * w);
    sigma = 1 ./ (1 + exp(-margins));
    % Per-sample weight y * (1 - sigmoid(margin)).
    sample_weights = y .* (1 - sigma);
    grad = -X' * sample_weights / batch_rows;
end

function grad = grad_c(w, rho)
% GRAD_C  Gradient of (rho/2) * (norm(w)^2 - 1)^2 with respect to w,
% i.e. 2 * rho * (norm(w)^2 - 1) * w.
    scale = 2 * rho * (norm(w)^2 - 1);
    grad = scale * w;
end

function result = prox_h(u, eta)
% PROX_H  Elementwise soft-thresholding: each entry of u is moved toward
% zero by eta, and entries with |u_i| <= eta become zero.
    shrunk_magnitude = max(abs(u) - eta, 0);
    result = sign(u) .* shrunk_magnitude;
end

function result = prox_g(u, eta)
% PROX_G  Scales the whole vector u by max(1 - eta / norm(u), 0).
% A vector with zero norm maps to zeros of the same size.
    result = zeros(size(u));
    u_len = norm(u);
    if u_len ~= 0
        shrink = max(1 - eta / u_len, 0);
        result = shrink * u;
    end
end