function M = fixedrankembeddedfactory(m, n, k)
% Manifold struct to optimize fixed-rank matrices w/ an embedded geometry.
%
% function M = fixedrankembeddedfactory(m, n, k)
%
% Manifold of m-by-n real matrices of fixed rank k. This follows the
% embedded geometry described in Bart Vandereycken's 2013 paper:
% "Low-rank matrix completion by Riemannian optimization".
% 
% Paper link: http://arxiv.org/pdf/1209.3834.pdf
%
% A point X on the manifold is represented as a structure with three
% fields: U, S and V. The matrices U (mxk) and V (nxk) are orthonormal,
% while the matrix S (kxk) is any /diagonal/ full-rank matrix.
% Following the SVD formalism, X = U*S*V'. Note that the diagonal entries
% of S are not constrained to be nonnegative.
%
% Tangent vectors are represented as a structure with three fields: Up, M
% and Vp. The matrices Up (mxk) and Vp (mxk) obey Up'*U = 0 and Vp'*V = 0.
% The matrix M (kxk) is arbitrary. Such a structure corresponds to the
% following tangent vector in the ambient space of mxn matrices:
%   Z = U*M*V' + Up*V' + U*Vp'
% where (U, S, V) is the current point and (Up, M, Vp) is the tangent
% vector at that point.
%
% Vectors in the ambient space are best represented as mxn matrices. If
% these are low-rank, they may also be represented as structures with
% U, S, V fields, such that Z = U*S*V'. There are no restrictions on what
% U, S and V are, as long as their product as indicated yields a real, mxn
% matrix.
%
% The chosen geometry yields a Riemannian submanifold of the embedding
% space R^(mxn) equipped with the usual trace (Frobenius) inner product.
%
% The tools
%    X_triplet = M.matrix2triplet(X_matrix) and
%    X_matrix = M.triplet2matrix(X_triplet)
% can be used to easily convert between full matrix representation (as a
% matrix of size mxn) and triplet representation as a structure with fields
% U, S, V. The tool matrix2triplet also accepts an optional second input r
% to choose the rank of the triplet representation. By default, r = k. If
% the input matrix X_matrix has rank more than r, the triplet represents
% its best rank-r approximation in the Frobenius norm (truncated SVD).
% Note that these conversions are computationally expensive for large m
% and n: this is mostly useful for small matrices and for prototyping.
%
%
% Please cite the Manopt paper as well as the research paper:
%     @Article{vandereycken2013lowrank,
%       Title   = {Low-rank matrix completion by {Riemannian} optimization},
%       Author  = {Vandereycken, B.},
%       Journal = {SIAM Journal on Optimization},
%       Year    = {2013},
%       Number  = {2},
%       Pages   = {1214--1236},
%       Volume  = {23},
%       Doi     = {10.1137/110845768}
%     }
%
% See also: fixedrankfactory_2factors fixedrankfactory_3factors fixedranktensorembeddedfactory

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, Dec. 30, 2012.
% Contributors: Bart Vandereycken, Eitan Levin
% Change log: 
%
%   Feb. 20, 2014 (NB):
%       Added function tangent to work with checkgradient.
%
%   June 24, 2014 (NB):
%       A couple modifications following Bart's feedback:
%       - The checksum (hash) was replaced for a faster alternative: it's a
%         bit less "safe" in that collisions could arise with higher
%         probability, but they're still very unlikely.
%       - The vector transport was changed.
%       The typical distance was also modified, hopefully giving the
%       trustregions method a better initial guess for the trust-region
%       radius, but that should be tested for different cost functions too.
%
%    July 11, 2014 (NB):
%       Added ehess2rhess and tangent2ambient, supplied by Bart.
%
%    July 14, 2014 (NB):
%       Added vec, mat and vecmatareisometries so that hessianspectrum now
%       works with this geometry. Implemented the tangent function.
%       Made it clearer in the code and in the documentation in what format
%       ambient vectors may be supplied, and generalized some functions so
%       that they should now work with both accepted formats.
%       It is now clearly stated that for a point X represented as a
%       triplet (U, S, V), the matrix S needs to be diagonal.
%
%    Sep.  6, 2018 (NB):
%       Removed M.exp() as it was not implemented.
%
%    March 20, 2019 (NB):
%       Added M.matrix2triplet and M.triplet2matrix to allow easy
%       conversion between matrix representations either as full matrices
%       or as triplets (U, S, V).
%
%    Dec. 14, 2019 (EL):
%       The original retraction code was repaced with a somewhat slower but
%       numerically more stable version. With the original code, trouble
%       could arise when the matrices Up, Vp defining the tangent vector
%       being retracted were ill-conditioned.
%
%    Jan. 28, 2020 (NB):
%       In retraction code, moved parameter t around to highlight the fact
%       that it comes up in only one computation.
%       Replaced vec/mat codes: they are still isometries, but they produce
%       representations of length k*(m+n+k) instead of m*n, which is much
%       more efficient: it only exceeds the true dimension by 2k^2. Also,
%       mat does not attempt to project to the tangent space (which it did
%       before but shouldn't): compose mat with tangent if that is the
%       desired effect.

    M.name = @() sprintf('Manifold of %dx%d matrices of rank %d', m, n, k);
    
    M.dim = @() (m+n-k)*k;
    
    M.inner = @(x, d1, d2) d1.M(:).'*d2.M(:) + d1.Up(:).'*d2.Up(:) ...
                                             + d1.Vp(:).'*d2.Vp(:);
    
    M.norm = @(x, d) sqrt(norm(d.M, 'fro')^2 + norm(d.Up, 'fro')^2 ...
                                             + norm(d.Vp, 'fro')^2);
    
    M.dist = @(x, y) error('fixedrankembeddedfactory.dist not implemented yet.');
    
    M.typicaldist = @() M.dim();
    
    % Given Z in tangent vector format, projects the components Up and Vp
    % such that they satisfy the tangent space constraints up to numerical
    % errors. If Z was indeed a tangent vector at X, this should barely
    % affect Z (it would not at all if we had infinite numerical accuracy).
    M.tangent = @tangent;
    function Z = tangent(X, Z)
        Z.Up = Z.Up - X.U*(X.U'*Z.Up);
        Z.Vp = Z.Vp - X.V*(X.V'*Z.Vp);
    end

    % For a given ambient vector Z, applies it to a matrix W. If Z is given
    % as a matrix, this is straightforward. If Z is given as a structure
    % with fields U, S, V such that Z = U*S*V', the product is executed
    % efficiently.
    function ZW = apply_ambient(Z, W)
        if ~isstruct(Z)
            ZW = Z*W;
        else
            ZW = Z.U*(Z.S*(Z.V'*W));
        end
    end

    % Same as apply_ambient, but applies Z' to W.
    function ZtW = apply_ambient_transpose(Z, W)
        if ~isstruct(Z)
            ZtW = Z'*W;
        else
            ZtW = Z.V*(Z.S'*(Z.U'*W));
        end
    end
    
    % Orthogonal projection of an ambient vector Z represented as an mxn
    % matrix or as a structure with fields U, S, V to the tangent space at
    % X, in a tangent vector structure format.
    M.proj = @projection;
    function Zproj = projection(X, Z)
            
        ZV = apply_ambient(Z, X.V);
        UtZV = X.U'*ZV;
        ZtU = apply_ambient_transpose(Z, X.U);

        Zproj.M = UtZV;
        Zproj.Up = ZV  - X.U*UtZV;
        Zproj.Vp = ZtU - X.V*UtZV';

    end

    M.egrad2rgrad = @projection;
    
    % Code supplied by Bart.
    % Given the Euclidean gradient at X and the Euclidean Hessian at X
    % along H, where egrad and ehess are vectors in the ambient space and H
    % is a tangent vector at X, returns the Riemannian Hessian at X along
    % H, which is a tangent vector.
    M.ehess2rhess = @ehess2rhess;
    function rhess = ehess2rhess(X, egrad, ehess, H)
        
        % Euclidean part
        rhess = projection(X, ehess);
        
        % Curvature part
        T = apply_ambient(egrad, H.Vp)/X.S;
        rhess.Up = rhess.Up + (T - X.U*(X.U'*T));
        T = apply_ambient_transpose(egrad, H.Up)/X.S;
        rhess.Vp = rhess.Vp + (T - X.V*(X.V'*T));
        
    end

    % Transforms a tangent vector Z represented as a structure (Up, M, Vp)
    % into a structure with fields (U, S, V) that represents that same
    % tangent vector in the ambient space of mxn matrices, as U*S*V'.
    % This matrix is equal to X.U*Z.M*X.V' + Z.Up*X.V' + X.U*Z.Vp'. The
    % latter is an mxn matrix, which could be too large to build
    % explicitly, and this is why we return a low-rank representation
    % instead. Note that there are no guarantees on U, S and V other than
    % that USV' is the desired matrix. In particular, U and V are not (in
    % general) orthonormal and S is not (in general) diagonal.
    % (In this implementation, S is identity, but this might change.)
    M.tangent2ambient_is_identity = false;
    M.tangent2ambient = @tangent2ambient;
    function Zambient = tangent2ambient(X, Z)
        Zambient.U = [X.U*Z.M + Z.Up, X.U];
        Zambient.S = eye(2*k);
        Zambient.V = [X.V, Z.Vp];
    end
    
    % This retraction is second order, following general results from
    % Absil, Malick, "Projection-like retractions on matrix manifolds",
    % SIAM J. Optim., 22 (2012), pp. 135-158.
    %
    % Notice that this retraction is only locally smooth. Indeed, the
    % following code exhibits a discontinuity when retracting from
    % X = [1 0 ; 0 0] along V = [0 0 ; 0 1]:
    %
    % M = fixedrankembeddedfactory(2, 2, 1);
    % X = struct('U', [1;0], 'V', [1;0], 'S', 1);
    % V = struct('Up', [0;1], 'Vp', [0;1], 'M', 1);
    % entry = @(M) M(1, 1);
    % mat = @(X) X.U*X.S*X.V';
    % g = @(t) entry(mat(M.retr(X, V, t)));
    % ezplot(g, [-2, 2]);
    %
    M.retr = @retraction;
    function Y = retraction(X, Z, t)
        if nargin < 3
            t = 1.0;
        end

        % Mathematically, Z.Up is orthogonal to X.U, and likewise for
        % Z.Vp compared to X.V. Thus, in principle, we could call QR
        % on Z.Up and Z.Vp alone, which should be about 4 times faster
        % than the calls here where we orthonormalize twice as many
        % vectors. However, when Z.Up, Z.Vp are poorly conditioned,
        % orthonormalizing them can lead to loss of orthogonality
        % against X.U, X.V.
        [Qu, Ru] = qr([X.U, Z.Up], 0);
        [Qv, Rv] = qr([X.V, Z.Vp], 0);
        
        % Calling svds or svd should yield the same result, but BV
        % advocated svd is more robust, and it doesn't change the
        % asymptotic complexity to call svd then trim rather than call
        % svds. Also, apparently Matlab calls ARPACK in a suboptimal way
        % for svds in this scenario.
        % Notice that the parameter t appears only here. Thus, in princple,
        % we could make some savings for line-search procedures where we
        % retract the same vector multiple times, only with different
        % values of t. The asymptotic complexity remains the same though
        % (up to a constant factor) because of the matrix-matrix products
        % below which cost essentially the same as the QR factorizations.
        [U, S, V] = svd(Ru*[X.S + t*Z.M, t*eye(k); t*eye(k), zeros(k)]*Rv');
    
        Y.U = Qu*U(:, 1:k); 
        Y.V = Qv*V(:, 1:k); 
        Y.S = S(1:k, 1:k);
        
        % Equivalent but very slow code
        % [U, S, V] = svds(X.U*X.S*X.V' + t*(X.U*Z.M*X.V' + Z.Up*X.V' + X.U*Z.Vp'), k);
        % Y.U = U; Y.V = V; Y.S = S;
    end


    % Orthographic retraction provided by Teng Zhang. One interest of the
    % orthographic retraction is that if matrices are represented in full
    % size, it can be computed without any SVDs. If for an application it
    % makes sense to represent the matrices in full size, this may be a
    % good idea, but it won't shine in the present implementation of the
    % manifold.
    M.retr_ortho = @retraction_orthographic;
    function Y = retraction_orthographic(X, Z, t)
        if nargin < 3
            t = 1.0;
        end
        
        % First, write Y (the output) as U1*S0*V1', where U1 and V1 are
        % orthogonal matrices and S0 is of size r by r.
        [U1, ~] = qr(t*(X.U*Z.M  + Z.Up) + X.U*X.S, 0);
        [V1, ~] = qr(t*(X.V*Z.M' + Z.Vp) + X.V*X.S, 0);
        S0 = (U1'*X.U)*(X.S + t*Z.M)*(X.V'*V1) ...
                         + t*((U1'*Z.Up)*(X.V'*V1) + (U1'*X.U)*(Z.Vp'*V1));
        
        % Then, obtain the singular value decomposition of Y.
        [U2, S2, V2] = svd(S0);
        Y.U = U1*U2;
        Y.S = S2;
        Y.V = V1*V2;
        
    end


    % Less safe but much faster checksum, June 24, 2014.
    % Older version right below.
    M.hash = @(X) ['z' hashmd5([sum(X.U(:)) ; sum(X.S(:)); sum(X.V(:)) ])];
    %M.hash = @(X) ['z' hashmd5([X.U(:) ; X.S(:) ; X.V(:)])];
    
    M.rand = @random;
    % Factors U, V live on Stiefel manifolds: reuse their random generator.
    stiefelm = stiefelfactory(m, k);
    stiefeln = stiefelfactory(n, k);
    function X = random()
        X.U = stiefelm.rand();
        X.V = stiefeln.rand();
        X.S = diag(sort(rand(k, 1), 1, 'descend'));
    end
    
    % Generate a random tangent vector at X.
    % Note: this may not be the uniform distribution over the set of
    % unit-norm tangent vectors.
    M.randvec = @randomvec;
    function Z = randomvec(X)
        Z.M  = randn(k);
        Z.Up = randn(m, k);
        Z.Vp = randn(n, k);
        Z = tangent(X, Z);
        nrm = M.norm(X, Z);
        Z.M  = Z.M  / nrm;
        Z.Up = Z.Up / nrm;
        Z.Vp = Z.Vp / nrm;
    end
    
    M.lincomb = @lincomb;
    
    M.zerovec = @(X) struct('M', zeros(k, k), 'Up', zeros(m, k), ...
                                              'Vp', zeros(n, k));
    
    % New vector transport on June 24, 2014 (as indicated by Bart)
    % Reference: Absil, Mahony, Sepulchre 2008 section 8.1.3:
    % For Riemannian submanifolds of a Euclidean space, it is acceptable to
    % transport simply by orthogonal projection of the tangent vector
    % translated in the ambient space.
    M.transp = @project_tangent;
    function Z2 = project_tangent(X1, X2, Z1)
        Z2 = projection(X2, tangent2ambient(X1, Z1));
    end

    % The function 'vec' is isometric from the tangent space at X to real
    % vectors of length k(m+n+k). The function 'mat' is the left-inverse
    % of 'vec'. It is sometimes useful to apply 'tangent' to the output of
    % 'mat'.
    M.vec = @vec;
    function Zvec = vec(X, Z) %#ok<INUSL>
        A = Z.M;
        B = Z.Up;
        C = Z.Vp;
        Zvec = [A(:) ; B(:) ; C(:)];
    end
    rangeM = 1:(k^2);
    rangeUp = (k^2)+(1:(m*k));
    rangeVp = (k^2+m*k)+(1:(n*k));
    M.mat = @(X, Zvec) struct('M',  reshape(Zvec(rangeM),  [k, k]), ...
                              'Up', reshape(Zvec(rangeUp), [m, k]), ...
                              'Vp', reshape(Zvec(rangeVp), [n, k]));
    M.vecmatareisometries = @() true;
    
    
    % It is sometimes useful to switch between representation of matrices
    % as triplets or as full matrices of size m x n. The function to
    % convert a matrix to a triplet, matrix2triplet, allows to specify the
    % rank of the representation. By default, it is equal to k. Omit the
    % second input (or set to inf) to get a full SVD triplet (in economy
    % format). If so, the resulting triplet does not represent a point on
    % the manifold.
    M.matrix2triplet = @matrix2triplet;
    function X_triplet = matrix2triplet(X_matrix, r)
        if ~exist('r', 'var') || isempty(r) || r <= 0
            r = k;
        end
        if r < min(m, n)
            [U, S, V] = svds(X_matrix, r);
        else
            [U, S, V] = svd(X_matrix, 'econ');
        end
        X_triplet.U = U;
        X_triplet.S = S;
        X_triplet.V = V;
    end
    M.triplet2matrix = @triplet2matrix;
    function X_matrix = triplet2matrix(X_triplet)
        U = X_triplet.U;
        S = X_triplet.S;
        V = X_triplet.V;
        X_matrix = U*S*V';
    end

end

% Linear combination of tangent vectors
function d = lincomb(x, a1, d1, a2, d2) %#ok<INUSL>

    if nargin == 3
        d.Up = a1*d1.Up;
        d.Vp = a1*d1.Vp;
        d.M  = a1*d1.M;
    elseif nargin == 5
        d.Up = a1*d1.Up + a2*d2.Up;
        d.Vp = a1*d1.Vp + a2*d2.Vp;
        d.M  = a1*d1.M  + a2*d2.M;
    else
        error('fixedrank.lincomb takes either 3 or 5 inputs.');
    end

end
