% State evolution (SE) recursion for the rectangular model (1.2) with
% uniform noise

% Start from a clean session: blank console, no open figures, empty workspace
clc;
close('all');
clear;

% Load pre-computed spectral quantities from spect_teo_rect_unif.mat.
% For every SNR value, the script later picks out of pre-computed arrays:
%   \Delta_{\rm PCA}                 -> DeltaPCAval
%   R-transform                      -> Rtransval
%   derivative of the R-transform    -> Rtransval1
% (presumably those arrays, and the SNR grid they are indexed by, come from
% this file -- confirm against its contents)
load('spect_teo_rect_unif.mat');

% Simulation parameters
gamma = 0.5;             % aspect ratio, gamma = m/n
niter = 10;              % number of iterations of the SE recursion
alphagridAMP = [0.6 1];  % values of the SNR \alpha at which SE is run

% These choices of alphagridAMP and niter ensure that the program 
% runs fast on a laptop. The parameters employed to obtain the results
% reported in Figures 1-2 are specified in Section 4 of the paper


% computation of the limit rectangular free cumulants

max_it = niter;

% Pre-computed limit rectangular free cumulants. According to the authors'
% note the .mat file stores them as fractions (hence, infinite precision);
% the vector actually read below is kdouble.
load('freerectcum_unif_100.mat');

% The recursion reads cumulants with index up to 2*max_it
freecum = zeros(1, 2*max_it);

% Only the first 100 rectangular free cumulants have been pre-computed
if niter <= 50
    % enough cumulants are available: keep exactly the first 2*max_it
    freecum = kdouble(1:2*max_it)';
else
    % copy the 100 available values; higher-order entries stay at zero
    freecum(1:100) = kdouble';
end
    
for j = 1 : length(alphagridAMP)
    
    alpha = alphagridAMP(j);
    fprintf('alpha=%f\n', alpha);

    % Locate alpha on the pre-computed grid alphagrid (absolute tolerance
    % 1e-10). If several grid points match, the last one is used, which is
    % what a sequential scan that keeps overwriting would select.
    gridIdx = find(abs(alpha - alphagrid) < 10^(-10), 1, 'last');

    % Without a match, Rtransval, Rtransval1 and DeltaPCAval would either be
    % undefined (first SNR) or silently left over from the previous SNR,
    % producing meaningless SE curves. Stop with an explicit message instead.
    if isempty(gridIdx)
        error('SE:alphaNotOnGrid', ...
            'alpha=%g is not on the pre-computed grid alphagrid.', alpha);
    end

    % pre-computed quantities at this SNR
    Rtransval = Rtrans(gridIdx);      % R-transform
    Rtransval1 = Rtrans1(gridIdx);    % derivative of the R-transform
    DeltaPCAval = DeltaPCA(gridIdx);  % \Delta_{\rm PCA}

    % Storage for the SE recursion, reset for every SNR value.

    % one entry per iteration (column vectors)
    scalu = zeros(niter, 1);  % (limit) normalized scalar product between u^* and AMP iteration u^i
    scalv = zeros(niter, 1);  % (limit) normalized scalar product between v^* and AMP iteration v^i
    avgder = zeros(niter, 1); % E[u_i'(F_{i-1})]
    nuSE = zeros(niter, 1);   % \nu_i
    muSE = zeros(niter, 1);   % \mu_i

    % one entry per pair of iterations (niter-by-niter matrices)
    Mvprod = zeros(niter);    % E[V_i V_j]
    Muprod = zeros(niter);    % E[U_i U_j]
    omegaSE = zeros(niter);   % \omega_{i, j}
    sigmaSE = zeros(niter);   % \sigma_{i, j}

    % initialization of SE recursion
    %
    % Sets the first-iteration state (index 1) from the pre-computed
    % quantities DeltaPCAval, Rtransval, Rtransval1 at the current alpha.
    % \nu_1 and \mu_1 are both set to alpha * sqrt(\Delta_{\rm PCA}).
    nuSE(1) = alpha * sqrt(DeltaPCAval);
    muSE(1) = nuSE(1);
    % \omega_{1, 1} in closed form
    omegaSE(1, 1) = -1/(gamma^2/alpha^4 * Rtransval1) * (DeltaPCAval * (1+gamma^2/alpha^2 * Rtransval1) + ...
        (gamma/alpha^2 * Rtransval1-Rtransval-1));
    % E[U_1^2] is normalized to 1; E[V_1^2] carries the factor (gamma/alpha)^2
    Muprod(1, 1) = 1;
    Mvprod(1, 1) = gamma^2/alpha^2 * (nuSE(1)^2 + omegaSE(1, 1));
    
    % as a sanity check, we explicitly compute here \sigma_{1, 1}
    %
    % The expression is the sum of:
    %   - freecum(1) * E[V_1^2] + freecum(2) * E[U_1^2] * gamma^2/alpha^2;
    %   - twice a term built from the tails of the R-transform series,
    %     Rtransval - sum_k freecum(k) * (gamma/alpha^2)^k, truncated after
    %     k = 1 and after k = 2;
    %   - two terms involving Rtransval1, the first multiplied by
    %     (nuSE(1)^2 + omegaSE(1, 1)), both scaled by gamma^4/alpha^6.
    % NOTE(review): this appears to be the jj = ii = 1 instance of the general
    % \sigma_{i, j} formula used inside the iteration loop (which only runs
    % from jj = 2) -- confirm against the paper.
    sigmaSE(1, 1) = freecum(1) * Mvprod(1, 1) + freecum(2) * Muprod(1, 1)  * gamma^2/alpha^2 + 2 * ...
        ( alpha^2/gamma * (Rtransval - freecum(1)*gamma/alpha^2) * Mvprod(1, 1)  + ...
        alpha^3/gamma^2 * (Rtransval - freecum(1)*gamma/alpha^2 -freecum(2)*(gamma/alpha^2)^2) * gamma^2/alpha ) ...
        + (Rtransval1 * (alpha^2/gamma)^2 - 2 * (alpha^2/gamma)^3 * Rtransval + freecum(1) * (gamma/alpha^2)^(-2)) ...
        * gamma^4/alpha^6 * (nuSE(1)^2 + omegaSE(1, 1)) + ...
         ( (alpha^2/gamma)^2 * (Rtransval1* alpha^2/gamma - (alpha^2/gamma)^2 * Rtransval ) ...
        - 2 * (alpha^2/gamma)^4 * (Rtransval - freecum(1) * gamma/alpha^2) + freecum(2) * (gamma/alpha^2)^(-2) ) ...
        * gamma^4/alpha^6 ;
    
    % normalized correlations at iteration 1; since Mvprod(1, 1) equals
    % (gamma/alpha)^2 * (nuSE(1)^2 + omegaSE(1, 1)), scalv(1) reduces to
    % nuSE(1)/sqrt(nuSE(1)^2 + omegaSE(1, 1))
    scalu(1) = sqrt(DeltaPCAval);
    scalv(1) = gamma/alpha * nuSE(1)/sqrt(Mvprod(1, 1));
        
    fprintf('Iteration 1, scalu=%f, scalv=%f\n', scalu(1), scalv(1));    

    % computation of E[u_2'(F_{1})]
    %
    % fun is the standard normal density times the sum of tanh^2 evaluated at
    % the two arguments (+-muSE(1)^2/sigmaSE(1, 1) + muSE(1)/sqrt(sigmaSE(1, 1)) * x);
    % half of its integral is subtracted from 1 and scaled by
    % muSE(1)/sigmaSE(1, 1).
    fun = @(x) 1/sqrt(2*pi) * exp(-x.^2/2) .* ...
        ( (tanh( muSE(1)^2/sigmaSE(1, 1) + muSE(1)/sqrt(sigmaSE(1, 1)) * x )).^2 + ...
        (tanh( - muSE(1)^2/sigmaSE(1, 1) + muSE(1)/sqrt(sigmaSE(1, 1)) * x )).^2 ) ;
    
    avgder(1) = muSE(1)/sigmaSE(1, 1) * ( 1 - 1/2 * integral(fun,-Inf,Inf));

        
    for jj = 2 : niter
        
        % \nu_t and \mu_t
        %
        % \nu_t is alpha/2 times the standard normal expectation of the
        % difference of the two tanh branches built from the previous
        % iteration's (mu, sigma); \mu_t is \nu_t rescaled by alpha/gamma.
        prevShift = muSE(jj-1)^2/sigmaSE(jj-1, jj-1);
        prevSlope = muSE(jj-1)/sqrt(sigmaSE(jj-1, jj-1));
        tanhGap = @(z) 1/sqrt(2*pi) * exp(-z.^2/2) .* ...
            ( tanh(prevShift + prevSlope * z) - tanh(-prevShift + prevSlope * z) );
        nuSE(jj) = alpha/2 * integral(tanhGap, -Inf, Inf);
        muSE(jj) = alpha/gamma * nuSE(jj);
        
        % computation of E[U_t U_1]
        %
        % Stores in Muprod(jj, 1) = Muprod(1, jj) one half of the sum, over
        % the signs + and -, of the Gaussian expectation of
        %   tanh( muSE(jj-1)/sigmaSE(jj-1, jj-1) * (+-muSE(jj-1) + G_{jj-1}) )
        %   .* (+-muSE(1) + G_1)/alpha,
        % where (G_{jj-1}, G_1) is zero-mean Gaussian with covariance Sigma.
        if jj == 2

            % both factors depend on the same Gaussian (variance
            % sigmaSE(1, 1)): a 1D integral over a standard normal suffices
            fun = @(x) 1/sqrt(2*pi) * exp(-x.^2/2) .* ...
                ( (tanh( muSE(1)^2/sigmaSE(1, 1) + muSE(1)/sqrt(sigmaSE(1, 1)) * x )) .* (muSE(1)+sqrt(sigmaSE(1, 1)) * x)/alpha + ...
                (tanh( -muSE(1)^2/sigmaSE(1, 1) + muSE(1)/sqrt(sigmaSE(1, 1)) * x )) .* (-muSE(1)+sqrt(sigmaSE(1, 1)) * x)/alpha );
            Muprod(jj, 1) = 1/2 * integral(fun,-Inf,Inf);

        else

            Sigma = [sigmaSE(jj-1, jj-1), sigmaSE(jj-1, 1); ...
                    sigmaSE(jj-1, 1), sigmaSE(1, 1)];
            detSigma = det(Sigma);

            % If we run the SE recursion for many iterations, because of
            % numerical issues, the matrix Sigma may have determinant < 0.
            % This is not possible since Sigma is a covariance matrix.
            % Thus, in this case, we assume that Sigma has 0 determinant
            % and the 2D integral becomes a 1D integral.
            % A determinant that is exactly 0 (e.g. all entries of Sigma
            % equal once the recursion has converged) is handled the same
            % way: the 2D Gaussian density below divides by sqrt(det(Sigma))
            % and uses inv(Sigma), both of which are undefined in that case.

            if detSigma <= 0
                fun = @(x) 1/sqrt(2*pi) * exp(-x.^2/2) .* ...
                    ( ( (tanh( muSE(jj-1)^2/sigmaSE(jj-1, jj-1) + muSE(jj-1)/sqrt(sigmaSE(jj-1, jj-1)) * x )) .* ...
                    (muSE(1)+sqrt(sigmaSE(1, 1))*x)/alpha ) + ...
                    ( (tanh( - muSE(jj-1)^2/sigmaSE(jj-1, jj-1) + muSE(jj-1)/sqrt(sigmaSE(jj-1, jj-1)) * x )) .* ...
                    (-muSE(1)+sqrt(sigmaSE(1, 1))*x)/alpha ) );

                Muprod(jj, 1) = 1/2 * integral(fun,-Inf,Inf);
            else
                % Sigma is invertible here, so the inverse is only formed
                % when the bivariate density is actually needed
                invS = inv(Sigma);

                % x and y are NOT standardized: they carry the variances
                % sigmaSE(jj-1, jj-1) and sigmaSE(1, 1), hence no sqrt in
                % the tanh slope and in the linear factor
                fun = @(x,y) 1/(2*pi*sqrt(detSigma)) * ...
                    exp(-1/2 * ( invS(1, 1) * x.^2 + invS(2, 2) * y.^2 + 2*invS(1, 2)*x.*y) ) .* ...
                    ( ( (tanh( muSE(jj-1)^2/sigmaSE(jj-1, jj-1) + muSE(jj-1)/sigmaSE(jj-1, jj-1) * x )) .* ...
                    (muSE(1)+y)/alpha ) + ...
                    ( (tanh( - muSE(jj-1)^2/sigmaSE(jj-1, jj-1) + muSE(jj-1)/sigmaSE(jj-1, jj-1) * x )) .* ...
                    (-muSE(1)+y)/alpha ) );

                Muprod(jj, 1) = 1/2 * integral2(fun,-Inf,Inf,-Inf,Inf);
            end
        end

        % Muprod is symmetric
        Muprod(1, jj) = Muprod(jj, 1);
        
        % E[U_t^2]: half the standard normal expectation of the sum of the
        % two squared tanh branches built from iteration jj-1
        sqShift = muSE(jj-1)^2/sigmaSE(jj-1, jj-1);
        sqSlope = muSE(jj-1)/sqrt(sigmaSE(jj-1, jj-1));
        tanhSq = @(z) 1/sqrt(2*pi) * exp(-z.^2/2) .* ...
            ( tanh(sqShift + sqSlope * z).^2 + tanh(-sqShift + sqSlope * z).^2 );
        Muprod(jj, jj) = 1/2 * integral(tanhSq, -Inf, Inf);

        % normalized correlation between u^t and the signal u^*
        uNorm = sqrt(Muprod(jj, jj));
        scalu(jj) = nuSE(jj)/alpha/uNorm;

        % computation of E[U_i U_j] (all the remaining values)
        %
        % For 2 <= ii <= jj-1, Muprod(jj, ii) = Muprod(ii, jj) is one half of
        % the sum, over the signs + and -, of the Gaussian expectation of the
        % product of the two tanh denoisers built from iterations jj-1 and
        % ii-1, whose Gaussian parts have covariance Sigma.
        for ii = 2 : jj-1

            Sigma = [sigmaSE(jj-1, jj-1), sigmaSE(jj-1, ii-1); ...
                    sigmaSE(jj-1, ii-1), sigmaSE(ii-1, ii-1)];
            detSigma = det(Sigma);

            % Numerical errors can make det(Sigma) negative, which is not
            % possible for a covariance matrix; in that case Sigma is treated
            % as singular and the 2D integral collapses to a 1D integral over
            % one standard normal. An exactly zero determinant takes the same
            % path, because the 2D density divides by sqrt(det(Sigma)) and
            % uses inv(Sigma), both undefined for a singular Sigma.
            if detSigma <= 0

                fun = @(x) 1/sqrt(2*pi) * exp(-x.^2/2) .* ...
                    ( ( (tanh( muSE(jj-1)^2/sigmaSE(jj-1, jj-1) + muSE(jj-1)/sqrt(sigmaSE(jj-1, jj-1)) * x )) .* ...
                    (tanh( muSE(ii-1)^2/sigmaSE(ii-1, ii-1) + muSE(ii-1)/sqrt(sigmaSE(ii-1, ii-1)) * x )) ) + ...
                    ( (tanh( - muSE(jj-1)^2/sigmaSE(jj-1, jj-1) + muSE(jj-1)/sqrt(sigmaSE(jj-1, jj-1)) * x )) .* ...
                    (tanh( - muSE(ii-1)^2/sigmaSE(ii-1, ii-1) + muSE(ii-1)/sqrt(sigmaSE(ii-1, ii-1)) * x )) ) );
                Muprod(jj, ii) = 1/2 * integral(fun,-Inf,Inf);
            else
                % Sigma is invertible here, so the inverse is only formed
                % when the bivariate density is actually needed
                invS = inv(Sigma);

                % x and y are NOT standardized: they carry the variances
                % sigmaSE(jj-1, jj-1) and sigmaSE(ii-1, ii-1), hence no sqrt
                % in the tanh slopes
                fun = @(x,y) 1/(2*pi*sqrt(detSigma)) * ...
                    exp(-1/2 * ( invS(1, 1) * x.^2 + invS(2, 2) * y.^2 + 2*invS(1, 2)*x.*y) ) .* ...
                    ( ( (tanh( muSE(jj-1)^2/sigmaSE(jj-1, jj-1) + muSE(jj-1)/sigmaSE(jj-1, jj-1) * x )) .* ...
                    (tanh( muSE(ii-1)^2/sigmaSE(ii-1, ii-1) + muSE(ii-1)/sigmaSE(ii-1, ii-1) * y )) ) + ...
                    ( (tanh( - muSE(jj-1)^2/sigmaSE(jj-1, jj-1) + muSE(jj-1)/sigmaSE(jj-1, jj-1) * x )) .* ...
                    (tanh( - muSE(ii-1)^2/sigmaSE(ii-1, ii-1) + muSE(ii-1)/sigmaSE(ii-1, ii-1) * y )) ) );

                Muprod(jj, ii) = 1/2 * integral2(fun,-Inf,Inf,-Inf,Inf);
            end

            % Muprod is symmetric
            Muprod(ii, jj) = Muprod(jj, ii);
        end
        
        
        % computation of \omega_{i, j}        
        %
        % Fills row and column jj of omegaSE. For every ii <= jj,
        % omegaSE(jj, ii) is gamma times the sum of five contributions:
        %   M        double sum over (i1, i2) weighted by free cumulants;
        %   M1, M2   single sums involving tails of the R-transform series,
        %            Rtransval - sum_k freecum(k) * (gamma/alpha^2)^k;
        %   M3, M3b  closed-form terms involving Rtransval1.
        % Only previously computed entries are read: omegaSE and nuSE with
        % indices <= jj-1, avgder with indices <= jj-1, and Muprod with
        % indices <= jj (filled earlier in this iteration).
        % Note: prod over an empty index range is 1 and sum over an empty
        % range is 0, which the boundary cases below rely on.
        for ii = 1 : jj

            % M(i1, i2): cumulant-weighted contribution of the pair of past
            % iterations (jj-i1, ii-i2); empty when ii == 1
            M = zeros(jj-1, ii-1);
                      
            for i1 = 1 : jj-1
                for i2 = 1 : ii-1       
                    
                   % extra factor gamma/alpha when the referenced iteration
                   % jj-i1 is the first one
                   if i1 == jj-1
                       val1 = gamma/alpha;
                   else
                       val1 = 1;
                   end
                   
                   % same for the referenced iteration ii-i2
                   if i2 == ii-1
                       val2 = gamma/alpha;
                   else
                       val2 = 1;
                   end 
                    
                   M(i1, i2) = prod(avgder(ii-i2+1 : ii-1)) * prod(avgder(jj-i1+1 : jj-1)) ...
                       * ( freecum(i1+i2-1) * Muprod(jj-i1+1, ii-i2+1) + ...
                       freecum(i1+i2) * avgder(jj-i1) * avgder(ii-i2) * ...
                       val1 * val2 * (nuSE(jj-i1) * nuSE(ii-i2) + omegaSE(jj-i1, ii-i2)) );
                end
            end
            
            % M1(i2): terms that go all the way back to iteration 1 along
            % the jj index (full product avgder(1 : jj-1)); uses the
            % R-transform tails after jj+i2-2 and jj+i2-1 cumulants
            M1 = zeros(1, ii-1);
           
            for i2 = 1 : ii-1
                
               if i2 == ii-1
                   val2 = gamma/alpha;
               else
                   val2 = 1;
               end  
                
               M1(i2) = prod(avgder(1 : jj-1)) * prod(avgder(ii-i2+1 : ii-1)) * gamma/alpha ...
                   * ( (gamma/alpha^2)^(-jj-i2+1) * (Rtransval - sum( freecum(1:jj+i2-2) .* ((gamma/alpha^2).^(1:jj+i2-2)))) ...
                   * Muprod(1, ii-i2+1) + (gamma/alpha^2)^(-jj-i2) * (Rtransval - sum( freecum(1:jj+i2-1) .* ((gamma/alpha^2).^(1:jj+i2-1)))) ...
                   * gamma/alpha^2 * avgder(ii-i2) * val2 * (nuSE(1) * nuSE(ii-i2) + omegaSE(1, ii-i2)));
            end
            
            % M2(i1): counterpart of M1 with the roles of ii and jj swapped
            % (full product avgder(1 : ii-1))
            M2 = zeros(jj-1, 1);
           
            for i1 = 1 : jj-1
                
               % the factor attached to index ii is gamma/alpha only when
               % ii > 1 (for ii == 1 the product avgder(1 : ii-1) is empty)
               if ii == 1
                   val = 1;
               else
                   val = gamma/alpha;
               end
               
               if i1 == jj-1
                   val1 = gamma/alpha;
               else
                   val1 = 1;
               end  
                
               M2(i1) = prod(avgder(1 : ii-1)) * prod(avgder(jj-i1+1 : jj-1)) * val ...
                   * ( (gamma/alpha^2)^(-ii-i1+1) * (Rtransval - sum( freecum(1:ii+i1-2) .* ((gamma/alpha^2).^(1:ii+i1-2)))) ...
                   * Muprod(1, jj-i1+1) + (gamma/alpha^2)^(-ii-i1) * (Rtransval - sum( freecum(1:ii+i1-1) .* ((gamma/alpha^2).^(1:ii+i1-1)))) ...
                   * gamma/alpha^2 * avgder(jj-i1) * val1 * (nuSE(1) * nuSE(jj-i1) + omegaSE(1, jj-i1)) ) ;
            end
        
            % extrav1(i2), extrav2(i1): R-transform minus the first i2-1
            % (resp. i1-1) terms of the series sum_k freecum(k) * (gamma/alpha^2)^k
            extrav1 = zeros(1, ii-1);
           
            for i2 = 1 : ii-1
               extrav1(i2) = Rtransval - sum( freecum(1:i2-1) .* ((gamma/alpha^2).^(1:i2-1)));
            end
            
            extrav2 = zeros(1, jj-1);
           
            for i1 = 1 : jj-1
               extrav2(i1) = Rtransval - sum( freecum(1:i1-1) .* ((gamma/alpha^2).^(1:i1-1)));
            end
           
            % extrac(i1, i2): cumulant of order i1+i2-1 times
            % (gamma/alpha^2)^(i1+i2-2)
            extrac = zeros(jj-1, ii-1);
           
            for i1 = 1 : jj-1
                for i2 = 1 : ii-1
                     extrac(i1, i2) = freecum(i1+i2-1) * ((gamma/alpha^2)^(i1+i2-2));
                end
            end
            
            % same ii-dependent factor as in M2
            if ii == 1
                val = 1;
            else
                val = gamma/alpha;
            end
           
            % M3: term involving Rtransval1, corrected by the sums of
            % extrav1, extrav2 and extrac
            M3 = (alpha^2/gamma)^(jj+ii-2) * ( Rtransval1 - ...
               alpha^2/gamma*sum(extrav1) - alpha^2/gamma*sum(extrav2) + sum(sum(extrac)) ) * prod(avgder(1 : jj-1)) ...
                   * prod(avgder(1 : ii-1)) * val * gamma/alpha;
                       
            % "b" variants: same construction with every cumulant index
            % shifted up by one (tails after i2, resp. i1, terms; cumulants
            % of order i1+i2)
            extrav1b = zeros(1, ii-1);
           
            for i2 = 1 : ii-1
               extrav1b(i2) = Rtransval - sum( freecum(1:i2) .* ((gamma/alpha^2).^(1:i2)));
            end
            
            extrav2b = zeros(1, jj-1);
           
            for i1 = 1 : jj-1
               extrav2b(i1) = Rtransval - sum( freecum(1:i1) .* ((gamma/alpha^2).^(1:i1)));
            end
           
            extracb = zeros(jj-1, ii-1);
           
            for i1 = 1 : jj-1
                for i2 = 1 : ii-1
                     extracb(i1, i2) = freecum(i1+i2) * ((gamma/alpha^2)^(i1+i2-2));
                end
            end
           
            % M3b: shifted counterpart of M3, multiplied by
            % (omegaSE(1, 1) + nuSE(1)^2)
            M3b = (alpha^2/gamma)^(jj+ii-2) * ( Rtransval1 *alpha^2/gamma - (alpha^2/gamma)^2 * Rtransval -...
               (alpha^2/gamma)^2*sum(extrav1b) - (alpha^2/gamma)^2*sum(extrav2b) + sum(sum(extracb)) ) * prod(avgder(1 : jj-1)) ...
                   * prod(avgder(1 : ii-1)) * val * gamma^3/alpha^5 * (omegaSE(1, 1)+nuSE(1)^2);
                       
            % assemble the entry and mirror it (omegaSE is kept symmetric)
            omegaSE(jj, ii) = gamma * (sum(sum(M)) + sum(M1) + sum(M2) + M3 + M3b);
            omegaSE(ii, jj) = omegaSE(jj, ii);
           
        end

        % normalized correlation between v^t and the signal v^*:
        % \mu_t divided by alpha/gamma times sqrt(\omega_{t, t} + \nu_t^2)
        vScale = sqrt(omegaSE(jj, jj) + nuSE(jj)^2);
        scalv(jj) = muSE(jj) / (alpha/gamma * vScale);

        % progress report for this iteration
        fprintf('Iteration %d, scalu=%f, scalv=%f\n', jj, scalu(jj), scalv(jj));

        % computation of \sigma_{i, j}
        %
        % Fills row and column jj of sigmaSE. For every ii <= jj,
        % sigmaSE(jj, ii) is the sum of five contributions with the same
        % structure as in the \omega_{i, j} computation, but with index ranges
        % extended by one (i1 up to jj, i2 up to ii):
        %   M        double sum over (i1, i2) weighted by free cumulants;
        %   M1, M2   single sums involving tails of the R-transform series,
        %            Rtransval - sum_k freecum(k) * (gamma/alpha^2)^k;
        %   M3, M3b  closed-form terms involving Rtransval1.
        % It reads omegaSE, nuSE and Muprod with indices <= jj (all filled
        % earlier in this iteration) and avgder with indices <= jj-1.
        % Note: prod over an empty index range is 1 and sum over an empty
        % range is 0, which the boundary cases below rely on.
        for ii = 1 : jj

            % M(i1, i2): cumulant-weighted contribution of the pair of
            % iterations (jj-i1+1, ii-i2+1)
            M = zeros(jj, ii);
                      
            for i1 = 1 : jj
                for i2 = 1 : ii 
                    
                   % extra factor gamma/alpha when the referenced iteration
                   % jj-i1+1 is the first one
                   if i1 == jj
                       val1 = gamma/alpha;
                   else
                       val1 = 1;
                   end
                   
                   % same for the referenced iteration ii-i2+1
                   if i2 == ii
                       val2 = gamma/alpha;
                   else
                       val2 = 1;
                   end  
                    
                   M(i1, i2) = prod(avgder(ii-i2+1 : ii-1)) * prod(avgder(jj-i1+1 : jj-1)) ...
                       * ( freecum(i1+i2-1) * val1 * val2 * ( nuSE(jj-i1+1)*nuSE(ii-i2+1) + omegaSE(jj-i1+1, ii-i2+1) ) + ...
                       freecum(i1+i2) * val1 * val2 * Muprod(jj-i1+1, ii-i2+1) );
                end
            end
                        
            % M1(i2): terms that go all the way back to iteration 1 along
            % the jj index (full product avgder(1 : jj-1)); uses the
            % R-transform tails after jj+i2-1 and jj+i2 cumulants
            M1 = zeros(1, ii);
           
            for i2 = 1 : ii 
                
               if i2 == ii
                   val2 = gamma/alpha;
               else
                   val2 = 1;
               end   
                
               M1(i2) = prod(avgder(1 : jj-1)) * prod(avgder(ii-i2+1 : ii-1)) * gamma/alpha ...
                   * ( (gamma/alpha^2)^(-jj-i2+1) * (Rtransval - sum( freecum(1:jj+i2-1) .* ((gamma/alpha^2).^(1:jj+i2-1)))) ...
                   * val2 * (nuSE(ii-i2+1)*nuSE(1)+ omegaSE(1, ii-i2+1)) + ...
                   (gamma/alpha^2)^(-jj-i2) * (Rtransval - sum( freecum(1:jj+i2) .* ((gamma/alpha^2).^(1:jj+i2)))) ...
                   * val2 * Muprod(1, ii-i2+1) );
            end
            
            % M2(i1): counterpart of M1 with the roles of ii and jj swapped
            % (full product avgder(1 : ii-1))
            M2 = zeros(jj, 1);
           
            for i1 = 1 : jj
                
               if i1 == jj
                   val1 = gamma/alpha;
               else
                   val1 = 1;
               end  
                
               M2(i1) = prod(avgder(1 : ii-1)) * prod(avgder(jj-i1+1 : jj-1)) * gamma/alpha ...
                   * ( (gamma/alpha^2)^(-ii-i1+1) * (Rtransval - sum( freecum(1:ii+i1-1) .* ((gamma/alpha^2).^(1:ii+i1-1)))) ...
                   * val1 * (nuSE(jj-i1+1)*nuSE(1)+ omegaSE(1, jj-i1+1)) + ...
                   (gamma/alpha^2)^(-ii-i1) * (Rtransval - sum( freecum(1:ii+i1) .* ((gamma/alpha^2).^(1:ii+i1)))) ...
                   * val1 * Muprod(1, jj-i1+1) );
            end
        
            % extrav1(i2), extrav2(i1): R-transform minus the first i2-1
            % (resp. i1-1) terms of the series sum_k freecum(k) * (gamma/alpha^2)^k
            extrav1 = zeros(1, ii);
           
            for i2 = 1 : ii
               extrav1(i2) = Rtransval - sum( freecum(1:i2-1) .* ((gamma/alpha^2).^(1:i2-1)));
            end
            
            extrav2 = zeros(1, jj);
           
            for i1 = 1 : jj
               extrav2(i1) = Rtransval - sum( freecum(1:i1-1) .* ((gamma/alpha^2).^(1:i1-1)));
            end
           
            % extrac(i1, i2): cumulant of order i1+i2-1 times
            % (gamma/alpha^2)^(i1+i2-2)
            extrac = zeros(jj, ii);
           
            for i1 = 1 : jj
                for i2 = 1 : ii
                     extrac(i1, i2) = freecum(i1+i2-1) * ((gamma/alpha^2)^(i1+i2-2));
                end
            end
           
            % M3: term involving Rtransval1, corrected by the sums of
            % extrav1, extrav2 and extrac, multiplied by
            % (omegaSE(1, 1) + nuSE(1)^2)
            M3 = (alpha^2/gamma)^(jj+ii) * ( Rtransval1 - ...
               alpha^2/gamma*sum(extrav1) - alpha^2/gamma*sum(extrav2) + sum(sum(extrac)) ) * prod(avgder(1 : jj-1)) ...
                   * prod(avgder(1 : ii-1)) * gamma^4/alpha^6 * (omegaSE(1, 1)+nuSE(1)^2);
                       
            % "b" variants: same construction with every cumulant index
            % shifted up by one (tails after i2, resp. i1, terms; cumulants
            % of order i1+i2)
            extrav1b = zeros(1, ii);
           
            for i2 = 1 : ii
               extrav1b(i2) = Rtransval - sum( freecum(1:i2) .* ((gamma/alpha^2).^(1:i2)));
            end
            
            extrav2b = zeros(1, jj);
           
            for i1 = 1 : jj
               extrav2b(i1) = Rtransval - sum( freecum(1:i1) .* ((gamma/alpha^2).^(1:i1)));
            end
           
            extracb = zeros(jj, ii);
           
            for i1 = 1 : jj
                for i2 = 1 : ii
                     extracb(i1, i2) = freecum(i1+i2) * ((gamma/alpha^2)^(i1+i2-2));
                end
            end
           
            % M3b: shifted counterpart of M3 (no (omega + nu^2) factor here)
            M3b = (alpha^2/gamma)^(jj+ii) * ( Rtransval1 *alpha^2/gamma - (alpha^2/gamma)^2 * Rtransval -...
               (alpha^2/gamma)^2*sum(extrav1b) - (alpha^2/gamma)^2*sum(extrav2b) + sum(sum(extracb)) ) * prod(avgder(1 : jj-1)) ...
                   * prod(avgder(1 : ii-1)) * gamma^4/alpha^6;
                       
            % assemble the entry and mirror it (sigmaSE is kept symmetric)
            sigmaSE(jj, ii) = sum(sum(M)) + sum(M1) + sum(M2) + M3 + M3b;
            sigmaSE(ii, jj) = sigmaSE(jj, ii);
           
        end
        
        % E[u_t'(F_{t-1})] for the next iteration, from the freshly updated
        % muSE(jj) and sigmaSE(jj, jj): muSE/sigmaSE times one minus the
        % averaged squared tanh
        curShift = muSE(jj)^2/sigmaSE(jj, jj);
        curSlope = muSE(jj)/sqrt(sigmaSE(jj, jj));
        sqTanh = @(z) 1/sqrt(2*pi) * exp(-z.^2/2) .* ...
            ( tanh(curShift + curSlope * z).^2 + tanh(-curShift + curSlope * z).^2 );
        meanSqTanh = 1/2 * integral(sqTanh, -Inf, Inf);
        avgder(jj) = muSE(jj)/sigmaSE(jj, jj) * ( 1 - meanSqTanh );
        
        
    end
    
end
