function M = fixedrankfactory_tucker_preconditioned(tensor_size, tensor_rank)
% Manifold of fixed multilinear rank tensors in Tucker format.
%
% function M = fixedrankfactory_tucker_preconditioned(tensor_size, tensor_rank)
%
% n1 = tensor_size(1);
% n2 = tensor_size(2);
% n3 = tensor_size(3);
% r1 = tensor_rank(1);
% r2 = tensor_rank(2);
% r3 = tensor_rank(3);
%
% Both tensor_size and tensor_rank must have exactly three entries: only
% third-order tensors are supported.
%
% A point X on the manifold is represented as a structure with four
% fields: U1, U2, U3 and G. The matrices U1 (n1-by-r1), U2 (n2-by-r2),
% and U3 (n3-by-r3) have orthonormal columns. G (r1-by-r2-by-r3) is a
% multidimensional array.
%
% Tangent vectors are represented as a structure with four fields: 
% U1, U2, U3, and G.
%
% We exploit the quotient nature of Tucker decompositions to impose a
% scaled inner product on the manifold. This suits least-squares problems.
% For details, refer to the technical report:
% "{R}iemannian preconditioning for tensor completion",
% H. Kasai and B. Mishra, Arxiv preprint arXiv:1506.02159, 2015.
%
% Paper link: http://arxiv.org/abs/1506.02159.
%
% Please cite the Manopt paper as well as the research paper:
%     @TechReport{kasai2015precon,
%       Title   = {{R}iemannian preconditioning for tensor completion},
%       Author  = {Kasai, H. and Mishra, B.},
%       Journal = {Arxiv preprint arXiv:1506.02159},
%       Year    = {2015}
%     }
%
% See also fixedranktensorembeddedfactory

% Original authors: Hiroyuki Kasai and Bamdev Mishra, June 5, 2015.
% Contributors: 
% Change log:
%
%    Apr. 17, 2018 (NB):
%        Removed dependency on lyap.
%
%    Sep.  6, 2018 (NB):
%        Removed M.exp() as it was not implemented.
%
%   Jan. 4, 2021 (NB):
%       Compatibility with Octave 6.1.0. Besides some of the steps also
%       taken in other factories, a special issue here was: as M.transp
%       calls the nested function 'projection', it is important to define
%       M.transp as an explicit nested function and not as an anonymous
%       function with @, as otherwise the scope of the mother (helper)
%       function was invisible to 'projection' when called through transp.

    % Only third-order tensors are handled. Both inputs must have exactly
    % three entries. Without this check, shorter inputs fail on the
    % indexing below with an unhelpful message, and extra entries in
    % tensor_size are silently ignored.
    if numel(tensor_size) ~= 3 || numel(tensor_rank) ~= 3
        error('Bad usage of fixedrankfactory_tucker_preconditioned. Currently, only handles 3-order tensors.');
    end
    
    % Tensor size
    n1 = tensor_size(1);
    n2 = tensor_size(2);
    n3 = tensor_size(3);
    
    % Core size or multilinear rank
    r1 = tensor_rank(1);
    r2 = tensor_rank(2);
    r3 = tensor_rank(3);
    
    % Sparse identities, used in M.proj to build the Kronecker-structured
    % linear operators of the horizontal projection.
    speyer1 = speye(r1);
    speyer2 = speye(r2);
    speyer3 = speye(r3);
    
    M = fixedrankfactory_tucker_preconditioned_helper(...
                        tensor_size, tensor_rank, ...
                        n1, n2, n3, r1, r2, r3, speyer1, speyer2, speyer3);
                              
end


% This is the actual factory
function M = fixedrankfactory_tucker_preconditioned_helper(...
                     tensor_size, tensor_rank, ...
                     n1, n2, n3, r1, r2, r3, speyer1, speyer2, speyer3) %#ok<INUSL>
% Builds the manifold structure M. All function handles stored in M are
% nested functions (or anonymous functions) that share this workspace, so
% they see the sizes n1..n3, the ranks r1..r3 and the sparse identities
% speyer1..speyer3 passed in by the caller. tensor_size and tensor_rank are
% accepted but not used here.

    M.name = @() sprintf('G x U1 x U2 x U3 quotient Tucker manifold of %d-by-%d-by-%d tensor of rank %d-by-%d-by-%d.', n1, n2, n3, r1, r2, r3);
    
    % sum_i (n_i*r_i - r_i^2) for the factors plus r1*r2*r3 for the core.
    M.dim = @() n1*r1-r1^2 + n2*r2-r2^2 + n3*r3-r3^2 + r1*r2*r3;
    
    % Some precomputations at point X to be used in the inner product (and
    % pretty much everywhere else): the three matricizations G1, G2, G3 of
    % the core and their Gram matrices GkGkt. Computed only if any of these
    % cached fields is missing from X.
    function X = prepare(X)
        if ~all(isfield(X,{'G1G1t','G1',...
                'G2G2t','G2', ...
                'G3G3t','G3'}) == 1)
            
            X.G1 =  reshape(X.G, r1, r2*r3);
            X.G1G1t = X.G1*X.G1'; % Positive definite  
            
            
            X.G2 = reshape(permute(X.G, [2 1 3]), r2, r1*r3); 
            X.G2G2t = X.G2*X.G2'; % Positive definite  
            
            
            X.G3 = reshape(permute(X.G, [3 1 2]), r3, r1*r2);  
            X.G3G3t = X.G3*X.G3'; % Positive definite  
        end
        
        
    end
    
    % Choice of metric is motivated by symmetry and tuned to least-squares
    % cost function: each factor block is weighted by the Gram matrix of
    % the corresponding core matricization; the core block uses the plain
    % Euclidean inner product.
    M.inner = @iproduct;
    function ip = iproduct(X, eta, zeta)
        X = prepare(X);
        ip =  trace(X.G1G1t*(eta.U1'*zeta.U1)) ...
            + trace(X.G2G2t*(eta.U2'*zeta.U2)) ...
            + trace(X.G3G3t*(eta.U3'*zeta.U3)) ...
            + (eta.G(:)'*zeta.G(:));
    end
    M.norm = @(X, eta) sqrt(M.inner(X, eta, eta));
    
    M.typicaldist = @() 10*n1*r1; % BM: To do  
    
    M.egrad2rgrad = @egrad2rgrad;
    function rgrad = egrad2rgrad(X, egrad)
        X = prepare(X); % Reuse already computed terms
        
        SSU1 = X.G1G1t;
        ASU1 = 2*symm(SSU1*(X.U1' * egrad.U1));
        
        SSU2 = X.G2G2t;
        ASU2 = 2*symm(SSU2*(X.U2' * egrad.U2));
        
        SSU3 = X.G3G3t;
        ASU3 = 2*symm(SSU3*(X.U3' * egrad.U3));
        
        % lyapunov_symmetric is an external (Manopt) tool, not defined in
        % this file; presumably it solves SSU*B + B*SSU = ASU for
        % symmetric B -- confirm against its documentation.
        BU1 = lyapunov_symmetric(SSU1, ASU1);
        BU2 = lyapunov_symmetric(SSU2, ASU2);
        BU3 = lyapunov_symmetric(SSU3, ASU3);
        
        % The sylvester solutions ensure that the Riemannian gradient rgrad 
        % is now on the tangent space. From the Riemannian submersion 
        % theory, it also belongs to the horizontal space. Therefore,
        % no need to further project it on the horizontal space.
        
        rgrad.U1 = (egrad.U1 - X.U1*BU1)/X.G1G1t;
        rgrad.U2 = (egrad.U2 - X.U2*BU2)/X.G2G2t;
        rgrad.U3 = (egrad.U3 - X.U3*BU3)/X.G3G3t;
        rgrad.G = egrad.G;

        
    end
    
    
    
    M.ehess2rhess = @ehess2rhess;
    function Hess = ehess2rhess(X, egrad, ehess, eta) 
        X = prepare(X); % Reuse already computed terms
        
        % Riemannian gradient (same computation as in egrad2rgrad; the
        % intermediate BU1, BU2, BU3 are reused below).
        SSU1 = X.G1G1t;
        ASU1 = 2*symm(SSU1*(X.U1' * egrad.U1));
        SSU2 = X.G2G2t;
        ASU2 = 2*symm(SSU2*(X.U2' * egrad.U2));
        SSU3 = X.G3G3t;
        ASU3 = 2*symm(SSU3*(X.U3' * egrad.U3));
        
        BU1 = lyapunov_symmetric(SSU1, ASU1);
        BU2 = lyapunov_symmetric(SSU2, ASU2);
        BU3 = lyapunov_symmetric(SSU3, ASU3);
        
        rgrad.U1 = (egrad.U1 - X.U1*BU1)/X.G1G1t;
        rgrad.U2 = (egrad.U2 - X.U2*BU2)/X.G2G2t;
        rgrad.U3 = (egrad.U3 - X.U3*BU3)/X.G3G3t;
        rgrad.G = egrad.G;
        
        % Directional derivative of Riemannian gradient
        
        eta_G1 = reshape(eta.G, r1, r2*r3); % double(tenmat(eta.G,1));
        eta_G2 = reshape(permute(eta.G, [2 1 3]), r2, r1*r3); % double(tenmat(eta.G,2));
        eta_G3 = reshape(permute(eta.G, [3 1 2]), r3, r1*r2); % double(tenmat(eta.G,3));
        egrad_G1 = reshape(egrad.G, r1, r2*r3); % double(tenmat(egrad.G,1));
        egrad_G2 = reshape(permute(egrad.G, [2 1 3]), r2, r1*r3); % double(tenmat(egrad.G,2));
        egrad_G3 = reshape(permute(egrad.G, [3 1 2]), r3, r1*r2); % double(tenmat(egrad.G,3));
        ehess_G1 = reshape(ehess.G, r1, r2*r3); % double(tenmat(ehess.G,1));
        ehess_G2 = reshape(permute(ehess.G, [2 1 3]), r2, r1*r3); % double(tenmat(ehess.G,2));
        ehess_G3 = reshape(permute(ehess.G, [3 1 2]), r3, r1*r2); % double(tenmat(ehess.G,3));
        rgrad_G1 = reshape(rgrad.G, r1, r2*r3); % double(tenmat(rgrad.G,1));
        rgrad_G2 = reshape(permute(rgrad.G, [2 1 3]), r2, r1*r3); % double(tenmat(rgrad.G,2));
        rgrad_G3 = reshape(permute(rgrad.G, [3 1 2]), r3, r1*r2); % double(tenmat(rgrad.G,3));
        
        ASU1dot = 2*symm((2*symm(X.G1*eta_G1')*(egrad_G1*X.G1')) + X.G1G1t*(ehess_G1*X.G1' + egrad_G1*eta_G1')) - 4*symm(symm(eta_G1*X.G1')*BU1);
        ASU2dot = 2*symm((2*symm(X.G2*eta_G2')*(egrad_G2*X.G2')) + X.G2G2t*(ehess_G2*X.G2' + egrad_G2*eta_G2')) - 4*symm(symm(eta_G2*X.G2')*BU2);
        ASU3dot = 2*symm((2*symm(X.G3*eta_G3')*(egrad_G3*X.G3')) + X.G3G3t*(ehess_G3*X.G3' + egrad_G3*eta_G3')) - 4*symm(symm(eta_G3*X.G3')*BU3);
        
        
        SSU1dot = X.G1G1t;
        SSU2dot = X.G2G2t;
        SSU3dot = X.G3G3t;
        BU1dot = lyapunov_symmetric(SSU1dot, ASU1dot);
        BU2dot = lyapunov_symmetric(SSU2dot, ASU2dot);
        BU3dot = lyapunov_symmetric(SSU3dot, ASU3dot);
        
        
        Hess.U1 = (ehess.U1 - eta.U1*BU1 - X.U1*BU1dot - 2*rgrad.U1*symm(eta_G1*X.G1'))/X.G1G1t;
        Hess.U2 = (ehess.U2 - eta.U2*BU2 - X.U2*BU2dot - 2*rgrad.U2*symm(eta_G2*X.G2'))/X.G2G2t;
        Hess.U3 = (ehess.U3 - eta.U3*BU3 - X.U3*BU3dot - 2*rgrad.U3*symm(eta_G3*X.G3'))/X.G3G3t;
        Hess.G = ehess.G;
        
        
        
        % BM: we need a correction factor for the non-constant metric
        % The correction factor owes itself to the Koszul formula.
        % This is the Riemannian connection in the Euclidean space with the
        % scaled metric.
        
        
        Hess.U1 = Hess.U1 + (eta.U1*symm(rgrad_G1*X.G1') + rgrad.U1*symm(eta_G1*X.G1'))/X.G1G1t;
        Hess.U2 = Hess.U2 + (eta.U2*symm(rgrad_G2*X.G2') + rgrad.U2*symm(eta_G2*X.G2'))/X.G2G2t;
        Hess.U3 = Hess.U3 + (eta.U3*symm(rgrad_G3*X.G3') + rgrad.U3*symm(eta_G3*X.G3'))/X.G3G3t;
        Hess.G = Hess.G  - permute(reshape(symm(rgrad.U1'*eta.U1)*X.G1,r1,r2,r3), [1 2 3]) ...
            - permute(reshape(symm(rgrad.U2'*eta.U2)*X.G2,r2,r1,r3), [2 1 3]) ...
            - permute(reshape(symm(rgrad.U3'*eta.U3)*X.G3,r3,r1,r2), [2 3 1]);
        
        % The Riemannian connection on the quotient space is the
        % projection on the tangent space of the total space and then onto the horizontal
        % space. This is accomplished with the following operation.
        % The nested function is called directly rather than through the
        % M.proj handle, so this does not depend on M being in scope.
        
        Hess = projection(X, Hess);
        
        
    end
    
    
    
    
    M.proj = @projection;
    function etaproj = projection(X, eta)
        X = prepare(X); % Reuse already computed terms
        
        % First, projection onto tangent space of total space
        SSU1 = X.G1G1t;
        ASU1 = 2*symm(X.G1G1t*(X.U1'*eta.U1)*X.G1G1t);
        BU1 = lyapunov_symmetric(SSU1, ASU1);
        eta.U1 = eta.U1 - X.U1*(BU1/X.G1G1t);
        
        SSU2 = X.G2G2t;
        ASU2 = 2*symm(X.G2G2t*(X.U2'*eta.U2)*X.G2G2t);
        BU2 = lyapunov_symmetric(SSU2, ASU2);
        eta.U2 = eta.U2 - X.U2*(BU2/X.G2G2t);
        
        SSU3 = X.G3G3t;
        ASU3 = 2*symm(X.G3G3t*(X.U3'*eta.U3)*X.G3G3t);
        BU3 = lyapunov_symmetric(SSU3, ASU3);
        eta.U3 = eta.U3 - X.U3*(BU3/X.G3G3t);

        eta_G1 = reshape(eta.G, r1, r2*r3); 
        eta_G2 = reshape(permute(eta.G, [2 1 3]), r2, r1*r3); 
        eta_G3 = reshape(permute(eta.G, [3 1 2]), r3, r1*r2);
        
        
        % Project onto the horizontal space.
        PU1 = skew((X.U1'*eta.U1)*X.G1G1t) + skew(X.G1*eta_G1');
        PU2 = skew((X.U2'*eta.U2)*X.G2G2t) + skew(X.G2*eta_G2');
        PU3 = skew((X.U3'*eta.U3)*X.G3G3t) + skew(X.G3*eta_G3');
        
        % Calculate Omega1, Omega2, Omega3 that are required in finding the
        % horizontal component. 
        % We use the Matlab's pcg function to solve the system efficiently.
        % We exploit the structure by designing a good preconditioner as well.
        % The preconditioner takes the block positive definite part of the
        % linear system.
        
        % Options for PCG
        tol_omegax_pcg = 1e-6; % BM: standard tolerance as suggested in PCG.
        max_iterations_pcg = 15;% BM: fix this to 15 for simulations. In practice, it requires 7 to 10 iterations.
        
        % Preconditioner for PCG: the Kronecker form of the operator
        % Omega -> Omega*SSUk + SSUk*Omega for each mode k, placed on the
        % diagonal. blkdiag assembles the sparse block-diagonal matrix
        % directly instead of allocating a dense square zero matrix of
        % side r1^2 + r2^2 + r3^2 on every call.
        M1 = kron(speyer1,SSU1) + kron(SSU1, speyer1);
        M2 = kron(speyer2,SSU2) + kron(SSU2, speyer2);
        M3 = kron(speyer3,SSU3) + kron(SSU3, speyer3);
        
        Mprecon_pcg = blkdiag(M1, M2, M3);
        
        % Call PCG. The flag output is requested (and discarded) so that
        % pcg does not print a message when it stops before converging.
        [Omegaxsol, unused] = pcg(@compute_residual, [PU1(:); PU2(:); PU3(:)],  tol_omegax_pcg, max_iterations_pcg, Mprecon_pcg); %#ok<ASGLU>
        
        Omega1 = reshape(Omegaxsol(1:r1^2), r1, r1);
        Omega2 = reshape(Omegaxsol(1 + r1^2 : r1^2 + r2^2), r2, r2);
        Omega3 = reshape(Omegaxsol(1 + r1^2 + r2^2 : end), r3, r3);
            
        % Applies the linear operator of the horizontal-projection system
        % to the stacked vector [Omega1(:); Omega2(:); Omega3(:)].
        function AOmegax = compute_residual(Omegax)
            Omegax1 = reshape(Omegax(1:r1^2), r1, r1);
            Omegax2 = reshape(Omegax(1 + r1^2 : r1^2 + r2^2), r2, r2);
            Omegax3 = reshape(Omegax(1 + r1^2 + r2^2 : end), r3, r3);
            
            OffsetU1 = X.G1*((kron(speyer3,Omegax2) + kron(Omegax3, speyer2))*X.G1');
            OffsetU2 = X.G2*((kron(speyer3,Omegax1) + kron(Omegax3, speyer1))*X.G2');
            OffsetU3 = X.G3*((kron(speyer2,Omegax1) + kron(Omegax2, speyer1))*X.G3');
            
            residual1 = Omegax1*SSU1 + SSU1*Omegax1 - OffsetU1;
            residual2 = Omegax2*SSU2 + SSU2*Omegax2 - OffsetU2;
            residual3 = Omegax3*SSU3 + SSU3*Omegax3 - OffsetU3;
            
            AOmegax = [residual1(:); residual2(:); residual3(:)];
        end
        
        
        % Calculate projection along U1, U2, and U3
        etaproj.U1 = eta.U1 - (X.U1*Omega1);
        etaproj.U2 = eta.U2 - (X.U2*Omega2);
        etaproj.U3 = eta.U3 - (X.U3*Omega3);
        
        % Calculate projection along G 
        GOmega1 = reshape(Omega1*X.G1, r1, r2, r3);
        GOmega2 = permute(reshape(Omega2*X.G2, r2, r1, r3), [2 1 3]);
        GOmega3 = permute(reshape(Omega3*X.G3, r3, r1, r2), [2 3 1]); 
        etaproj.G = eta.G -(-(GOmega1+GOmega2+GOmega3));
        
    end
    
    
    
    M.tangent = M.proj;
    M.tangent2ambient = @(X, eta) eta;
    
    % Retraction: the core moves linearly; each factor is mapped back to
    % orthonormal columns through the U factor of its polar decomposition.
    M.retr = @retraction;
    function Y = retraction(X, eta, t)
        if nargin < 3
            t = 1.0;
        end
        
        Y.G = (X.G + t*eta.G);
        Y.U1 = uf((X.U1 + t*eta.U1)); % U factor of Polar factorization
        Y.U2 = uf((X.U2 + t*eta.U2));
        Y.U3 = uf((X.U3 + t*eta.U3));
        
        Y = prepare(Y);
    end
    
    
    M.hash = @(X) ['z' hashmd5([sum(X.U1(:)) ; sum(X.U2(:)); sum(X.U3(:)); sum(X.G(:)) ])]; % Efficient, suggested by Bart Vandereycken.
    % M.hash = @(X) ['z' hashmd5([X.U1(:); X.U2(:); X.U3(:); X.G(:)])];
    
    M.rand = @random;
    function X = random()
        %         % Random generator on the total space
        %         % Factors U1, U2, and U3 are on Stiefel manifolds, hence we reuse
        %         % their random generator.
        %         stiefell = stiefelfactory(n1, r1);
        %         stiefelm = stiefelfactory(n2, r2);
        %         stiefeln = stiefelfactory(n3, r3);
        %
        %         X.U1 = stiefell.rand();
        %         X.U2 = stiefelm.rand();
        %         X.U3 = stiefeln.rand();
        %
        %         % Random initialization: generalization of randn(r1, r1 = r2) in the
        %         % matrix case.
        %         X.G = randn(r1,r2,r3);
        
        
        %  Random generator on the fixed-rank space from a uniform distribution on [0, 1].
        %  The tensor C x1 A1 x2 A2 x3 A3 (A_k uniform) is rewritten in
        %  Tucker form by taking economy QR factors A_k = U_k*R_k and
        %  absorbing every R_k into the core.
        [U1, R1] = qr(rand(n1, r1), 0);
        [U2, R2] = qr(rand(n2, r2), 0);
        [U3, R3] = qr(rand(n3, r3), 0);
        C  = rand(r1, r2, r3);
        
        C1 = reshape(C, r1, r2*r3);
        CR1 = reshape(R1*C1, r1, r2, r3); % Multiplication by R1
        
        C2 = reshape(permute(CR1, [2 1 3]), r2, r1*r3);
        CR1R2 = permute(reshape(R2*C2, r2, r1, r3), [2 1 3]); % Multiplication by R2
        
        C3 = reshape(permute(CR1R2, [3 1 2]), r3, r1*r2);
        CR1R2R3 = permute(reshape(R3*C3, r3, r1, r2), [2 3 1]); % Multiplication by R3
        
        X.U1 = U1;
        X.U2 = U2;
        X.U3 = U3;
        X.G = CR1R2R3;
    
        
        % Compute some terms that are used subsequently.
        X = prepare(X);
        
    end
    
    M.randvec = @randomvec;
    function eta = randomvec(X)
        % A random vector on the horizontal space, normalized to unit norm
        % in the scaled metric.
        eta.U1 = randn(n1, r1);
        eta.U2 = randn(n2, r2);
        eta.U3 = randn(n3, r3);
        eta.G = randn(r1, r2, r3);
        eta = projection(X, eta);
        nrm = M.norm(X, eta);
        eta.U1 = eta.U1 / nrm;
        eta.U2 = eta.U2 / nrm;
        eta.U3 = eta.U3 / nrm;
        eta.G = eta.G / nrm;
    end
    
    M.lincomb = @lincomb;
    
    M.zerovec = @(X) struct('U1', zeros(n1, r1), 'U2', zeros(n2, r2), ...
                            'U3', zeros(n3, r3), 'G',  zeros(r1, r2, r3));
    
    % Vector transport by projection onto the horizontal space at x2.
    % Kept as an explicit nested function (not an anonymous @) for
    % Octave: 'projection' must see this function's workspace.
    M.transp = @transp;
    function v = transp(x1, x2, d) %#ok<INUSL>
        v = projection(x2, d);
    end
    
    % vec and mat are not isometries, because of the scaled metric.
    % vec stacks the fields of a tangent vector as [U1(:); U2(:); U3(:); G(:)].
    M.vec = @(X, eta) [eta.U1(:); eta.U2(:); eta.U3(:); eta.G(:)];
    M.mat = @(X, u) struct ...
        ('U1', reshape(u(1  : n1*r1), n1, r1), ...
        'U2', reshape(u(n1*r1 + 1 : n1*r1 + n2*r2), n2, r2), ...
        'U3', reshape(u(n1*r1 + n2*r2 + 1 : n1*r1 + n2*r2 + n3*r3), n3, r3), ...
        'G', reshape(u(n1*r1 + n2*r2 + n3*r3 + 1 : end), r1, r2, r3));
    M.vecmatareisometries = @() false;
    
end

% Linear combination of tangent vectors
function d = lincomb(X, a1, d1, a2, d2) %#ok<INUSL>
% Linear combination of tangent vectors, field by field:
%   lincomb(X, a1, d1)         returns a1*d1;
%   lincomb(X, a1, d1, a2, d2) returns a1*d1 + a2*d2.
% Any other number of arguments is an error. The point X is not used.

    tangent_fields = {'U1', 'U2', 'U3', 'G'};
    
    switch nargin
        case 3
            for k = 1:numel(tangent_fields)
                name = tangent_fields{k};
                d.(name) = a1*d1.(name);
            end
        case 5
            for k = 1:numel(tangent_fields)
                name = tangent_fields{k};
                d.(name) = a1*d1.(name) + a2*d2.(name);
            end
        otherwise
            error('Bad use of fixedrankfactory_tucker_preconditioned.lincomb.');
    end
    
end

% U factor of Polar factorization of a tall matrix A.
function U = uf(A)
% Orthonormal factor of the polar decomposition of a tall matrix A.
% With the economy SVD A = P*S*Q', the polar factor is U = P*Q'.
    [P, S, Q] = svd(A, 0); %#ok<ASGLU>
    U = P*Q';
end

function A = symm(Z)
% Hermitian (symmetric, for real input) part of a square matrix Z.
    Zadj = Z';
    A = (Z + Zadj)/2;
end

function A = skew(Z)
% Skew-Hermitian (skew-symmetric, for real input) part of a square matrix Z.
    Zadj = Z';
    A = (Z - Zadj)/2;
end
