% Runs proxsgdtaurobust and proxlintaurobust (each with the final flag set
% to 0 and to 1) on one data set from the same starting point, then plots a
% zoomed-in window of the four objective histories and saves the figure.
clear; clc; close all;

% NOTE: the data directory is resolved relative to the current working
% directory, so the script must be launched from its own folder.
addpath(fullfile('..', 'data'));

% Data-set name kept in one place so the loaded file and the saved figure
% name cannot drift apart.
dataset = 'zipcode_64_pfail_2';
load([dataset, '.mat']);   % expected to define the struct `data`

rng(123);                  % fixed seed so that init_x is reproducible
A = data.A;
b = data.b;
[m, n] = size(A);
maxiter = 400;
gamma = sqrt(maxiter * m);
beta = 0.0;
init_x = randn(n, 1) * 10;
% init_x = init_x / norm(init_x);
tol = data.bestloss * 1.5; % tolerance is 1.5x the best loss stored with the data
early_stop = true;
alpha_0 = 100;             % also embedded in the output file name below
show_info = true;

% All four runs share every argument except the last one (0 or 1). Judging by
% the legend labels below, 1 selects the "Safe-" variant -- confirm against
% the solver sources.
[sgdsol, sgdinfo] = proxsgdtaurobust(A, b, gamma, beta, init_x, maxiter, tol, ...
    early_stop, alpha_0, show_info, 0);
[sgdsolr, sgdinfor] = proxsgdtaurobust(A, b, gamma, beta, init_x, maxiter, tol, ...
    early_stop, alpha_0, show_info, 1);
[splsol, splinfo] = proxlintaurobust(A, b, gamma, beta, init_x, maxiter, tol, ...
    early_stop, alpha_0, show_info, 0);
[splsolr, splinfor] = proxlintaurobust(A, b, gamma, beta, init_x, maxiter, tol, ...
    early_stop, alpha_0, show_info, 1);

% Plot window in units of m history entries (presumably one epoch = m
% entries of objs -- TODO confirm against the solvers).
startepo = 30;
endepo = 35;
lastIdx = m * endepo;

% With early_stop enabled a run may return fewer than lastIdx values
% (depends on the solver implementation); clamp the slice so that indexing
% cannot run past the end of a history and abort the script.
clip = @(v) v(1:min(numel(v), lastIdx));

histLens = [numel(sgdinfo.objs), numel(sgdinfor.objs), ...
    numel(splinfo.objs), numel(splinfor.objs)];
if any(histLens < lastIdx)
    warning('Objective history shorter than %d entries (shortest: %d); plotting what is available.', ...
        lastIdx, min(histLens));
end

% The first plot creates the axes; hold is enabled afterwards so the axes
% keep the defaults that plot sets up.
plot(clip(sgdinfo.objs), 'LineWidth', 2);
hold on
plot(clip(sgdinfor.objs), 'LineWidth', 2);
plot(clip(splinfo.objs), 'LineWidth', 2);
plot(clip(splinfor.objs), 'LineWidth', 2);
legend(["DSGD", "Safe-DSGD", "DSPL", "Safe-DSPL"]);

% Zoom in on the window [startepo, endepo].
xlim([m * startepo, m * endepo]);
set(gca, 'FontSize', 20);
tightfig;

% Produces e.g. "step_100_zipcode_64_pfail_2.fig" in the current directory.
saveas(gcf, "step_" + alpha_0 + "_" + dataset + ".fig");