% runExp2_varyD.m
% — Experiment 2: fix n & bits, vary d

% Start from a clean workspace, close open figures, and fix the RNG seed
% so every run of this script draws the same random numbers.
clear;
close all;
rng(0);
% Global graphics defaults for every axes/line created afterwards.
set(groot, 'defaultAxesFontSize', 16, 'defaultLineLineWidth', 2);

%% Settings
% Experiment configuration: n and the bit budget are fixed, d is swept.
reps     = 100;          % repetitions per (d, betaVals) combination
betaVals = 0.5;          % value(s) passed to generateData as its 3rd argument (printed as λ)
bits     = 8;            % bit budget handed to pickQuantParams (not the batch parameter B)
n        = 5000;         % number of samples per generated dataset
dList    = 100:100:500;  % dimensions d to sweep

% Method labels (LQ / NLQ were previously named SR / MR).
methods = {'standard', 'standard_LQ', 'standard_NLQ', ...
           'batched',  'batched_LQ',  'batched_NLQ'};

%% Run & collect
% errAll(m, rep, i, bIdx) holds the final error (last entry of errs) of
% methods{m} in repetition rep, for d = dList(i) and betaVals(bIdx).
% Entries that are never written stay NaN.
errAll = nan(numel(methods), reps, numel(dList), numel(betaVals));

% Batch parameter passed as the last argument to batchedOja*; the batched
% step size is the standard step scaled by n/B. Loop-invariant, so set once.
B = 10;

for bIdx = 1:numel(betaVals)
  lambda = betaVals(bIdx);
  fprintf('=== Starting λ = %.1f ===\n', lambda);
  for i = 1:numel(dList)
    d = dList(i);
    fprintf('  -> d = %d\n', d);

    % Data, quantization parameters and the unit-norm start vector w0 are
    % drawn once per d and shared by all reps and methods.
    % NOTE(review): rep-to-rep spread therefore comes only from randomness
    % inside the methods / rounding functions — confirm this is intended.
    [X, Sigma, v1, lambda1, lambda2] = generateData(n, d, lambda);
    [deltau, deltae, alphae] = pickQuantParams(X, bits);
    fprintf('     quant params: δ_uni=%.3e, δ_exp=%.3e, α=%.3e\n', ...
            deltau, deltae, alphae);

    w0  = randn(d,1); w0 = w0/norm(w0);
    eta = 2*log(n)/(n*(lambda1 - lambda2));
    etaBatched = eta*(n/B);   % same expression as before, computed once per d

    for rep = 1:reps
      fprintf('    rep %d/%d\n', rep, reps);
      for m = 1:numel(methods)
        method = methods{m};
        fprintf('      method: %s ... ', method);

        switch method
          case 'standard'
            [~,errs] = regularOja(           X, w0, eta,     v1);
          case 'standard_LQ'
            [~,errs] = regularOja_lowPrecision( ...
                       X, w0, eta, v1, @stochasticRound, deltau, []);
          case 'standard_NLQ'
            [~,errs] = regularOja_lowPrecision( ...
                       X, w0, eta, v1, @mantissaRound,   deltae, alphae);
          case 'batched'
            [~,errs] = batchedOja(           X, w0, etaBatched, v1, B);
          case 'batched_LQ'
            [~,errs] = batchedOja_lowPrecision( ...
                       X, w0, etaBatched, v1, @stochasticRound, deltau, [], B);
          case 'batched_NLQ'
            [~,errs] = batchedOja_lowPrecision( ...
                       X, w0, etaBatched, v1, @mantissaRound,   deltae, alphae, B);
          otherwise
            % Without this, errs would silently keep the previous method's
            % result (or be undefined) and be recorded under the wrong label.
            error('runExp2_varyD:unknownMethod', ...
                  'Unknown method "%s".', method);
        end

        errAll(m,rep,i,bIdx) = errs(end);
        fprintf('done (err=%.3e)\n', errs(end));
      end
    end
  end
end

%% Save & plot
fprintf('Saving resultsExp2.mat and plotting Exp2...\n');
% Persist the raw per-rep errors together with the configuration used.
save resultsExp2.mat errAll methods dList betaVals reps bits n

% Compute mean & standard error over repetitions (dimension 2 of errAll).
% permute (rather than squeeze) always yields [methods × dList × betaVals],
% even when some of those dimensions have length 1; squeeze would drop them
% and break the mu(m,:,bIdx) / se(m,:,bIdx) indexing used when plotting.
mu = permute(mean(errAll,2), [1 3 4 2]);
se = permute(std(errAll,[],2)/sqrt(reps), [1 3 4 2]);

% Per-method plot styles, in the same order as methods: solid lines with
% point markers for the first three, dash-dot lines with diamonds for the rest.
% (Alternative: lineStyles = {'-o','--s',':d','-.^','-x','--*'};)
lineStyles = [repmat({'-'}, 1, 3), repmat({'-.'}, 1, 3)];
markers    = [repmat({'.'}, 1, 3), repmat({'d'},  1, 3)];

% Create figure: one subplot per entry of betaVals, each showing the mean
% final error versus d on log-log axes with ±1 standard-error bars per method.
figure('Name','Error vs d','NumberTitle','off','Color','w');
for bIdx = 1:numel(betaVals)
    ax = subplot(1,numel(betaVals),bIdx);
    hold(ax,'on');

    % Grid styling
    grid(ax,'on'); grid(ax,'minor');
    ax.GridLineStyle     = '--';
    ax.MinorGridLineStyle= ':';
    ax.GridAlpha         = 0.7;
    ax.MinorGridAlpha    = 0.4;

    % Axis text styling
    ax.FontSize    = 25;
    ax.FontWeight  = 'bold';
    ax.LineWidth   = 1.5;
    ax.XScale      = 'log';
    ax.YScale      = 'log';
    % NOTE(review): on a log y-axis, error bars whose lower end (mu - se) is
    % <= 0 cannot be drawn to scale.

    % Plot each method. Style lists are cycled with mod so that adding more
    % methods than there are styles does not raise an index error.
    for m = 1:numel(methods)
        lineSpec  = lineStyles{mod(m-1, numel(lineStyles)) + 1};
        markerSym = markers{mod(m-1, numel(markers)) + 1};
        errorbar(ax, dList, ...
                 mu(m,:,bIdx), ...
                 se(m,:,bIdx), ...
                 lineSpec, ...
                 'DisplayName',methods{m}, ...
                 'LineWidth',3, ...
                 'MarkerSize',8, ...
                 'Marker', markerSym);
    end

    % Labels & title
    xlabel(ax, 'd (dimension)',       'FontSize',25,'FontWeight','bold');
    ylabel(ax, 'Final sin^2-error',   'FontSize',25,'FontWeight','bold');
    title(ax, sprintf('\\lambda = %.1f, bits = %d, n = %d', ...
                      betaVals(bIdx), bits, n), ...
          'FontSize',25,'FontWeight','bold');

    % Interpreter 'none' so underscores in method names are shown literally.
    legend(ax,'Interpreter','none','Location','best','FontSize',25);
    hold(ax,'off');
end