% compareOjaVariants.m
% -------------------------------------------------------------------------
%   Runs and compares four variants of Oja’s algorithm on one synthetic
%   dataset:
%       1. regularOja               – online, full precision
%       2. regularOja_lowPrecision  – online, per‑sample quantisation
%       3. batchedOja               – mini-batch (B batches), full precision
%       4. batchedOja_lowPrecision  – mini-batch (B batches), Algorithm 2 quantisation
% -------------------------------------------------------------------------

clear; clc; close all;

%% ---------------------------  Data generation  ---------------------------
% Seed the global random stream so that the synthetic dataset and the
% shared random start vector w0 (drawn below) are identical from run to
% run. Without this, every saved comparison figure comes from a different
% random problem instance and cannot be reproduced.
% NOTE(review): this only affects code that draws from the global stream;
% if generateData or the Oja variants use their own RandStream, seed those too.
seed = 0;
rng(seed, 'twister');

n    = 3000;   % samples
d    = 100;    % dimension
beta = 0.5;    % variance‑decay exponent
% Only the data matrix X and the reference vector v1 are used below.
% NOTE(review): v1 is presumably the true leading eigenvector that the
% error traces are measured against — confirm against generateData.
[X, ~, v1, ~, ~] = generateData(n, d, beta);

%% ----------------------  Algorithm hyper‑parameters  ---------------------
% Step sizes, batching and quantisation settings used by the four runs.

eta        = 1e-4;                     % base learning rate (online variants)
numBatches = 25;                       % B: number of batches for batched variants
eta_batch  = eta * (n / numBatches);   % η scaled by samples per batch (n/B)
delta      = 0.01;                     % quantisation step for low‑precision variants

% Common random starting point on the unit sphere, reused by every variant
w0 = randn(d, 1);
w0 = w0 / norm(w0);

%% -----------------------------  Run variants  ----------------------------
% Every variant starts from the same w0 and receives the same reference v1.
% Only the second output (the error trace, err_*) is kept.

% Online, per-sample updates with the base step size
[~, err_reg]    = regularOja(X, w0, eta, v1);
[~, err_reg_lp] = regularOja_lowPrecision(X, w0, eta, v1, delta);

% Batched updates with the batch-scaled step size
[~, err_batch]    = batchedOja(X, w0, eta_batch, v1, numBatches);
[~, err_batch_lp] = batchedOja_lowPrecision(X, w0, eta_batch, v1, delta, numBatches);

%% ------------------------  Align x‑axes for plotting ---------------------
% Put every curve on a common "samples processed" axis.
%   - Online variants: entry t of the error trace is plotted at x = t.
%   - Batched variants: entry k is plotted at x = k*(n/B), i.e. after k
%     batches of n/B samples each. Using linspace(1, n, B) here instead
%     would place the first batch at x = 1 (as if only one sample had been
%     seen) and space the rest by (n-1)/(B-1), misaligning the batched
%     curves with the online ones.
% NOTE(review): assumes batchedOja* return one error per batch, recorded
% after that batch's update, with equal-size batches — confirm against
% their implementation.

iters_reg   = 1:length(err_reg);
iters_batch = (1:numBatches) * (n / numBatches);

%% --------------------------------  Plot  ---------------------------------
% One curve per variant on a shared axis: online runs against iters_reg,
% batched runs against iters_batch (with a marker at every batch).

figure('Color', 'w');
hold on;

lineW   = 1.4;
markIdx = 1:numBatches;

plot(iters_reg,   err_reg,      'k-',  'LineWidth', lineW);
plot(iters_reg,   err_reg_lp,   'b--', 'LineWidth', lineW);
plot(iters_batch, err_batch,    'r-o', 'LineWidth', lineW, 'MarkerIndices', markIdx);
plot(iters_batch, err_batch_lp, 'm-s', 'LineWidth', lineW, 'MarkerIndices', markIdx);

xlabel('Processed samples');
ylabel('sin^2 error');

seriesNames = {'Regular Oja', 'Regular Oja (low‑prec)', ...
               'Batched Oja', 'Batched Oja (low‑prec)'};
legend(seriesNames, 'Location', 'northeast');

% '\\' in the sprintf format yields a single backslash, so the TeX
% interpreter renders \eta and \delta as Greek letters.
titleText = sprintf('Convergence comparison (n=%d, d=%d, B=%d, \\eta=%.0e, \\delta=%.3f)', ...
                    n, d, numBatches, eta, delta);
title(titleText, 'Interpreter', 'tex');

grid on;
box on;

% exportgraphics warns when the axes toolbar is visible, so hide it
% whenever this MATLAB version exposes a Toolbar property.
ax = gca;
if isprop(ax, 'Toolbar'), ax.Toolbar.Visible = 'off'; end
%% ----------------------------  Save figure  -----------------------------
% Write the current figure as a timestamped PNG (300 dpi) into ./figures.
% If that folder cannot be created, fall back to the current folder.
% Failures only warn; they never stop the script.

outDir = fullfile(pwd, 'figures');  % absolute path avoids relative‑path issues
needDir = ~exist(outDir, 'dir');
if needDir
    [dirOk, dirMsg] = mkdir(outDir);
    if ~dirOk
        warning('Could not create output directory: %s', dirMsg);
        outDir = pwd;  % fall back to current dir
    end
end

stamp   = char(datetime('now', 'Format', 'yyyyMMdd_HHmmss'));
filePNG = fullfile(outDir, sprintf('OjaComparison_%s.png', stamp));

try
    exportgraphics(gcf, filePNG, 'Resolution', 300);
    fprintf('Plot saved to %s\n', filePNG);
catch exportErr
    warning('exportgraphics failed: %s\nPlot was not saved.', exportErr.message);
end


%% -----------------------  Print final errors -----------------------------
% Report the last entry of each variant's error trace.

lastOf = @(v) v(end);   % final element of an error trace

fprintf('\nFinal sin^2 errors\n');
fprintf('  Regular Oja              : %.3e\n', lastOf(err_reg));
fprintf('  Regular Oja (low‑prec)   : %.3e\n', lastOf(err_reg_lp));
fprintf('  Batched Oja              : %.3e\n', lastOf(err_batch));
fprintf('  Batched Oja (low‑prec)   : %.3e\n', lastOf(err_batch_lp));
