function M = grassmannfactory(n, p, k, gpuflag)
% Returns a manifold struct to optimize over the space of vector subspaces.
%
% function M = grassmannfactory(n, p)
% function M = grassmannfactory(n, p, k)
% function M = grassmannfactory(n, p, k, gpuflag)
%
% Grassmann manifold: each point on this manifold is a collection of k
% vector subspaces of dimension p embedded in R^n.
%
% The metric is obtained by making the Grassmannian a Riemannian quotient
% manifold of the Stiefel manifold, i.e., the manifold of orthonormal
% matrices, itself endowed with a metric by making it a Riemannian
% submanifold of the Euclidean space, endowed with the usual inner product.
% In short: it is the usual metric used in most cases.
% 
% This structure deals with matrices X of size n x p x k (or n x p if
% k = 1, which is the default) such that each n x p matrix is orthonormal,
% i.e., X'*X = eye(p) if k = 1, or X(:, :, i)' * X(:, :, i) = eye(p) for
% i = 1 : k if k > 1. Each n x p matrix is a numerical representation of
% the vector subspace its columns span.
%
% The retraction is based on a polar factorization and is second order.
%
% Set gpuflag = true to have points, tangent vectors and ambient vectors
% stored on the GPU. If so, computations can be done on the GPU directly.
%
% By default, k = 1 and gpuflag = false.
%
% See also: stiefelfactory grassmanncomplexfactory grassmanngeneralizedfactory

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, Dec. 30, 2012.
% Contributors: 
% Change log: 
%   March 22, 2013 (NB):
%       Implemented geodesic distance.
% 
%   April 17, 2013 (NB):
%       Retraction changed to the polar decomposition, so that the vector
%       transport is now correct, in the sense that it is compatible with
%       the retraction, i.e., transporting a tangent vector G from U to V
%       where V = Retr(U, H) will give Z, and transporting GQ from UQ to VQ
%       will give ZQ: there is no dependence on the representation, which
%       is as it should be. Notice that the polar factorization requires an
%       SVD whereas the qfactor retraction requires a QR decomposition,
%       which is cheaper. Hence, if the retraction happens to be a
%       bottleneck in your application and you are not using vector
%       transports, you may want to replace the retraction with a qfactor.
% 
%   July  4, 2013 (NB):
%       Added support for the logarithmic map 'log'.
%
%   July  5, 2013 (NB):
%       Added support for ehess2rhess.
%
%   June 24, 2014 (NB):
%       Small bug fix in the retraction, and added final
%       re-orthonormalization at the end of the exponential map. This
%       follows discussions on the forum where it appeared there is a
%       significant loss in orthonormality without that extra step. Also
%       changed the randvec function so that it now returns a globally
%       normalized vector, not a vector where each component is normalized
%       (this only matters if k>1).
%
%   July 8, 2018 (NB):
%       Inverse retraction implemented.
%
%   Aug. 3, 2018 (NB):
%       Added GPU support: just set gpuflag = true.
%
%   Apr. 19, 2019 (NB):
%       ehess2rhess: to ensure horizontality, it makes sense to project
%       last, same as in stiefelfactory.
%
%   May 3, 2019 (NB):
%       Added explanation about vector transport relation to retraction.
%
%   Nov. 13, 2019 (NB):
%       Added pairmean function.
%
%   Jan. 8, 2021 (NB)
%       Added tangent2ambient/tangent2ambient_is_identity pair.
%       Here, 'ambient' refers to the total space.

    assert(n >= p, ...
           ['The dimension n of the ambient space must be larger ' ...
            'than the dimension p of the subspaces.']);
    
    if ~exist('k', 'var') || isempty(k)
        k = 1;
    end
    if ~exist('gpuflag', 'var') || isempty(gpuflag)
        gpuflag = false;
    end
    
    % If gpuflag is active, new arrays (e.g., via rand, randn, zeros, ones)
    % are created directly on the GPU; otherwise, they are created in the
    % usual way (in double precision).
    if gpuflag
        array_type = 'gpuArray';
    else
        array_type = 'double';
    end
    
    if k == 1
        M.name = @() sprintf('Grassmann manifold Gr(%d, %d)', n, p);
    elseif k > 1
        M.name = @() sprintf('Multi Grassmann manifold Gr(%d, %d)^%d', ...
                             n, p, k);
    else
        error('k must be an integer no less than 1.');
    end
    
    M.dim = @() k*p*(n-p);
    
    M.inner = @(x, d1, d2) d1(:).'*d2(:);
    
    M.norm = @(x, d) norm(d(:));
    
    M.dist = @distance;
    function d = distance(x, y)
        square_d = 0;
        XtY = multiprod(multitransp(x), y);
        for kk = 1 : k
            cos_princ_angle = svd(XtY(:, :, kk));
            % For x and y closer than ~sqrt(eps), this function is
            % inaccurate, and typically returns values close to ~sqrt(eps).
            square_d = square_d + sum(real(acos(cos_princ_angle)).^2);
        end
        d = sqrt(square_d);
    end
    
    M.typicaldist = @() sqrt(p*k);
    
    % Orthogonal projection of an ambient vector U to the horizontal space
    % at X.
    M.proj = @projection;
    function Up = projection(X, U)
        
        XtU = multiprod(multitransp(X), U);
        Up = U - multiprod(X, XtU);

    end
    
    M.tangent = M.proj;
    
    M.tangent2ambient_is_identity = true;
    M.tangent2ambient = @(X, U) U;
    
    M.egrad2rgrad = M.proj;
    
    M.ehess2rhess = @ehess2rhess;
    function rhess = ehess2rhess(X, egrad, ehess, H)
        XtG = multiprod(multitransp(X), egrad);
        HXtG = multiprod(H, XtG);
        rhess = projection(X, ehess - HXtG);
    end
    
    M.retr = @retraction;
    function Y = retraction(X, U, t)
        if nargin < 3
            Y = X + U;
        else
            Y = X + t*U;
        end
        for kk = 1 : k
        
            % Compute the polar factorization of Y = X+tU
            [u, s, v] = svd(Y(:, :, kk), 'econ'); %#ok
            Y(:, :, kk) = u*v';
            
            % Another way to compute this retraction uses QR instead of SVD.
            % As compared with the Stiefel factory, we do not need to
            % worry about flipping signs of columns here, since only
            % the column space is important, not the actual columns.
            % We prefer the polar factor to the Q-factor computation for
            % reasons explained below: see M.transp.
            %
            % [Q, unused] = qr(Y(:, :, kk), 0); %#ok
            % Y(:, :, kk) = Q;
            
        end
    end
    
    % This inverse retraction is valid for both the QR retraction and the
    % polar retraction.
    M.invretr = @invretr;
    function U = invretr(X, Y)
        XtY = multiprod(multitransp(X), Y);
        U = zeros(n, p, k, array_type);
        for kk = 1 : k
            U(:, :, kk) = Y(:, :, kk) / XtY(:, :, kk);
        end
        U = U - X;
    end
    
    % See Eq. (2.65) in Edelman, Arias and Smith 1998.
    M.exp = @exponential;
    function Y = exponential(X, U, t)
        if nargin == 3
            tU = t*U;
        else
            tU = U;
        end
        Y = zeros(size(X), array_type);
        for kk = 1 : k
            [u, s, v] = svd(tU(:, :, kk), 0);
            cos_s = diag(cos(diag(s)));
            sin_s = diag(sin(diag(s)));
            Y(:, :, kk) = X(:, :, kk)*v*cos_s*v' + u*sin_s*v';
            % From numerical experiments, it seems necessary to
            % re-orthonormalize. This is overall quite expensive.
            [q, unused] = qr(Y(:, :, kk), 0); %#ok
            Y(:, :, kk) = q;
        end
    end

    % Test code for the logarithm:
    % Gr = grassmannfactory(5, 2, 3);
    % x = Gr.rand()
    % y = Gr.rand()
    % u = Gr.log(x, y)
    % Gr.dist(x, y) % These two numbers should
    % Gr.norm(x, u) % be the same.
    % z = Gr.exp(x, u) % z needs not be the same matrix as y, but it should
    % v = Gr.log(x, z) % be the same point as y on Grassmann: dist almost 0.
    M.log = @logarithm;
    function U = logarithm(X, Y)
        U = zeros(n, p, k, array_type);
        for kk = 1 : k
            x = X(:, :, kk);
            y = Y(:, :, kk);
            ytx = y.'*x;
            At = y.'-ytx*x.';
            Bt = ytx\At;
            [u, s, v] = svd(Bt.', 'econ');

            u = u(:, 1:p);
            s = diag(s);
            s = s(1:p);
            v = v(:, 1:p);

            U(:, :, kk) = u*diag(atan(s))*v.';
        end
    end

    M.hash = @(X) ['z' hashmd5(X(:))];
    
    M.rand = @random;
    function X = random()
        X = randn(n, p, k, array_type);
        for kk = 1 : k
            [Q, unused] = qr(X(:, :, kk), 0); %#ok
            X(:, :, kk) = Q;
        end
    end
    
    M.randvec = @randomvec;
    function U = randomvec(X)
        U = projection(X, randn(n, p, k, array_type));
        U = U / norm(U(:));
    end
    
    M.lincomb = @matrixlincomb;
    
    M.zerovec = @(x) zeros(n, p, k, array_type);
    
    % This transport is compatible with the polar retraction, in the
    % following sense:
    %
    % n = 7; p = 3;
    % Gr = grassmannfactory(n, p);
    % X = Gr.rand();
    % U = Gr.randvec(X);
    % V = Gr.randvec(X);
    % [Q, ~] = qr(randn(p));
    % Gr.transp(X*Q, Gr.retr(X*Q, V*Q), U*Q) % these two
    % Gr.transp(X, Gr.retr(X, V), U)*Q       % are equal (up to eps)
    %
    % That is, if we transport U, the horizontal lift of some tangent
    % vector at X, to Y, and Y = Retr_X(V) with V the horizontal lift of
    % some tangent vector at X, we get the horizontal lift of some tangent
    % vector at Y. If we displace X, U, V to XQ, UQ, VQ for some arbitrary
    % orthogonal matrix Q, we get a horizontal lift of some vector at YQ.
    % Importantly, these two vectors are the lifts of the same tangent
    % vector, only lifted at Y and YQ.
    %
    % However, this vector transport is /not/ fully invariant, in the sense
    % that transporting U from X to some arbitrary Y may well yield the
    % lift of a different vector when compared to transporting U from X
    % to YQ, where Q is an arbitrary orthogonal matrix, even though YQ is
    % equivalent to Y. Specifically:
    %
    % Y = Gr.rand();
    % Gr.transp(X, Y*Q, U) - Gr.transp(X, Y, U)*Q   % this is not zero.
    %
    % However, the following vectors are equal:
    %
    % Gr.transp(X, Y*Q, U) - Gr.transp(X, Y, U)     % this *is* zero.
    %
    % For this to be a proper vector transport from [X] to [Y] in general,
    % assuming X'Y is invertible, one should multiply the output of this
    % function on the right with the polar factor of X'*Y, that is,
    % multiply by u*v' where [u, s, v] = svd(X'*Y), for each slice.
    M.transp = @(X, Y, U) projection(Y, U);
    
    % The mean of two points is here defined as the midpoint of a
    % minimizing geodesic connecting the two points. If the log of (X1, X2)
    % is not uniquely defined, then the returned object may not be
    % meaningful; in other words: this works best if (X1, X2) are close.
    M.pairmean = @pairmean;
    function Y = pairmean(X1, X2)
        Y = M.exp(X1, .5*M.log(X1, X2));
    end
    
    M.vec = @(x, u_mat) u_mat(:);
    M.mat = @(x, u_vec) reshape(u_vec, [n, p, k]);
    M.vecmatareisometries = @() true;

    
    % Automatically convert a number of tools to support GPU.
    if gpuflag
        M = factorygpuhelper(M);
    end

end
