% % runExp3_varyBits.m
% % — Experiment 3: fix n & d, vary bits → Error vs δ (with LQ / NLQ naming & styling)
% 
% clear; close all; rng(1);
% set(groot, ...
%   'defaultAxesFontSize',16, ...
%   'defaultLineLineWidth',2);
% 
% %% Settings
% reps      = 100;
% betaVals  = [0.5];
% bitsList  = [4, 6, 8, 10, 12];
% n         = 1000;
% d         = 100;
% % rename SR→LQ, MR→NLQ
% methods   = { 'regular_LQ','batched_LQ','regular_NLQ','batched_NLQ' };
% 
% %% Precompute δ’s
% deltaUni = zeros(numel(betaVals), numel(bitsList));
% deltaExp = zeros(numel(betaVals), numel(bitsList));
% alphaExp = zeros(numel(betaVals), numel(bitsList));
% for bIdx = 1:numel(betaVals)
%   lambda = betaVals(bIdx);
%   fprintf('Precompute quant params for λ = %.1f\n', lambda);
%   [Xref,~,~] = generateData(n, d, lambda);
%   for i = 1:numel(bitsList)
%     B = bitsList(i);
%     [du, de, ae] = pickQuantParams(Xref, B);
%     deltaUni(bIdx,i) = du;
%     deltaExp(bIdx,i) = de;
%     alphaExp(bIdx,i) = ae;
%     fprintf('  bits=%2d: δ_u=%.3e, δ_e=%.3e, α=%.3e\n', B, du, de, ae);
%   end
% end
% 
% %% Run & collect
% errAll = nan(numel(methods), reps, numel(bitsList), numel(betaVals));
% for bIdx = 1:numel(betaVals)
%   lambda = betaVals(bIdx);
%   fprintf('=== Starting λ = %.1f ===\n', lambda);
%   for i = 1:numel(bitsList)
%     B  = bitsList(i);
%     du = deltaUni(bIdx,i);
%     de = deltaExp(bIdx,i);
%     ae = alphaExp(bIdx,i);
%     fprintf('  -> bits = %2d, δ_u=%.3e, δ_e=%.3e\n', B, du, de);
% 
%     [X,~,v1,lambda1,lambda2] = generateData(n, d, lambda);
%     w0 = randn(d,1); w0 = w0/norm(w0);
%     eta = 2*log(n)/(n*(lambda1 - lambda2));
%     nb  = 25;
% 
%     for rep = 1:reps
%       fprintf('    rep %d/%d\n', rep, reps);
% 
%       % LQ online
%       fprintf('      method: regular_LQ ... ');
%       [~,errs] = regularOja_lowPrecision( ...
%                X, w0, eta, v1, @stochasticRound, du, [] );
%       errAll(1,rep,i,bIdx) = errs(end);
%       fprintf('done (%.3e)\n', errs(end));
% 
%       % LQ batched
%       fprintf('      method: batched_LQ ... ');
%       [~,errs] = batchedOja_lowPrecision( ...
%                X, w0, eta*(n/nb), v1, @stochasticRound, du, [], nb );
%       errAll(2,rep,i,bIdx) = errs(end);
%       fprintf('done (%.3e)\n', errs(end));
% 
%       % NLQ online
%       fprintf('      method: regular_NLQ ... ');
%       [~,errs] = regularOja_lowPrecision( ...
%                X, w0, eta, v1, @mantissaRound, de, ae );
%       errAll(3,rep,i,bIdx) = errs(end);
%       fprintf('done (%.3e)\n', errs(end));
% 
%       % NLQ batched
%       fprintf('      method: batched_NLQ ... ');
%       [~,errs] = batchedOja_lowPrecision( ...
%                X, w0, eta*(n/nb), v1, @mantissaRound, de, ae, nb );
%       errAll(4,rep,i,bIdx) = errs(end);
%       fprintf('done (%.3e)\n', errs(end));
%     end
%   end
% end
% 
% %% Save & plot
% fprintf('Saving resultsExp3.mat and plotting Exp3...\n');
% save('resultsExp3.mat', ...
%   'errAll','methods','bitsList','betaVals','reps','deltaUni','deltaExp');
% 
% % Compute mean & SE
% mu = squeeze(mean(errAll,2));    % [methods × bitsList × betaVals]
% se = squeeze(std(errAll,[],2)/sqrt(reps));
% 
% % Line+marker styles
% lineStyles = {'-o','--s',':d','-.^'};
% 
% % Create figure
% figure('Name','Error vs \delta','NumberTitle','off','Color','w');
% for bIdx = 1:numel(betaVals)
%     ax = subplot(1,numel(betaVals),bIdx);
%     hold(ax,'on');
%     grid(ax,'on'); grid(ax,'minor');
%     ax.GridLineStyle     = '--';
%     ax.MinorGridLineStyle= ':';
%     ax.GridAlpha         = 0.7;
%     ax.MinorGridAlpha    = 0.4;
%     ax.FontSize    = 25;
%     ax.FontWeight  = 'bold';
%     ax.LineWidth   = 1.5;
%     % ax.XScale      = 'log';
%     % ax.YScale      = 'log';
% 
%     for m = 1:numel(methods)
%         isLQ   = contains(methods{m}, '_LQ');
%         deltas = isLQ * deltaUni(bIdx,:) + ~isLQ * deltaExp(bIdx,:);
%         errorbar(...
%           ax, deltas, ...
%           mu(m,:,bIdx), ...
%           se(m,:,bIdx), ...
%           lineStyles{m}, ...
%           'DisplayName', methods{m}, ...
%           'LineWidth', 3, ...
%           'MarkerSize', 8);
%     end
% 
%     xlabel(ax, '\delta', 'FontSize',25, 'FontWeight', 'bold');
%     ylabel(ax, 'Final sin^2-error', 'FontSize', 25, 'FontWeight', 'bold');
%     title(ax, sprintf('\\lambda = %.1f, n = %d, d = %d', ...
%                       betaVals(bIdx), n, d), ...
%           'FontSize',25,'FontWeight','bold');
%     legend(ax, 'Interpreter','none', ...
%            'Location','best','FontSize',20);
%     hold(ax,'off');
% end


% runExp3_varyBits.m
% — Experiment 3: fix n & d, vary bits → Error vs bits (with LQ/NLQ and debug prints)

% Reset the workspace and figures, seed the RNG for reproducibility, and set
% global plotting defaults.
clear; close all; rng(0);
set(groot, 'defaultAxesFontSize', 16, 'defaultLineLineWidth', 2);

%% Settings
reps     = 100;                  % repetitions per (bits, beta) configuration
betaVals = 0.5;                  % value(s) passed to generateData as its 3rd argument
bitsList = [4, 6, 8, 10, 12];    % bit widths to sweep
n        = 1000;                 % first argument to generateData
d        = 100;                  % second argument to generateData; also length of w0
% Method labels; their order fixes the first index of errAll below.
methods  = {'standard', 'standard_LQ', 'standard_NLQ', ...
            'batched',  'batched_LQ',  'batched_NLQ'};

%% Precompute quant params
% For each value in betaVals, draw one reference data set and let
% pickQuantParams choose the quantization parameters for every bit width in
% bitsList. Tables are indexed (betaVals index, bitsList index).
nBeta = numel(betaVals);
nBits = numel(bitsList);
[deltaUni, deltaExp, alphaExp] = deal(zeros(nBeta, nBits));
for bIdx = 1:nBeta
  lambda = betaVals(bIdx);
  fprintf('Precompute quant params for λ = %.1f\n', lambda);
  % Reference data only; the runs below draw their own data with generateData.
  [Xref,~,~] = generateData(n, d, lambda);
  for col = 1:nBits
    bits = bitsList(col);
    [dUni, dExp, aExp] = pickQuantParams(Xref, bits);
    deltaUni(bIdx,col) = dUni;
    deltaExp(bIdx,col) = dExp;
    alphaExp(bIdx,col) = aExp;
    fprintf('  bits=%2d: δ_uni=%.3e, δ_exp=%.3e, α=%.3e\n', ...
            bits, dUni, dExp, aExp);
  end
end

%% Run & collect
% errAll(m, rep, i, bIdx) holds the last entry of the `errs` output of method m
% (in the order of `methods`) for repetition rep, bit width bitsList(i) and
% data parameter betaVals(bIdx).
errAll = nan(numel(methods), reps, numel(bitsList), numel(betaVals));

% Batch parameter shared by ALL batched variants (full precision, LQ, NLQ).
% Loop-invariant, so it is set once here.
% NOTE(review): whether nb is a batch size or a number of batches is defined
% by batchedOja / batchedOja_lowPrecision — confirm.
nb = 25;

for bIdx = 1:numel(betaVals)
  lambda = betaVals(bIdx);
  fprintf('=== Starting λ = %.1f ===\n', lambda);

  for i = 1:numel(bitsList)
    B  = bitsList(i);          % bit width: only selects quantization params
    du = deltaUni(bIdx,i);     % parameter passed with @stochasticRound (LQ)
    de = deltaExp(bIdx,i);     % parameters passed with @mantissaRound (NLQ)
    ae = alphaExp(bIdx,i);
    fprintf('  -> bits = %2d, δ_uni=%.3e, δ_exp=%.3e\n', B, du, de);

    % One data set and one unit-norm start vector per bit width; every method
    % and every repetition below reuses them.
    [X,~,v1,lambda1,lambda2] = generateData(n, d, lambda);
    w0 = randn(d,1); w0 = w0/norm(w0);
    eta      = 2*log(n)/(n*(lambda1 - lambda2));
    etaBatch = eta*(n/nb);     % step used by every batched variant

    for rep = 1:reps
      fprintf('    rep %d/%d\n', rep, reps);

      % standard (full precision)
      fprintf('      method: standard ... ');
      [~,errs] = regularOja(X, w0, eta, v1);
      errAll(1,rep,i,bIdx) = errs(end);
      fprintf('done (err=%.3e)\n', errs(end));

      % regular_LQ
      fprintf('      method: standard_LQ ... ');
      [~,errs] = regularOja_lowPrecision( ...
               X, w0, eta, v1, @stochasticRound, du, [] );
      errAll(2,rep,i,bIdx) = errs(end);
      fprintf('done (err=%.3e)\n', errs(end));

      % regular_NLQ
      fprintf('      method: standard_NLQ ... ');
      [~,errs] = regularOja_lowPrecision( ...
               X, w0, eta, v1, @mantissaRound, de, ae );
      errAll(3,rep,i,bIdx) = errs(end);
      fprintf('done (err=%.3e)\n', errs(end));

      % batched (full precision). Must use the batch parameter nb, exactly like
      % the LQ/NLQ batched runs. It must not use the bit width B, otherwise the
      % baseline would not be comparable.
      fprintf('      method: batched ... ');
      [~,errs] = batchedOja(X, w0, etaBatch, v1, nb);
      errAll(4,rep,i,bIdx) = errs(end);
      fprintf('done (err=%.3e)\n', errs(end));

      % batched_LQ
      fprintf('      method: batched_LQ ... ');
      [~,errs] = batchedOja_lowPrecision( ...
               X, w0, etaBatch, v1, @stochasticRound, du, [], nb );
      errAll(5,rep,i,bIdx) = errs(end);
      fprintf('done (err=%.3e)\n', errs(end));

      % batched_NLQ
      fprintf('      method: batched_NLQ ... ');
      [~,errs] = batchedOja_lowPrecision( ...
               X, w0, etaBatch, v1, @mantissaRound, de, ae, nb );
      errAll(6,rep,i,bIdx) = errs(end);
      fprintf('done (err=%.3e)\n', errs(end));
    end
  end
end

%% Save & plot
fprintf('Saving resultsExp3.mat and plotting Exp3 vs bits...\n');
% Store the quantization parameters and problem size with the errors so the
% results file is self-describing.
save('resultsExp3.mat', 'errAll','methods','bitsList','betaVals','reps', ...
     'deltaUni','deltaExp','alphaExp','n','d');

% Mean & standard error over repetitions (dimension 2 of errAll).
% reshape (not squeeze) keeps the [methods × bitsList × betaVals] layout even
% when bitsList or betaVals has a single entry, so mu(m,:,bIdx) stays valid.
nM    = numel(methods);
nBits = numel(bitsList);
nBeta = numel(betaVals);
mu = reshape(mean(errAll,2),       nM, nBits, nBeta);
se = reshape(std(errAll,[],2),     nM, nBits, nBeta) / sqrt(reps);

% Line + marker styles, one per entry of `methods`:
% solid/dot for the standard variants, dash-dot/diamond for the batched ones.
lineStyles = { '-','-','-','-.','-.','-.' };
markers    = { '.','.','.','d','d','d' };

figure('Name','Error vs bits','NumberTitle','off','Color','w');
for bIdx = 1:nBeta
    ax = subplot(1,nBeta,bIdx);
    hold(ax,'on');
    grid(ax,'on'); grid(ax,'minor');
    ax.GridLineStyle      = '--';
    ax.MinorGridLineStyle = ':';
    ax.GridAlpha          = 0.7;
    ax.MinorGridAlpha     = 0.4;
    ax.FontSize           = 25;
    ax.FontWeight         = 'bold';
    ax.LineWidth          = 1.5;
    ax.XScale             = 'linear';
    ax.YScale             = 'log';

    for m = 1:nM
        errorbar(...
          ax, bitsList, ...
          mu(m,:,bIdx), ...
          se(m,:,bIdx), ...
          lineStyles{m}, ...
          'DisplayName', methods{m}, ...
          'LineWidth', 3, ...
          'MarkerSize', 8, ...
          'Marker', markers{m});
    end

    xlabel(ax, 'Bits',              'FontSize',25,'FontWeight','bold');
    ylabel(ax, 'Final sin^2-error', 'FontSize',25,'FontWeight','bold');
    title(ax, sprintf('\\lambda = %.1f, n = %d, d = %d', ...
                      betaVals(bIdx), n, d), ...
          'FontSize',25,'FontWeight','bold');
    legend(ax, 'Interpreter','none', ...
           'Location','best','FontSize',20);
    hold(ax,'off');
end