function [x, cost, info, options] = barzilaiborwein(problem, x, options)
% Riemannian Barzilai-Borwein solver with non-monotone line-search.
%
% function [x, cost, info, options] = barzilaiborwein(problem)
% function [x, cost, info, options] = barzilaiborwein(problem, x0)
% function [x, cost, info, options] = barzilaiborwein(problem, x0, options)
% function [x, cost, info, options] = barzilaiborwein(problem, [], options)
%
% Apply the Barzilai-Borwein minimization algorithm to the problem defined
% in the problem structure, starting at x0 if it is provided (otherwise, at
% a random point on the manifold). To specify options whilst not specifying
% an initial guess, give x0 as [] (the empty matrix).
%
% The algorithm uses its own special non-monotone line-search strategy.
% Therefore, no line-search algorithm should be specified in the problem
% structure or in the options structure.
%
% In most of the examples bundled with the toolbox (see link below), the
% solver can be replaced by the present one if need be.
%
% The outputs x and cost are the last reached point on the manifold and its
% cost. This is not necessarily the best point generated since the method
% is not monotone. The struct-array info contains information about the
% iterations:
%   iter : the iteration number (0 for the initial guess)
%   cost : cost value
%   time : elapsed time in seconds
%   gradnorm : Riemannian norm of the gradient
%   stepsize : norm of the last tangent vector retracted
%   linesearch : information logged by the line-search algorithm
%   And possibly additional information logged by options.statsfun.
% For example, type [info.gradnorm] to obtain a vector of the successive
% gradient norms reached.
%
% The options structure is used to overwrite the default values. All
% options have a default value and are hence optional. To force an option
% value, pass an options structure with a field options.optionname, where
% optionname is one of the following and the default value is indicated
% between parentheses:
%
%   tolgradnorm (1e-6)
%       The algorithm terminates if the norm of the gradient drops below this.
%   maxiter (1000)
%       The algorithm terminates if maxiter iterations have been executed.
%   maxtime (Inf)
%       The algorithm terminates if maxtime seconds elapsed.
%   minstepsize (1e-10)
%       The algorithm terminates if the linesearch returns a displacement
%       vector (to be retracted) smaller in norm than this value.
%   linesearch (@linesearch_hint)
%       This option should not be changed, as the present solver has its
%       own dedicated line-search strategy.
%   strategy ('direct')
%       The strategy used for the Barzilai-Borwein stepsize
%       'direct', compute the direct step <s_k,s_k>/<s_k,y_k>
%       'inverse', compute the inverse step <s_k,y_k>/<y_k,y_k>
%       'alternate', alternates between direct and inverse step
%   lambdamax (1e3)
%       The maximum stepsize allowed by the Barzilai-Borwein method
%   lambdamin (1e-3)
%       The minimum stepsize allowed by the Barzilai-Borwein method
%   lambda0 (1/10)
%       The initial stepsize of the Barzilai-Borwein method
%   ls_nmsteps (10)
%       The non-monotone line-search checks a sufficient decrease with respect
%       to the previous ls_nmsteps objective function values.
%   statsfun (none)
%       Function handle to a function that will be called after each
%       iteration to provide the opportunity to log additional statistics.
%       They will be returned in the info struct. See the generic Manopt
%       documentation about solvers for further information.
%   stopfun (none)
%       Function handle to a function that will be called at each iteration
%       to provide the opportunity to specify additional stopping criteria.
%       See the generic Manopt documentation about solvers for further
%       information.
%   verbosity (3)
%       Integer number used to tune the amount of output the algorithm
%       generates during execution (mostly as text in the command window).
%       The higher, the more output. 0 means silent.
%   storedepth (2)
%       Maximum number of different points x of the manifold for which a
%       store structure will be kept in memory in the storedb. If the
%       caching features of Manopt are not used, this is irrelevant. For
%       this algorithm, a store depth of 2 should always be sufficient.
%   
%
% The implementation of the Barzilai-Borwein method is based on the paper:
%
% B. Iannazzo, M. Porcelli, "The Riemannian Barzilai-Borwein method with 
% nonmonotone line-search and the matrix geometric mean computation",
% IMA Journal of Numerical Analysis, to appear, https://doi.org/10.1093/imanum/drx015.
%
% See also: steepestdescent conjugategradient trustregions

% This file is part of Manopt: www.manopt.org.
% Original author: Margherita Porcelli, May 31, 2017
% Contributors: Nicolas Boumal, Bruno Iannazzo
% Change log: 
%
%   Aug. 2, 2018 (NB):
%       Now using storedb.remove() to keep the cache lean.

    
    % Verify that the problem description is sufficient for the solver.
    if ~canGetCost(problem)
        warning('manopt:getCost', ...
                'No cost provided. The algorithm will likely abort.');  
    end
    if ~canGetGradient(problem) && ~canGetApproxGradient(problem)
        % Note: we do not give a warning if an approximate gradient is
        % explicitly given in the problem description, as in that case the
        % user seems to be aware of the issue.
        warning('manopt:getGradient:approx', ...
               ['No gradient provided. Using an FD approximation instead (slow).\n' ...
                'It may be necessary to increase options.tolgradnorm.\n' ...
                'To disable this warning: warning(''off'', ''manopt:getGradient:approx'')']);
        problem.approxgrad = approxgradientFD(problem);
    end

    % Ensure options exists as a structure
    if ~exist('options', 'var') || isempty(options)
        options = struct();
    end
    
    % Set local defaults here
    localdefaults.minstepsize = 1e-10;
    localdefaults.maxiter = 1000;
    localdefaults.tolgradnorm = 1e-6;

    % Upper and lower bound for the Barzilai-Borwein stepsize
    localdefaults.lambdamax = 1e3;
    localdefaults.lambdamin = 1e-3;
    % Initial Barzilai-Borwein stepsize
    localdefaults.lambda0 = 1/10;

    % Barzilai-Borwein strategy (direct, inverse or alternate)
    localdefaults.strategy = 'direct';

    % Line-search parameters
    % 1) Make sure the user didn't try to define a line search
    if canGetLinesearch(problem) || isfield(options, 'linesearch')
        error('manopt:BB:ls', ...
              ['The problem structure may not specify a line-search ' ...
               'hint for the BB solver,\nand the options structure ' ...
               'may not specify a line-search algorithm for BB.']);
    end
    % 2) Define the line-search parameters
    problem.linesearch = @(x, d, storedb, key) 1;
    options.linesearch = @linesearch_hint;
    % The Armijo sufficient decrease parameter
    localdefaults.ls_suff_decr = 1e-4;
    % The previous steps checked in the non-monotone line-search strategy
    localdefaults.ls_nmsteps = 10;
    
    
    % Merge global and local defaults, then merge w/ user options, if any.
    localdefaults = mergeOptions(getGlobalDefaults(), localdefaults);
    options = mergeOptions(localdefaults, options); 

    
    % Shorthands for some parameters
    strategy = options.strategy;
    lambdamax = options.lambdamax;
    lambdamin = options.lambdamin;
    lambda0 = options.lambda0;
    
    timetic = tic();
    
    
    % If no initial point x is given by the user, generate one at random.
    if ~exist('x', 'var') || isempty(x)
        x = problem.M.rand();
    end

    % Create a store database and get a key for the current x
    storedb = StoreDB(options.storedepth);
    key = storedb.getNewKey();
    
    % Compute objective-related quantities for x
    [cost, grad] = getCostGrad(problem, x, storedb, key);
    gradnorm = problem.M.norm(x, grad);

    % Some variables below need to store information about iterations. We
    % preallocate for a reasonable amount of intended iterations to avoid
    % memory re-allocations.
    mem_init_size = min(10000, options.maxiter+1);
    
    % Store the cost value
    f = zeros(mem_init_size, 1);
    f(1) = cost;
    
    % Iteration counter (at any point, iter is the number of fully executed
    % iterations so far)
    iter = 0;
    
    % Save stats in a struct array info, and preallocate.
    stats = savestats();
    info(1) = stats;
    info(mem_init_size).iter = [];
    
    if options.verbosity >= 2
        fprintf(' iter\t                cost val\t     grad. norm\n');
    end

    % Set the initial Barzilai-Borwein stepsize
    lambda = lambda0;

    % Start iterating until stopping criterion triggers
    while true

        % Display iteration information
        if options.verbosity >= 2
            fprintf('%5d\t%+.16e\t%.8e\n', iter, cost, gradnorm);
        end
        
        % Start timing this iteration
        timetic = tic();
        
        % Run standard stopping criterion checks
        [stop, reason] = stoppingcriterion(problem, x, options, ...
                                                             info, iter+1);
        
        % If none triggered, run specific stopping criterion check
        if ~stop && stats.stepsize < options.minstepsize
            stop = true;
            reason = sprintf(['Last stepsize smaller than minimum '  ...
                              'allowed; options.minstepsize = %g.'], ...
                              options.minstepsize);
        end
    
        if stop
            if options.verbosity >= 1
                fprintf([reason '\n']);
            end
            break;
        end

        % Pick the descent direction as minus the gradient (scaled)
        desc_dir = problem.M.lincomb(x, -lambda, grad);

        % Execute the nonmonotone line search
        k = iter + 1; 
        start = max(1, k - options.ls_nmsteps + 1);
        
        [stepsize, newx, newkey, lsstats] = ...
            options.linesearch(problem, x, desc_dir, max(f(start:k)), ...
                            -lambda * gradnorm^2, options, storedb, key);

        % Updates the value of lambda
        lambda = lambda * lsstats.alpha;

        % Compute the new cost-related quantities for newx
        [newcost, newgrad] = getCostGrad(problem, newx, storedb, newkey);
        newgradnorm = problem.M.norm(newx, newgrad);

        %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        % BARZILAI-BORWEIN STRATEGY

        % Store the cost value
        f(iter+2) = newcost;
       
        % Transport the old gradient to newx
        grad_transp = problem.M.transp(x, newx, grad);

        % Compute the difference between grandients 
        Y = problem.M.lincomb(newx, 1, newgrad, -1, grad_transp);

        % Compute the transported step
        Stransp =  problem.M.lincomb(x, -lambda, grad_transp); 

        % Compute the new Barzilai-Borwein step following the strategy
        % direct strategy
        if strcmp(strategy, 'direct')
          num = problem.M.norm(newx, Stransp)^2; 
          den = problem.M.inner(newx, Stransp, Y);
          if den > 0
            lambda = min( lambdamax, max(lambdamin, num/den) );
          else
            lambda = lambdamax;
          end
        end

        % inverse strategy
        if strcmp(strategy, 'inverse')
          num = problem.M.inner(newx, Stransp, Y);
          den = problem.M.norm(newx, Y)^2;

          if num > 0  
            lambda = min( lambdamax, max(lambdamin, num/den) );
          else
            lambda = lambdamax;
          end
        end

        % alternate strategy
        if strcmp(strategy, 'alternate')
          num = problem.M.norm(newx, Stransp)^2; 
          den = problem.M.inner(newx, Stransp, Y);
          den2 = problem.M.norm(newx, Y)^2;
          if (den > 0)  
            if mod(iter,2)==0
                lambda = min( lambdamax, max(lambdamin, num/den) );
        else
                lambda = min( lambdamax, max(lambdamin, den/den2) );
            end
          else
            lambda = lambdamax;
          end
        end
        
        %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        
        % Update iterate info
        storedb.removefirstifdifferent(key, newkey);
        x = newx;
        key = newkey;
        cost = newcost;
        grad = newgrad;
        gradnorm = newgradnorm;

        % iter is the number of iterations we have accomplished.
        iter = iter + 1;

        % Make sure we don't use too much memory for the store database.
        storedb.purge();
        
        % Log statistics for freshly executed iteration
        stats = savestats();
        info(iter+1) = stats;
        
    end
    
    info = info(1:iter+1);

    if options.verbosity >= 1
        fprintf('Total time is %f [s] (excludes statsfun)\n', ...
                info(end).time);
    end
    
    
    
    % Routine in charge of collecting the current iteration stats
    function stats = savestats()
        stats.iter = iter;
        stats.cost = cost;
        stats.gradnorm = gradnorm;
        if iter == 0
            stats.stepsize = NaN;
            stats.time = toc(timetic);
            stats.linesearch = [];
        else
            stats.stepsize = stepsize;
            stats.time = info(iter).time + toc(timetic);
            stats.linesearch = lsstats;
        end
        stats = applyStatsfun(problem, x, storedb, key, options, stats);
    end
    
end
