% runExp1_varyN.m
% Experiment 1: fix the dimension d and the bit budget, and vary the sample size n.

% Clean workspace, no open figures, reproducible random stream.
clear;
close all;
rng(2);

% Session-wide graphics defaults for every axes/line created afterwards.
set(groot, 'defaultAxesFontSize', 16, 'defaultLineLineWidth', 2);

%% Settings
reps     = 10;                             % repetitions per (lambda, n) pair
betaVals = 1;                              % lambda values passed to generateData
bits     = 8;                              % bit budget passed to pickQuantParams
                                           % (distinct from B = 10 given to the batched methods)
d        = 200;                            % data dimension (w0 is d-by-1)
nList    = round(linspace(1000,5000,5));   % sample sizes to sweep

% Method labels (renamed: SR -> LQ, MR -> NLQ).
%   *_LQ  variants quantize with @stochasticRound and step deltau
%   *_NLQ variants quantize with @mantissaRound and deltae/alphae
methods  = { ...
  'standard',  'standard_LQ',  'standard_NLQ', ...
  'batched',   'batched_LQ',   'batched_NLQ' ...
};

%% Run & collect
% errAll(m, rep, i, bIdx) = final error of methods{m} on repetition rep,
% with sample size nList(i) and lambda = betaVals(bIdx).
errAll = nan(numel(methods), reps, numel(nList), numel(betaVals));

% Batch parameter handed to the batched methods. It never changes inside
% the sweep, so it is set once here instead of on every n.
B = 10;

for bIdx = 1:numel(betaVals)
  lambda = betaVals(bIdx);
  fprintf('=== Starting λ = %.1f ===\n', lambda);
  for i = 1:numel(nList)
    n = nList(i);
    fprintf('  -> n = %d\n', n);

    % Data, quantization parameters, start vector and step size are drawn
    % once per (lambda, n) and shared by every repetition and method below,
    % so any spread across reps can only come from the method calls themselves.
    [X, Sigma, v1, lambda1, lambda2] = generateData(n,d,lambda);
    [deltau, deltae, alphae]       = pickQuantParams(X,bits);
    fprintf('     quant params: δ_uni=%.3e, δ_exp=%.3e, α=%.3e\n', ...
            deltau, deltae, alphae);

    w0  = randn(d,1); w0 = w0/norm(w0);   % random unit-norm start vector

    % eta divides by the gap lambda1 - lambda2. A zero, negative or NaN gap
    % would silently give an Inf/negative/NaN step size, so stop instead.
    gap = lambda1 - lambda2;
    if ~(gap > 0)
      error('runExp1_varyN:nonPositiveGap', ...
            'Expected lambda1 > lambda2 (got %g and %g) for n = %d.', ...
            lambda1, lambda2, n);
    end
    eta = 2*log(n)/(n*gap);

    for rep = 1:reps
      fprintf('    rep %d/%d\n', rep, reps);
      for m = 1:numel(methods)
        method = methods{m};
        fprintf('      method: %s ... ', method);

        % Batched variants receive the step size scaled by n/B.
        switch method
          case 'standard'
            [~,errs] = regularOja(           X, w0, eta,     v1);
          case 'standard_LQ'
            [~,errs] = regularOja_lowPrecision( ...
                       X, w0, eta, v1, @stochasticRound, deltau, []);
          case 'standard_NLQ'
            [~,errs] = regularOja_lowPrecision( ...
                       X, w0, eta, v1, @mantissaRound,   deltae, alphae);
          case 'batched'
            [~,errs] = batchedOja(           X, w0, eta*(n/B), v1, B);
          case 'batched_LQ'
            [~,errs] = batchedOja_lowPrecision( ...
                       X, w0, eta*(n/B), v1, @stochasticRound, deltau, [], B);
          case 'batched_NLQ'
            [~,errs] = batchedOja_lowPrecision( ...
                       X, w0, eta*(n/B), v1, @mantissaRound,   deltae, alphae, B);
          otherwise
            % Without this branch errs would silently keep the previous
            % method's result (or be undefined on the very first call).
            error('runExp1_varyN:unknownMethod', ...
                  'Unknown method "%s" in methods list.', method);
        end

        errAll(m,rep,i,bIdx) = errs(end);   % keep only the final error
        fprintf('done (err=%.3e)\n', errs(end));
      end
    end
  end
end

%% Save & plot
fprintf('Saving resultsExp1.mat and plotting Exp1...\n');
save('resultsExp1.mat', ...
     'errAll','methods','nList','betaVals','reps','bits','d');

% Mean and standard error over the repetition axis (dim 2), laid out as
% [methods x nList x betaVals]. reshape (rather than squeeze) keeps this
% layout even when methods, nList or betaVals has a single entry, so
% mu(m,:,bIdx) always indexes the intended axes.
nMethods = numel(methods);
nSizes   = numel(nList);
nBetas   = numel(betaVals);
mu = reshape(mean(errAll,2),                nMethods, nSizes, nBetas);
se = reshape(std(errAll,[],2) / sqrt(reps), nMethods, nSizes, nBetas);

% Per-method line styles and markers. They are indexed cyclically (mod), so
% adding entries to methods cannot cause an out-of-range index.
% lineStyles = {'-o','--s',':d','-.^','-x','--*'};
lineStyles = { '-','-','-','-.','-.','-.' };
markers    = {'.','.','.','d','d','d'};
txtOpts    = {'FontSize',25,'FontWeight','bold'};   % shared label/title formatting

figure('Name','Error vs n','NumberTitle','off','Color','w');
for bIdx = 1:numel(betaVals)
    % One log-log panel per lambda value.
    ax = subplot(1,numel(betaVals),bIdx); hold(ax,'on');
    grid(ax,'on'); grid(ax,'minor');
    ax.GridLineStyle     = '--';
    ax.MinorGridLineStyle= ':';
    ax.GridAlpha         = 0.7;
    ax.MinorGridAlpha    = 0.4;
    ax.FontSize    = 25;
    ax.FontWeight  = 'bold';
    ax.LineWidth   = 1.5;
    ax.XScale      = 'log';
    ax.YScale      = 'log';

    for m = 1:numel(methods)
        iStyle  = mod(m-1, numel(lineStyles)) + 1;
        iMarker = mod(m-1, numel(markers))    + 1;
        % Mean final error with standard-error bars across repetitions.
        errorbar(ax, nList, mu(m,:,bIdx), se(m,:,bIdx), ...
                 lineStyles{iStyle}, ...
                 'DisplayName',methods{m}, ...
                 'LineWidth',3, ...
                 'MarkerSize',8, ...
                 'Marker', markers{iMarker});
    end

    xlabel(ax, 'n (samples)', txtOpts{:});
    ylabel(ax, 'Final sin^2-error', txtOpts{:});
    title(ax, sprintf('\\lambda = %.1f, bits = %d, d = %d', ...
                      betaVals(bIdx), bits, d), ...
          txtOpts{:});
    % Interpreter 'none' keeps the underscores in method names literal.
    legend(ax,'Interpreter','none','Location','best','FontSize',25);
    hold(ax,'off');
end