function xsol = radio_interferometric_calibration(N, K)
% Returns the gain matrices of N stations with K receivers.
%
% function xsol = radio_interferometric_calibration(N, K)
%
% Inputs (both optional; pass [] or omit to use the default):
%   N - number of stations (default 10).
%   K - number of receivers per station (default 2).
% N >= K is always assumed (checked with an assertion).
%
% Output:
%   xsol - NK x K matrix stacking the N estimated K x K station gain
%          matrices, as returned by Manopt's trustregions solver.
%
% The example considers calibration of an array of N stations.
% We simulate a system with N stations, each having K receivers.
% For radio astronomy, K = 2.
%
% For a detailed exposition of the problem at hand, refer to the paper:
% "Radio interferometric calibration using a Riemannian manifold",
% Sarod Yatawatta, ICASSP, 2013.
% Available at http://dx.doi.org/10.1109/ICASSP.2013.6638382.
%
% The source of the signal is unpolarized (given by the coherence matrix
% C = eye(K)). The measured data are the cross correlations of the signals
% of every pair of distinct stations, so there are N(N-1)/2 baselines, each
% contributing a K x K visibility matrix. Noise with a given SNR is added
% to these visibilities.
%
% The objective is to estimate the K x K gain matrix of each station, so
% the total size of the solution is N x (K x K), which is written as an
% NK x K matrix.
%
% Note: the solution is only determined up to a unitary ambiguity. Since
% C = eye(K), replacing the NK x K gain matrix x by x*U for any K x K
% unitary U leaves every modeled visibility x_p*C*x_q' unchanged. This
% ambiguity is common to all stations, therefore we use the quotient
% manifold structure, and the solution obtained by the optimization routine
% always differs from the true solution by an unknown unitary matrix.
%
% Side effects: opens a figure with the convergence history and, for K = 2,
% a second figure with the estimation error after removing the ambiguity.
%

% This file is part of Manopt: www.manopt.org.
% Original author: Sarod Yatawatta, June 29, 2015.
% Contributors: Bamdev Mishra.
% Change log:
%    
%   June 28, 2016 (BM):
%       Modified the egrad and ehess operations according to 
%       the modified metric in the symfixedrankYYcomplexfactory file, 
%       where a factor of 2 was removed from the metric. Accordingly, 
%       a factor of 2 was added to egrad and ehess operations.
%   Aug  31, 2021 (XJ):
%       Added AD to compute the egrad and the ehess

    % Generate some random data to test the function
    
    if ~exist('N', 'var') || isempty(N)
        N = 10; 
    end
    if ~exist('K', 'var') || isempty(K)
        K = 2; 
    end
    
    assert(N >= K, 'N must be larger than or equal to K.');
    
    % Baselines (pairs of correlations)
    B = N*(N-1)/2;
    
    % Row-index tables: stationRows(:, p) selects the K rows of station p
    % in an NK x K gain matrix, and baselineRows(:, b) selects the K rows
    % of baseline b in the (KB x K) visibility matrix. These are equal to
    % K*(p-1)+1 : K*p and K*(b-1)+1 : K*b respectively.
    stationRows = reshape(1 : K*N, K, N);
    baselineRows = reshape(1 : K*B, K, B);
    
    % Source coherence, at phase center
    C = eye(K);
    
    % Random J (gains) of all stations
    J = 0.2*rand(K*N,K) + 1i*rand(K*N,K);
 
    % Visibilities (cross correlations). Baselines are enumerated in the
    % order (1,2), (1,3), ..., (1,N), (2,3), ..., (N-1,N); the cost,
    % egrad and ehess functions below rely on this same ordering.
    % The counter is named ibl (not shared with the nested functions),
    % so that each nested function keeps its own local counter.
    V = zeros(K*B,K);
    ibl = 0;
    for ci = 1 : N-1
        for cj = ci+1 : N
            ibl = ibl + 1;
            % Compute cross correlation of each station pair.
            V(baselineRows(:, ibl), :) = J(stationRows(:, ci), :)*C*J(stationRows(:, cj), :)';
        end
    end
    
    % Generate noise scaled so that ||V||_F^2 / ||noise||_F^2 = SNR.
    % Frobenius norms measure total signal and noise power; the spectral
    % norm (the default of norm for a matrix) would not give that ratio.
    % Set SNR = inf for noiseless data.
    SNR = 10000;
    nn = randn(K*B,K)+1i*randn(K*B,K);
    noise_var = norm(V, 'fro')^2/(norm(nn, 'fro')^2*SNR);
    nn = nn*sqrt(noise_var);
    
    % Add noise to signal
    V = V + nn;
    
    
    % Optimization part by creating the problem structure.
    % First, we use the manifold description.
    % Second, we define the problem cost, gradient and Hessian functions.
   
    
    % Manifold description
    % Note that the actual dimension is KN x K.
    problem.M = symfixedrankYYcomplexfactory(K*N, K);
    
    
    % Cost function: sum over baselines (p, q) of the squared Frobenius
    % norm of the residual V_pq - x_p*C*x_q'.
    problem.cost = @cost;
    function fval = cost(x)
        fval = 0.0;
        b = 0;
        for p = 1 : N-1
            xp = x(stationRows(:, p), :);
            for q = p+1 : N
                b = b + 1;
                res = V(baselineRows(:, b), :) - xp*C*x(stationRows(:, q), :)'; % Residual
                fval = fval + real(res(:)'*res(:)); % Add squared norm of the residual.
            end
        end
    end
    
    % Euclidean gradient of the cost function.
    % Manopt automatically converts it to the Riemannian counterpart.
    % The code involves for-loops for readability, but could be vectorized
    % for improved speed.
    problem.egrad = @egrad;
    function grad = egrad(x)
        grad = zeros(K*N, K);
        b = 0;
        for p = 1 : N-1
            rp = stationRows(:, p);
            xp = x(rp, :);
            for q = p+1 : N
                b = b + 1;
                rq = stationRows(:, q);
                xq = x(rq, :);
                % Residual, scaled by 2 to match the metric of
                % symfixedrankYYcomplexfactory (see change log).
                res = 2*(V(baselineRows(:, b), :) - xp*C*xq');
                grad(rp, :) = grad(rp, :) - res*xq*C';
                grad(rq, :) = grad(rq, :) - res'*xp*C;
            end
        end
    end
    
    % Euclidean Hessian of the cost function along a search direction eta.
    % Manopt automatically converts it to the Riemannian counterpart.
    % It is the directional derivative of egrad along eta.
    problem.ehess = @ehess;
    function hess = ehess(x, eta)
        hess = zeros(K*N, K);
        b = 0;
        for p = 1 : N-1
            rp = stationRows(:, p);
            xp = x(rp, :);
            etap = eta(rp, :);
            for q = p+1 : N
                b = b + 1;
                rq = stationRows(:, q);
                xq = x(rq, :);
                etaq = eta(rq, :);
                % Scaled residual (as in egrad) and its derivative along eta.
                res = 2*(V(baselineRows(:, b), :) - xp*C*xq');
                resdot = 2*(-xp*C*etaq' - etap*C*xq');
                
                hess(rp, :) = hess(rp, :) - (res*etaq + resdot*xq)*C';
                hess(rq, :) = hess(rq, :) - (res'*etap + resdot'*xp)*C;
            end
        end
    end
    

    % An alternative way to compute the egrad and the ehess is to use 
    % automatic differentiation provided in the deep learning toolbox.
    % Notice that the for loop in the cost function can make AD much
    % slower. Call manoptAD to prepare AD for the problem structure.
    % problem = manoptAD(problem);


    % Execute some checks on the derivatives for early debugging.
    % checkgradient(problem);
    % pause;
    % checkhessian(problem);
    % pause;
    
    
    % Solve.
    [xsol,  xcost,  info] = trustregions(problem); 
    fprintf('Final cost: %g.\n', xcost);
    
    
    % Display some statistics.
    fs = 11;
    figure;
    semilogy([info.iter], [info.gradnorm], 'o-.','Color','blue', 'MarkerSize',6, 'LineWidth',1.1);
    ax = gca;
    set(ax,'FontSize',fs);
    xlabel(ax, 'Iteration #', 'FontSize',fs);
    ylabel(ax, 'Gradient norm', 'FontSize',fs);
    title('Convergence of the trust-regions algorithm');

    % Make a plot of estimation error (only for K = 2).
    if K == 2
        % Find unitary ambiguity first by solving min ||J - xsol U|| over
        % unitary U (orthogonal Procrustes problem). Its closed-form
        % solution is U = u*v', where u, v are the singular vectors of
        % xsol'*J.
        [u, ~, v] = svd(xsol'*J);

        % Error in position, normalized.
        E = (J - xsol*u*v')/norm(J);

        % Plot one panel per coordinate (column) of the gains.
        ordinals = {'1st', '2nd'};
        figure;
        for col = 1 : K
            ax = subplot(1, K, col);
            quiver(real(J(:,col)), imag(J(:,col)), real(E(:,col)), imag(E(:,col)));
            hold on;
            scatter(real(J(:,col)), imag(J(:,col)));
            set(ax,'FontSize',fs);
            xlabel(sprintf('Real E_%d', col));
            ylabel(sprintf('Imag E_%d', col));
            title(sprintf('Position error %s coordinate', ordinals{col})); 
            axis equal;
        end
    end
    
end
