% State evolution of correct AMP 
% Free convolution of Rademacher and semicircle spectra

% Reset the session: empty workspace, no open figures, clean command window.
clear;
close('all');
clc;

% --- Noise spectrum ---------------------------------------------------------
% Interpolation parameter of the spectrum: t = 0 gives a pure Rademacher
% spectrum, t = 1 a pure semicircle (see how freecum is built below).
t = 0.5;

% --- Signal strength --------------------------------------------------------
SNR = [10, 11];           % signal-to-noise ratios to scan
alphagrid = sqrt(SNR);    % the SE recursion is run once per alpha = sqrt(SNR)

% --- AMP --------------------------------------------------------------------
epsl = 0.1;               % correlation of the initialization with the signal
niter = 40;               % number of AMP iterations tracked by the SE

% These choices of SNR (and hence alphagrid) and niter keep the running time
% short on a laptop. The parameters used to obtain the results reported in
% Figures 1-2 are specified in Section 4 of the paper.


% Limiting free cumulants of the noise spectrum, stored as freecum(k) = kappa_k
% for k = 1, ..., 2*max_it (starting from the 1st). Free cumulants add under
% free convolution, so the convolved spectrum has freecum = freecumRAD + freecumWIG.
max_it = niter;

% Rademacher part (scaled by the interpolation parameter): only even cumulants
% are nonzero. With seq(k) = 2*(3 - 2k)/k, the running product
% seq(1)*...*seq(i) equals 2*(-1)^(i-1)*Catalan(i-1), so
%   kappa_{2i} = (1-t)^i * (-1)^(i-1) * Catalan(i-1).
idx = 1 : max_it;
seq = 2 * (3 - 2*idx) ./ idx;

freecumRAD = zeros(1, 2*max_it);
runprod = 1;
for i = idx
    runprod = runprod * seq(i);          % = prod(seq(1:i)), accumulated in order
    freecumRAD(2*i) = (1-t)^i * runprod / 2;
end

% Semicircle part: only the 2nd free cumulant (its variance, t) is nonzero.
freecumWIG = zeros(1, 2*max_it);
freecumWIG(2) = t;

freecum = freecumRAD + freecumWIG;
    
% State-evolution (SE) recursion, run once for every alpha in alphagrid.
% The per-alpha outputs are also collected column-wise in MSEgrid / scalgrid
% (column j <-> alphagrid(j)) so they remain available after the loop;
% muSE, sigmaSE, DeltaMAT, Phi, scal and MSE hold the values of the last alpha.
MSEgrid = zeros(niter, length(alphagrid));
scalgrid = zeros(niter, length(alphagrid));

for j = 1 : length(alphagrid)

    alpha = alphagrid(j);
    fprintf('alpha=%f\n', alpha);

    % allocate vectors for SE recursion
    muSE = zeros(niter, 1);          % SE mean parameters
    sigmaSE = zeros(niter, niter);   % SE matrix Sigma (leading jj x jj block filled at step jj)
    DeltaMAT = zeros(niter, niter);  % SE matrix Delta (symmetric, filled row/column jj at step jj)
    Phi = zeros(niter, niter);       % Phi(k, 1:k-1): coefficients of iterate k; strictly lower triangular
    scal = zeros(niter, 1);          % muSE(k)/alpha / sqrt(DeltaMAT(k,k))
    MSE = zeros(niter, 1);           % (1 - 2*(muSE(k)/alpha)^2 + DeltaMAT(k,k)^2)/2

    % initialization of SE recursion
    muSE(1) = alpha * epsl;
    sigmaSE(1, 1) = freecum(2);
    scal(1) = epsl;
    % Same formula as MSE(jj) below, evaluated with muSE(1)/alpha = epsl and DeltaMAT(1,1) = 1.
    MSE(1) = 1-epsl^2;
    DeltaMAT(1, 1) = 1;

    fprintf('Iteration %d, scal=%f, MSE=%f\n', 1, scal(1), MSE(1));

    for jj = 2 : niter

        prev = 1 : jj-1;
        SigmaPrev = sigmaSE(prev, prev);
        mu = muSE(prev);

        % Coefficients of iterate jj on the previous ones:
        %   phi = mu' * inv(SigmaPrev + mu*mu').
        % NOTE(review): this is the linear-MMSE form, presumably the posterior
        % mean for a unit-variance Gaussian signal prior -- confirm with the paper.
        % It is solved once here and reused below.
        phi = mu' / (SigmaPrev + mu * mu');
        Phi(jj, prev) = phi;

        muSE(jj) = alpha * (phi * mu);

        % Delta entries between iterate jj and the initialization (iterate 1) ...
        DeltaMAT(1, jj) = muSE(jj)/alpha * epsl;
        DeltaMAT(jj, 1) = DeltaMAT(1, jj);

        % ... and between iterate jj and iterates 2..jj. The stored row
        % Phi(i1+1, 1:i1) = muSE(1:i1)' / (Sigma(1:i1,1:i1) + muSE(1:i1)*muSE(1:i1)'):
        % because Phi is strictly lower triangular, recomputing Sigma at later
        % steps leaves its leading i1 x i1 block mathematically unchanged.
        for i1 = 1 : jj-1
            DeltaMAT(jj, i1+1) = phi * (SigmaPrev(:, 1:i1) + mu * muSE(1:i1)') ...
                * Phi(i1+1, 1:i1)';
            DeltaMAT(i1+1, jj) = DeltaMAT(jj, i1+1);
        end

        % Sigma update:
        %   Sigma = sum_{k=0}^{2(jj-1)} freecum(k+2) * sum_{a=0}^{k} Phi^a * Delta * (Phi')^(k-a).
        % Phired is jj x jj and strictly lower triangular, hence Phired^a = 0 for
        % a >= jj. Writing b = k - a, the surviving terms are exactly a, b in
        % 0..jj-1, and the b-sum can be folded into a matrix polynomial:
        %   Sigma = sum_a Phi^a * Delta * ( sum_b freecum(a+b+2) * Phi^b )'.
        Phired = Phi(1:jj, 1:jj);
        Deltared = DeltaMAT(1:jj, 1:jj);

        Ppow = cell(1, jj);              % Ppow{a+1} = Phired^a
        Ppow{1} = eye(jj);
        for a = 1 : jj-1
            Ppow{a+1} = Ppow{a} * Phired;
        end

        SigmaNew = zeros(jj, jj);
        for a = 0 : jj-1
            Q = zeros(jj, jj);
            for b = 0 : jj-1
                Q = Q + freecum(a+b+2) * Ppow{b+1};
            end
            SigmaNew = SigmaNew + Ppow{a+1} * Deltared * Q';
        end

        sigmaSE(1:jj, 1:jj) = SigmaNew;
        scal(jj) = muSE(jj)/alpha/sqrt(DeltaMAT(jj, jj));
        MSE(jj) = (1 - 2 * (muSE(jj)/alpha)^2 + DeltaMAT(jj, jj)^2)/2;
        fprintf('Iteration %d, scal=%f, MSE=%f\n', jj, scal(jj), MSE(jj));

    end

    MSEgrid(:, j) = MSE;
    scalgrid(:, j) = scal;

end