% State evolution of correct AMP 
% Wigner matrix/semicircle spectrum with mismatched SNR

% Start from a clean state: remove all workspace variables, close every
% open figure window and clear the command window.
clear();
close('all');
clc();

% ---- Model and algorithm parameters ----
gamma = 2;              % SNR mismatch parameter
SNR = [10, 11];         % SNR values to sweep over
alphagrid = sqrt(SNR);  % corresponding signal strengths alpha = sqrt(SNR)
epsl = 0.1;             % correlation of the initialization
niter = 20;             % number of AMP iterations tracked by the state evolution

% The values of alphagrid and niter above are chosen so that the script
% runs quickly on a laptop. The parameters used to obtain the results
% reported in Figure 1 are specified in Section 4 of the paper.

% Free cumulants of the limiting noise spectrum, indexed from the first:
% freecum(k) is the k-th free cumulant. For the semicircle law only the
% second cumulant (the variance, here gamma^2) is nonzero. The vector has
% 2*niter entries because the state-evolution loop reads freecum up to
% index 2*niter.
max_it = niter;
freecumWIG = [0, gamma^2, zeros(1, 2*max_it - 2)];
freecum = freecumWIG;
    
% State-evolution (SE) recursion of AMP, run once for every alpha in
% alphagrid. Each iteration prints the normalized overlap (scal) and the
% MSE. The per-iteration values are also stored row-wise in scalAll and
% MSEAll (row j corresponds to alphagrid(j)), so they remain available
% after the loop instead of being overwritten by the next value of alpha.
scalAll = zeros(length(alphagrid), niter);
MSEAll = zeros(length(alphagrid), niter);

for j = 1 : length(alphagrid)
    
    alpha = alphagrid(j);
    fprintf('alpha=%f\n', alpha);
    
    % allocate vectors for SE recursion
    muSE = zeros(niter, 1);         % SE means
    sigmaSE = zeros(niter, niter);  % SE noise covariance
    DeltaMAT = zeros(niter, niter); % cross-moments Phi_s*(Sigma + mu*mu')*Phi_t' of the denoiser outputs
    Phi = zeros(niter+1, niter+1);  % row t: coefficients of the linear denoiser at iteration t
    scal = zeros(niter, 1);         % normalized overlap per iteration
    MSE = zeros(niter, 1);          % MSE per iteration
    
    % initialization of SE recursion
    muSE(1) = alpha * epsl;
    sigmaSE(1, 1) = freecum(2);
    scal(1) = epsl;
    
    % Same as the general MSE formula used below, with overlap epsl and
    % DeltaMAT(1, 1) = 1.
    MSE(1) = 1-epsl^2;
    
    DeltaMAT(1, 1) = 1;
    fprintf('Iteration %d, scal=%f, MSE=%f\n', 1, scal(1), MSE(1));
    
    for jj = 2 : niter
        
        mured = muSE(1:jj-1);
        Sigmared = sigmaSE(1:jj-1, 1:jj-1);
        
        % Coefficients of the linear denoiser for iteration jj:
        % phiRow = mu' * (Sigma + mu*mu')^{-1}. Solved once here and reused
        % for muSE(jj) and for every new entry of DeltaMAT below, instead of
        % re-solving the same system inside the inner loop.
        phiRow = mured' / (Sigmared + mured * mured');
        Phi(jj, 1:jj-1) = phiRow;
        
        muSE(jj) = alpha * (phiRow * mured);
        
        DeltaMAT(1, jj) = muSE(jj)/alpha * epsl;
        DeltaMAT(jj, 1) = DeltaMAT(1, jj);
        
        for i1 = 1 : jj-1
            % Phi(i1+1, 1:i1) holds the coefficients stored at iteration i1+1,
            % i.e. muSE(1:i1)'/(sigmaSE(1:i1,1:i1) + muSE(1:i1)*muSE(1:i1)').
            % The leading i1-by-i1 block of sigmaSE is recomputed with the
            % same value at later iterations (Phi is strictly lower triangular
            % and only row/column jj of DeltaMAT changes), up to rounding.
            DeltaMAT(jj, i1+1) = phiRow ...
                * (Sigmared(1:jj-1, 1:i1) + mured * muSE(1:i1)') ...
                * Phi(i1+1, 1:i1)';
            DeltaMAT(i1+1, jj) = DeltaMAT(jj, i1+1);
        end
        
        % Covariance update:
        % Sigma = sum_k freecum(k+2) * sum_{i=0}^{k} Phi^i * Delta * (Phi')^(k-i)
        Sigmared = zeros(jj, jj);
        Phired = Phi(1:jj, 1:jj);
        Deltared = DeltaMAT(1:jj, 1:jj);
        
        for i1 = 0 : 2*(jj-1)
            kappa = freecum(i1+2);
            if kappa == 0
                % A zero free cumulant contributes nothing; skip building
                % its Theta sum (for the semicircle only kappa_2 is nonzero).
                continue;
            end
            
            ThetaMAT = zeros(jj, jj);
            for i2 = 0 : i1
                ThetaMAT = ThetaMAT + Phired^i2 * Deltared * (Phired')^(i1-i2);
            end
            
            Sigmared = Sigmared + kappa * ThetaMAT;
        end
        
        sigmaSE(1:jj, 1:jj) = Sigmared;
        
        % Normalized overlap: (muSE/alpha) / sqrt(DeltaMAT(jj, jj)).
        scal(jj) = muSE(jj)/alpha/sqrt(DeltaMAT(jj, jj));
        % MSE from overlap m = muSE(jj)/alpha and d = DeltaMAT(jj, jj):
        % (1 - 2 m^2 + d^2)/2. NOTE(review): presumably the matrix MSE
        % ||x x' - u u'||_F^2 / (2 n^2); confirm against the paper.
        MSE(jj) = (1 - 2 * (muSE(jj)/alpha)^2 + DeltaMAT(jj, jj)^2)/2;
        fprintf('Iteration %d, scal=%f, MSE=%f\n', jj, scal(jj), MSE(jj));
        
    end
    
    % Keep the results for this alpha (1:numel(...) also covers niter < 1,
    % where the first entry is created by growth).
    scalAll(j, 1:numel(scal)) = scal.';
    MSEAll(j, 1:numel(MSE)) = MSE.';
    
end