clc; clear;

% Fix the RNG seed so repeated runs are reproducible.
rng(42);

% Dataset to run on (uncomment exactly one). Paths are relative to the
% current working directory.
% data_file = '../data/a9a.mat'; 
% data_file = '../data/mushroom.mat';
data_file = '../data/phishing.mat';
% data_file = '../data/australian.mat';
% data_file = '../data/w2a.mat';

[~, dataset_name, ~] = fileparts(data_file);

% Load into a struct instead of creating variables implicitly, so the
% script's dependency on the .mat contents (variables X and y) is explicit
% and a missing variable fails here with a clear message.
data = load(data_file);
if ~isfield(data, 'X') || ~isfield(data, 'y')
    error('%s must contain variables ''X'' and ''y''.', data_file);
end
X = data.X;
y = data.y;
clear data;

[m, n] = size(X);
if length(y) ~= m
    error('X has %d rows but y has %d entries.', m, length(y));
end

% Scale every row of X to unit Euclidean norm (normc normalizes columns,
% hence the transposes; full() densifies a possibly sparse X first).
X = normc(full(X)')'; 
fprintf('%s, X=[%d,%d], y=[%d,1]\n', dataset_name, m, n, length(y));

max_iter = 25000; 
lambda_val = 0.01;
infeas_tol = 1e-4;
bs1 = 32; 
opt = "storm";   % method variant: "storm" or "polyak"

% Method-specific parameters (scaled with the iteration budget) and the
% sample budget. The budget is sized so the run can complete max_iter
% iterations given the per-iteration sample usage of each variant.
switch opt
    case "polyak"
        beta = 1;
        alpha = 0.905;
        rho = 0.4 * max_iter^(1/4);
        mu = 3 * max_iter^(-1/2);
        % polyak: bs1 samples every iteration.
        bs0 = bs1;
        max_samples = max_iter * bs1 + 1000;                  % +1000 buffer
    case "storm"
        beta = 8/5;
        alpha = 0.9;
        rho = 0.3 * max_iter^(1/3);
        mu  = 1 * max_iter^(-1/3);
        % storm: bs0 samples in the first iteration, then 2*bs1 per iteration.
        bs0 = round(max_iter^(1/3));
        max_samples = bs0 + (max_iter - 1) * 2 * bs1 + 1000;  % +1000 buffer
    otherwise
        % Without this, beta/mu/bs0/max_samples would be undefined and the
        % script would fail later with an unrelated-looking error.
        error('Unknown opt "%s"; expected "storm" or "polyak".', opt);
end
fprintf('  mu = %.3f, beta = %.3f, rho = %.3f\n', mu, beta, rho);

eta = mu * lambda_val;

x_range = 22000;  % Plot X-axis range



%% Run algorithm
% Put the s2meal solvers on the path. The folder is resolved relative to this
% script's location so the run does not depend on the current working
% directory. If mfilename is empty (code not run from a saved file),
% fullfile drops the empty part and this reduces to the pwd-relative
% '../code/algos/s2meal'.
script_dir = fileparts(mfilename('fullpath'));
addpath(fullfile(script_dir, '..', 'code', 'algos', 's2meal'));

switch opt
    case "storm"
        % storm variant takes an extra initial batch size bs0.
        [x, obj_vals, infeas_vals, kkt_vals, sample_nums] = s2meal_storm(X, y, lambda_val, mu, beta, alpha, rho, eta, ...
            max_iter, max_samples, infeas_tol, bs0, bs1);
    case "polyak"
        [x, obj_vals, infeas_vals, kkt_vals, sample_nums] = s2meal_polyak(X, y, lambda_val, mu, beta, alpha, rho, eta, ...
            max_iter, max_samples, infeas_tol, bs1);
    otherwise
        error('Unknown opt "%s"; expected "storm" or "polyak".', opt);
end

%% Results output
% Summarize the run: last recorded objective, constraint violation and KKT
% residual, followed by the final sample count and number of recorded values.
final_obj    = obj_vals(end);
final_infeas = infeas_vals(end);
final_kkt    = kkt_vals(end);
fprintf('Completed! Final objective: %.6f, Violation: %.2e, KKT residual: %.2e\n', ...
    final_obj, final_infeas, final_kkt);

num_recorded = length(obj_vals);
fprintf('Total samples: %d, Iterations: %d\n', sample_nums(end), num_recorded);


%% Plotting
label_str = sprintf('s2meal-%s: μ=%.3f, β=%.2f, ρ=%.2f (obj=%.4f)', ...
    upper(opt), mu, beta, rho, obj_vals(end));

% One figure per metric, each held so that repeated runs with different
% parameters overlay as separate legend entries.
% Columns: figure number, y data, y-axis scale, y label, title.
plot_specs = {
    1, obj_vals,    'linear', 'objective value', 'Objective Value - Parameter Comparison';
    2, infeas_vals, 'log',    'violation',       'Constraint Violation - Parameter Comparison';
    3, kkt_vals,    'log',    'KKT Residual',    'KKT Residual - Parameter Comparison'};

for k = 1:size(plot_specs, 1)
    [fig_id, y_data, y_scale, y_label, fig_title] = plot_specs{k, :};

    figure(fig_id)
    hold on;
    h = plot(sample_nums, y_data, '-', 'LineWidth', 2);
    h.DisplayName = label_str;
    % Set the scale explicitly. When 'hold on' is issued first, semilogy
    % does not switch the axes to log scale, so the previous semilogy calls
    % produced linear violation/KKT plots.
    set(gca, 'YScale', y_scale);
    xlim([0 x_range]);
    xlabel('# of stochastic oracle calls');
    ylabel(y_label);
    title(fig_title);
    grid on;
    legend('Location', 'best');
end

%% Save experiment results
% save_experiment lives in the ../plot folder next to this script; put it on
% the path, then hand it the run's traces, final iterate x and configuration.
plot_dir = fullfile(fileparts(mfilename('fullpath')), '..', 'plot');
addpath(plot_dir);
save_experiment(dataset_name, 's2meal', opt, ...
    obj_vals, infeas_vals, kkt_vals, sample_nums, x, ...
    m, n, max_iter, lambda_val, alpha, mu, beta, rho, eta, ...
    bs0, bs1, max_samples, infeas_tol);