function test_PM_vs_SPM_L_implicit(n)
% TEST_PM_VS_SPM_L_IMPLICIT  Benchmark the algorithms in Algs over an (L, R) grid.
%
%   test_PM_vs_SPM_L_implicit(n) evaluates every algorithm listed in Algs
%   for nLR values of L (log-spaced between 10 and 40, rounded) and, for
%   each L, one R per row of vecRL. Each (L, R) pair is handed to
%   benchmark_tester, which returns the mean error, its standard deviation
%   and all individual errors per algorithm.
%
%   Input:
%     n - value forwarded to benchmark_tester and used in the output file
%         name results/PM_vs_SPM_L_<n>.mat.
%
%   Side effects:
%     * adds helper_functions/ to the MATLAB path,
%     * prints the elapsed time,
%     * saves the ENTIRE function workspace to the .mat file above. The
%       local variable names are therefore part of the saved file's layout;
%       renaming them changes what downstream scripts find in the file.

addpath 'helper_functions/'

% Make sure the output folder exists before the (long) benchmark starts, so
% the final save cannot fail, and discard the whole run, because the folder
% is missing.
if ~isfolder('results')
    mkdir('results');
end

tic

% Each entry is {display name, function handle}.
Algs = {{'PM', @PM_test},...
        {'SPM', @SPM_test},...
        };
    
nAlgs = length(Algs);

ntries_inner = 10;
ntries_outer = 100;
ntries_tot = ntries_inner * ntries_outer;

%% L test

nLR = 10;
vecL = round(exp(linspace(log(10), log(40), nLR)));

% Rows of vecRL are three growth regimes of R as a function of L.
% A general-n form of this grid is kept here for reference but is disabled;
% the active line below equals that form evaluated at n = 2:
%   scale_n = 1.5*factorial(n);
%   vecRL = round([scale_n*vecL; vecL.^((n+1)/2); vecL.^n / scale_n]);
vecRL = round([3*vecL; vecL.^1.5; vecL.^2 / 3]);
nRR = size(vecRL, 1);

% Not used by the computation below; it is still written to the .mat file.
noise = 0;

err_all_LR = zeros(nLR, nRR, nAlgs, ntries_tot);
err_LR = zeros(nLR, nRR, nAlgs);
err_std_LR = zeros(nLR, nRR, nAlgs);


for iter=1:nLR
    
    L = vecL(iter);
    
    for k=1:nRR
        R = vecRL(k, iter);
    
    
        [err_LR(iter, k, :), err_std_LR(iter, k, :), err_all_] = ...
            benchmark_tester(L, R, n, Algs, ntries_inner, ntries_outer);
        % err_all_ is ntries_inner x ntries_outer x nAlgs; flatten the two
        % trial dimensions and transpose to nAlgs x ntries_tot.
        err_all_LR(iter, k, :, :) = reshape(err_all_, [], nAlgs)';
    end
    
end

% No semicolon on purpose: the elapsed time is echoed to the console.
totaltime = toc

save(sprintf("results/PM_vs_SPM_L_%d.mat", n))

end

function [err_mean, err_std, err_vec] = ...
            benchmark_tester(L, R, n, Algs, ntries_inner, ntries_outer)
% BENCHMARK_TESTER  Run each algorithm on random problems and collect errors.
%
%   Inputs:
%     L, R         - size of the random factor matrix (L rows, R columns).
%     n            - the algorithms are called with 2*n as third argument.
%     Algs         - cell array of {name, handle} pairs; each handle is
%                    called as handle(A_true, lambda_true, 2*n, ntries_inner)
%                    and must return one estimate per column.
%     ntries_inner - number of estimates requested per algorithm call.
%     ntries_outer - number of random problems drawn.
%
%   Outputs:
%     err_vec  - ntries_inner x ntries_outer x nAlgs array of errors.
%     err_mean - mean of err_vec over its first two dimensions.
%     err_std  - standard deviation of err_vec over its first two dimensions.

    numAlgs = length(Algs);
    err_vec = zeros(ntries_inner, ntries_outer, numAlgs);

    fprintf('L=%d, R=%d, n=%d\n', L, R, n);

    for trial = 1:ntries_outer

        % Random ground truth: unit-norm columns and weights drawn
        % uniformly from [0.5, 2], scaled by 1/sqrt(R).
        directions = randn(L, R);
        A_true = directions ./ vecnorm(directions);
        lambda_true = (.5 + 1.5 * rand(1,R)) / sqrt(R);

        for a = 1:numAlgs
            solver = Algs{a}{2};
            A_est = solver(A_true, lambda_true, 2*n, ntries_inner);

            % Match every estimate to the true column with the largest
            % absolute inner product (smallest 2 - 2*|<a_est, a_true>|).
            dist2 = 2 - 2 * abs(A_est' * A_true);
            [~, nearest] = min(dist2, [], 2);
            A_match = A_true(:, nearest);

            % Columns are compared up to sign: flip the matched column when
            % its inner product with the estimate is not positive.
            sgn = 2 * (sum(A_est .* A_match) > 0) - 1;
            err_vec(:, trial, a) = vecnorm(A_est - sgn .* A_match);
        end

    end

    err_mean = mean(err_vec, [1 2]);
    err_std = std(err_vec, 0, [1 2]);

end

function X = SPM_test(A, lambda, n, ntries)
% SPM_TEST  Shifted power iteration from ntries random starting vectors.
%
%   Inputs:
%     A      - matrix whose columns are normalised to unit norm on entry.
%     lambda - accepted for interface compatibility; not used here.
%     n      - order; the iteration works with n/2.
%     ntries - number of random starts.
%
%   Output:
%     X - size(A,1) x ntries matrix; column k is the final iterate of the
%         k-th start (either converged or stopped at the iteration cap).

iterCap = 10000;
halfOrder = n / 2;

A = A ./ vecnorm(A);
dim = size(A, 1);

% Constant multiplying the adaptive shift; its formula depends on n/2.
if halfOrder > 4
    shiftConst = (2-sqrt(2))*sqrt(halfOrder);
else
    shiftConst = sqrt(2*(halfOrder-1)/halfOrder);
end

% Factor the entrywise power of the Gram matrix once; it is reused in every
% iteration of every start.
gramPow = (A'*A).^halfOrder;
decG = decomposition(gramPow);

X = zeros(dim, ntries);

for col = 1:ntries

    % Random unit-norm starting vector
    v = randn(dim, 1);
    v = v / norm(v);

    for it = 1:iterCap

        proj = v' * A;

        coef = decG \ (proj.^halfOrder)';
        step = ((proj.^(halfOrder-1)) .* A) * coef;

        f = step' * v;

        % Clamp f to [0.5, 1] before computing the shift; numerical error
        % can push f slightly above 1.
        fClamped = max(min(f, 1), .5);
        shift = shiftConst * sqrt(fClamped*(1-fClamped));

        % Shifted update followed by renormalisation
        step = step + shift*v;
        step = step / norm(step);
        change = norm(v - step);
        v = step;

        % Stop once successive iterates agree to 1e-12
        if change < 1e-12
            break
        end

    end

    X(:, col) = v;

end

end


function X = PM_test(A, lambda, n, ntries)
% PM_TEST  Power iteration (with the current iterate added back) from
% ntries random starting vectors.
%
%   Inputs:
%     A      - matrix with one component per column.
%     lambda - row vector of weights, one per column of A.
%     n      - order; the update uses the (n-1)-th power of the inner
%              products between the iterate and the columns of A.
%     ntries - number of random starts.
%
%   Output:
%     X - size(A,1) x ntries matrix; column k is the final iterate of the
%         k-th start (either converged or stopped at the iteration cap).

    dim = size(A, 1);

    %% Options
    iterCap = 10000;
    stepTol = 1e-12;

    X = zeros(dim, ntries);

    for col = 1:ntries

        % Random unit-norm starting vector
        x = randn(dim, 1);
        x = x / norm(x);

        for it = 1:iterCap

            % Weighted power step plus the current iterate, renormalised
            weights = ((x' * A).^(n-1) .* lambda)';
            x_next = x + A * weights;
            x_next = x_next / norm(x_next);

            % The new iterate is kept whether or not the loop stops here
            moved = norm(x - x_next);
            x = x_next;

            if moved < stepTol
                % Algorithm converged
                break
            end

        end

        X(:, col) = x;

    end

end