% AMP with PCA initialization for the rectangular spiked model (1.2), where
% the eigenvalues of the noise matrix W W^T are uniform on [0, 1]

% start from a clean workspace, with no open figures and an empty console
clear; close('all'); clc;

% pre-computed theoretical spectral quantities: Rtrans, Rtrans1 and
% DeltaPCA are read from this file further below, as fallback values used
% when the top singular value of the data matrix is not above 1
load('spect_teo_rect_unif.mat');

alphagridAMP = [0.6 1];   % SNR values \alpha to be tested
n = 2000;                 % length of the signal v^*
gamma = 0.5;              % aspect ratio, gamma = m/n (shadows the builtin gamma())
m = floor(gamma*n);       % length of the signal u^*
ntrials = 2;              % number of Monte Carlo trials per value of \alpha
niter = 10;               % number of iterations of AMP with PCA initialization

% alphagridAMP, n, ntrials and niter are chosen small so that the script
% finishes quickly on a laptop. The values used to produce the results
% reported in Figures 1-2 are listed in Section 4 of the paper.


% computation of the limit rectangular free cumulants

% the AMP loop below reads at most the first 2*niter free cumulants
max_it = niter;

% the limit rectangular free cumulants have been pre-computed and stored in
% this .mat file as fractions (hence, infinite precision); the variable
% kdouble loaded from it is used below as the vector of cumulant values
% (presumably their double-precision conversion -- confirm against the file)
load freerectcum_unif_100.mat;

% freecum is kept as a ROW vector: it is later multiplied element-wise with
% row vectors such as (gamma/alphabar^2).^(1:k), and a column here would
% silently broadcast into a matrix. Linear-index assignment into the
% preallocated row keeps the orientation whatever the shape of kdouble.
freecum = zeros(1, 2*max_it);

% copy as many pre-computed cumulants as are needed and available; any
% missing higher-order cumulants are left at zero, which is reported
ncum = min(numel(kdouble), 2*max_it);
freecum(1:ncum) = kdouble(1:ncum);
if ncum < 2*max_it
    warning('Only %d pre-computed rectangular free cumulants are available; cumulants %d to %d are set to zero.', ...
        ncum, ncum+1, 2*max_it);
end

% Main experiment. For every SNR value in alphagridAMP and every Monte Carlo
% trial: draw a fresh instance of the rectangular model, compute the PCA
% initialization, run AMP, and print the normalized squared correlations
% between the AMP iterates and the true signals u^*, v^*.
% Uses alphagridAMP, n, m, gamma, ntrials, niter, freecum (defined earlier
% in the script) and Rtrans, Rtrans1, DeltaPCA (from spect_teo_rect_unif.mat).
for j = 1 : length(alphagridAMP)
    
    alpha = alphagridAMP(j);

    for i = 1 : ntrials
        
        fprintf('alpha=%f, trial #%d\n', alpha, i);

        % The signal u has a Rademacher prior, i.e., its entries are i.i.d.
        % and uniform in {-1, 1}; it is then rescaled to norm sqrt(m)
        u = sign(rand(m, 1)-0.5);
        normu = sqrt(sum(u.^2));
        u = sqrt(m)*u/normu;

        % The signal v has a Gaussian prior; after the normalization below
        % it is uniformly distributed on the sphere of radius \sqrt{n}
        v = randn(n, 1);
        normv = sqrt(sum(v.^2));
        v = sqrt(n)*v/normv;
        
        % The eigenvalues of W W^T are i.i.d. and uniformly distributed
        % in the interval [0, 1]; Lambda (m x n) carries their square
        % roots, i.e. the singular values of W, on its diagonal
        Lambda = [diag(sqrt(rand(m, 1))), zeros(m, n-m)];
        
        % the SVD of a matrix with Gaussian entries gives Haar-distributed
        % matrices; U (m x m) and V (n x n) serve as the singular vectors
        % of the noise W (S is not used)
        G = randn(m,n);
        [U, S, V] = svd(G);
        
        % data matrix: rank-one spike (alpha/m) u v^T plus noise U Lambda V^T
        X = alpha/m * (u * v') + U * Lambda * V';

        % initialization with PCA: top left (u0) and right (v0) singular
        % vectors of X, together with the sorted singular values valeig
        [Uc, Sc, Vc] = svd(X);
        Vr = real(Vc);
        Sr = real(Sc);
        Ur = real(Uc);
        [valeig, indeig] = sort(diag(Sr), 'descend');
        v0 = Vr(:, indeig(1));
        u0 = Ur(:, indeig(1));
        
        % computation of alphabar, which is the estimate of the SNR \alpha
        % obtained from the largest singular value of the data matrix X.
        % The formulas involve log(x^2/(x^2-1)) with x = valeig(1), which is
        % real only for x > 1, hence the guard
        if valeig(1) > 1
            
            zval = gamma * valeig(1)^2 * (log(valeig(1)^2/(valeig(1)^2-1)))^2 + ...
                (1-gamma) * log(valeig(1)^2/(valeig(1)^2-1));
            alphabar = sqrt(gamma/zval);
            
            invD = valeig(1);

            % estimation from the data of a few useful quantities: the
            % derivative of the D-transform (Dder), \Delta_{\rm PCA}
            % (DeltaPCAval), the R-transform (Rtransval) and the derivative
            % of the R-transform (Rtransval1), all evaluated at the point
            % determined by the top singular value invD
            Dder = 2*gamma* invD * log(invD^2/(invD^2-1)) * ( log(invD^2/(invD^2-1)) - ...
            2/(invD^2-1)) - 2 * (1-gamma) / (invD*(invD^2-1));
        
            DeltaPCAval = - 2*zval* (invD * log(invD^2/(invD^2-1)))/Dder;
        
            Rtransval = (-gamma-1+sqrt((gamma+1)^2+4*gamma*(invD^2*zval-1)))/(2*gamma);
            Rtransval1 = (invD^2+2*zval*invD/Dder)/sqrt((gamma+1)^2+4*gamma*(invD^2*zval-1));
            
        else    
            
            % if we are too close to the PCA threshold, the formulas above
            % do not work well. Thus, we just set alphabar to some small
            % value (0.05), and the related quantities (\Delta_{\rm PCA},
            % R-transform and its derivative) are taken from the first
            % entries of the pre-computed arrays stored in
            % spect_teo_rect_unif.mat
            alphabar = 0.05; 
            Rtransval = Rtrans(1);
            Rtransval1 = Rtrans1(1);
            DeltaPCAval = DeltaPCA(1);
        end
            
        % allocate vectors for AMP iterations; uAMP has one extra column
        % because iteration t produces u^{t+1}
        uAMP = zeros(m, niter+1);
        fAMP = zeros(m, niter);
        vAMP = zeros(n, niter);
        muSE = zeros(niter, 1);
        nuSE = zeros(niter, 1);
        sigmaSE = zeros(niter, 1);
        avgder = zeros(niter, 1);
        scalu = zeros(niter, 1);
        scalv = zeros(niter, 1);
        
        
        % first iteration of AMP with PCA initialization: u^1 is the PCA
        % estimate u0 rescaled to norm sqrt(m). f0 and g0 are kept because
        % they enter the estimates M1, M2, M2b of \sigma_{t,t} below
        f0 = alphabar * sqrt(m) * u0;
        uAMP(:, 1) = sqrt(m) * u0;
        g0 = X' * uAMP(:, 1)/(1 + gamma * Rtransval);
        vAMP(:, 1) = gamma/alphabar * g0;
        
        % estimate SE parameters from their limit expressions (using the
        % value of \alpha and \Delta_{\rm PCA} obtained from the data)
        nuSE(1) = alphabar * sqrt(DeltaPCAval);
        muSE(1) = nuSE(1);

        % normalized squared correlations of u^1, v^1 with the signals
        scalu(1) = (sum(uAMP(:, 1).* u))^2/sum(u.^2)/sum(uAMP(:, 1).^2);
        scalv(1) = (sum(vAMP(:, 1).* v))^2/sum(v.^2)/sum(vAMP(:, 1).^2);
        fprintf('Iteration %d, scalu=%f, scalv=%f\n', 1, scalu(1), scalv(1));
                
        % iterate f^1
        fAMP(:, 1) = X * vAMP(:, 1) - alphabar * Rtransval * uAMP(:, 1);

        % \sigma_{1,1} estimated from the empirical second moment of f^1;
        % abs() guards against a negative difference
        sigmaSE(1) = abs(fAMP(:, 1)' * fAMP(:, 1)/m - muSE(1)^2);
        
        % iterate u^2 (tanh denoiser, matching the Rademacher prior on u)
        uAMP(:, 2) = tanh(muSE(1)/sigmaSE(1)*fAMP(:, 1));
        
        % normalized squared correlation between u^2 and the signal u^*
        scalu(2) = (sum(uAMP(:, 2).* u))^2/sum(u.^2)/sum(uAMP(:, 2).^2);
        
        % average of the derivative of the denoiser u_2
        avgder(1) = muSE(1)/sigmaSE(1) * ( 1 - mean((tanh(muSE(1)/sigmaSE(1)*fAMP(:, 1))).^2));
        
        % NOTE(review): the loop stops at niter-1, so the last iterate
        % produced is u^{niter} (scalu fills exactly its niter entries);
        % vAMP(:, niter), fAMP(:, niter) and scalv(niter) stay zero
        for jj = 2 : niter-1
            
           % computation of the memory coefficients: a (1 x jj) multiplies
           % u^1..u^jj in the f-update, b (1 x jj-1) multiplies
           % v^1..v^{jj-1} in the v-update. a(1) and b(1) carry extra terms
           % involving Rtransval that come from the PCA initialization
           a = zeros(1, jj);
           b = zeros(1, jj-1);
           a(jj)=freecum(1);
           for ii = 0 : jj-3
               a(jj-ii-1) = freecum(ii+2) * prod(avgder(jj-ii-1 : jj-1));
               b(jj-ii-1) = gamma * freecum(ii+1) * prod(avgder(jj-ii-1 : jj-1));
           end
           a(1) = (alphabar^2/gamma)^jj * (Rtransval - sum( freecum(1:jj-1) .* ((gamma/alphabar^2).^(1:jj-1)))) * prod(avgder(1 : jj-1)) * gamma/alphabar;
           b(1) = ( (alphabar^2/gamma)^jj * (Rtransval - sum( freecum(1:jj-1) .* ((gamma/alphabar^2).^(1:jj-1))))*gamma/alphabar^2 + freecum(jj-1) ) ...
               * gamma * prod(avgder(1 : jj-1));
           
           % iterate v^t
           vAMP(:, jj) = X' * uAMP(:, jj) - sum(repmat(b, n, 1) .* vAMP(:, 1:jj-1), 2);

           % iterate f^t
           fAMP(:, jj) = X * vAMP(:, jj) - sum(repmat(a, m, 1) .* uAMP(:, 1:jj), 2);
           
           % normalized squared correlation between v^t and the signal v^*
           scalv(jj) = (sum(vAMP(:, jj).* v))^2/sum(v.^2)/sum(vAMP(:, jj).^2);
           fprintf('Iteration %d, scalu=%f, scalv=%f\n', jj, scalu(jj), scalv(jj));
           
           % estimation of \sigma_{t, t} from the data.
           % M: contributions of pairs of AMP iterates (inner products of
           % the u- and v-iterates, weighted by free cumulants and by
           % products of the average denoiser derivatives)
           M = zeros(jj, jj);
           
           for i1 = 1 : jj
               for i2 = 1 : jj
                   
                   if i1 == jj
                       der1 = gamma/alphabar;
                   else
                       der1 = 1;
                   end
                   
                   if i2 == jj
                       der2 = gamma/alphabar;
                   else
                       der2 = 1;
                   end
                   
                   M(i1, i2) = (freecum(i1+i2) * prod(avgder(jj-i1+1 : jj-1)) ...
                       * prod(avgder(jj-i2+1 : jj-1)) * der1 * der2 * (uAMP(:, jj-i1+1)' * uAMP(:, jj-i2+1))/m) ...
                       + (freecum(i1+i2-1) * prod(avgder(jj-i1+1 : jj-1)) ...
                       * prod(avgder(jj-i2+1 : jj-1)) * (vAMP(:, jj-i1+1)' * vAMP(:, jj-i2+1))/n);
                   
               end
           end
           
           % M1: cross terms between the PCA initialization (g0, f0) and
           % the AMP iterates
           M1 = zeros(1, jj);
           
           for i2 = 1 : jj
               
               
               if i2 == jj
                   der2 = gamma/alphabar;
               else
                   der2 = 1;
               end
               
               M1(1, i2) = gamma/alphabar * prod(avgder(1 : jj-1)) * prod(avgder(jj-i2+1 : jj-1)) * ( ( ... 
                   (gamma/alphabar^2)^(-jj-i2+1) * (Rtransval - sum( freecum(1:jj+i2-1) .* ((gamma/alphabar^2).^(1:jj+i2-1)))) ...
                   * (g0' * vAMP(:, jj-i2+1))/n ) + ( ... 
                   (gamma/alphabar^2)^(-jj-i2)/alphabar * (Rtransval - sum( freecum(1:jj+i2) .* ((gamma/alphabar^2).^(1:jj+i2)))) ...
                   * der2 * (f0' * uAMP(:, jj-i2+1))/m ) );
           end
           
           % helper terms for M2 (the term involving g0' * g0)
           extrav = zeros(1, jj);
           
           for i1 = 1 : jj
               extrav(i1) = (Rtransval - sum( freecum(1:i1-1) .* ((gamma/alphabar^2).^(1:i1-1)) ) );
           end
           
           extrac = zeros(jj, jj);
           
           for i1 = 1 : jj
               for i2 = 1 : jj
                    extrac(i1, i2) = freecum(i1+i2-1) *(gamma/alphabar^2)^(i1+i2-2);
               end
           end
           
           M2 = (gamma/alphabar^2)^(-2*jj)* (Rtransval1 - ...
               2*alphabar^2/gamma*sum(extrav) + sum(sum(extrac)) ) * gamma^4/alphabar^6 * (prod(avgder(1 : jj-1)))^2 ...
                   * (g0' * g0)/n;

           % helper terms for M2b (the term involving f0' * f0)
           extravb = zeros(1, jj);
           
           for i1 = 1 : jj
               extravb(i1) = (Rtransval - sum( freecum(1:i1) .* ((gamma/alphabar^2).^(1:i1)) ) );
           end
           
           extracb = zeros(jj, jj);
           
           for i1 = 1 : jj
               for i2 = 1 : jj
                    extracb(i1, i2) = freecum(i1+i2) *(gamma/alphabar^2)^(i1+i2-2);
               end
           end
           
           M2b = (gamma/alphabar^2)^(-2*jj)* (Rtransval1 * alphabar^2/gamma - (alphabar^2/gamma)^2 * Rtransval - ...
               2*(alphabar^2/gamma)^2*sum(extravb) + sum(sum(extracb)) ) * gamma^4 /alphabar^8 * (prod(avgder(1 : jj-1)))^2 ...
                   * (f0' * f0)/m;
           
           % \sigma_{t,t}: iterate terms + twice the cross terms + the two
           % PCA-initialization terms
           sigmaSE(jj) = sum(sum(M))+2*sum(M1)+M2+M2b;
           
           % given \sigma_{t, t}, \mu_t can be estimated from the norm of
           % the iterate f^t (abs() guards against a negative difference)
           muSE(jj) = sqrt(abs(sum(fAMP(:, jj).^2)/m - sigmaSE(jj)));
           
           % average of the derivative of the denoiser u_{t+1}
           avgder(jj) = muSE(jj)/sigmaSE(jj) * ( 1 - mean((tanh(muSE(jj)/sigmaSE(jj)*fAMP(:, jj))).^2));

           % iterate u^{t+1}
           uAMP(:, jj+1) = tanh(muSE(jj)/sigmaSE(jj)*fAMP(:, jj));
           
           % normalized squared correlation between u^{t+1} and the signal
           % u^*
           scalu(jj+1) = (sum(uAMP(:, jj+1).* u))^2/sum(u.^2)/sum(uAMP(:, jj+1).^2);
           
        end        
    end
end
