function low_rank_tensor_completion()
% Given partial observation of a low rank tensor, attempts to complete it.
%
% function low_rank_tensor_completion()
%
% This example demonstrates how to use the geometry factory for the
% quotient manifold of fixed-rank tensors, 
% fixedrankfactory_tucker_preconditioned.
%
% This geometry is described in the technical report
% "Riemannian preconditioning for tensor completion"
% Hiroyuki Kasai and Bamdev Mishra, arXiv:1506.02159, 2015.
%
% This can be a starting point for many optimization problems of the form:
%
% minimize f(X) such that rank(X) = [r1 r2 r3], size(X) = [n1, n2, n3].
%
% Input:  None. This example file generates random data.
% 
% Output: None.
%
% Please cite the Manopt paper as well as the research paper:
%     @Techreport{kasai2015,
%       Title   = {{R}iemannian preconditioning for tensor completion},
%       Author  = {Kasai, H. and Mishra, B.},
%       Journal = {Arxiv preprint arXiv:1506.02159},
%       Year    = {2015}
%     }

% This file is part of Manopt and is copyrighted. See the license file.
% 
% Main authors: Hiroyuki Kasai and Bamdev Mishra, June 16, 2015.
% Contributors:
% 
% Change log:
%
%   Xiaowen Jiang Aug. 20, 2021
%       Added AD to compute the egrad and the ehess  
    

    % Random data generation with pseudo-random numbers from a 
    % uniform distribution on [0, 1].
    % First, choose the size of the problem.
    % We will complete a tensor of size n1-by-n2-by-n3 of rank (r1, r2, r3):  
    n1 = 70;
    n2 = 60;
    n3 = 50;
    r1 = 3;
    r2 = 4;
    r3 = 5;
    tensor_dims = [n1 n2 n3];
    core_dims = [r1 r2 r3];
    total_entries = n1*n2*n3;
    
    % Generate a random tensor A of size n1-by-n2-by-n3 of rank (r1, r2, r3).
    [U1,R1] = qr(rand(n1, r1), 0);
    [U2,R2] = qr(rand(n2, r2), 0);
    [U3,R3] = qr(rand(n3, r3), 0);

    Z.U1 = R1;
    Z.U2 = R2;
    Z.U3 = R3;   
    Z.G = rand( core_dims );
    Core = tucker2multiarray(Z); % Converts tucker format tensor to full tensor.

    Y.U1 = U1;
    Y.U2 = U2;
    Y.U3 = U3;
    Y.G = Core;
    A = tucker2multiarray(Y);       
    
    % Generate a random mask P for observed entries: P(i, j, k) = 1 if the entry
    % (i, j, k) of A is observed, and 0 otherwise.    
    % Observation ratio
    fraction = 0.1; % Fraction of known entries.
    nr = round(fraction * total_entries);
    ind = randperm(total_entries);
    ind = ind(1 : nr);
    P = false(tensor_dims);
    P(ind) = true;    
    % Hence, we know the nonzero entries in PA:
    PA = P.*A;  
    

    
    
    % Pick the manifold of tensors of size n1-by-n2-by-n3 of rank (r1, r2, r3).
    problem.M = fixedrankfactory_tucker_preconditioned(tensor_dims, core_dims);
    
    
    
    
    % Define the problem cost function. The input X is a structure with
    % fields U1, U2, U3, G representing a rank (r1,r2,r3) tensor.
    % f(X) = 1/2 * || P.*(X - A) ||^2
    problem.cost = @cost;
    function f = cost(X)
        Xmultiarray = tucker2multiarray(X);
        Diffmultiarray = P.*Xmultiarray - PA;
        Diffmultiarray_flat = reshape(Diffmultiarray, n1, n2*n3);
        f = .5*norm(Diffmultiarray_flat , 'fro')^2;
    end


    
    
    % Define the Euclidean gradient of the cost function, that is, the
    % gradient of f(X) seen as a standard function of X.
    % nabla f(X) = P.*(X-A)
    % We only need to give the Euclidean gradient. Manopt converts it
    % internally to the Riemannian counterpart.
    problem.egrad =  @egrad;
    function [g] = egrad(X)
        Xmultiarray = tucker2multiarray(X);
        Smultiarray = P.*Xmultiarray - PA;     

        % BM: computation of S, S1, S2, and S3
        S1multiarray = reshape(Smultiarray, [n1, n2*n3]);
        S2multiarray = reshape(permute(Smultiarray, [2 1 3]),[n2, n1*n3]);
        S3multiarray = reshape(permute(Smultiarray, [3 1 2]),[n3, n1*n2]);

        g.U1 = double(S1multiarray) * kron(X.U3, X.U2) * reshape(X.G, r1, r2*r3)';
        g.U2 = double(S2multiarray) * kron(X.U3, X.U1) * reshape(permute(X.G, [2 1 3]), r2, r1*r3)';
        g.U3 = double(S3multiarray) * kron(X.U2, X.U1) * reshape(permute(X.G, [3 1 2]), r3, r1*r2)';
        g.G = reshape(X.U1' * reshape(double(Smultiarray),n1,n2*n3) * kron(X.U3', X.U2')', r1, r2, r3);  
    end
    
    
    
    
    
    % Define the Euclidean Hessian of the cost at X, along eta, where eta is
    % represented as a tangent vector: a structure with fields U1, U2, U3, G.
    % This is the directional derivative of nabla f(X) at X along Xdot:
    % nabla^2 f(X)[Xdot] = P.*Xdot
    % We only need to give the Euclidean Hessian. Manopt converts it
    % internally to the Riemannian counterpart.
    problem.ehess = @ehess;
    function [Hess] = ehess(X, eta)

        % Computing S, and its unfolding matrices, S1, S2, and S3.
        Xmultiarray = tucker2multiarray(X);
        S = P.*Xmultiarray - PA;     
        S1 = reshape(S, [n1, n2*n3]);
        S2 = reshape(permute(S, [2 1 3]),[n2, n1*n3]);
        S3 = reshape(permute(S, [3 1 2]),[n3, n1*n2]);            

        % Computing Sdot, S1dot, S2dot and S3dot.
        XG = X.G;
        etaG = eta.G;
        G1 = zeros(4*size(X.G));
        G1(1:r1, 1:r2, 1:r3) = XG;
        G1(r1 + 1 : 2*r1, r2 + 1 : 2*r2, r3 + 1 : 2*r3) = XG;
        G1(2*r1 + 1 : 3*r1, 2*r2 + 1 : 3*r2, 2*r3 + 1 : 3*r3) = XG;
        G1(3*r1 + 1 : 4*r1, 3*r2 + 1 : 4*r2, 3*r3 + 1 : 4*r3) = etaG;      
             
        X1.G = G1;
        X1.U1 = [eta.U1 X.U1 X.U1 X.U1];
        X1.U2 = [X.U2 eta.U2 X.U2 X.U2];
        X1.U3 = [X.U3 X.U3 eta.U3 X.U3];
        
        X1multiarray = tucker2multiarray(X1);
        Sdot = P.*X1multiarray;
        S1dot = reshape(Sdot, [n1, n2*n3]);
        S2dot = reshape(permute(Sdot, [2 1 3]),[n2, n1*n3]);
        S3dot = reshape(permute(Sdot, [3 1 2]),[n3, n1*n2]);
        
        % Computing unfolding matrices of X.G and eta.G.
        X_G1 = reshape(double(X.G),r1, r2*r3);
        X_G2 = reshape(permute(double(X.G),[2 1 3]),r2, r1*r3);
        X_G3 = reshape(permute(double(X.G),[3 1 2]),r3, r1*r2);
        eta_G1 = reshape(double(eta.G),r1, r2*r3);
        eta_G2 = reshape(permute(double(eta.G),[2 1 3]),r2, r1*r3);
        eta_G3 = reshape(permute(double(eta.G),[3 1 2]),r3, r1*r2);             

        % Computing Hessians for U1, U2 and U3.
        T1 = double(S1dot) * (kron(X.U3,X.U2)*X_G1') ...
            + double(S1) * (kron(eta.U3,X.U2)*X_G1' ...
            + kron(X.U3,eta.U2)*X_G1' + kron(X.U3,X.U2)*eta_G1');
        
        T2 = double(S2dot) * (kron(X.U3,X.U1)*X_G2') ...
            + double(S2) * (kron(eta.U3,X.U1)*X_G2' ...
            + kron(X.U3,eta.U1)*X_G2' + kron(X.U3,X.U1)*eta_G2');

        T3 = double(S3dot) * (kron(X.U2,X.U1)*X_G3') ...
            + double(S3) * (kron(eta.U2,X.U1)*X_G3' ...
            + kron(X.U2,eta.U1)*X_G3' + kron(X.U2,X.U1)*eta_G3');
        
        Hess.U1 = T1;
        Hess.U2 = T2;
        Hess.U3 = T3;  
        
        % Computing Hessian for G
        N.U1 = X.U1';
        N.U2 = X.U2';
        N.U3 = X.U3';
        N.G = Sdot;
        M0array = tucker2multiarray(N);
        
        M1.U1 = eta.U1';
        M1.U2 = X.U2';
        M1.U3 = X.U3';
        M1.G = S;    
        M1array = tucker2multiarray(M1);
        
        M2.U1 = X.U1';
        M2.U2 = eta.U2';
        M2.U3 = X.U3';
        M2.G = S;    
        M2array = tucker2multiarray(M2); 
        
        M3.U1 = X.U1';
        M3.U2 = X.U2';
        M3.U3 = eta.U3';
        M3.G = S;    
        M3array = tucker2multiarray(M3);   
        
        Hess.G = M0array + M1array + M2array + M3array; 
    end
    
    % An alternative way to compute the egrad and the ehess is to use 
    % automatic differentiation provided in the deep learning toolbox (slower)
    % Notice that the function norm is not supported for AD so far.
    % Replace norm(...,'fro')^2 with cnormsqfro described in the file 
    % manoptADhelp.m. Also, operations between sparse matrices and dlarrays 
    % are not supported for now. convert PA and P into full arrays.
    % problem.cost = @cost_AD;
    % PA_full = full(PA);
    % P_full = full(P);
    %    function f = cost_AD(X)
    %        Xmultiarray = tucker2multiarray(X);
    %        Diffmultiarray = P_full.*Xmultiarray - PA_full;
    %        Diffmultiarray_flat = reshape(Diffmultiarray, n1, n2*n3);
    %        f = .5*cnormsqfro(Diffmultiarray_flat);
    %    end
    % call manoptAD to prepare AD for the problem structure
    % problem = manoptAD(problem);
 

    % Check consistency of the gradient and the Hessian. Useful if you
    % adapt this example for a new cost function and you would like to make
    % sure there is no mistake.
    %
    % Notice that the checkhessian test fails: the slope is not right. 
    % This is because the retraction is not second-order compatible with 
    % the Riemannian exponential on this manifold, making 
    % the checkhessian tool unusable. The Hessian is correct though. 
    % % warning('off', 'manopt:fixedrankfactory_tucker_preconditioned:exp');
    % % checkgradient(problem);
    % % drawnow;
    % % pause;
    % % checkhessian(problem);
    % % drawnow;
    % % pause;
    

    
    % options
    options.maxiter = 200;
    options.maxinner = 30;
    options.maxtime = inf;
    options.tolgradnorm = 1e-5;     


    
    
    % Minimize the cost function using Riemannian trust-regions
    Xtr = trustregions(problem, [], options);

    
    
    % The reconstructed tensor is X, represented as a structure with fields
    % U1, U2, U3 and G.    
    Xtrmultiarray = tucker2multiarray(Xtr);
    fprintf('||X-A||_F = %g\n', norm(reshape(Xtrmultiarray - A, [n1 n2*n3]), 'fro'));   
    
   
    
    
    % Alternatively, we could decide to use a solver such as steepestdescent (SD) 
    % or conjugategradient (CG). These solvers need to solve a
    % line-search problem at each iteration. Standard line searches in
    % Manopt have generic purpose systems to do this. But for the problem
    % at hand, we could exploit the least-squares structure to compute an
    % approximate stepsize for the line-search problem. The approximation
    % is obtained by linearizing the nonlinear manifold locally and further
    % approximating it with a degree 2 polynomial approximation.
    % The specific derivation is in the paper referenced above.
    
    problem.linesearch = @linesearch_helper;
    function tmin = linesearch_helper(X, eta)
        
        % term0
        Xmultiarray = tucker2multiarray(X);
        residual_mat = P.*Xmultiarray - PA;     
        residual_vec = residual_mat(:);
        term0 = residual_vec;
        
        % term1
        XG = X.G;
        etaG = eta.G;        
        G1 = zeros(4*size(X.G));
        G1(1:r1, 1:r2, 1:r3) = XG;
        G1(r1 + 1 : 2*r1, r2 + 1 : 2*r2, r3 + 1 : 2*r3) = XG;
        G1(2*r1 + 1 : 3*r1, 2*r2 + 1 : 3*r2, 2*r3 + 1 : 3*r3) = XG;
        G1(3*r1 + 1 : 4*r1, 3*r2 + 1 : 4*r2, 3*r3 + 1 : 4*r3) = etaG;  

        X1.U1 = [eta.U1 X.U1 X.U1 X.U1];
        X1.U2 = [X.U2 eta.U2 X.U2 X.U2];
        X1.U3 = [X.U3 X.U3 eta.U3 X.U3];
        X1.G = G1;
        
        X1multiarray = tucker2multiarray(X1);
        term1_mat = P.*X1multiarray;    
        term1 = term1_mat(:);
        
        % tmin is the solution to the problem argmin a2*t^2 + a1*t, where
        % the coefficients a1 and a2 are shown below.
        a2 = (term1'*term1);
        a1 = 2*(term1'*term0);
        tmin = - 0.5*(a1 / a2);
        
    end    

    % Notice that for this solver, the Hessian is not needed.
    [Xcg, costcg, infocg] = conjugategradient(problem, [], options);
    
    fprintf('Take a look at the options that CG used:\n');
    disp(options);
    fprintf('And see how many trials were made at each line search call:\n');
    info_ls = [infocg.linesearch];
    disp([info_ls.costevals]); 
    
    
     
    fprintf('Try it again without the linesearch helper.\n');
    
    % Remove the linesearch helper from the problem structure.
    problem = rmfield(problem, 'linesearch');
    
    [Xcg, xcost, info, options] = conjugategradient(problem, []); %#ok<ASGLU>
    
    fprintf('Take a look at the options that CG used:\n');
    disp(options);
    fprintf('And see how many trials were made at each line search call:\n');
    info_ls = [info.linesearch];
    disp([info_ls.costevals]);
    
    
    
end
