function M = stiefelfactory(n, p, k, gpuflag)
% Returns a manifold structure to optimize over orthonormal matrices.
%
% function M = stiefelfactory(n, p)
% function M = stiefelfactory(n, p, k)
% function M = stiefelfactory(n, p, k, gpuflag)
%
% The Stiefel manifold is the set of orthonormal nxp matrices. If k
% is larger than 1, this is the Cartesian product of the Stiefel manifold
% taken k times. The metric is such that the manifold is a Riemannian
% submanifold of R^nxp equipped with the usual trace inner product, that
% is, it is the usual metric.
%
% Points are represented as matrices X of size n x p x k (or n x p if k=1,
% which is the default) such that each n x p matrix is orthonormal,
% i.e., X'*X = eye(p) if k = 1, or X(:, :, i)' * X(:, :, i) = eye(p) for
% i = 1 : k if k > 1. Tangent vectors are represented as matrices the same
% size as points.
%
% The default retraction is QR-based: it is only a first-order retraction.
% To use the polar retraction (which is second order), run
%    M.retr = M.retr_polar;
% after creating M with this factory. This can be reverted with
%    M.retr = M.retr_qr;
% If used, you may also want to update M.invretr similarly.
% The polar retraction is also available in M.retr2.
%
% Set gpuflag = true to have points, tangent vectors and ambient vectors
% stored on the GPU. If so, computations can be done on the GPU directly.
%
% By default, k = 1 and gpuflag = false.
%
% See also: grassmannfactory rotationsfactory

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, Dec. 30, 2012.
% Contributors: 
% Change log: 
%  July  5, 2013 (NB) : Added ehess2rhess.
%  Jan. 27, 2014 (BM) : Bug in ehess2rhess corrected.
%  June 24, 2014 (NB) : Added true exponential map and changed the randvec
%                       function so that it now returns a globally
%                       normalized vector, not a vector where each
%                       component is normalized (this only matters if k>1).
%  July 17, 2018 (NB) : Now both QR (default) and polar retractions are
%                       directly accessible, and their inverses are also
%                       implemented.
%  Aug.  2, 2018 (NB) : Added GPU support: just set gpuflag = true.
%  June 18, 2019 (NB) : Using qr_unique for retr and rand.
%  July  9, 2019 (NB) : Added a comment about QR retraction being first
%                       order only.
%  Jan.  8, 2021 (NB) : Added tangent2ambient+tangent2ambient_is_identity.
%  July  2, 2024 (NB) : Made polar retraction available as retr2.

    assert(n >= p, 'The dimension n must be larger than the dimension p.');
    
    if ~exist('k', 'var') || isempty(k)
        k = 1;
    end
    if ~exist('gpuflag', 'var') || isempty(gpuflag)
        gpuflag = false;
    end
    
    % If gpuflag is active, new arrays (e.g., via rand, randn, zeros, ones)
    % are created directly on the GPU; otherwise, they are created in the
    % usual way (in double precision).
    if gpuflag
        array_type = 'gpuArray';
    else
        array_type = 'double';
    end
    
    if k == 1
        M.name = @() sprintf('Stiefel manifold St(%d, %d)', n, p);
    elseif k > 1
        M.name = @() sprintf('Product Stiefel manifold St(%d, %d)^%d', n, p, k);
    else
        error('k must be an integer no less than 1.');
    end
    
    M.dim = @() k*(n*p - .5*p*(p+1));
    
    M.inner = @(x, d1, d2) d1(:).'*d2(:);
    
    M.norm = @(x, d) norm(d(:));
    
    M.typicaldist = @() sqrt(p*k);
    
    M.proj = @projection;
    function Up = projection(X, U)
        
        XtU = multiprod(multitransp(X), U);
        symXtU = multisym(XtU);
        Up = U - multiprod(X, symXtU);
        
% The code above is equivalent to, but faster than, the code below.
%         
%     Up = zeros(size(U));
%     function A = sym(A), A = .5*(A+A'); end
%     for i = 1 : k
%         Xi = X(:, :, i);
%         Ui = U(:, :, i);
%         Up(:, :, i) = Ui - Xi*sym(Xi'*Ui);
%     end

    end
    
    M.tangent = M.proj;
    
    M.tangent2ambient_is_identity = true;
    M.tangent2ambient = @(X, U) U;
    
    % For Riemannian submanifolds, converting a Euclidean gradient into a
    % Riemannian gradient amounts to an orthogonal projection.
    M.egrad2rgrad = M.proj;
    
    M.ehess2rhess = @ehess2rhess;
    function rhess = ehess2rhess(X, egrad, ehess, H)
        XtG = multiprod(multitransp(X), egrad);
        symXtG = multisym(XtG);
        HsymXtG = multiprod(H, symXtG);
        rhess = projection(X, ehess - HsymXtG);
    end
    
    M.retr_qr = @retraction_qr;
    function Y = retraction_qr(X, U, t)
        % It is necessary to call qr_unique rather than simply qr to ensure
        % this is a retraction, to avoid spurious column sign flips.
        if nargin < 3
            Y = qr_unique(X + U);
        else
            Y = qr_unique(X + t*U);
        end
    end

    M.invretr_qr = @invretr_qr;
    function U = invretr_qr(X, Y)
        XtY = multiprod(multitransp(X), Y);
        R = zeros(p, p, k, array_type);
        H = 2*eye(p, array_type);
        for kk = 1 : k
            % For each slice, assuming the inverse retraction is well
            % defined for the given inputs, we have:
            %   X + U = YR
            % Left multiply with X' to get
            %   I + X'U = X'Y R
            % Since X'U is skew symmetric for a tangent vector U at X, add
            % up this equation with its transpose to get:
            %   2I = (X'Y) R + R' (X'Y)'
            % Contrary to the polar factorization, here R is not symmetric
            % but it is upper triangular. As a result, this is not a
            % Sylvester equation and we must solve it differently.
            R(:, :, kk) = solve_for_triu(XtY(:, :, kk), H);
            % Then,
            %   U = YR - X
            % which is what we compute below.
        end
        U = multiprod(Y, R) - X;
    end
    
    M.retr_polar = @retraction_polar;
    function Y = retraction_polar(X, U, t)
        if nargin < 3
            Y = X + U;
        else
            Y = X + t*U;
        end
        for kk = 1 : k
            [u, s, v] = svd(Y(:, :, kk), 'econ'); %#ok
            Y(:, :, kk) = u*v';
        end
    end
    
    M.invretr_polar = @invretr_polar;
    function U = invretr_polar(X, Y)
        XtY = multiprod(multitransp(X), Y);
        MM = zeros(p, p, k, array_type);
        H = 2*eye(p, array_type);
        for kk = 1 : k
            % For each slice, assuming the inverse retraction is well
            % defined for the given inputs, we have:
            %   X + U = YM
            % Left multiply with X' to get
            %   I + X'U = X'Y M
            % Since X'U is skew symmetric for a tangent vector U at X, add
            % up this equation with its transpose to get:
            %   2I = (X'Y) M + M' (X'Y)'
            %      = (X'Y) M + M (X'Y)'   since M is symmetric.
            % Solve for M symmetric with a call to sylvester:
            MM(:, :, kk) = sylvester_nochecks(XtY(:, :, kk), XtY(:, :, kk)', H);
            % Note that the above is really a Lyapunov equation: it could
            % be solved faster by exploiting the fact the same matrix
            % appears twice on the left, with one the transpose of the
            % other. Then,
            %   U = YM - X
            % which is what we compute below.
        end
        U = multiprod(Y, MM) - X;
    end
    
    % By default, we use the QR retraction, which is first order.
    M.retr = M.retr_qr;
    M.invretr = M.invretr_qr;

    % The polar retraction is second order, and tagged as such in retr2.
    M.retr2 = M.retr_polar;

    M.exp = @exponential;
    function Y = exponential(X, U, t)
        if nargin == 2
            tU = U;
        else
            tU = t*U;
        end
        Y = zeros(size(X), array_type);
        I = eye(p, array_type);
        Z = zeros(p, array_type);
        for kk = 1 : k
            % From a formula by Ross Lippert, Example 5.4.2 in AMS08.
            Xkk = X(:, :, kk);
            Ukk = tU(:, :, kk);
            Y(:, :, kk) = [Xkk Ukk] * ...
                         expm([Xkk'*Ukk , -Ukk'*Ukk ; I , Xkk'*Ukk]) * ...
                         [ expm(-Xkk'*Ukk) ; Z ];
        end
        
    end

    M.hash = @(X) ['z' hashmd5(X(:))];
    
    M.rand = @() qr_unique(randn(n, p, k, array_type));
    
    M.randvec = @randomvec;
    function U = randomvec(X)
        U = projection(X, randn(n, p, k, array_type));
        U = U / norm(U(:));
    end
    
    M.lincomb = @matrixlincomb;
    
    M.zerovec = @(x) zeros(n, p, k, array_type);
    
    M.transp = @(x1, x2, d) projection(x2, d);
    
    M.vec = @(x, u_mat) u_mat(:);
    M.mat = @(x, u_vec) reshape(u_vec, [n, p, k]);
    M.vecmatareisometries = @() true;

    
    % Automatically convert a number of tools to support GPU.
    if gpuflag
        M = factorygpuhelper(M);
    end
    
end
