% Correct AMP
% Spiked Wigner matrix (rank-one spike plus noise with semicircle spectrum) with mismatched SNR

clear; close all; clc;

% ---------------------------------------------------------------------
% Simulation parameters
% ---------------------------------------------------------------------
SNR       = [6, 7];       % SNR values to sweep over
alphagrid = sqrt(SNR);    % corresponding spike strengths alpha = sqrt(SNR)
epsl      = 0.1;          % correlation between initialization and signal
n         = 1000;         % signal length
ntrials   = 2;            % number of Monte Carlo trials per SNR value
niter     = 30;           % number of AMP iterations
gamma     = 2;            % SNR mismatch parameter (multiplies the Wigner noise)

% The values of SNR, n, ntrials and niter above are chosen so that the
% script runs quickly on a laptop. The parameters used to produce the
% results in Figure 1 are given in Section 4 of the paper.

% ---------------------------------------------------------------------
% Limit free cumulants kappa_1, kappa_2, ... of the noise gamma * W.
% A Wigner matrix (semicircle law) has kappa_2 = 1 and all other free
% cumulants equal to zero, so after scaling by gamma only kappa_2 = gamma^2
% survives. Enough entries are reserved for the sums used by AMP.
% ---------------------------------------------------------------------
max_it = niter;
freecumWIG = [0, gamma^2, zeros(1, 2*max_it)];
freecum = freecumWIG;     % free cumulants, indexed from the 1st

% Monte Carlo loop: for every spike strength alpha and every trial, draw a
% spiked Wigner matrix X = alpha/n * u*u' + gamma * W, run niter AMP
% iterations whose Onsager correction and state-evolution (SE) covariance
% are built from the free cumulants in freecum, and record the normalized
% squared correlation (scal) and matrix MSE of each iterate.
% The results of all trials are kept in scal_all / MSE_all, with
% dimensions (iteration, trial, alpha index).

if niter < 2
    error('niter must be at least 2 (the first two AMP iterates are computed explicitly).');
end

nalpha = length(alphagrid);
scal_all = zeros(niter, ntrials, nalpha);
MSE_all = zeros(niter, ntrials, nalpha);

for ia = 1 : nalpha

    alpha = alphagrid(ia);

    for itrial = 1 : ntrials

        fprintf('alpha=%f, trial #%d\n', alpha, itrial);

        % The signal u has a spherical prior: a Gaussian vector rescaled
        % to have norm sqrt(n)
        u = randn(n, 1);
        normu = sqrt(sum(u.^2));
        u = sqrt(n)*u/normu;

        % Wigner noise: symmetric Gaussian matrix with semicircle spectrum
        A = randn(n, n);
        WWIG = 1/sqrt(2*n) * (A + A');

        % rank-one spike plus noise scaled by the mismatch parameter gamma
        X = alpha/n * (u * u') + gamma * WWIG;

        % Quality metrics of an estimate v of u (u is fixed for this trial):
        % normalized squared correlation, and ||u*u' - v*v'||_F^2 / (2 n^2)
        scal_fn = @(v) (sum(v .* u))^2 / sum(u.^2) / sum(v.^2);
        MSE_fn = @(v) 1/n^2 * (sum(u.^2)^2 + sum(v.^2)^2 - 2 * (sum(v .* u))^2) / 2;

        % initialization: correlated with u, rescaled to norm sqrt(n)
        v0 = epsl * u + sqrt(1-epsl^2) * randn(n, 1);
        normv0 = sqrt(sum(v0.^2));
        u_init = sqrt(n) * v0/normv0;

        % allocate vectors for AMP iterations
        uAMP = zeros(n, niter+1);
        fAMP = zeros(n, niter);
        muSE = zeros(niter, 1);
        sigmaSE = zeros(niter, 1);   % only sigmaSE(1) is used (a variance)
        DeltaMAT = zeros(niter, niter);
        Phi = zeros(niter+1, niter+1);
        scal = zeros(niter, 1);
        MSE = zeros(niter, 1);

        % initialization of SE parameters: signal strength alpha*epsl and
        % noise variance kappa_2 * DeltaMAT(1,1), where DeltaMAT(1,1) = 1
        % because u_init has norm sqrt(n)
        muSE(1) = alpha * epsl;
        sigmaSE(1) = freecum(2);

        % first AMP iterate
        uAMP(:, 1) = u_init;
        scal(1) = scal_fn(uAMP(:, 1));
        MSE(1) = MSE_fn(uAMP(:, 1));

        fprintf('Iteration %d, scal=%f, MSE=%f\n', 1, scal(1), MSE(1));

        DeltaMAT(1, 1) = (uAMP(:, 1)' * uAMP(:, 1))/n;

        % memory (Onsager) coefficient of the first step: kappa_1
        b11 = freecum(1);
        fAMP(:, 1) = X * uAMP(:, 1) - b11 * uAMP(:, 1);

        % second AMP iterate: scalar linear MMSE denoiser mu/(sigma + mu^2)
        coef1 = muSE(1)/(sigmaSE(1) + muSE(1)^2);
        uAMP(:, 2) = coef1 * fAMP(:, 1);
        Phi(2, 1) = coef1;

        scal(2) = scal_fn(uAMP(:, 2));
        MSE(2) = MSE_fn(uAMP(:, 2));

        DeltaMAT(1, 2) = (uAMP(:, 2)' * uAMP(:, 1))/n;
        DeltaMAT(2, 2) = (uAMP(:, 2)' * uAMP(:, 2))/n;
        DeltaMAT(2, 1) = DeltaMAT(1, 2);

        fprintf('Iteration %d, scal=%f, MSE=%f\n', 2, scal(2), MSE(2));

        for jj = 2 : niter-1

            Phired = Phi(1:jj, 1:jj);
            Deltared = DeltaMAT(1:jj, 1:jj);

            % Powers Phired^k for k = 0 .. 2(jj-1), computed once and shared
            % by the Onsager coefficients and the SE covariance below
            npow = 2*(jj-1) + 1;
            PhiPow = cell(1, npow);
            PhiPow{1} = eye(jj);
            for k = 2 : npow
                PhiPow{k} = PhiPow{k-1} * Phired;
            end

            % memory coefficients: B = sum_{k=0}^{jj-1} kappa_{k+1} Phi^k
            B = zeros(jj, jj);
            for k = 0 : jj-1
                B = B + freecum(k+1) * PhiPow{k+1};
            end
            b = B(jj, 1:jj);

            % subtract the memory term sum_s b(s) * u^s
            fAMP(:, jj) = X * uAMP(:, jj) - uAMP(:, 1:jj) * b';

            % SE covariance:
            % Sigma = sum_{k=0}^{2(jj-1)} kappa_{k+2} sum_{m=0}^{k} Phi^m Delta (Phi')^(k-m)
            Sigmared = zeros(jj, jj);
            for k = 0 : 2*(jj-1)
                ThetaMAT = zeros(jj, jj);
                for m = 0 : k
                    ThetaMAT = ThetaMAT + PhiPow{m+1} * Deltared * PhiPow{k-m+1}';
                end
                Sigmared = Sigmared + freecum(k+2) * ThetaMAT;
            end

            % signal strength of f^jj estimated from data:
            % ||f^jj||^2/n is compared with Sigma(jj,jj); abs() keeps the
            % argument of sqrt nonnegative when the difference is negative
            muSE(jj) = sqrt(abs(sum(fAMP(:, jj).^2)/n - Sigmared(jj, jj)));

            % (t+1)-st AMP iterate: linear combination of f^1..f^jj with
            % weights mu' * (Sigma + mu*mu')^(-1); the same weights form the
            % next row of Phi
            mu = muSE(1:jj);
            coef = mu' / (Sigmared + mu * mu');
            uAMP(:, jj+1) = fAMP(:, 1:jj) * coef';
            Phi(jj+1, 1:jj) = coef;

            % MSE and normalized squared correlation
            scal(jj+1) = scal_fn(uAMP(:, jj+1));
            MSE(jj+1) = MSE_fn(uAMP(:, jj+1));

            % empirical overlaps of the new iterate with all previous ones
            DeltaMAT(1:jj+1, jj+1) = (uAMP(:, 1:jj+1)' * uAMP(:, jj+1))/n;
            DeltaMAT(jj+1, 1:jj+1) = DeltaMAT(1:jj+1, jj+1)';

            fprintf('Iteration %d, scal=%f, MSE=%f\n', jj+1, scal(jj+1), MSE(jj+1));

        end

        % keep the per-iteration results of this trial
        scal_all(:, itrial, ia) = scal;
        MSE_all(:, itrial, ia) = MSE;

    end

    fprintf('alpha=%f, average over %d trials at iteration %d: scal=%f, MSE=%f\n', ...
        alpha, ntrials, niter, mean(scal_all(niter, :, ia)), mean(MSE_all(niter, :, ia)));
end