mex_all;
clear;
%% Load Dataset
% The workspace was just cleared, so X (and y, used by the solver calls
% later in this script) must be provided by this file.
load('a9a.mat');

%% Add Bias
% Prepend a constant-1 column to X, record the resulting size, and then
% store the data densely and transposed: after this section X is Dim-by-N.
N = size(X, 1);
X = [ones(N, 1) X];
Dim = size(X, 2);
X = full(X');

%% Normalize Data
% Scale every column of X to unit Euclidean norm. Only the first column is
% inspected to decide whether the data already looks normalised.
col_scale = 1 ./ sqrt(sum(X.^2, 1));
needs_scaling = abs(col_scale(1) - 1) > 10^(-10);
if needs_scaling
    % Multiply column j of X by col_scale(j) without materialising a
    % Dim-by-N copy of the scale factors.
    X = bsxfun(@times, X, col_scale);
end
clear col_scale needs_scaling;

%% Set Params
passes = 300;
regularizer = 'L2';
model = 'logistic';
init_weight = zeros(Dim, 1);
mu = 0; % Non-strongly convex
% A quarter of the largest squared column norm, plus mu; handed to the
% solver together with mu (presumably a smoothness constant -- confirm).
L = 0.25 * max(sum(X.^2, 1)) + mu;
fprintf('Model: %s-%s\n', regularizer, model);

%% Run Algorithms
% Each configuration below is run `seeds` times through Interface; the
% per-run histories are collected side by side and then reduced to a
% row-wise mean and standard deviation.
% NOTE(review): the loop counter is only printed, it is not passed to
% Interface -- how (or whether) runs are seeded is not visible here.
seeds = 5;

% x-axis values 0..passes that are paired with the averaged histories.
X_L2S = (0:passes)';
X_Acc_SVRG_G = (0:passes)';

hist_L2S = [];
hist_L2S_N = [];
hist_Acc_SVRG_G = [];

% L2S (N-independent)
algorithm = 'L2S';
step_size = 0.1 / L; % sometimes we observed this choice diverged.
for run = 1:seeds
    fprintf('Algorithm: %s, Seeds: %d\n', algorithm, run);
    t_start = tic;
    [~, run_hist] = Interface(X, y, algorithm, model, regularizer, init_weight, mu, L, ...
        step_size, passes);
    elapsed = toc(t_start);
    fprintf('Time: %f seconds \n', elapsed);
    hist_L2S = [hist_L2S, run_hist]; %#ok<AGROW>
end

% Column 1: pass index, column 2: mean over runs.
hist1 = [X_L2S, mean(hist_L2S, 2)];
std1 = std(hist_L2S, 0, 2);

% L2S (N-dependent)
algorithm = 'L2S';
step_size = 1 / (L * sqrt(N));
for run = 1:seeds
    fprintf('Algorithm: %s, Seeds: %d\n', algorithm, run);
    t_start = tic;
    [~, run_hist] = Interface(X, y, algorithm, model, regularizer, init_weight, mu, L, ...
        step_size, passes);
    elapsed = toc(t_start);
    fprintf('Time: %f seconds \n', elapsed);
    hist_L2S_N = [hist_L2S_N, run_hist]; %#ok<AGROW>
end

hist2 = [X_L2S, mean(hist_L2S_N, 2)];
std2 = std(hist_L2S_N, 0, 2);

% Acc-SVRG-G (Theorem 4.2)
% NOTE(review): step_size is not reassigned for this configuration, so the
% N-dependent value from the previous section is what gets passed. Whether
% Interface uses it for 'Acc_SVRG_G' cannot be told from this script.
algorithm = 'Acc_SVRG_G';
for run = 1:seeds
    fprintf('Algorithm: %s, Seeds: %d\n', algorithm, run);
    t_start = tic;
    [~, run_hist] = Interface(X, y, algorithm, model, regularizer, init_weight, mu, L, ...
        step_size, passes);
    elapsed = toc(t_start);
    fprintf('Time: %f seconds \n', elapsed);
    hist_Acc_SVRG_G = [hist_Acc_SVRG_G, run_hist]; %#ok<AGROW>
end

hist3 = [X_Acc_SVRG_G, mean(hist_Acc_SVRG_G, 2)];
std3 = std(hist_Acc_SVRG_G, 0, 2);

%% Plot
% Draws, for each of the three configurations, a translucent band of
% mean +/- one standard deviation and the mean curve on top, using a
% logarithmic y-axis.

% Upper y-limit: the largest mean value among the three averaged curves.
aa = max([max(hist1(:, 2)), max(hist2(:, 2)), max(hist3(:, 2))]);
b = 3; % plot every b-th recorded point

figure(101);

set(gcf,'position',[520,500,436,269]);

% One row per curve, in legend order:
%   {[pass index, mean], std, band colour, line spec}
series = {hist1, std1, 'b', 'b--'; ...   % L2S (N-independent)
          hist2, std2, 'r', 'r-.'; ...   % L2S (N-dependent)
          hist3, std3, 'm', 'm-'};       % Acc-SVRG-G

for k = 1:size(series, 1)
    avg = series{k, 1};
    dev = series{k, 2};

    xs = avg(1:b:end, 1)';
    ys = avg(1:b:end, 2)';
    band_hi = ys + dev(1:b:end)';
    % Clamp the lower edge to a small positive value so the band stays
    % drawable on the logarithmic y-axis.
    band_lo = max(ys - dev(1:b:end)', 1E-10);

    % Closed polygon: upper edge left-to-right, lower edge right-to-left.
    h = fill([xs, fliplr(xs)], [band_hi, fliplr(band_lo)], series{k, 3}, ...
        'LineStyle', 'none');
    set(h, 'facealpha', 0.15);
    % Keep the shaded band out of the legend; only the mean lines are listed.
    h.Annotation.LegendInformation.IconDisplayStyle = 'off';
    hold on;
    plot(xs, ys, series{k, 4}, 'linewidth', 1);
end
hold off;
set(gca, 'YScale', 'log');

xlabel('Number of effective passes', 'Interpreter', 'latex');
ylabel('$||\nabla f (x)||$', 'Interpreter', 'latex');
axis([0, passes, 1E-8, aa]);
% The labels contain LaTeX math ($n$); without an explicit LaTeX
% interpreter the default 'tex' interpreter prints the dollar signs
% literally. This also matches the axis labels above.
lgd = legend('L2S ($n$-independent)', 'L2S ($n$-dependent)', 'Acc-SVRG-G', ...
    'Interpreter', 'latex');
lgd.FontSize = 12;
