function M = unitaryfactory(n, k)
% Returns a manifold structure to optimize over unitary matrices.
% 
% function M = unitaryfactory(n)
% function M = unitaryfactory(n, k)
%
% Unitary group: deals with arrays U of size n x n x k (or n x n if k = 1,
% which is the default) such that each n x n matrix is unitary, that is,
%     X'*X = eye(n) if k = 1, or
%     X(:, :, i)' * X(:, :, i) = eye(n) for i = 1 : k if k > 1.
%
% This is a description of U(n)^k with the induced metric from the
% embedding space (C^nxn)^k, i.e., this manifold is a Riemannian
% submanifold of (C^nxn)^k endowed with the usual real inner product on
% C^nxn, namely, <A, B> = real(trace(A'*B)).
%
% This is important:
% Tangent vectors are represented in the Lie algebra, i.e., as
% skew-Hermitian matrices. Use the function M.tangent2ambient(X, H) to
% switch from the Lie algebra representation to the embedding space
% representation. This is often necessary to define problem.ehess(X, H),
% as the input H will then be a skew-Hermitian matrix (but the output must
% not be, as the output is the Hessian in the embedding Euclidean space.)
%
% By default, the retraction is only a first-order approximation of the
% exponential. To force the use of a second-order approximation, call
% M.retr = M.retr_polar after creating M. This switches from a QR-based
% computation to an SVD-based computation.
%
% By default, k = 1.
%
% See also: stiefelcomplexfactory rotationsgroup stiefelfactory

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, June 18, 2019.
% Contributors: 
% Change log:

    
    if ~exist('k', 'var') || isempty(k)
        k = 1;
    end
    
    if k == 1
        M.name = @() sprintf('Unitary manifold U(%d)', n);
    elseif k > 1
        M.name = @() sprintf('Product unitary manifold U(%d)^%d', n, k);
    else
        error('k must be an integer no less than 1.');
    end
    
    M.dim = @() k*(n^2);
    
    M.inner = @(x, d1, d2) real(d1(:)'*d2(:));
    
    M.norm = @(x, d) norm(d(:));
    
    M.typicaldist = @() pi*sqrt(n*k);
    
    M.proj = @(X, H) multiskewh(multiprod(multihconj(X), H));
    
    M.tangent = @(X, H) multiskewh(H);
    
    M.tangent2ambient_is_identity = false;
    M.tangent2ambient = @(X, U) multiprod(X, U);
    
    M.egrad2rgrad = M.proj;
    
    M.ehess2rhess = @ehess2rhess;
    function Rhess = ehess2rhess(X, Egrad, Ehess, H)
        % Reminder : H contains skew-Hermitian matrices. The actual
        % direction that the point X is moved along is X*H.
        Xt = multihconj(X);
        XtEgrad = multiprod(Xt, Egrad);
        symXtEgrad = multiherm(XtEgrad);
        XtEhess = multiprod(Xt, Ehess);
        Rhess = multiskewh( XtEhess - multiprod(H, symXtEgrad) );
    end
    
    % This QR-based retraction is only a first-order approximation
    % of the exponential map, not a second-order one.
    M.retr_qr = @retraction_qr;
    function Y = retraction_qr(X, U, t)
        % It is necessary to call qr_unique rather than simply qr to ensure
        % this is a retraction, to avoid spurious column sign flips.
        XU = multiprod(X, U);
        if nargin < 3
            Y = qr_unique(X + XU);
        else
            Y = qr_unique(X + t*XU);
        end
    end
    
    % A second-order retraction is implemented here. To force its use,
    % after creating the factory M, execute M.retr = M.retr_polar.
    M.retr_polar = @retraction_polar;
    function Y = retraction_polar(X, U, t)
        if nargin == 3
            tU = t*U;
        else
            tU = U;
        end
        Y = X + multiprod(X, tU);
        for kk = 1 : k
            [Uk, ~, Vk] = svd(Y(:, :, kk));
            Y(:, :, kk) = Uk*Vk';
        end
    end

    % By default, use QR retraction
    M.retr = M.retr_qr;
    
    M.exp = @exponential;
    function Y = exponential(X, U, t)
        if nargin == 3
            exptU = t*U;
        else
            exptU = U;
        end
        for kk = 1 : k
            exptU(:, :, kk) = expm(exptU(:, :, kk));
        end
        Y = multiprod(X, exptU);
    end
    
    M.log = @logarithm;
    function U = logarithm(X, Y)
        U = multiprod(multihconj(X), Y);
        for kk = 1 : k
            U(:, :, kk) = logm(U(:, :, kk));
        end
        % Ensure the tangent vector is in the Lie algebra.
        U = multiskewh(U);
    end

    M.hash = @(X) ['z' hashmd5([real(X(:)) ; imag(X(:))])];
    
    M.rand = @() randunitary(n, k);
    
    M.randvec = @randomvec;
    function U = randomvec(X) %#ok<INUSD>
        U = randskewh(n, k);
        nrmU = sqrt(U(:)'*U(:));
        U = U / nrmU;
    end
    
    M.lincomb = @matrixlincomb;
    
    M.zerovec = @(x) zeros(n, n, k);
    
    M.transp = @(x1, x2, d) d;
    M.isotransp = M.transp; % the transport is isometric
    
    M.pairmean = @pairmean;
    function Y = pairmean(X1, X2)
        V = M.log(X1, X2);
        Y = M.exp(X1, .5*V);
    end
    
    M.dist = @(x, y) M.norm(x, M.log(x, y));
    
    sz = n*n*k;
    M.vec = @(x, u_mat) [real(u_mat(:)) ; imag(u_mat(:))];
    M.mat = @(x, u_vec) reshape(u_vec(1:sz), [n, n, k]) ...
                        + 1i*reshape(u_vec((sz+1):end), [n, n, k]);
    M.vecmatareisometries = @() true;

end
