function [x, obj_vals, infeas_vals, kkt_vals, sample_nums] = s2meal_polyak(X, y, lambda_val, mu, beta, alpha, rho, eta, max_iter, max_samples, infeas_tol, bs1)
% S2MEAL_POLYAK  S2MEAL iterations driven by Polyak-momentum minibatch gradients.
%   [x, obj_vals, infeas_vals, kkt_vals, sample_nums] = S2MEAL_POLYAK(X, y,
%   lambda_val, mu, beta, alpha, rho, eta, max_iter, max_samples, infeas_tol, bs1)
%   runs up to max_iter iterations. The objective is obj_func (logistic
%   loss + lambda_val*(||x||_1 - ||x||_2)). The constraint violation is
%   measured as |norm(x)^2 - 1| and penalised through grad_c with weight rho.
%
%   Inputs:
%     X, y         - data matrix (m-by-n) and labels. Each iteration draws a
%                    minibatch of bs1 distinct rows (randperm).
%                    NOTE(review): labels look like they are expected in {-1,+1} - confirm.
%     lambda_val   - regulariser weight (also scales the prox thresholds).
%     mu           - base step size; reset/decayed periodically when use_decay is true.
%     beta         - relaxation weight of the z update.
%     alpha        - weight of the fresh minibatch gradient in the momentum average.
%     rho          - penalty weight passed to grad_c and compute_kkt_residual.
%     eta          - only passed through to compute_kkt_residual.
%     max_iter     - maximum number of iterations.
%     max_samples  - stop after the iteration in which more than this many rows were drawn.
%     infeas_tol   - currently unused.
%     bs1          - minibatch size.
%
%   Outputs:
%     x            - final iterate.
%     obj_vals, infeas_vals, kkt_vals
%                  - running means (over iterations 1..k) of the objective,
%                    infeasibility and KKT residual, one entry per iteration run.
%     sample_nums  - cumulative number of rows drawn at each iteration.
%
%   If 'salm_warmstart.mat' exists and contains a numeric x with n elements,
%   it is used (as a column) as the starting point and the file is deleted.
%   Otherwise a random unit vector is used.

    [m, n] = size(X);

    % ----- Initial point: SALM warmstart if valid, else random unit vector -----
    warm_file = 'salm_warmstart.mat';
    use_warm = false;
    if exist(warm_file, 'file')
        % Load into a struct so a missing/invalid x is detected instead of
        % leaving x undefined in this workspace.
        S = load(warm_file, 'x');
        use_warm = isfield(S, 'x') && isnumeric(S.x) && numel(S.x) == n;
        if use_warm
            % Force a column: a row x would make dk + grad_c(x, rho)
            % broadcast to an n-by-n matrix.
            x = S.x(:);
            fprintf('Using SALM warmstart point as initial: norm(x) = %.6f\n', norm(x));
            delete(warm_file);  % Delete after use to avoid future misuse
        else
            warning('s2meal_polyak:invalidWarmstart', ...
                '%s has no numeric x with %d elements; ignoring it.', warm_file, n);
        end
    end
    if ~use_warm
        x = randn(n, 1);
        x = x / norm(x);
        fprintf('Using random initial point\n');
    end

    dk = zeros(n, 1);   % momentum (moving-average) gradient estimate
    z = zeros(n, 1);    % auxiliary splitting variable
    oracle_num = 0;     % cumulative number of sampled rows
    mu0 = mu;           % base step size restored at every restart

    % ===== Learning rate decay switch =====
    use_decay = true;       % true: enable periodic restart+decay; false: fixed step size
    restart_period = 5000;  % iterations between resets of mu to mu0
    decay_period = 1000;    % iterations between successive decays within a cycle
    decay_factor = 0.95;    % multiplicative decay applied every decay_period iterations
    % ======================================

    % Rows of sam2: [oracle_num; running-mean obj; running-mean infeas; running-mean KKT]
    sam2 = zeros(4, max_iter);
    n_rec = 0;          % number of filled columns of sam2
    obj_avg = 0;
    infeas_avg = 0;
    kkt_avg = 0;

    if use_decay
        fprintf('Starting training (learning rate strategy: %d restart + step decay)...\n', restart_period);
    else
        fprintf('Starting training (learning rate strategy: fixed step size)...\n');
    end
    addpath('./opt/');  % Ensure compute_kkt_residual is accessible

    for k = 1:max_iter
        % Polyak momentum: exponential moving average of minibatch gradients
        idx = randperm(m, bs1);
        fk = grad_f(x, X(idx,:), y(idx));
        dk = alpha * fk + (1 - alpha) * dk;
        oracle_num = oracle_num + bs1;

        % Learning rate: reset to mu0 every restart_period iterations and
        % multiply by decay_factor after every decay_period iterations of a
        % cycle (iterations 1-1000 use mu0, 1001-2000 use 0.95*mu0, ...).
        if use_decay
            pos_in_cycle = mod(k - 1, restart_period);   % 0 .. restart_period-1
            mu = mu0 * decay_factor ^ floor(pos_in_cycle / decay_period);
        end

        % Total direction: momentum gradient plus penalty gradient
        D_k = dk + grad_c(x, rho);

        % S2MEAL update: new x from prox_h at the shifted z, then z is
        % relaxed (weight beta) using the previous z and the new x.
        % NOTE(review): prox_g shrinks towards the origin (prox of +t*||.||_2),
        % while obj_func subtracts lambda_val*||x||_2 - confirm the sign is intended.
        x = prox_h(z - mu * D_k, mu * lambda_val);
        z = z - beta * (prox_g(z, mu * lambda_val) - x);

        % Metrics at the new iterate
        obj_val = obj_func(x, X, y, lambda_val);
        infeas = abs(norm(x)^2 - 1);
        kkt_val = compute_kkt_residual(x, z, X, y, lambda_val, rho, eta);

        % Running means over iterations 1..k
        obj_avg = (obj_avg * (k - 1) + obj_val) / k;
        infeas_avg = (infeas_avg * (k - 1) + infeas) / k;
        kkt_avg = (kkt_avg * (k - 1) + kkt_val) / k;

        n_rec = n_rec + 1;
        sam2(:, n_rec) = [oracle_num; obj_avg; infeas_avg; kkt_avg];

        % Stop once the sample budget is exceeded (this iteration is kept)
        if oracle_num > max_samples
            break;
        end

        % Output information
        if mod(k, 100) == 0 || k <= 5
            fprintf('Iter %4d: Obj=%.6f, Infeas=%.2e, KKT=%.2e, Samples=%d\n', ...
                k, obj_avg, infeas_avg, kkt_avg, oracle_num);
        end
    end

    % Return only the recorded iterations (also well-defined when
    % max_iter == 0 or the loop stopped early on the sample budget).
    obj_vals = sam2(2, 1:n_rec)';
    infeas_vals = sam2(3, 1:n_rec)';
    kkt_vals = sam2(4, 1:n_rec)';
    sample_nums = sam2(1, 1:n_rec)';
end

function val = obj_func(w, X, y, lambda_val)
% OBJ_FUNC  Mean logistic loss plus the L1-minus-L2 regulariser.
%   val = OBJ_FUNC(w, X, y, lambda_val) returns
%       mean(log(1 + exp(-y .* (X*w)))) + lambda_val * (||w||_1 - ||w||_2).
%   The loss uses the overflow-safe identity
%       log(1 + exp(t)) = max(t, 0) + log1p(exp(-|t|)),
%   so large negative margins give a finite loss instead of Inf, and very
%   small losses are not rounded to zero.
    t = -y .* (X * w);
    loss = mean(max(t, 0) + log1p(exp(-abs(t))));
    val = loss + lambda_val * (norm(w, 1) - norm(w, 2));
end

function grad = grad_f(w, X, y)
% GRAD_F  Gradient of the mean logistic loss over the rows of (X, y).
%   grad = GRAD_F(w, X, y) returns the gradient with respect to w of
%       mean(log(1 + exp(-y .* (X*w)))),
%   i.e. -X' * (y .* sigmoid(-y .* (X*w))) / size(X, 1).
%   sigmoid(-a) is computed directly as 1 ./ (1 + exp(a)) instead of
%   1 - sigmoid(a), which cancels to exactly 0 once sigmoid(a) rounds to 1.
%   NOTE(review): assumes y is a column with size(X, 1) entries - confirm.
    margin = y .* (X * w);
    b = y ./ (1 + exp(margin));
    grad = -X' * b / size(X, 1);
end

function grad = grad_c(w, rho)
% GRAD_C  Gradient of the quartic penalty (rho/4) * (||w||^2 - 1)^2.
%   grad = GRAD_C(w, rho) returns rho * (||w||^2 - 1) * w.
    sq_norm = norm(w)^2;
    scale = rho * (sq_norm - 1);
    grad = scale * w;
end

function result = prox_h(u, eta)
% PROX_H  Element-wise soft-thresholding (proximal operator of eta*||.||_1).
%   result = PROX_H(u, eta) shrinks each entry of u towards zero by eta,
%   setting entries with |u_i| <= eta to zero.
    magnitude = max(abs(u) - eta, 0);
    result = magnitude .* sign(u);
end

function result = prox_g(u, eta)
% PROX_G  Block soft-thresholding (proximal operator of eta*||.||_2).
%   result = PROX_G(u, eta) scales u by max(1 - eta/||u||, 0), so the whole
%   vector collapses to zero when ||u|| <= eta. A zero input returns zeros
%   of the same size.
    r = norm(u);
    result = zeros(size(u));
    if r ~= 0
        result = max(1 - eta / r, 0) * u;
    end
end