% AMP with PCA initialization for the square model (1.1) with
% uniform noise

% Start from a clean session: empty workspace, no open figures, clean console
clear;
close all;
clc;

% ---- Experiment configuration ----
n = 1000;               % length of the signal
alphagrid = [0.6, 1];   % values of the SNR \alpha
niter = 10;             % number of iterations of AMP with PCA initialization
ntrials = 2;            % number of Monte Carlo trials

% The values above are deliberately small so that the script runs quickly
% on a laptop. The parameters used to obtain the results reported in
% Figures 1-2 are specified in Section 4 of the paper.


% ---- Limit free cumulants ----
% freecum(k) holds the k-th free cumulant, indexed from the 1st one.
% The 1st entry is left at 0; for k >= 2 the value is bernoulli(k)/k!.
max_it = niter;
freecum = zeros(1, 2*max_it);
for k = 2 : numel(freecum)
    freecum(k) = bernoulli(k)/factorial(k);
end

% The closed-form expression above does not work well for large orders, so
% when many iterations are requested the cumulants from order 200 onwards
% are set to 0.
if niter > 100
    freecum(200:end) = 0;
end

% Monte Carlo simulation of AMP with PCA initialization.
%
% Workspace inputs: alphagrid, n, ntrials, niter, freecum.
%
% For every SNR value in alphagrid and every trial the loop
%   1. draws a Rademacher signal u and a noise matrix W*Lambda*W' whose
%      eigenvalues are i.i.d. uniform in [-1/2, 1/2],
%   2. forms the data matrix X = alpha/n * u*u' + noise,
%   3. initializes AMP with the principal eigenvector of X,
%   4. runs the AMP recursion and prints, at every iteration, the normalized
%      squared correlation (scal) between the current iterate and u.
for j = 1 : length(alphagrid)

    alpha = alphagrid(j);

    for i = 1 : ntrials

        fprintf('alpha=%f, trial #%d\n', alpha, i);

        % The signal u has a Rademacher prior, i.e., its entries are i.i.d.
        % and uniform in {-1, 1}; it is rescaled so that ||u||^2 = n
        u = sign(rand(n, 1)-0.5);
        normu = sqrt(sum(u.^2));
        u = sqrt(n)*u/normu;

        % The eigenvalues of the noise are i.i.d. and uniformly distributed
        % in the interval [-1/2, 1/2]
        Lambda = diag(rand(n, 1)-0.5);

        % the eigenvectors of a Wigner matrix give a Haar-distributed
        % matrix (only the eigenvectors are needed, not the eigenvalues)
        G = randn(n,n);
        A = (G + G') / sqrt(2*n);
        [W, ~] = eig(A);

        X = alpha/n * (u * u') + W * Lambda * W';

        % W * Lambda * W' is symmetric only up to round-off. Averaging X
        % with its transpose makes it exactly symmetric, so that eig
        % returns real eigenvalues and real eigenvectors and no imaginary
        % part has to be discarded afterwards.
        X = (X + X') / 2;

        % initialization with PCA: principal eigenvector of X, rescaled to
        % have squared norm n
        [Vr, Dr] = eig(X);
        [valeig, indeig] = sort(diag(Dr), 'descend');
        v0 = Vr(:, indeig(1));
        u_init = sqrt(n) * v0;

        % alpha1 is the estimate of the SNR \alpha obtained from the
        % largest eigenvalue of the data matrix X
        if valeig(1) > (0.5 + 10^(-6))
            alpha1 = 1/log((2*valeig(1)+1)/(2*valeig(1)-1));
        else
            % if we are too close to the PCA threshold, the formula above
            % does not work well. Thus, we just set alpha1 to be some
            % small value (e.g., 0.05)
            alpha1 = 0.05;
        end

        % tanh(1/(2*alpha1)) appears in all the limit expressions below
        th = tanh(1/(2*alpha1));

        % allocate vectors for AMP iterations
        uAMP = zeros(n, niter+1);   % iterates u^t (one per column)
        fAMP = zeros(n, niter);     % iterates f^t (one per column)
        muSE = zeros(niter, 1);     % state evolution parameter \mu_t
        sigmaSE = zeros(niter, 1);  % state evolution parameter \sigma_{t, t}
        avgder = zeros(niter, 1);   % empirical average of the denoiser derivative
        scal = zeros(niter, 1);     % normalized squared correlation with u


        % first iteration of AMP with PCA initialization
        uAMP(:, 1) = u_init;
        scal(1) = (sum(uAMP(:, 1).* u))^2/sum(u.^2)/sum(uAMP(:, 1).^2);
        fprintf('Iteration %d, scal=%f\n', 1, scal(1));

        % estimate SE parameters from their limit expressions (using the
        % value of \alpha obtained from the data)
        muSE(1) = alpha1 * sqrt( 1/alpha1^2 * (1-th^2)/(4*th^2) );
        sigmaSE(1) = alpha1^2 * (1- 1/alpha1^2 * (1-th^2)/(4*th^2));

        % computation of the memory coefficients
        b11 = 1/(2 * th) -alpha1;

        % iterate f^1
        fAMP(:, 1) = X * uAMP(:, 1) - b11 * uAMP(:, 1);

        % iterate u^2 (tanh denoiser with argument scaled by mu/sigma)
        rho = muSE(1)/sigmaSE(1);
        uAMP(:, 2) = tanh(rho*fAMP(:, 1));

        % normalized squared correlation between u^2 and the signal
        scal(2) = (sum(uAMP(:, 2).* u))^2/sum(u.^2)/sum(uAMP(:, 2).^2);
        fprintf('Iteration %d, scal=%f\n', 2, scal(2));

        % average of the derivative of the denoiser u_2; the derivative of
        % tanh(rho*x) is rho*(1 - tanh(rho*x)^2), and tanh(rho*f^1) is
        % already stored in uAMP(:, 2)
        avgder(1) = rho * ( 1 - mean(uAMP(:, 2).^2));

        for jj = 2 : niter-1

           % computation of the memory coefficients
           b = zeros(1, jj);
           b(jj)=0;
           for ii = 0 : jj-3
               b(jj-ii-1) = freecum(ii+2) * prod(avgder(jj-ii-1 : jj-1));
           end
           b(1) = alpha1^(jj-1) * (b11 - sum( freecum(1:jj-1) ./ (alpha1.^(0:jj-2)))) * prod(avgder(1 : jj-1));

           % iterate f^t: the memory term sum_k b(k) * u^k is a
           % matrix-vector product with the first jj iterates
           fAMP(:, jj) = X * uAMP(:, jj) - uAMP(:, 1:jj) * b.';

           % estimation of \sigma_{t, t} from the data
           % (for i1 = 1 the index range jj-i1+1 : jj-1 is empty and the
           % corresponding product equals 1)
           M = zeros(jj-1, jj-1);

           for i1 = 1 : jj-1
               for i2 = 1 : jj-1
                   M(i1, i2) = freecum(i1+i2) * prod(avgder(jj-i1+1 : jj-1)) ...
                       * prod(avgder(jj-i2+1 : jj-1)) * (uAMP(:, jj-i1+1)' * uAMP(:, jj-i2+1))/n;
               end
           end

           % cross terms between f^1 and the iterates u^2, ..., u^jj
           M1 = zeros(1, jj-1);

           for i2 = 1 : jj-1
               M1(1, i2) = alpha1^(jj+i2-2) * (b11 - sum( freecum(1:jj+i2-1) ./ (alpha1.^(0:jj+i2-2)))) ...
                   * prod(avgder(1 : jj-1)) ...
                   * prod(avgder(jj-i2+1 : jj-1)) * (fAMP(:, 1)' * uAMP(:, jj-i2+1))/n;
           end

           % auxiliary sums entering the term that involves only f^1
           extrav = zeros(1, jj-1);

           for i1 = 1 : jj-1
               extrav(i1) = (b11 - sum( freecum(1:i1) ./ (alpha1.^(0:i1-1))));
           end

           extrac = zeros(jj-1, jj-1);

           for i1 = 1 : jj-1
               for i2 = 1 : jj-1
                    extrac(i1, i2) = freecum(i1+i2) /alpha1^(i1+i2-2);
               end
           end

           M2 = alpha1^(2*jj-4) * ( alpha1^2 + (th^2 -1)/(4*th^2) - ...
               2*alpha1*sum(extrav) + sum(sum(extrac)) ) * (prod(avgder(1 : jj-1)))^2 ...
                   * (fAMP(:, 1)' * fAMP(:, 1))/n;

           sigmaSE(jj) = sum(sum(M))+2*sum(M1)+M2;

           % given \sigma_{t, t}, \mu_t can be estimated from the norm of
           % the iterate f^t (abs guards against a negative argument of
           % the square root)
           muSE(jj) = sqrt(abs(sum(fAMP(:, jj).^2)/n - sigmaSE(jj)));

           % iterate u^{t+1}
           rho = muSE(jj)/sigmaSE(jj);
           uAMP(:, jj+1) = tanh(rho*fAMP(:, jj));

           % average of the derivative of the denoiser u_{t+1}
           avgder(jj) = rho * ( 1 - mean(uAMP(:, jj+1).^2));

           % normalized squared correlation between u^{t+1} and the signal
           scal(jj+1) = (sum(uAMP(:, jj+1).* u))^2/sum(u.^2)/sum(uAMP(:, jj+1).^2);
           fprintf('Iteration %d, scal=%f\n', jj+1, scal(jj+1));

        end
    end
end