function M = spherefactory(n, m, gpuflag)
% Returns a manifold struct to optimize over unit-norm vectors or matrices.
%
% function M = spherefactory(n)
% function M = spherefactory(n, m)
% function M = spherefactory(n, m, gpuflag)
%
% Manifold of n-by-m real matrices of unit Frobenius norm.
% By default, m = 1, which corresponds to the unit sphere in R^n. The
% metric is such that the sphere is a Riemannian submanifold of the space
% of nxm matrices with the usual trace inner product, i.e., the usual
% metric.
%
% Set gpuflag = true to have points, tangent vectors and ambient vectors
% stored on the GPU. If so, computations can be done on the GPU directly.
%
% See also: obliquefactory spherecomplexfactory

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, Dec. 30, 2012.
% Contributors:
% Change log:
%
%   Oct. 8, 2016 (NB)
%       Code for exponential was simplified to only treat the zero vector
%       as a particular case.
%
%   Oct. 22, 2016 (NB)
%       Distance function dist now significantly more accurate for points
%       within 1e-7 and less from each other.
%
%   July 20, 2017 (NB)
%       Following conversations with Bruno Iannazzo and P.-A. Absil,
%       the distance function is now even more accurate.
%
%   Sep. 7, 2017 (NB)
%       New isometric transporter available in M.isotransp,
%       contributed by Changshuo Liu.
%
%   April 17, 2018 (NB)
%       ehess2rhess: Used to compute projection of ehess, then subtract a
%       multiple of u (which is assumed tangent.) Now, similarly to what
%       happens in stiefelfactory, we first subtract the multiple of u from
%       ehess, then we project. Mathematically, these operations are the
%       same. Numerically, the former version used to be better because tCG
%       in trustregions had some drift near fine convergence. Now that the
%       drift in tCG has been fixed, it is reasonable to apply the
%       projection last, to ensure best tangency of the output.
%
%   July 18, 2018 (NB)
%       Added the inverse retraction (M.invretr) for the sphere.
%
%   Aug. 3, 2018 (NB)
%       Added GPU support: just set gpuflag = true.
%
%   Jan. 8, 2021 (NB)
%       Added tangent2ambient/tangent2ambient_is_identity pair.
%
%   July 1, 2024 (NB)
%       Added M.retr2 = M.retr to mark it as a second-order retraction.
%       Made M.paralleltransp = M.isotransp available.


    if ~exist('m', 'var') || isempty(m)
        m = 1;
    end
    if ~exist('gpuflag', 'var') || isempty(gpuflag)
        gpuflag = false;
    end

    % If gpuflag is active, new arrays (e.g., via rand, randn, zeros, ones)
    % are created directly on the GPU; otherwise, they are created in the
    % usual way (in double precision).
    if gpuflag
        array_type = 'gpuArray';
    else
        array_type = 'double';
    end


    if m == 1
        M.name = @() sprintf('Sphere S^%d', n-1);
    else
        M.name = @() sprintf('Unit F-norm %dx%d matrices', n, m);
    end

    M.dim = @() n*m-1;

    M.inner = @(x, d1, d2) d1(:)'*d2(:);

    M.norm = @(x, d) norm(d, 'fro');

    M.dist = @dist;
    function d = dist(x, y)

        % The following code is mathematically equivalent to the
        % computation d = acos(x(:)'*y(:)) but is much more accurate when
        % x and y are close.

        chordal_distance = norm(x - y, 'fro');
        d = real(2*asin(.5*chordal_distance));

        % Note: for x and y almost antipodal, the accuracy is good but not
        % as good as possible. One way to improve it is by using the
        % following branching:
        % % if chordal_distance > 1.9
        % %     d = pi - dist(x, -y);
        % % end
        % It is rarely necessary to compute the distance between
        % almost-antipodal points with full accuracy in Manopt, hence we
        % favor a simpler code.

    end

    M.typicaldist = @() pi;

    M.proj = @(x, d) d - x*(x(:)'*d(:));

    M.tangent = M.proj;

    M.tangent2ambient_is_identity = true;
    M.tangent2ambient = @(x, u) u;

    % For Riemannian submanifolds, converting a Euclidean gradient into a
    % Riemannian gradient amounts to an orthogonal projection.
    M.egrad2rgrad = M.proj;

    M.ehess2rhess = @ehess2rhess;
    function rhess = ehess2rhess(x, egrad, ehess, u)
        rhess = M.proj(x, ehess - (x(:)'*egrad(:))*u);
    end

    M.exp = @exponential;

    M.retr = @retraction;
    M.retr2 = M.retr;
    M.invretr = @inverse_retraction;

    M.log = @logarithm;
    function v = logarithm(x1, x2)
        v = M.proj(x1, x2 - x1);
        di = M.dist(x1, x2);
        % If the two points are "far apart", correct the norm.
        if di > 1e-6
            nv = norm(v, 'fro');
            v = v * (di / nv);
        end
    end

    M.hash = @(x) ['z' hashmd5(x(:))];

    M.rand = @() random(n, m, array_type);

    M.randvec = @(x) randomvec(n, m, x, array_type);

    M.zerovec = @(x) zeros(n, m, array_type);

    M.lincomb = @matrixlincomb;

    M.transp = @(x1, x2, d) M.proj(x2, d);

    % Parallel transport of d from the tangent space at x1 to x2.
    % See Sec. 5 in http://epubs.siam.org/doi/pdf/10.1137/16M1069298
    % "A Riemannian Gradient Sampling Algorithm for Nonsmooth Optimization
    %  on Manifolds", by Hosseini and Uschmajew, SIOPT 2017
    M.paralleltransp = @(x1, x2, d) paralleltransp(x1, x2, d);
    M.isotransp = M.paralleltransp;
    function Td = paralleltransp(x1, x2, d)
        v = logarithm(x1, x2);
        dist_x1x2 = norm(v, 'fro');
        if dist_x1x2 > 0
            u = v / dist_x1x2;
            utd = u(:)'*d(:);
            Td = d + (cos(dist_x1x2)-1)*utd*u ...
                    -  sin(dist_x1x2)  *utd*x1;
        else
            % x1 == x2, so the transport is identity
            Td = d;
        end
    end

    M.pairmean = @pairmean;
    function y = pairmean(x1, x2)
        y = x1+x2;
        y = y / norm(y, 'fro');
    end

    M.vec = @(x, u_mat) u_mat(:);
    M.mat = @(x, u_vec) reshape(u_vec, [n, m]);
    M.vecmatareisometries = @() true;


    % Automatically convert a number of tools to support GPU.
    if gpuflag
        M = factorygpuhelper(M);
    end


end

% Exponential on the sphere
function y = exponential(x, d, t)
    
    if nargin == 2
        % t = 1
        td = d;
    else
        td = t*d;
    end
    
    nrm_td = norm(td, 'fro');
    y = x*cos(nrm_td) + td*sinxoverx(nrm_td);
    
end

% Retraction on the sphere
function y = retraction(x, d, t)

    if nargin == 2
        % t = 1;
        td = d;
    else
        td = t*d;
    end

    y = x + td;
    y = y / norm(y, 'fro');

end

% Given x and y two points on the manifold, if there exists a tangent
% vector d at x such that Retr_x(d) = y, this function returns d.
function d = inverse_retraction(x, y)

    % Since
    %   x + d = y*||x + d||
    % and x'd = 0, multiply the above by x' on the left:
    %   1 + 0 = x'y * ||x + d||
    % Then solve for d:

    d = y/(x(:)'*y(:)) - x;

end

% Uniform random sampling on the sphere.
function x = random(n, m, array_type)

    x = randn(n, m, array_type);
    x = x / norm(x, 'fro');

end

% Random normalized tangent vector at x.
function d = randomvec(n, m, x, array_type)

    d = randn(n, m, array_type);
    d = d - x*(x(:)'*d(:));
    d = d / norm(d, 'fro');

end
