function M = complexcirclefactory(n, m, gpuflag)
% Returns a manifold struct to optimize over unit-modulus complex numbers.
%
% function M = complexcirclefactory()
% function M = complexcirclefactory(n)
% function M = complexcirclefactory(n, m)
% function M = complexcirclefactory(n, m, gpuflag)
%
% Description of matrices z in C^(nxm) (complex) such that each entry
% z(i, j) has unit modulus. The manifold structure is the Riemannian
% submanifold structure from the embedding space R^2 to the power n-by-m,
% i.e., the complex circle is identified with the unit circle in the real
% plane. Points and tangent vectors are represented as complex matrices of
% size n-by-m.
%
% Set gpuflag = true to have points, tangent vectors and ambient vectors
% stored on the GPU. If so, computations can be done on the GPU directly.
%
% By default, n = 1, m = 1 and gpuflag = false.
%
% See also spherecomplexfactory

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, Dec. 30, 2012.
% Contributors: 
% Change log: 
%
%   July 7, 2014 (NB): Added ehess2rhess function.
%
%   Sep. 3, 2014 (NB): Correction to the dist function (extract real part).
%
%   April 13, 2015 (NB): Fixed logarithm.
%
%   Oct. 8, 2016 (NB)
%       Code for exponential was simplified to only treat the zero vector
%       as a particular case.
%
%   July 20, 2017 (NB)
%       The distance function is now even more accurate. Improved logarithm
%       accordingly.
%
%   July 18, 2018 (NB)
%       Added inverse retraction function M.invretr.
%
%   Aug. 3, 2018 (NB)
%       Added support for matrices of unit-modulus (as opposed to vectors).
%       Added GPU support: just set gpuflag = true.
    
    if ~exist('n', 'var') || isempty(n)
        n = 1;
    end
    if ~exist('m', 'var') || isempty(m)
        m = 1;
    end
    if ~exist('gpuflag', 'var') || isempty(gpuflag)
        gpuflag = false;
    end
    
    % If gpuflag is active, new arrays (e.g., via rand, randn, zeros, ones)
    % are created directly on the GPU; otherwise, they are created in the
    % usual way (in double precision).
    if gpuflag
        array_type = 'gpuArray';
    else
        array_type = 'double';
    end

    if m == 1
        M.name = @() sprintf('Complex circle (S^1)^%d', n);
    else
        M.name = @() sprintf('Complex circle (S^1)^(%dx%d)', n, m);
    end
    
    M.dim = @() n*m;
    
    M.inner = @(z, v, w) real(v(:)'*w(:));
    
    M.norm = @(x, v) norm(v, 'fro');
    
    M.dist = @(x, y) norm(real(2*asin(.5*abs(x - y))), 'fro');
    
    M.typicaldist = @() pi*sqrt(n*m);
    
    M.proj = @(z, u) u - real( conj(u) .* z ) .* z;
    
    M.tangent = M.proj;
    
    % For Riemannian submanifolds, converting a Euclidean gradient into a
    % Riemannian gradient amounts to an orthogonal projection.
    M.egrad2rgrad = M.proj;
    
    M.ehess2rhess = @ehess2rhess;
    function rhess = ehess2rhess(z, egrad, ehess, zdot)
        rhess = M.proj(z, ehess - real(z.*conj(egrad)).*zdot);
    end
    
    M.exp = @exponential;
    function y = exponential(z, v, t)
        
        if nargin == 2
            % t = 1;
            tv = v;
        else
            tv = t*v;
        end

        y = zeros(n, m, array_type);

        nrm_tv = abs(tv);
        
        % We need to be careful for zero steps.
        mask = (nrm_tv > 0);
        y(mask) = z(mask).*cos(nrm_tv(mask)) + ...
                  tv(mask).*(sin(nrm_tv(mask))./nrm_tv(mask));
        y(~mask) = z(~mask);
        
    end
    
    M.retr = @retraction;
    function y = retraction(z, v, t)
        if nargin == 2
            % t = 1;
            tv = v;
        else
            tv = t*v;
        end
        y = sign(z+tv);
    end

    M.invretr = @inverse_retraction;
    function v = inverse_retraction(x, y)
        v = y ./ real(conj(x) .* y) - x;
    end

    M.log = @logarithm;
    function v = logarithm(x1, x2)
        v = M.proj(x1, x2 - x1);
        di = real(2*asin(.5*abs(x1 - x2)));
        nv = abs(v);
        factors = di ./ nv;
        factors(di <= 1e-10) = 1;
        v = v .* factors;
    end
    
    M.hash = @(z) ['z' hashmd5( [real(z(:)) ; imag(z(:))] ) ];
    
    M.rand = @random;
    function z = random()
        z = sign(randn(n, m, array_type) + 1i*randn(n, m, array_type));
    end
    
    M.randvec = @randomvec;
    function v = randomvec(z)
        % i*z(k) is a basis vector of the tangent vector to the k-th circle
        v = randn(n, m, array_type) .* (1i*z);
        v = v / norm(v, 'fro');
    end
    
    M.lincomb = @matrixlincomb;
    
    M.zerovec = @(x) zeros(n, m, array_type);
    
    M.transp = @(x1, x2, d) M.proj(x2, d);
    
    M.pairmean = @pairmean;
    function z = pairmean(z1, z2)
        z = sign(z1+z2);
    end

    M.vec = @(x, u_mat) [real(u_mat(:)) ; imag(u_mat(:))];
    M.mat = @(x, u_vec) reshape(u_vec(1:(n*m)) + 1i*u_vec((n*m+1):end), [n, m]);
    M.vecmatareisometries = @() true;

    % Automatically convert a number of tools to support GPU.
    if gpuflag
        M = factorygpuhelper(M);
    end
    
end
