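% Correct AMP: AMP whose memory (Onsager) coefficients are computed from the
% free cumulants of the noise spectrum, applied to a rank-one spiked matrix.
% Noise spectrum: free convolution of Rademacher and semicircle spectra.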

% Start from an empty workspace, with no open figures and a clean console.
clear variables;
close('all');
clc;

% ---- Experiment parameters ----
SNR       = [1.5 2];      % values of the SNR to sweep
alphagrid = sqrt(SNR);    % spike strengths: alpha = sqrt(SNR)
epsl      = 1e-1;         % (approximate) correlation of the initial iterate with the signal
t         = 0.5;          % noise mixing parameter: W = sqrt(t)*WWIG + sqrt(1-t)*WRAD
n         = 1e3;          % length of the signal
ntrials   = 2;            % number of Monte Carlo trials per SNR value
niter     = 30;           % number of AMP iterations

% These choices of SNR, n, ntrials and niter ensure that the program
% runs fast on a laptop. The parameters used to obtain the results
% reported in Figures 1-2 are specified in Section 4 of the paper.


% Computation of the limiting free cumulants of the noise spectrum

% Free cumulants are stored as row vectors whose k-th entry is kappa_k
% (indexing starts from the 1st cumulant); 2*max_it of them are kept.
max_it = niter;
orders = 1 : max_it;

% Rademacher spectrum scaled by sqrt(1-t): only even-order cumulants are
% nonzero, with kappa_{2k} = (1-t)^k * prod_{i=1..k} [2*(3-2i)/i] / 2.
seq = 2 * (3-2*orders)./orders;
freecumRAD = zeros(1, 2*max_it);
for k = orders
    freecumRAD(2*k) = (1-t)^k * prod(seq(1:k))/2;
end

% Semicircle (Wigner) spectrum scaled by sqrt(t): only kappa_2 = t is nonzero.
freecumWIG = zeros(1, 2*max_it);
freecumWIG(2) = t;

% Free cumulants are additive under free convolution.
freecum = freecumWIG + freecumRAD;

% Run AMP for every value in alphagrid and every Monte Carlo trial.
% For each trial: draw a spherical signal u, build the spiked matrix
% X = alpha/n * u*u' + sqrt(t)*WWIG + sqrt(1-t)*WRAD, run niter AMP
% iterations whose memory coefficients (b) and noise covariance (Sigma)
% are built from the free cumulants in freecum, and print the normalized
% squared correlation and the MSE of every iterate.
%
% The per-iteration metrics are also kept in scalAll / MSEAll, indexed as
% (iteration, trial, index into alphagrid), so they survive the loops.

% The largest cumulant index read below is 2*(niter-2)+2 (Sigma at jj = niter-1).
assert(numel(freecum) >= 2*niter - 2, 'freecum holds too few free cumulants for niter iterations');

scalAll = zeros(niter, ntrials, length(alphagrid));
MSEAll = zeros(niter, ntrials, length(alphagrid));

% Normalized squared correlation between an estimate x and the signal y.
sqcorr = @(x, y) (sum(x .* y))^2/sum(y.^2)/sum(x.^2);
% ||y*y' - x*x'||_F^2 / (2*n^2), expanded in terms of norms and the inner
% product so that no n-by-n matrix is formed.
rankOneMSE = @(x, y) 1/numel(y)^2 * ( sum(y.^2)^2 + sum(x.^2)^2 - 2 * (sum(x .* y))^2 )/2;

for j = 1 : length(alphagrid)

    alpha = alphagrid(j);

    for i = 1 : ntrials

        fprintf('alpha=%f, trial #%d\n', alpha, i);

        % The signal u has a spherical prior: a Gaussian vector rescaled
        % so that ||u||^2 = n.
        u = randn(n, 1);
        normu = sqrt(sum(u.^2));
        u = sqrt(n)*u/normu;

        % The spectrum of the noise matrix W = sqrt(t) * WWIG + sqrt(1-t) * WRAD
        % is the free convolution of Rademacher and semicircle spectra
        A = randn(n, n);
        WWIG = 1/sqrt(2*n) * (A + A'); % Wigner matrix

        % Conjugate a diagonal of random +-1 signs by the left singular
        % vectors of a Gaussian matrix. Only U is used, so the other SVD
        % factors are discarded instead of being kept as n-by-n matrices.
        A1 = randn(n, n);
        [U, ~, ~] = svd(A1);
        signs = sign(rand(n, 1)-0.5);
        signs(signs == 0) = 1; % sign(0) = 0 would give WRAD a zero eigenvalue
        WRAD = U * diag(signs) * U'; % matrix with Rademacher spectrum

        X = alpha/n * (u * u') + sqrt(t) * WWIG + sqrt(1-t) * WRAD;

        % initialization: mix u with Gaussian noise (weight epsl on u) and
        % rescale to norm sqrt(n)
        v0 = epsl * u + sqrt(1-epsl^2) * randn(n, 1);
        normv0 = sqrt(sum(v0.^2));
        u_init = sqrt(n) * v0/normv0;

        % allocate vectors for AMP iterations
        uAMP = zeros(n, niter+1);
        fAMP = zeros(n, niter);
        muSE = zeros(niter, 1);
        sigmaSE = zeros(niter, 1);
        SigmaMAT = zeros(niter, niter);
        DeltaMAT = zeros(niter, niter);
        Phi = zeros(niter+1, niter+1);
        scal = zeros(niter, 1);
        MSE = zeros(niter, 1);

        % initialization of SE parameters
        muSE(1) = alpha * epsl;
        sigmaSE(1) = freecum(2);
        SigmaMAT(1, 1) = sigmaSE(1);

        % first AMP iterate
        uAMP(:, 1) = u_init;

        % MSE and normalized squared correlation
        scal(1) = sqcorr(uAMP(:, 1), u);
        MSE(1) = rankOneMSE(uAMP(:, 1), u);

        fprintf('Iteration %d, scal=%f, MSE=%f\n', 1, scal(1), MSE(1));

        DeltaMAT(1, 1) = (uAMP(:, 1)' * uAMP(:, 1))/n;

        % memory coefficient b_11 = kappa_1
        b11 = freecum(1);

        fAMP(:, 1) = X * uAMP(:, 1) - b11 * uAMP(:, 1);

        % second AMP iterate: scalar linear denoiser mu/(sigma + mu^2),
        % which is also the first entry of Phi
        Phi(2, 1) = muSE(1)/(sigmaSE(1) + muSE(1)^2);
        uAMP(:, 2) = Phi(2, 1)*fAMP(:, 1);

        % MSE and normalized squared correlation
        scal(2) = sqcorr(uAMP(:, 2), u);
        MSE(2) = rankOneMSE(uAMP(:, 2), u);

        DeltaMAT(1, 2) = (uAMP(:, 2)' * uAMP(:, 1))/n;
        DeltaMAT(2, 2) = (uAMP(:, 2)' * uAMP(:, 2))/n;
        DeltaMAT(2, 1) = DeltaMAT(1, 2);

        fprintf('Iteration %d, scal=%f, MSE=%f\n', 2, scal(2), MSE(2));

        for jj = 2 : niter-1

            Phired = Phi(1:jj, 1:jj);
            Deltared = DeltaMAT(1:jj, 1:jj);

            % memory coefficients: B = sum_{k=0}^{jj-1} kappa_{k+1} * Phi^k
            B = zeros(jj, jj);
            for ii = 0 : jj-1
                B = B + freecum(ii+1) * Phired^ii;
            end
            b = B(jj, 1:jj);

            % Onsager-corrected iterate: subtract sum_k b(k) * u^k
            fAMP(:, jj) = X * uAMP(:, jj) - uAMP(:, 1:jj) * b';

            % estimate SE parameters from data:
            % Sigma = sum_{k=0}^{2(jj-1)} kappa_{k+2} * Theta^(k),
            % Theta^(k) = sum_{l=0}^{k} Phi^l * Delta * (Phi')^(k-l)
            Sigmared = zeros(jj, jj);
            for i1 = 0 : 2*(jj-1)
                ThetaMAT = zeros(jj, jj);
                for i2 = 0 : i1
                    ThetaMAT = ThetaMAT + Phired^i2 * Deltared * (Phired')^(i1-i2);
                end
                Sigmared = Sigmared + freecum(i1+2) * ThetaMAT;
            end
            SigmaMAT(1:jj, 1:jj) = Sigmared;

            % mu_jj^2 is estimated as ||f^jj||^2/n - Sigma(jj,jj); abs()
            % keeps the square root real if that difference is negative
            muSE(jj) = sqrt(abs(sum(fAMP(:, jj).^2)/n - Sigmared(jj, jj)));

            % Denoiser weights mu' * (Sigma + mu*mu')^{-1}. They are both the
            % combination weights of the new iterate and the new row of
            % Phi, so they are computed once.
            mu = muSE(1:jj);
            coef = mu' / (Sigmared + mu * mu');

            % (t+1)-st AMP iterate
            uAMP(:, jj+1) = fAMP(:, 1:jj) * coef';

            % MSE and normalized squared correlation
            scal(jj+1) = sqcorr(uAMP(:, jj+1), u);
            MSE(jj+1) = rankOneMSE(uAMP(:, jj+1), u);

            % new row/column of the empirical Gram matrix Delta
            gram = (uAMP(:, 1:jj+1)' * uAMP(:, jj+1))/n;
            DeltaMAT(1:jj+1, jj+1) = gram;
            DeltaMAT(jj+1, 1:jj+1) = gram';

            Phi(jj+1, 1:jj) = coef;

            fprintf('Iteration %d, scal=%f, MSE=%f\n', jj+1, scal(jj+1), MSE(jj+1));

        end

        % keep this trial's metrics
        scalAll(:, i, j) = scal;
        MSEAll(:, i, j) = MSE;

    end
end