% runExp4_compareOverTime.m
% — Experiment 4: fix n, d, lambda, bits; compare all 6 algorithms over time

% Start from a clean workspace, no open figures, and a fixed RNG seed.
clear;
close all;
rng(0);

% Root-level graphics defaults (axes font size and line width).
set(groot, 'defaultAxesFontSize', 16, 'defaultLineLineWidth', 1);

%% Settings
reps   = 10;      % number of repetitions to average over
lambda = 1;       % "lambda" parameter, forwarded to generateData
bits   = 8;       % quantization bits, forwarded to pickQuantParams
d      = 500;     % data dimension
n      = 2000;    % number of samples
B      = 10;      % number of batches for the batched methods

% Algorithm labels: three 'standard' variants followed by three 'batched' ones.
methods = [ ...
    {'standard', 'standard_LQ', 'standard_NLQ'}, ...
    {'batched',  'batched_LQ',  'batched_NLQ'}];

% Per-method plot appearance, listed in the same order as `methods`.
lineStyles = [repmat({'-'}, 1, 3), repmat({'-.'}, 1, 3)];
markers    = [repmat({'.'}, 1, 3), repmat({'d'},  1, 3)];

%% Generate one dataset (same X for all reps)
[X, Sigma, v1, lambda1, lambda2] = generateData(n, d, lambda);
% Quantization parameters derived from the data and the bit budget; deltau is
% later paired with @stochasticRound, deltae/alphae with @mantissaRound.
[deltau, deltae, alphae] = pickQuantParams(X, bits);

%% Preallocate history storage
% One reps-by-T NaN matrix per method, where T = B for methods whose name
% starts with 'batched' and T = n otherwise. The choice is derived from the
% method name rather than its position in `methods`, so the sizes stay correct
% if the list is reordered or extended.
histories = cell(numel(methods),1);
for m = 1:numel(methods)
    if strncmp(methods{m}, 'batched', 7)
        histories{m} = nan(reps, B);   % one entry per batch
    else
        histories{m} = nan(reps, n);   % one entry per sample
    end
end

%% Run all methods, collect sin²-error over time
% For each repetition, draw one random unit-norm start vector and run every
% method from that same start on the shared dataset X. The error history
% returned by each solver is stored as row `rep` of histories{m}, so it must
% have as many elements as that matrix has columns.

% Step sizes do not depend on the repetition, so compute them once.
eta      = 2*log(n)/(n*(lambda1 - lambda2));   % step size for the per-sample methods
etaBatch = eta*(n/B);                          % step size passed to the batched methods

for rep = 1:reps
    fprintf('=== rep %d/%d ===\n', rep, reps);

    % Random initial direction, normalised to unit length.
    w0 = randn(d,1);
    w0 = w0/norm(w0);

    for m = 1:numel(methods)
        method = methods{m};
        fprintf('  Running %s...\n', method);

        switch method
            case 'standard'
                [~, errs] = regularOja(X, w0, eta, v1);
            case 'standard_LQ'
                [~, errs] = regularOja_lowPrecision( ...
                    X, w0, eta, v1, @stochasticRound, deltau, []);
            case 'standard_NLQ'
                [~, errs] = regularOja_lowPrecision( ...
                    X, w0, eta, v1, @mantissaRound, deltae, alphae);
            case 'batched'
                [~, errs] = batchedOja(X, w0, etaBatch, v1, B);
            case 'batched_LQ'
                [~, errs] = batchedOja_lowPrecision( ...
                    X, w0, etaBatch, v1, @stochasticRound, deltau, [], B);
            case 'batched_NLQ'
                [~, errs] = batchedOja_lowPrecision( ...
                    X, w0, etaBatch, v1, @mantissaRound, deltae, alphae, B);
            otherwise
                % Without this branch an unrecognised name would silently reuse
                % the previous method's errs (or hit an undefined variable).
                error('runExp4:unknownMethod', 'Unknown method "%s".', method);
        end

        histories{m}(rep,:) = errs;
    end
end

%% Compute mean & stderr over reps
% For every method, collapse the repetitions (rows) of its history matrix into
% a row of means and a row of standard errors (sample std / sqrt(reps)).
means   = cell(size(histories));
stderrs = cell(size(histories));
for k = 1:numel(histories)
    H = histories{k};
    means{k}   = mean(H, 1);
    stderrs{k} = std(H, [], 1) / sqrt(reps);
end

%% Plot: error vs processed samples
% Draws the mean error curve of every method on a shared axes with a log-scaled
% y-axis. Per-sample methods are plotted at x = 1..n; methods whose name starts
% with 'batched' are plotted at the cumulative sample count after each batch.
figure('Name','Error vs processed samples','NumberTitle','off','Color','w');
ax = gca; hold(ax,'on');

% Grid styling
grid(ax,'on'); grid(ax,'minor');
ax.GridLineStyle     = '--';
ax.MinorGridLineStyle= ':';
ax.GridAlpha         = 0.7;
ax.MinorGridAlpha    = 0.4;

% Axis styling
ax.FontSize    = 25;
ax.FontWeight  = 'bold';
ax.LineWidth   = 1;

% Scales
ax.XScale = 'linear';
ax.YScale = 'log';

% Plot each method. The loop bound and the batched/per-sample decision follow
% the `methods` list itself instead of hard-coded positions, so the plot stays
% consistent if the list is reordered or extended.
for m = 1:numel(methods)
    if strncmp(methods{m}, 'batched', 7)
        t = (1:B) * (n/B);       % samples processed at each batch
    else
        t = 1:n;                 % samples processed at each step
    end
    y = means{m};
    plot(ax, t, y, lineStyles{m}, ...
        'DisplayName', methods{m}, ...
        'LineWidth', 3, ...
        'Marker', markers{m},...
        'MarkerSize', 8);
end

% Labels & title
xlabel(ax, 'Samples processed',       'FontSize',25,'FontWeight','bold');
ylabel(ax, 'Mean sin^2-error',       'FontSize',25,'FontWeight','bold');
title(ax, sprintf('\\lambda = %.1f, \\beta = %d, d = %d, n = %d', ...
      lambda, bits, d, n), ...
      'FontSize',25,'FontWeight','bold');

% Interpreter 'none' keeps the underscores in method names literal in the legend.
legend(ax, 'Interpreter','none', 'Location','best','FontSize',20);
hold(ax,'off');