function M = grassmanngeneralizedfactory(n, p, B)
% Returns a manifold struct of "scaled" vector subspaces.
%
% function M = grassmanngeneralizedfactory(n, p)
% function M = grassmanngeneralizedfactory(n, p, B)
%
% Generalized Grassmann manifold: each point on this manifold is a
% collection of "scaled" vector subspaces of dimension p embedded in R^n.
% The scaling is due to the symmetric positive definite matrix B.
%
% When B is identity, the manifold is the standard Grassmann manifold.
%
% The metric is obtained by viewing the generalized Grassmannian
% a Riemannian quotient manifold of the generalized Stiefel manifold, 
% which is the manifold of "scaled" orthonormal matrices. Specifically, 
% the scaled Stiefel manifold is the set {X : X'*B*X = I}. 
% The generalized Grassmann manifold is the Grassmannian of the 
% generalized Stiefel manifold.
%
% The generalized Stiefel manifold is endowed with a scaled metric
% by viewing it as a Riemannian submanifold of the Euclidean space, which
% is again endowed with the scaled inner product.
%
% Some notions (not all) are from Section 4.5 of the paper
% "The geometry of algorithms with orthogonality constraints",
% A. Edelman, T. A. Arias, S. T. Smith, SIMAX, 1998.
%
% Paper link: http://arxiv.org/abs/physics/9806030.
%
% 
% Note: some computations such as restricted_svd, distance, logarithm, and 
% exponential are new and we believe them to be correct.
% Also, we hope that the computations are numerically stable.
% In case some things do not work out as expected or there is some trouble,
% please contact us at http://www.manopt.org.
%
% Note: egrad2rgrad and ehess2rhess involve solving linear systems in B. If
% this is a bottleneck for a specific application, then a way forward is to
% create a modified version of this file which preprocesses B to speed this
% up (typically, by computing a Cholesky factorization of it, then calling
% an appropriate solver).
%
% See also: stiefelgeneralizedfactory  stiefelfactory  grassmannfactory


% This file is part of Manopt: www.manopt.org.
% Original author: Bamdev Mishra, June 30, 2015.
% Contributors:
%
% Change log:
%   Jan. 4, 2021 (NB):
%       Moved symm() function to end of file for compatibility with Octave 6.1.0
    
    assert(n >= p, ...
        ['The dimension n of the ambient space must be larger ' ...
        'than the dimension p of the subspaces.']);
    
    if ~exist('B', 'var') || isempty(B)
        B = speye(n); % Standard Grassmann manifold.
    end
    
    M.name = @() sprintf('Generalized Grassmann manifold Gr(%d, %d)', n, p);
    
    M.dim = @() p*(n - p);   
    
    M.inner = @(X, eta, zeta) trace(eta'*(B*zeta)); % Scaled metric, but horizontally invariant.
    
    M.norm = @(X, eta) sqrt(M.inner(X, eta, eta));
    
    M.dist = @distance; 
    function d = distance(X, Y)
        XtBY = X'*(B*Y); % XtY ---> XtBY
        cos_princ_angle = svd(XtBY); % svd(XtY) ---> svd(XtBY)
        % Two next instructions not necessary: the imaginary parts that
        % would appear if the cosines are not between -1 and 1, when
        % passed to the acos function, would be very small, and would
        % thus vanish when the norm is taken.
        % cos_princ_angle = min(cos_princ_angle,  1);
        % cos_princ_angle = max(cos_princ_angle, -1);
        square_d = norm(acos(cos_princ_angle))^2;
        
        d = sqrt(square_d);
    end
    
    M.typicaldist = @() sqrt(p);
    
    
    % Orthogonal projection of an ambient vector U onto the 
    % horizontal space at X.
    M.proj = @projection;
    function Up = projection(X, U)
        BX = B*X;
        
        % Projection onto the tangent space
        % U = U - X*symm(BX'*U);
        % Projection onto the horizontal space
        % Up = U - X*skew(BX'*U);
        
        Up = U - X*(BX'*U);
    end
    
    M.tangent = M.proj;
    
    M.egrad2rgrad = @egrad2rgrad;
    function rgrad = egrad2rgrad(X, egrad)
        
        % First, scale egrad according to the scaled metric in the
        % Euclidean space.
        egrad_scaled = B\egrad;
        
        % Second, project onto the tangent space.
        % No need to project onto the horizontal space as
        % by the Riemannian submersion theory, this quantity automatically
        % belongs to the horizontal space.
        %
        %
        % rgrad = egrad_scaled - X*symm((B*X)'*egrad_scaled);
        %
        % Verify that symm(BX'*egrad_scaled) = symm(X'*egrad).
        
        rgrad = egrad_scaled - X*symm(X'*egrad);
    end
    
    
    M.ehess2rhess = @ehess2rhess;
    function rhess = ehess2rhess(X, egrad, ehess, H)
        egraddot = ehess;
        Xdot = H;
        
        % Directional derivative of the Riemannian gradient.
        egrad_scaleddot = B\egraddot;
        rgraddot = egrad_scaleddot - Xdot*symm(X'*egrad)...
            - X*symm(Xdot'*egrad)...
            - X*symm(X'*egraddot);
        
        % Project onto the horizontal space.
        rhess = M.proj(X, rgraddot);
    end
    
    
    M.retr = @retraction;
    function Y = retraction(X, U, t)
        if nargin < 3
            t = 1.0;
        end
        Y = guf(X + t*U); % Ensure that Y'*B*Y is identity.
    end
    
    
    M.exp = @exponential;
    function Y = exponential(X, U, t)
        if nargin == 3
            tU = t*U;
        else
            tU = U;
        end
        
        % restricted_svd is defined later in the file.
        [u, s, v] = restricted_svd(tU);% svd(tU, 0) ---> restricted_svd(tU).
        cos_s = diag(cos(diag(s)));
        sin_s = diag(sin(diag(s)));
        Y = X*v*cos_s*v' + u*sin_s*v';% Verify that Y'*B*Y is identity
        
        % From numerical experiments, it seems necessary to
        % re-orthonormalize.
        Y = guf(Y);% Ensure that Y'*B*Y is identity.
    end
    
    
    
    % Test code for the logarithm:
    % gGr = grassmanngeneralizedfactory(5, 2, diag(rand(5,1)));
    % x = gGr.rand()
    % y = gGr.rand()
    % u = gGr.log(x, y)
    % gGr.dist(x, y) % These two numbers should
    % gGr.norm(x, u) % be the same.
    % z = gGr.exp(x, u) % z needs not be the same matrix as y, but it should
    % v = gGr.log(x, z) % be the same point as y on Grassmann: dist almost 0.
    % gGr.dist(z, y)
    M.log = @logarithm;
    function U = logarithm(X, Y)
        YtBX = Y'*(B*X); % YtX ---> YtBX.
        At = (Y' - YtBX*X');
        Bt = YtBX\At;
        [u, s, v] = restricted_svd(Bt');% svd(Bt', 'econ') ---> restricted_svd(Bt').
        
        u = u(:, 1:p);
        s = diag(s);
        s = s(1:p);
        v = v(:, 1:p);
        U = u*diag(atan(s))*v'; % A horizontal vector, i.e., U'*(B*X) is zero.
    end
    
    
    M.hash = @(X) ['z' hashmd5(X(:))];
    
    M.rand = @random;
    function X = random()
        X = guf(randn(n, p)); % Ensure that X'*B*X is identity;
    end
    
    M.randvec = @randomvec;
    function U = randomvec(X)
        U = projection(X, randn(n, p));
        U = U / norm(U(:));
    end
    
    M.lincomb = @matrixlincomb;
    
    M.zerovec = @(X) zeros(n, p);
    
    % This transport is compatible with the generalized polar retraction.
    M.transp = @(X1, X2, d) projection(X2, d);
    
    M.vec = @(X, u_mat) u_mat(:);
    M.mat = @(X, u_vec) reshape(u_vec, [n, p]);
    M.vecmatareisometries = @() false;
    
    % Some auxiliary functions    
    function X = guf(Y)
        % Generalized polar decomposition of an n-by-p matrix Y.
        % X'*B*X is identity.
        
        % Method 1
        [u, ~, v] = svd(Y, 0);
  
        % Instead of the following three steps, an equivalent, but an 
        % expensive, way is to do X = u*(sqrtm(u'*(B*u))\(v')).
        [q, ssquare] = eig(u'*(B*u));
        qsinv = q/sparse(diag(sqrt(diag(ssquare))));
        X = u*((qsinv*q')*v'); % X'*B*X is identity.
        
        
        % Another computation using restricted_svd
        % [u, ~, v] = restricted_svd(Y);
        % X = u*v'; % X'*B*X is identity.
        
    end
    
    function [u, s, v] = restricted_svd(Y)
        % We compute a thin svd-like decomposition of an n-by-p matrix Y 
        % into matrices u, s, and v such that u is an n-by-p matrix
        % with u'*B*u being identity, s is a p-by-p diagonal matrix 
        % with positive entries, and v is a p-by-p orthogonal matrix.
        % Y = u*s*v'.
        
        [v, ssquare] = eig(symm(Y'*(B*Y))); % Y*B*Y is positive definite
        ssquarevec = diag(ssquare);
        
        s = sparse(diag(abs(sqrt(ssquarevec))));
        u = Y*(v/s); % u'*B*u is identity.
    end
    
end

function A = symm(Z)
    A = (Z + Z')/2;
end
