% State evolution of Gaussian AMP 
% Uniform spectrum

clear;
close all;
clc;

SNR = [10, 11]; % values of the SNR
alphagridAMP = sqrt(SNR); % signal strength alpha = sqrt(SNR) for each run
eps = 0.1; % correlation of initialization
% NOTE: this variable shadows MATLAB's built-in eps (machine epsilon) for
% the rest of the script; the name is kept because the recursion uses it.

niter = 20; % number of iterations of AMP

% These choices of SNR (hence alphagridAMP) and niter ensure that the
% program runs fast on a laptop. The parameters employed to obtain the
% results reported in Figure 1 are specified in Section 4 of the paper.

% Computation of the limit free cumulants.
% The state-evolution recursion reads freecum(k) for k up to 2*niter;
% max_it = 2*niter+2 leaves a small margin.
max_it = 2*niter+2;
freecum = zeros(1, max_it);

% Load into a struct (function syntax) so that only the variable actually
% needed is brought into the workspace.
cumdata = load('Cumulants_UnifSpec_a_sqrt3.mat', 'cumulants');
cumulants = cumdata.cumulants;
if length(cumulants) < 2*niter
    % Missing cumulants would silently be treated as zero by the recursion.
    warning('Only %d free cumulants loaded but %d are needed; the missing ones are treated as zero.', ...
        length(cumulants), 2*niter);
end
freecum(1:length(cumulants)) = cumulants;

% Per-SNR results: column j holds the per-iteration scal / MSE values
% obtained for alphagridAMP(j).
scal_allu = zeros(niter, length(alphagridAMP));
MSEu = zeros(niter, length(alphagridAMP));

for j = 1 : length(alphagridAMP)

    alpha = alphagridAMP(j);
    fprintf('State evolution\nSNR=%f\n', alpha^2);

    % allocate vectors for SE recursion
    muSE = zeros(niter, 1);
    betac = zeros(niter+1, 1);
    SigmaSE = zeros(niter, niter);
    OmegaSE = zeros(niter, niter);
    DeltaSE = zeros(niter, niter);
    PhiSE = zeros(niter, niter);
    Bmat = zeros(niter, niter);
    scalu = zeros(niter, 1);
    MSE = zeros(niter, 1);

    % initialization of SE recursion
    muSE(1) = alpha * eps;
    SigmaSE(1, 1) = freecum(2);
    % betac(k+1) = mean / (mean^2 + variance); the same form is used for
    % later iterations (see betac(t+2) below).
    betac(2) = muSE(1)/(muSE(1)^2+SigmaSE(1, 1));
    OmegaSE(1, 1) = 0;
    DeltaSE(1, 1) = 1;
    Bmat(1, 1) = freecum(1);

    scalu(1) = eps^2;
    MSE(1) = 1-eps^2;

    fprintf('Iteration %d, scal=%f, MSE=%f\n', 1, scalu(1), MSE(1));

    for t = 1 : niter-1

        % Memory term betac(t)*muSE(t-1); absent at the first step.
        if t == 1
            muprev = 0;
        else
            muprev = betac(t)*muSE(t-1);
        end

        % Mean of the next iterate.
        muSE(t+1) = betac(t+1) * (Bmat(t, 1:t) * muSE(1:t) + alpha * muSE(t) - muprev);

        % Off-diagonal entries of DeltaSE for the new row/column t+1
        % (DeltaSE is kept symmetric).
        for j1 = 0 : t-1

            if t == 1
                Deltaprev = 0;
            else
                Deltaprev = betac(t) *DeltaSE(j1+1, t-1);
            end

            DeltaSE(t+1, j1+1) = betac(t+1) * (OmegaSE(t, j1+1) + Bmat(t, 1:t) * DeltaSE(j1+1, 1:t)' ...
                + muSE(t)*muSE(j1+1)/alpha - Deltaprev);
            DeltaSE(j1+1, t+1) = DeltaSE(t+1, j1+1);
        end

        % Diagonal entry DeltaSE(t+1, t+1). At t == 1 the terms involving
        % index t-1 are absent (and Bmat(t, 2:t) is empty, contributing 0).
        if t == 1
            DeltaSE(t+1, t+1) = betac(t+1)^2 * (SigmaSE(t, t) + Bmat(t, 1:t) * DeltaSE(1:t, 1:t) * Bmat(t, 1:t)' ...
                + muSE(t)^2 + 2 * Bmat(t, 2:t) * OmegaSE(t, 2:t)' ...
                + 2 * muSE(t)/alpha * Bmat(t, 1:t) * muSE(1:t));
        else
            DeltaSE(t+1, t+1) = betac(t+1)^2 * (SigmaSE(t, t) + Bmat(t, 1:t) * DeltaSE(1:t, 1:t) * Bmat(t, 1:t)' ...
                + muSE(t)^2 + betac(t)^2 * DeltaSE(t-1, t-1) + 2 * Bmat(t, 2:t) * OmegaSE(t, 2:t)' ...
                - 2 * betac(t) * OmegaSE(t, t-1) + 2 * muSE(t)/alpha * Bmat(t, 1:t) * muSE(1:t) ...
                - 2 * Bmat(t, 1:t) * betac(t) * DeltaSE(t-1, 1:t)' - 2 * betac(t) * muSE(t) * muSE(t-1)/alpha);
        end

        % New row t+1 of PhiSE.
        PhiSE(t+1, t) = betac(t+1);

        if t > 1
            for j1 = 1 : t-1
                PhiSE(t+1, j1) = betac(t+1) * (Bmat(t, 2:t) * PhiSE(2:t, j1) - betac(t) * PhiSE(t-1, j1));
            end
        end

        % Bmat = sum_{k=0}^{t} freecum(k+1) * PhiSE^k on the leading (t+1)x(t+1) block.
        Bnew = zeros(t+1, t+1);

        for j1 = 0 : t
            Bnew = Bnew + freecum(j1+1) * (PhiSE(1:t+1, 1:t+1))^j1;
        end

        Bmat(1:t+1, 1:t+1) = Bnew;

        % SigmaSE = sum_{k=0}^{2t} freecum(k+2) * sum_{i=0}^{k} PhiSE^i * DeltaSE * (PhiSE')^(k-i)
        % on the leading (t+1)x(t+1) block.
        Sigmanew = zeros(t+1, t+1);

        for j1 = 0 : 2*t
            Matadd = zeros(t+1, t+1);

            for j2 = 0 : j1
                Matadd = Matadd + (PhiSE(1:t+1, 1:t+1))^j2 * DeltaSE(1:t+1, 1:t+1) * (PhiSE(1:t+1, 1:t+1)')^(j1-j2);
            end
            Sigmanew = Sigmanew + freecum(j1+2) * Matadd;
        end

        % Consistency check: the recomputed Sigma must agree with the
        % previous one on the leading t x t block.
        if max(max(abs(Sigmanew(1:t, 1:t) - SigmaSE(1:t, 1:t)))) > 10^(-8)
            fprintf('Something wrong for SigmaSE\n');
        end

        SigmaSE(1:t+1, 1:t+1) = Sigmanew;

        % New column t+1 of OmegaSE.
        for j1 = 1 : t
            if t == 1
                OmegaSE(j1, t+1) = betac(t+1) * (SigmaSE(t, j1) + Bmat(t, 2:t) * OmegaSE(j1, 2:t)');
            else
                OmegaSE(j1, t+1) = betac(t+1) * (SigmaSE(t, j1) + Bmat(t, 2:t) * OmegaSE(j1, 2:t)' - betac(t) * OmegaSE(j1, t-1));
            end
        end

        % New row t+1 of OmegaSE.
        OmegaSE(t+1, 2) = betac(2) * SigmaSE(t+1, 1);
        for j1 = 2 : t
            OmegaSE(t+1, j1+1) = betac(j1+1) * (SigmaSE(t+1, j1) + Bmat(j1, 2:j1) * OmegaSE(t+1, 2:j1)' - betac(j1) * OmegaSE(t+1, j1-1));
        end

        % tildeDelta is the bracketed expression of the diagonal Delta update
        % evaluated one step ahead (t -> t+1), i.e. without the betac(t+2)^2
        % factor; it is used to derive the next coefficient betac(t+2).
        tildeDelta = SigmaSE(t+1, t+1) + Bmat(t+1, 1:t+1) * DeltaSE(1:t+1, 1:t+1) * Bmat(t+1, 1:t+1)' ...
                + muSE(t+1)^2 + betac(t+1)^2 * DeltaSE(t, t) + 2 * Bmat(t+1, 2:t+1) * OmegaSE(t+1, 2:t+1)' ...
                - 2 * betac(t+1) * OmegaSE(t+1, t) + 2 * muSE(t+1)/alpha * Bmat(t+1, 1:t+1) * muSE(1:t+1) ...
                - 2 * Bmat(t+1, 1:t+1) * betac(t+1) * DeltaSE(t, 1:t+1)' - 2 * betac(t+1) * muSE(t+1) * muSE(t)/alpha;

        % NOTE(review): abs() guards against tiny negative values from
        % rounding, but would also silently mask a genuinely negative
        % difference.
        tildemu = sqrt(abs(tildeDelta-DeltaSE(t+1, t+1)));
        % Same form as the initialization betac(2) = mu/(mu^2 + variance).
        betac(t+2) = tildemu/(tildemu^2+DeltaSE(t+1, t+1));

        % scal = (muSE/alpha)^2 / DeltaSE; MSE = (1 - 2*(muSE/alpha)^2 + DeltaSE^2)/2,
        % which reduces to MSE(1) = 1 - eps^2 at the first iteration
        % (DeltaSE(1, 1) = 1, muSE(1)/alpha = eps).
        scalu(t+1) = (muSE(t+1)/alpha)^2/DeltaSE(t+1, t+1);
        MSE(t+1) = (1 - 2 * (muSE(t+1)/alpha)^2 + DeltaSE(t+1, t+1)^2)/2;
        fprintf('Iteration %d, scal=%f, MSE=%f\n', t+1, scalu(t+1), MSE(t+1));

    end

    % Store the trajectory for this SNR so that results for every value in
    % alphagridAMP remain available after the loop.
    scal_allu(:, j) = scalu;
    MSEu(:, j) = MSE;

end