function M = realphasefactory(n, z0, zmax)
% Returns a manifold struct to optimize over phases of fft's of real signals
%
% function M = realphasefactory(n)
% function M = realphasefactory(n, z0)
% function M = realphasefactory(n, z0, zmax)
%
% If x is a real vector of length n, then y = fft(x) is a complex vector
% which obeys certain symmetries. Specifically, for any integer k,
%
%  y(1+mod(k, n)) = conj(y(1+mod(n-k, n)))
%
% The same holds for the phases of the Fourier transform z = sign(y).
%
% This factory returns a Manopt manifold structure which represents the set
% of complex vectors z of length n which could be the phases of the Fourier
% transform of a real signal of length n:
%
%   abs(z) = 1   and   z(1+mod(k, n)) = conj(z(1+mod(n-k, n))) for each k.
%
% For k = 1, this readily implies that z(1) is +1 or -1, so that the set of
% possible z's is disconnected. To choose which connected component to work
% with, set the second input z0 to +1 or -1 (this is the sign of the mean
% of x). By default, z0 = 1.
%
% Furthermore, if n is even, then k = n/2 implies z(1+n/2) is +1 or -1 as
% well, thus further disconnecting the set of acceptable z's. To choose
% which component to work with, set the third input zmax to +1 or -1. By
% default, it is +1.
%
% The Riemannian manifold structure is the Riemannian submanifold
% structure from the embedding space R^2 x ... x R^2, i.e., the complex
% circles are identified with the unit circle in the real plane.
% Concretely, this means the inner product is <u, v>_z = real(u'*v).
% Tangent vectors at z are complex vectors of length n which notably
% satisfy z(1+0) = 0 and, if n is even, z(1+n/2) = 0.
%
% n must be integer and n >= 3 (for n = 1:2 the manifold has dimension 0).
%
% Extra functions available in M include M.up, M.down and M.downup. They
% allow to capture the symmetries concisely, as:
%
%    M.up(z) == conj(M.down(z)).
%
% See in code for more details.
%
% See also complexcirclefactory

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, Feb. 2, 2017.
% Contributors: joint work with Tamir Bendory, Zhizhen Zhao and Amit Singer
% Change log: 
%
%   July 20, 2017 (NB)
%       The distance function is now more accurate. Improved logarithm
%       accordingly.

    assert(n == round(n) && n >= 3, 'n must be an integer >= 3.');
    
    even_n = (round(n/2) == n/2);
    
    if ~exist('z0', 'var') || isempty(z0)
        z0 = 1;
    end
    if ~exist('zmax', 'var') || isempty(zmax)
        zmax = 1;
    end
    
    assert(z0 == 1 || z0 == -1, 'z0 must be +1 or -1.');
    assert(zmax == 1 || zmax == -1, 'zmax must be +1 or -1.');

    if even_n
        M.name = @() sprintf('Phases of fft''s of real signals of length %d (z0 = %d, zmax = %d)', n, z0, zmax);
    else
        M.name = @() sprintf('Phases of fft''s of real signals of length %d (z0 = %d)', n, z0);
    end
    
    M.dim = @() floor((n-1)/2);
    
    M.inner = @(z, v, w) real(v'*w);
    
    M.norm = @(z, u) norm(u);
    
    M.dist = @(z1, z2) norm(real(2*asin(.5*abs(z1 - z2))));
    
    M.typicaldist = @() pi*sqrt(n/2);
    
    % Special functions to ease working with the symmetries.
    down = @(u) u;
    up = @(u) u([1 ; (n:-1:2)']);
    downup = @(u) (down(u) + conj(up(u)))/2;
    M.down = down;
    M.up = up;
    M.downup = downup;
    
    M.proj = @proj;
    function pu = proj(z, u)
        duu = downup(u);
        pu = duu - real(duu .* conj(z)).*z;
        % Note that there is no need to enforce pu(1) = 0 or (if n is even)
        % pu(1+n/2) = 0 manually, since the IEEE standard ensures that the
        % above operation will be exact for those entries provided z(1)
        % (and possibly z(1+n/2) is +1 or -1, as should be the case.
    end
    
    M.tangent = M.proj;
    
    % For Riemannian submanifolds, converting a Euclidean gradient into a
    % Riemannian gradient amounts to an orthogonal projection.
    M.egrad2rgrad = M.proj;
    
    M.ehess2rhess = @ehess2rhess;
    function rhess = ehess2rhess(z, egrad, ehess, zdot)
        rhess = M.proj(z, ehess - real(downup(egrad) .* conj(z)).*zdot);
    end
    
    M.exp = @exponential;
    function y = exponential(z, v, t)
        
        if nargin == 2
            % t = 1;
            tv = v;
        else
            tv = t*v;
        end

        y = zeros(n, 1);

        nrm_tv = abs(tv);
        
        % We need to be careful for zero steps.
        mask = (nrm_tv > 0);
        y(mask) = z(mask).*cos(nrm_tv(mask)) + ...
                  tv(mask).*(sin(nrm_tv(mask))./nrm_tv(mask));
        y(~mask) = z(~mask);
        
    end
    
    M.retr = @retraction;
    function y = retraction(z, v, t)
        if nargin == 2
            % t = 1;
            tv = v;
        else
            tv = t*v;
        end
        y = sign(z+tv);
    end

    M.log = @logarithm;
    function v = logarithm(x1, x2)
        v = M.proj(x1, x2 - x1);
        di = real(2*asin(.5*abs(x1 - x2)));
        nv = abs(v);
        factors = di ./ nv;
        factors(di <= 1e-6) = 1;
        v = v .* factors;
    end
    
    M.hash = @(z) ['z' hashmd5( [real(z(:)) ; imag(z(:))] ) ];
    
    M.rand = @random;
    function z = random()
        z = sign(downup(randn(n, 1) + 1i*randn(n, 1)));
        z(1) = z0;
        if even_n
            z(1 + n/2) = zmax;
        end
    end
    
    M.randvec = @randomvec;
    function v = randomvec(z)
        v = M.proj(z, randn(n, 1) + 1i*randn(n, 1));
        v = v / norm(v);
    end
    
    M.lincomb = @matrixlincomb;
    
    M.zerovec = @(z) zeros(n, 1);
    
    M.transp = @(z1, z2, u) M.proj(z2, u);
    
    M.pairmean = @pairmean;
    function z = pairmean(z1, z2)
        z = sign(z1+z2);
    end

    % This vec/mat pair is an isometry which allows to switch between the
    % classical representation of tangent vectors---as complex vectors of
    % length n---to real vectors of length M.dim() whose entries are the
    % coordinates of the tangent vector in the basis 1i*z, for the first
    % half. A scaling of sqrt(2) is applied to ensure isometry, since
    % tangent vectors are represented with only half of their entries.
    I = 2 : floor((n+1)/2);
    if even_n
        middle = 0;
    else
        middle = [];
    end
    M.vec = @(z, u_mat) sqrt(2)*real(u_mat(I) .* conj(1i*z(I)));
    M.mat = @(z, u_vec) [0 ; u_vec.*(1i*z(I)) ; middle ; ...
                             flipud(conj(u_vec.*(1i*z(I))))]/sqrt(2);
    M.vecmatareisometries = @() true;
    
end
