function M = hyperbolicfactory(n, m, transposed)
% Factory for matrices whose columns live on the hyperbolic manifold
%
% function M = hyperbolicfactory(n)
% function M = hyperbolicfactory(n, m)
% function M = hyperbolicfactory(n, m, transposed)
%
% Returns a structure M which describes the hyperbolic manifold in Manopt.
% A point on the manifold is a matrix X of size (n+1)-by-m whose columns
% live on the hyperbolic manifold, that is, for each column x of X, we have
%
%   -x(1)^2 + x(2)^2 + x(3)^2 + ... + x(n+1)^2 = -1.
%
% The positive branch is selected by M.rand(), that is, x(1) > 0, but all
% tools work on the negative branch as well.
%
% Equivalently, defining the Minkowski (semi) inner product
%
%   <x, y> = -x(1)y(1) + x(2)y(2) + x(3)y(3) + ... + x(n+1)y(n+1)
%
% and the induced Minkowski (semi) norm ||x||^2 = <x, x>, we can write
% compactly that each column of X has squared Minkowski norm equal to -1.
%
% The set of matrices X that satisfy this constraint is a smooth manifold.
% Tangent vectors at X are matrices U of the same size as X. If x and u are
% the kth columns of X and U respectively, then <x, u> = 0.
%
% This manifold is turned into a Riemannian manifold by restricting the
% Minkowski inner product to each tangent space (a simple calculation
% confirms that this metric is indeed Riemannian and not just semi
% Riemannian, that is, it is positive definite when restricted to each
% tangent space). This is the hyperbolic manifold: for m = 1, all of its
% sectional curvatures are equal to -1. This is called the hyperboloid or
% the Lorentz geometry.
%
% This manifold is an embedded submanifold of Euclidean space (the set of
% matrices of size (n+1)-by-m equipped with the usual trace inner product).
% Thus, when defining the Euclidean gradient for example (problem.egrad),
% it should be specified as if the function were defined in Euclidean space
% directly. The tool M.egrad2rgrad will automatically convert that gradient
% to the correct Riemannian gradient, as needed to satisfy the metric. The
% same is true for the Euclidean Hessian and other tools that manipulate
% elements in the embedding space.
%
% Importantly, the resulting manifold is /not/ a Riemannian submanifold of
% Euclidean space, because its metric is not obtained simply by restricting
% the Euclidean metric to the tangent spaces. However, it is a
% semi-Riemannian submanifold of Minkowski space, that is, the set of
% matrices of size (n+1)-by-m equipped with the Minkowski inner product.
% Minkowski space itself can be seen as a (linear) semi-Riemannian manifold
% embedded in Euclidean space. This view is entirely equivalent to the one
% described above (the Riemannian structure of the resulting manifold is
% exactly the same), and it is useful to derive some of the tools this
% factory provides.
%
% If transposed is set to true (it is false by default), then the matrices
% are transposed: a point X on the manifold is a matrix of size m-by-(n+1)
% and each row is an element in hyperbolic space. It is the same geometry,
% just a different representation.
%
%
% Resources:
%
% 1. Nickel and Kiela, "Learning Continuous Hierarchies in the Lorentz
%    Model of Hyperbolic Geometry", ICML, 2018.
%
% 2. Wilson and Leimeister, "Gradient descent in hyperbolic space",
%    arXiv preprint arXiv:1805.08207 (2018).
% 
% 3. Pennec, "Hessian of the Riemannian squared distance", HAL INRIA, 2017.
%
% Ported primarily from the McTorch toolbox at
% https://github.com/mctorch/mctorch.
%
% See also: poincareballfactory spherefactory obliquefactory obliquecomplexfactory


% This file is part of Manopt: www.manopt.org.
% Original authors: Bamdev Mishra <bamdevm@gmail.com>, Mayank Meghwanshi, 
% Pratik Jawanpuria, Anoop Kunchukuttan, and Hiroyuki Kasai Oct 28, 2018.
% Contributors: Nicolas Boumal
% Change log:
%   May 14, 2020 (NB):
%       Clarified comments about distance computation.
%   July 13, 2020 (NB):
%       Added pairmean function.

    % Design note: all functions that are defined here but not exposed
    % outside work for non-transposed representations. Only the wrappers
    % that eventually expose functionalities handle transposition. This
    % makes it easier to compose functions internally.

    if ~exist('m', 'var') || isempty(m)
        m = 1;
    end 
    
    if ~exist('transposed', 'var') || isempty(transposed)
        transposed = false;
    end
    
    if transposed
        trnsp = @(X) X';
        trnspstr = ', transposed';
    else
        trnsp = @(X) X;
        trnspstr = '';
    end

    M.name = @() sprintf('Hyperbolic manifold H(%d, %d)%s', n, m, trnspstr);
    
    M.dim = @() n*m;
    
    M.typicaldist = @() sqrt(n*m);

    % Returns a row vector q such that q(k) is the Minkowski inner product
    % of columns U(:, k) and V(:, k). This is defined in all of Minkowski
    % space, not only on tangent spaces. In particular, if X is a point on
    % the manifold, then inner_minkowski_columns(X, X) should return a
    % vector of all -1's.
    function q = inner_minkowski_columns(U, V)
        q = -U(1, :).*V(1, :) + sum(U(2:end, :).*V(2:end, :), 1);
    end
    
    % Riemannian metric: we sum over the m copies of the hyperbolic
    % manifold, each equipped with a restriction of the Minkowski metric.
    M.inner = @(X, U, V) sum(inner_minkowski_columns(trnsp(U), trnsp(V)));
    
    % Mathematically, the Riemannian metric is positive definite, hence
    % M.inner always returns a nonnegative number when U is tangent at X.
    % Numerically, because the inner product involves a difference of
    % positive numbers, round-off may result in a small negative number.
    % Taking the max against 0 avoids imaginary results.
    M.norm = @(X, U) sqrt(max(M.inner(X, U, U), 0));
    
    M.dist = @(X, Y) norm(dists(trnsp(X), trnsp(Y)));
    % This function returns a row vector of length m such that d(k) is the
    % geodesic distance between X(:, k) and Y(:, k).
    function d = dists(X, Y)
        % Mathematically, each column of U = X-Y has nonnegative squared
        % Minkowski norm. To avoid potentially imaginary results due to
        % round-off errors, we take the max against 0.
        U = X-Y;
        mink_sqnorms = max(0, inner_minkowski_columns(U, U));
        mink_norms = sqrt(mink_sqnorms);
        d = 2*asinh(.5*mink_norms);
        % The formula above is equivalent to
        % d = max(0, real(acosh(-inner_minkowski_columns(X, Y))));
        % but is numerically more accurate when distances are small.
        % When distances are large, it is better to use the acosh formula.
    end
    
    M.proj = @(X, U) trnsp(projection(trnsp(X), trnsp(U)));
    function PU = projection(X, U)
        inners = inner_minkowski_columns(X, U);
        PU = U + bsxfun(@times, X, inners);
    end
    
    M.tangent = M.proj;
    
    % For Riemannian submanifolds, converting the Euclidean gradient into
    % the Riemannian gradient amounts to an orthogonal projection. Here
    % however, the manifold is not a Riemannian submanifold of Euclidean
    % space, hence extra corrections are required to account for the change
    % of metric.
    M.egrad2rgrad = @(X, egrad) trnsp(egrad2rgrad(trnsp(X), trnsp(egrad)));
    function rgrad = egrad2rgrad(X, egrad)
        egrad(1, :) = -egrad(1, :);
        rgrad = projection(X, egrad);
    end
    
    M.ehess2rhess = @(X, egrad, ehess, U) ...
        trnsp(ehess2rhess(trnsp(X), trnsp(egrad), trnsp(ehess), trnsp(U)));
    function rhess = ehess2rhess(X, egrad, ehess, U)
        egrad(1, :) = -egrad(1, :);
        ehess(1, :) = -ehess(1, :);
        inners = inner_minkowski_columns(X, egrad);
        rhess = projection(X, bsxfun(@times, U, inners) + ehess);
    end
    
    % For the exponential, we cannot separate trnsp() nicely from the main
    % function because the third input, t, is optional.
    M.exp = @exponential;
    function Y = exponential(X, U, t)
        X = trnsp(X);
        U = trnsp(U);
        
        if nargin < 3
            tU = U;   % corresponds to t = 1
        else
            tU = t*U;
        end
        
        % Compute the individual Minkowski norms of the columns of U.
        mink_inners = inner_minkowski_columns(tU, tU);
        mink_norms = sqrt(max(0, mink_inners));
        
        % Coefficients for the exponential. For b, note that NaN's appear
        % when an element of mink_norms is zero, in which case the correct
        % convention is to define sinh(0)/0 = 1.
        a = cosh(mink_norms);
        b = sinh(mink_norms)./mink_norms;
        b(isnan(b)) = 1;
        
        Y = bsxfun(@times, X, a) + bsxfun(@times, tU, b);

        Y = trnsp(Y);
    end
    
    M.retr = M.exp;
    
    M.log = @(X, Y) trnsp(logarithm(trnsp(X), trnsp(Y)));
    function U = logarithm(X, Y)
        d = dists(X, Y);
        a = d./sinh(d);
        a(isnan(a)) = 1;
        U = projection(X, bsxfun(@times, Y, a));
    end

    M.hash = @(X) ['z' hashmd5(X(:))];
    
    M.rand = @() trnsp(myrand());
    function X = myrand()
        X1 = randn(n, m);
        x0 = sqrt(1 + sum(X1.^2, 1)); % selects positive branch
        X = [x0; X1];
    end
    
    M.normalize = @(X, U) U / M.norm(X, U);
    M.randvec = @(X) M.normalize(X, M.proj(X, randn(size(X))));
    
    M.lincomb = @matrixlincomb;
    
    M.zerovec = @(X) zeros(size(X));
    
    M.transp = @(X1, X2, U) M.proj(X2, U);
    
    M.pairmean = @(x1, x2) M.exp(x1, M.log(x1, x2), .5);
   
    % vec returns a vector representation of an input tangent vector which
    % is represented as a matrix; mat returns the original matrix
    % representation of the input vector representation of a tangent
    % vector; vec and mat are thus inverse of each other.
    vect = @(X) X(:);
    M.vec = @(X, U_mat) vect(trnsp(U_mat));
    M.mat = @(X, U_vec) trnsp(reshape(U_vec, [n+1, m]));
    M.vecmatareisometries = @() false;

end
