function M = grassmanncomplexfactory(n, p, k)
% Returns a manifold struct to optimize over the set of subspaces in C^n.
%
% function M = grassmanncomplexfactory(n, p)
% function M = grassmanncomplexfactory(n, p, k)
%
% Complex Grassmann manifold: each point on this manifold is a collection
% of k vector subspaces of dimension p embedded in C^n.
%
% The metric is obtained by making the Grassmannian a Riemannian quotient
% manifold of the complex Stiefel manifold, i.e., the manifold of
% orthonormal matrices, itself endowed with a metric by making it a
% Riemannian submanifold of the Euclidean space, endowed with the usual
% real-trace inner product, that is, it is the usual metric for the complex
% plane identified with R^2.
% 
% This structure deals with complex matrices X of size n x p x k
% (or n x p if k = 1, which is the default) such that each n x p matrix is
% orthonormal, i.e., X'*X = eye(p) if k = 1, or X(:, :, i)' * X(:, :, i) =
% eye(p) for i = 1 : k if k > 1. Each n x p matrix is a numerical
% representation of the vector subspace its columns span.
%
% By default, k = 1.
%
% See also: grassmannfactory, stiefelcomplexfactory, grassmanngeneralizedfactory

% This file is part of Manopt: www.manopt.org.
% Original author: Hiroyuki Sato, May 21, 2015.
% Contributors: 
% Change log: 

    assert(n >= p, ...
           ['The dimension n of the ambient space must be larger ' ...
            'than the dimension p of the subspaces.']);
    
    if ~exist('k', 'var') || isempty(k)
        k = 1;
    end
    
    if k == 1
        M.name = @() sprintf('Complex Grassmann manifold Gr(%d, %d)', n, p);
    elseif k > 1
        M.name = @() sprintf(['Multi complex Grassmann manifold ' ...
            'Gr(%d, %d)^%d'], n, p, k);
    else
        error('k must be an integer no less than 1.');
    end
    
    M.dim = @() 2*k*p*(n-p); %! k*p*(n-p) -> 2*k*p*(n-p)
    
    M.inner = @(x, d1, d2) real(d1(:)'*d2(:)); %! trace -> real-trace
    
    M.norm = @(x, d) norm(d(:));
    
    M.dist = @distance;
    function d = distance(x, y)
        principal_angles = zeros(p, k);
        XHY = multiprod(multihconj(x), y); %! XtY -> XHY, multitransp -> multihconj
        for i = 1 : k
            cos_princ_angle = svd(XHY(:, :, i));
            principal_angles(:, i) = acos(cos_princ_angle);
        end
        d = norm(real(principal_angles), 'fro');
    end
    
    M.typicaldist = @() sqrt(p*k);
    
    % Orthogonal projection of an ambient vector U to the horizontal space
    % at X.
    M.proj = @projection;
    function Up = projection(X, U)
        
        XHU = multiprod(multihconj(X), U); %! XtU -> XHU, multitransp -> multihconj
        Up = U - multiprod(X, XHU); %! XtU -> XHU

    end
    
    M.tangent = M.proj;
    
    M.egrad2rgrad = M.proj;
    
    M.ehess2rhess = @ehess2rhess;
    function rhess = ehess2rhess(X, egrad, ehess, H)
        PXehess = projection(X, ehess);
        XHG = multiprod(multihconj(X), egrad); %! XtG -> XHG, multitransp -> multihconj
        HXHG = multiprod(H, XHG); %! HXtG -> HXHG, XtG -> XHG
        rhess = PXehess - HXHG; %! HXtG -> HXHG
    end
    
    M.retr = @retraction;
    function Y = retraction(X, U, t)
        if nargin < 3
            t = 1.0;
        end
        Y = X + t*U;
        for i = 1 : k 
        
            % Compute the polar factorization of Y = X+tU
            [u, s, v] = svd(Y(:, :, i), 'econ'); %#ok
            Y(:, :, i) = u*v';
            
            % Another popular retraction uses QR instead of SVD.
            % As compared with the Stiefel factory, we do not need to
            % worry about flipping signs of columns here, since only
            % the column space is important, not the actual columns.
            % [Q, unused] = qr(Y(:, :, i), 0); %#ok
            % Y(:, :, i) = Q;
            
        end
    end
    
    M.exp = @exponential;
    function Y = exponential(X, U, t)
        if nargin == 3
            tU = t*U;
        else
            tU = U;
        end
        Y = zeros(size(X));
        for i = 1 : k
            [u, s, v] = svd(tU(:, :, i), 0);
            cos_s = diag(cos(diag(s)));
            sin_s = diag(sin(diag(s)));
            Y(:, :, i) = X(:, :, i)*v*cos_s*v' + u*sin_s*v';
            % From numerical experiments, it seems necessary to
            % re-orthonormalize. This is overall quite expensive.
            [q, unused] = qr(Y(:, :, i), 0); %#ok
            Y(:, :, i) = q;
        end
    end

    % Test code for the logarithm:
    % Gr = grassmanncomplexfactory(5, 2, 3);
    % x = Gr.rand()
    % y = Gr.rand()
    % u = Gr.log(x, y)
    % Gr.dist(x, y) % These two numbers should
    % Gr.norm(x, u) % be the same.
    % z = Gr.exp(x, u) % z needs not be the same matrix as y, but it should
    % v = Gr.log(x, z) % be the same point as y on Grassmann: dist almost 0.
    M.log = @logarithm;
    function U = logarithm(X, Y)
        U = zeros(n, p, k);
        for i = 1 : k
            x = X(:, :, i);
            y = Y(:, :, i);
            yHx = y'*x; %! ytx -> yHx, y.' -> y'
            AH = y'-yHx*x'; %! At -> AH, x.' -> x', y.' -> y'
            BH = yHx\AH; %! Bt -> BH, ytx -> yHx, At -> AH
            [u, s, v] = svd(BH', 'econ'); %! Bt.' -> BH'

            u = u(:, 1:p);
            s = diag(s);
            s = s(1:p);
            v = v(:, 1:p);

            U(:, :, i) = u*diag(atan(s))*v'; %! v.' -> v'
        end
    end

    M.hash = @(X) ['z' hashmd5([real(X(:)); imag(X(:))])]; %! X(:) -> [real(X(:)); imag(X(:))]
    
    M.rand = @random;
    function X = random()
        X = zeros(n, p, k);
        for j = 1 : k
            [Q, unused] = qr(randn(n, p) + 1i*randn(n, p), 0); %#ok<NASGU> %! Complex version
            X(:, :, j) = Q;
        end
    end
    
    M.randvec = @randomvec;
    function U = randomvec(X)
        U = projection(X, randn(n, p, k) + 1i*randn(n, p, k)); %! Complex version
        U = U / norm(U(:));
    end
    
    M.lincomb = @matrixlincomb;
    
    M.zerovec = @(x) zeros(n, p, k);
    
    % This transport is compatible with the polar retraction.
    M.transp = @(x1, x2, d) projection(x2, d);
    
    M.vec = @(x, u_mat) [real(u_mat(:)) ; imag(u_mat(:))];
    M.mat = @(x, u_vec) reshape(u_vec(1:(n*p*k)) + 1i*u_vec((n*p*k+1):end), [n, p, k]);
    M.vecmatareisometries = @() true;

end
