function M = stiefelcomplexfactory(n, p, k)
% Returns a manifold struct to optimize over complex orthonormal matrices.
%
% function M = stiefelcomplexfactory(n, p)
% function M = stiefelcomplexfactory(n, p, k)
%
% The complex Stiefel manifold is the set of complex orthonormal nxp
% matrices. If k is larger than 1, this is the Cartesian product of the
% complex Stiefel manifold taken k times. The metric is such that the
% manifold is a Riemannian submanifold of C^nxp equipped with the usual
% real-trace inner product <A, B> = real(A(:)'*B(:)), that is, it is the
% usual metric for the complex plane identified with R^2.
%
% Points are represented as matrices X of size n x p x k (or n x p if k=1,
% which is the default) such that each complex n x p matrix is orthonormal,
% i.e., X'*X = eye(p) if k = 1, or X(:, :, i)' * X(:, :, i) = eye(p) for
% i = 1 : k if k > 1. Tangent vectors are represented as matrices the same
% size as points; U is tangent at X when each X(:, :, i)' * U(:, :, i) is
% skew-Hermitian.
%
% By default, k = 1. Any k that is not a positive integer raises an error.
%
%
% Please cite the Manopt paper as well as either of these research papers
% pertaining to this specific geometry:
% @InProceedings{sato2013complex,
%   Title        = {A complex singular value decomposition algorithm based on the {R}iemannian {N}ewton method},
%   Author       = {Sato, H. and Iwai, T.},
%   Booktitle    = {Decision and Control ({CDC}), 2013 {IEEE} 52nd Annual Conference on},
%   Year         = {2013},
%   Organization = {IEEE},
%   Pages        = {2972--2978}
% }
% @InProceedings{sato2014Riemannian,
%   Title        = {{R}iemannian conjugate gradient method for complex singular value decomposition problem},
%   Author       = {Sato, H.},
%   Booktitle    = {Decision and Control ({CDC}), 2014 {IEEE} 53rd Annual Conference on},
%   Year         = {2014},
%   Organization = {IEEE},
%   Pages        = {5849--5854}
% }
%
%
% See also: stiefelfactory

% This file is part of Manopt: www.manopt.org.
% Original author: Hiroyuki Sato, April 27, 2015.
% Contributors: 
% Change log: 
%   June 18, 2019 (NB) : Using qr_unique for retr and rand.
    
    if ~exist('k', 'var') || isempty(k)
        k = 1;
    end
    
    % Validate k up front. Non-integer values (e.g., 1.5) or Inf would
    % otherwise pass the k > 1 test below and only fail later, with an
    % unrelated message, inside randn / zeros.
    if ~isscalar(k) || ~isnumeric(k) || ~isreal(k) || ~isfinite(k) || ...
       k < 1 || k ~= round(k)
        error('k must be an integer no less than 1.');
    end
    
    if k == 1
        M.name = @() sprintf('Complex Stiefel manifold St(%d, %d)', n, p);
    else
        M.name = @() sprintf('Product complex Stiefel manifold St(%d, %d)^%d', n, p, k);
    end
    
    % Real dimension: each factor has 2*n*p real degrees of freedom,
    % minus p^2 real constraints from the Hermitian equation X'*X = eye(p).
    M.dim = @() k*(2*n*p - p^2);
    
    % Real part of the Hermitian inner product on the stacked entries.
    M.inner = @(x, d1, d2) real(d1(:)'*d2(:));
    
    M.norm = @(x, d) norm(d(:));
    
    M.dist = @(x, y) error('stiefelcomplex.dist not implemented yet.');
    
    M.typicaldist = @() sqrt(p*k);
    
    M.proj = @projection;
    % Orthogonal projection of an ambient U onto the tangent space at X:
    % removes the Hermitian part of X'*U, slice by slice.
    function Up = projection(X, U)
        
        XHU = multiprod(multihconj(X), U);
        herXHU = multiherm(XHU);
        Up = U - multiprod(X, herXHU);
        
    end
    
    M.tangent = M.proj;
    
    % For Riemannian submanifolds, converting a Euclidean gradient into a
    % Riemannian gradient amounts to an orthogonal projection.
    M.egrad2rgrad = M.proj;
    
    M.ehess2rhess = @ehess2rhess;
    % Converts a Euclidean Hessian-vector product ehess (along direction
    % H) into the Riemannian one: subtract the curvature term
    % H * herm(X'*egrad), then project onto the tangent space at X.
    function rhess = ehess2rhess(X, egrad, ehess, H)
        XHG = multiprod(multihconj(X), egrad);
        herXHG = multiherm(XHG);
        HherXHG = multiprod(H, herXHG);
        rhess = projection(X, ehess - HherXHG);
    end
    
    M.retr = @retraction;
    function Y = retraction(X, U, t)
        % It is necessary to call qr_unique rather than simply qr to ensure
        % this is a retraction, to avoid spurious column sign flips.
        % This is only a first-order retraction.
        if nargin < 3
            Y = qr_unique(X + U);
        else
            Y = qr_unique(X + t*U);
        end
    end
    
    M.exp = @exponential;
    % Exponential map for the embedded metric, applied slice by slice.
    % The step size t defaults to 1.
    function Y = exponential(X, U, t)
        if nargin == 2
            t = 1;
        end
        tU = t*U;
        % Preallocate as complex: every slice written below is complex, so
        % this avoids a real-to-complex conversion on the first assignment.
        Y = complex(zeros(size(X)));
        for i = 1 : k
            % From a formula by Ross Lippert, Example 5.4.2 in AMS08.
            % Here ' is the conjugate transpose, which is what the complex
            % version of the formula requires.
            Xi = X(:, :, i);
            Ui = tU(:, :, i);
            Y(:, :, i) = [Xi Ui] * ...
                         expm([Xi'*Ui , -Ui'*Ui ; eye(p) , Xi'*Ui]) * ...
                         [ expm(-Xi'*Ui) ; zeros(p) ];
        end
        
    end

    % Hash the real and imaginary parts together: the imaginary part must
    % contribute to the hash.
    M.hash = @(X) ['z' hashmd5([real(X(:)) ; imag(X(:))])];
    
    M.rand = @() qr_unique(randn(n, p, k) + 1i*randn(n, p, k));
    
    M.randvec = @randomvec;
    % Random unit-norm tangent vector at X, obtained by projecting a
    % complex Gaussian ambient matrix.
    function U = randomvec(X)
        U = projection(X, randn(n, p, k) + 1i*randn(n, p, k));
        U = U / norm(U(:));
    end
    
    M.lincomb = @matrixlincomb;
    
    M.zerovec = @(x) zeros(n, p, k);
    
    M.transp = @(x1, x2, d) projection(x2, d);
    
    % Real vectorization: all real parts first, then all imaginary parts.
    % This is an isometry for the real-trace inner product above.
    M.vec = @(x, u_mat) [real(u_mat(:)) ; imag(u_mat(:))];
    M.mat = @(x, u_vec) reshape(u_vec(1:(n*p*k)) + 1i*u_vec((n*p*k+1):end), [n, p, k]);
    M.vecmatareisometries = @() true;

end
