function M = stiefelstackedfactory(m, d, k)
% Stiefel(k, d)^m, represented as matrices of size m*d-by-k.
%
% function M = stiefelstackedfactory(m, d, k)
%
% Points on this manifold are matrices Y of size n x k, with n = m*d.
% Y is thought of as m matrices of size d x k each, stacked on top of each
% other. Call them Y1, ..., Ym. Each Yi is an orthonormal matrix, that is,
% its d rows are unit norm and are orthogonal to each other. Thus, this
% geometry is a product of Stiefel manifolds.
% 
% To easily transform matrices Y to 3D arrays Y3 of size d x k x m such
% that each slice Y3(:, :, i) corresponds to one of the matrices Yi, use
% the functions
% 
%    Y3 = M.to3D(Y)   and   Y = M.to2D(Y3).
%
% The ambient space R^(nxk) is endowed with the usual inner product
% <A, B> = trace(A'*B). This inner product is restricted to the tangent
% spaces of the present manifold, thus making it a Riemannian submanifold
% of the Euclidean space R^(nxk). Tangent vectors are represented as
% matrices of the same size as Y, and can likewise be converted to 3D
% arrays and back using to3D() and to2D().
%
% In dealing with this geometry, especially when dealing with the 3D array
% representations of points and tangent vectors, the tools multiprod,
% multitransp, multitrace, multiscale etc. available in Manopt are often
% useful.
%
% See also: stiefelfactory obliquefactory multiprod multitransp

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, May 4, 2015.
% Contributors: 
% Change log: 
%   July 2, 2024 (NB): marked the retraction as second order with M.retr2.

    assert(k >= d, 'k must be at least as large as d.');

    n = m*d;
    
    M.name = @() sprintf('Manifold of %d orthonormal matrices of size %dx%d, stacked', m, d, k);
    
    M.dim = @() m*(k*d - .5*d*(d+1));
    
    M.size = @() [m, d, k];
    
    M.inner = @(x, d1, d2) d1(:).'*d2(:);
    
    M.norm = @(x, d) norm(d(:));
    
    M.typicaldist = @() sqrt(M.dim());

    % Convert a dxkxm matrix to an nxk matrix
    M.to2D = @to2D;
    function A2 = to2D(A3)
        A2 = reshape(multitransp(A3), [k, m*d])';
    end

    % Convert an nxk matrix to a dxkxm matrix
    M.to3D = @to3D;
    function A3 = to3D(A2)
        A3 = multitransp(reshape(A2', [k, d, m]));
    end

    % Given 2 3D matrices A and B of size dxkxm, returns a 3D matrix C of
    % size dxdxm such that each slice C(:, :, i) is the symmetric part of
    % the product A(:, :, i) * B(:, :, i)'. The name is short for
    % "symmetric-block-diagonal", because if A and B were transformed to
    % their 2D equivalents via to2D, then the output would contain the
    % symmetric parts of the diagonal blocks of A*B'.
    M.symbdiag = @symbdiag;
    function C = symbdiag(A, B)
        C = multisym(multiprod(A, multitransp(B)));
    end
    
    % Orthogonal projection from the ambient space R^(nxk) to the tangent
    % space at X.
    M.proj = @projection;
    function Zt = projection(Y, Z)
        Y3 = to3D(Y);
        Z3 = to3D(Z);
        Lambda = symbdiag(Y3, Z3);
        Zt3 = Z3 - multiprod(Lambda, Y3);
        Zt = to2D(Zt3);
    end    
    
    M.tangent = M.proj;
    
    M.egrad2rgrad = M.proj;
    
    M.ehess2rhess = @ehess2rhess;
    function rhess = ehess2rhess(Y, egrad, ehess, Ydot)
        Y3 = to3D(Y);
        Ydot3 = to3D(Ydot);
        egrad3 = to3D(egrad);
        C = symbdiag(Y3, egrad3);
        CYdot = to2D(multiprod(C, Ydot3));
        rhess = projection(Y, ehess - CYdot);
    end
    
    M.retr = @retraction;
    M.retr2 = M.retr; % this is a second-order retraction
    function Y = retraction(Y, U, t)
        if nargin < 3
            t = 1.0;
        end
        Y = Y + t*U;
        Y3 = to3D(Y);
        for i = 1 : m
            % Orthonormalize the rows of Y3(:, :, i):
            [u, s, v] = svd(Y3(:, :, i), 'econ'); %#ok<ASGLU>
            Y3(:, :, i) = u*v';
            % Alternatively, one could also use qr_unique as retraction.
        end
        Y = to2D(Y3);
    end
    
    M.exp = @exponential;
    function Y = exponential(Y, U, t)
        if nargin == 2
            t = 1;
        end
        tU3 = multitransp(to3D(t*U));
        Y3 = multitransp(to3D(Y));
        % From a formula by Ross Lippert, Example 5.4.2 in AMS08.
        for i = 1 : m
            X = Y3(:, :, i);
            Z = tU3(:, :, i);
            Y3(:, :, i) = [X, Z] * ...
                          expm([  X'*Z , -Z'*Z ; eye(d) , X'*Z]) * ...
                          [ expm(-X'*Z) ; zeros(d) ];
            % We may loose orthonormality here. Just to be sure:
            [u, s, v] = svd(Y3(:, :, i), 'econ'); %#ok<ASGLU>
            Y3(:, :, i) = u*v';
        end
        Y = to2D(multitransp(Y3));
    end

    M.hash = @(Y) ['z' hashmd5(Y(:))];
    
    M.rand = @random;
    function Y = random()
        Y3 = zeros(d, k, m);
        for i = 1 : m
            [Q, unused] = qr(randn(k, d), 0); %#ok<ASGLU>
            Y3(:, :, i) = Q';
        end
        Y = to2D(Y3);
    end
    
    M.randvec = @randomvec;
    function U = randomvec(Y)
        U = projection(Y, randn(n, k));
        U = U / M.norm(Y, U);
    end
    
    M.lincomb = @matrixlincomb;
    
    M.zerovec = @(x) zeros(n, k);
    
    M.transp = @(x1, x2, u) projection(x2, u);
    
    M.vec = @(x, u_mat) u_mat(:);
    M.mat = @(x, u_vec) reshape(u_vec, [n, k]);
    M.vecmatareisometries = @() true;

end
