function test_PM_vs_SPM_noise_f(L, R, n)
%TEST_PM_VS_SPM_NOISE_F Compare PM and SPM recovery error across noise levels.
%   TEST_PM_VS_SPM_NOISE_F(L, R, n) sweeps 11 log-spaced noise levels in
%   [1e-5, 1], runs every algorithm listed in Algs on the noisy tensors
%   produced by the nested function benchmark_tester, and saves the whole
%   function workspace to results/PM_vs_SPM_noise_<n/2>.mat.
%
%   Inputs (as used by benchmark_tester below):
%     L - number of rows of the random factor matrix A_true = randn(L, R)
%     R - number of columns of A_true; also forwarded to each algorithm
%     n - order handed to generate_lowrank_tensor; the noise tensor is
%         drawn with n modes of size L
%
%   Side effects: adds 'helper_functions/' to the path, prints progress and
%   the elapsed time, creates 'results/' if needed and writes a .mat file.
%
%   NOTE: save() stores every variable of this workspace, so renaming a
%   local variable here changes the contents of the saved file.

addpath 'helper_functions/'

tic

% Each entry is {display name, function handle}; benchmark_tester calls the
% handle as handle(T, R, ntries_inner).
Algs = {{'PM', @PM_test},...
        {'SPM', @SPM_test},...
        };

nAlgs = length(Algs);

ntries_inner = 20;
ntries_outer = 50;
ntries_tot = ntries_inner * ntries_outer;

%% Noise test

nn = 11;
vecnoise = 10.^linspace(-5, 0, nn);

% Size the result arrays from nAlgs (not a literal 2) so that adding or
% removing an entry of Algs does not break the assignments in the loop.
err_all_n = zeros(nn, nAlgs, ntries_tot);
err_n = zeros(nn, nAlgs);
err_std_n = zeros(nn, nAlgs);

for iter=1:nn

    noise = vecnoise(iter);

    [err_n(iter, :), err_std_n(iter, :), err_all_] = ...
        benchmark_tester(L, R, n, noise, Algs, ntries_inner, ntries_outer);
    % err_all_ is ntries_inner x ntries_outer x nAlgs; flatten the two trial
    % dimensions and put the algorithm index first.
    err_all_n(iter, :, :) = reshape(err_all_, [], nAlgs).';
end

% No semicolon: the elapsed time is echoed to the command window.
totaltime = toc

% Make sure the output folder exists before saving. Requesting (and
% discarding) both outputs keeps mkdir quiet when the folder is already
% there and adds no variable to the workspace that save() stores.
[~, ~] = mkdir('results');

save(sprintf('results/PM_vs_SPM_noise_%d.mat', n/2));

% %% noise test 
% 
% hf = figure(1);
% hf.Position = [100 100 300 200];
% ax = gca;
%     
% %ax = subplot('Position', [0.72, 0.15, 0.25, 0.75]);
% [hpl, hpa] = error_shaded(vecnoise, err_n, err_std_n);
% hpl(2).LineStyle = '--';
% for hplk=hpl.'
%     hplk.LineWidth = 1.5;
% end
% for hpai=hpa
%     hpai.YData = max(hpai.YData, 1e-14);
% end
% %h(1) = errorbar(vecL, err_LR(:,1), err_std_LR(:,1));
% %hold on;
% %h(2) = errorbar(vecL, err_LR(:,2), err_std_LR(:,2));
% 
% ax.XScale = 'log';
% ax.YScale = 'log';
% xlabel('$\sigma$','Interpreter','latex');
% ylabel('avg $\min_i\|x_*-a_i\|_2$','Interpreter','latex');
% legk = {'PM', 'SPM'};
% hl(3) = legend(ax,hpl, legk, 'Interpreter','latex','Location','southeast','FontSize',10);
% %hl(3).Position = [0.7292    0.5262    0.0797    0.1462];
% ylim([1e-6, 1]);
% 
% 
% set(ax,'XScale','log','YScale','log', 'FontSize',10,...
%         'TickLabelInterpreter','latex')
%     
% pdfprint(sprintf('results/PM_vs_SPM_noise_%d', n/2), hf)

function [err_mean, err_std, err_vec] = ...
            benchmark_tester(L, R, n, noise, Algs, ntries_inner, ntries_outer, scaling)
    % BENCHMARK_TESTER Run every algorithm on ntries_outer random noisy tensors.
    %
    %   Inputs:
    %     L, R, n      - size parameters of the random instance (see below)
    %     noise        - scalar multiplier of the additive noise tensor
    %     Algs         - cell array of {name, handle}; handle(T, R, ntries_inner)
    %                    must return an L x ntries_inner matrix of estimates
    %     ntries_inner - number of estimates requested from each algorithm
    %     ntries_outer - number of random tensors generated
    %     scaling      - not referenced in this body; the caller in this
    %                    file omits it
    %
    %   Outputs:
    %     err_vec  - ntries_inner x ntries_outer x nAlgs array of errors
    %     err_mean - mean of err_vec over its first two dimensions
    %                (size 1 x 1 x nAlgs)
    %     err_std  - standard deviation over the same two dimensions

    % nAlgs is also used by the enclosing function, so this assignment
    % writes the shared variable; the value is the same in both places.
    nAlgs = length(Algs);
    err_vec = zeros(ntries_inner, ntries_outer, nAlgs);
    
    fprintf('L=%d, R=%d, noise=%1.1e\n', L, R, noise);
    
    for outer = 1:ntries_outer

        % Generate A and T as explained in the numerical section
        % Columns of A_true are Gaussian and normalised to unit 2-norm.
        A_true = randn(L, R);
        A_true = A_true./vecnorm(A_true);
        
        % Weights uniform in [0.5, 2], scaled by sqrt(L^n / R).
        lambda_true = (.5 + 1.5 * rand(1,R)) * sqrt(L^n / R);
        
        T = generate_lowrank_tensor(A_true, lambda_true, n);

        % Symmetrized Gaussian tensor with n modes of size L, scaled by
        % noise * sqrt(n!).
        T_noise = (noise * sqrt(factorial(n))) * symmetrize_tensor(randn(L*ones(1,n)));

        T = T + T_noise;
        
        for i=1:nAlgs
            % Run each algorithm and calculate the error
            A_est = Algs{i}{2}(T, R, ntries_inner);
            % For every estimated column pick the column of A_true with
            % the largest absolute inner product.
            [~, I] = min(2 - 2 * abs(A_est' * A_true), [], 2);
            % Sign (+1 / -1) that aligns the matched true column with the
            % estimate; a zero inner product maps to -1.
            s = 2 * (sum(A_est .* A_true(:, I)) > 0) - 1;
            err_vec(:, outer, i) = vecnorm(A_est - s.* A_true(:, I));
        end

    end
    
    err_mean = mean(err_vec, [1 2]);
    err_std = std(err_vec,0, [1 2]);
    

end

end

function X = SPM_test(T, R, ntries)
% SPM_TEST Run the shifted power iteration from ntries random starts.
%
%   X = SPM_test(T, R, ntries)
%
%   Inputs:
%     T      - array whose first dimension has size L and whose number of
%              elements is L^n for an even positive integer n
%     R      - number of eigenvector columns kept from the decomposition
%              of the flattened tensor; pass [] to count the entries of D
%              with magnitude above eigtol instead
%     ntries - number of independent random initialisations
%
%   Output:
%     X      - L x ntries matrix; column k is the last (unit-norm) iterate
%              of run k, whether or not the stopping test was met within
%              maxiter iterations
%
%   Relies on the helpers symmetric_indices and eig2, which are defined
%   outside this file.

    L = size(T,1);
    n = round(log(numel(T))/log(L));

    % Validate the order before anything is derived from it.
    assert(mod(n,2)==0 && n > 0,'n is not an even positive integer');

    n2 = n/2;
    d = L^n2;

    %% Set options here
    maxiter = 10000;
    gradtol = 1e-12;
    eigtol = 1e-1;

    % Flatten T
    T = reshape(T,d,d);

    %%
    % In this block of code we exploit the fact that each column and row of
    % mat(T) is a symmetric tensor of order n2 and therefore has a lot of
    % repeated entries. Instead of calculating the eigen decomposition of
    % mat(T), which has L^(2*n2) entries, we calculate an equivalent eigen
    % decomposition of a matrix with roughly (L^n2/n2!)^2 entries, this way
    % getting a speed up of approximately (n2!)^3.

    
    [symind, findsym, symindscale] = symmetric_indices(L, n2);

    symindscale = sqrt(symindscale);
    findsym = reshape(findsym,[],1);
    findsymscale = 1./symindscale(findsym);
    % Turn each row of n2 subscripts (1-based) into a linear index into a
    % length-L^n2 vector.
    symind = (symind-1)*(L.^(0:n2-1)')+1;
    %%
    % Eigen decomposition
    [symV, D] = eig2(symindscale.*T(symind,symind).*symindscale');

    % Determine tensor rank by the eigenvalues of mat(T)
    % NOTE(review): this assumes eig2 returns D as a vector, so that the
    % sum below is a scalar count - confirm against eig2.
    if isempty(R)
        R = sum(abs(D) > eigtol);
    end

    % Expand the reduced eigenvectors back to all L^n2 index positions.
    % NOTE(review): keeping columns 1:R presumably relies on eig2 ordering
    % the eigenvectors by decreasing eigenvalue magnitude - confirm.
    V = findsymscale.*symV(findsym,1:R);
    
    % C_n from Lemma 4.7 in SPM paper
    if n2<=4
      cn = sqrt(2*(n2-1)/n2);
    else
      cn = (2-sqrt(2))*sqrt(n2);
    end
    
    X = zeros(L, ntries);
    
    % L^(n2-1) x (L*R) layout so that one matrix product contracts the
    % first n2-1 modes of every column of V at once.
    V = reshape(V,[],L*R);

    for k = 1:ntries
            
        % Initialize Xk
        Ak = randn(L,1);
        Ak = Ak/norm(Ak);

        for iter = 1:maxiter

            % Calculate power of Xk
            % After the loop Apow holds the (n2-1)-fold outer power of Ak,
            % stored as a column vector.
            Apow = Ak;
            for i=2:n2-1
                Apow = reshape(Apow.*Ak',[],1);
            end

            % Calculate contraction of V with x^(n2-1)
            VAk = reshape(Apow'*V,L,R);

            Ak_new = VAk*(Ak'*VAk)';

            f = Ak_new'*Ak;

            % Determine optimal shift
            % Sometimes due to numerical error f can be greater than 1
            % f is clamped to [0.5, 1] before the shift is computed.
            f_ = max(min(f,1),.5);
            clambda = sqrt(f_*(1-f_));
            shift = cn*clambda;

            % Shifted power method
            Ak_new = Ak_new + shift*Ak;
            Ak_new = Ak_new/norm(Ak_new);

            if norm(Ak - Ak_new) < gradtol
                % Algorithm converged
                Ak = Ak_new;
                break
            else
                Ak = Ak_new;
            end

        end
          
       X(:, k) = Ak;      
        
    end
    


end

function X = PM_test(T, R, ntries)
% PM_TEST Run the (unit-shift) tensor power iteration from ntries random starts.
%
%   X = PM_test(T, R, ntries)
%
%   Inputs:
%     T      - array whose first dimension has size L and whose number of
%              elements is L^n for an even positive integer n
%     R      - not used by this function; it is accepted so that PM_test
%              can be called with the same argument list as SPM_test
%     ntries - number of independent random initialisations
%
%   Output:
%     X      - L x ntries matrix; column k is the last (unit-norm) iterate
%              of run k, whether or not the stopping test was met within
%              maxiter iterations

    L = size(T,1);
    n = round(log(numel(T))/log(L));
    
    assert(mod(n,2)==0 && n > 0,'n is not an even positive integer');

    %% Set options here
    maxiter = 10000;
    gradtol = 1e-12;

    % Flatten T
    % L^(n-1) x L layout: the rows index the first n-1 modes.
    T = reshape(T,[],L);

    X = zeros(L, ntries);

    for k = 1:ntries

        % Initialize Xk
        Ak = randn(L,1);
        Ak = Ak/norm(Ak);

        for iter = 1:maxiter

            % Calculate power of Xk
            % After the loop Apow holds the (n-1)-fold outer power of Ak,
            % stored as a column vector.
            Apow = Ak;
            for i=2:n-1
                Apow = reshape(Apow.*Ak',[],1);
            end

            % Contract T with Ak along n-1 modes and add the current
            % iterate (a shift of 1) before renormalising.
            Ak_new = Ak + (Apow' * T)';
            Ak_new = Ak_new/norm(Ak_new);

            % Stop once two consecutive iterates differ by less than gradtol.
            if norm(Ak - Ak_new) < gradtol
                % Algorithm converged
                Ak = Ak_new;
                break
            else
                Ak = Ak_new;
            end

        end
          
        X(:, k) = Ak;      
        
    end

end