function grad = getGradient(problem, x, storedb, key)
% Computes the gradient of the cost function at x.
%
% function grad = getGradient(problem, x)
% function grad = getGradient(problem, x, storedb)
% function grad = getGradient(problem, x, storedb, key)
%
% Returns the gradient at x of the cost function described in the problem
% structure.
%
% storedb is a StoreDB object, key is the StoreDB key to point x.
%
% See also: getDirectionalDerivative canGetGradient

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, Dec. 30, 2012.
% Contributors: 
% Change log: 
%
%   April 3, 2015 (NB):
%       Works with the new StoreDB class system.
%
%  June 28, 2016 (NB):
%       Works with getPartialGradient.
%
%   Nov. 1, 2016 (NB):
%       Added support for gradient from directional derivatives.
%       Last resort is call to getApproxGradient instead of an exception.
%
%   Sep. 6, 2018 (NB):
%       The gradient is now cached by default. This is made practical by
%       the new storedb 'remove' functionalities that keep the number of
%       cached points down to a minimum. If the gradient is obtained via
%       costgrad, the cost is also cached.
%
%   Feb. 10, 2020 (NB):
%       Allowing M.egrad2rgrad to take (storedb, key) as extra inputs.

    % Allow omission of the key, and even of storedb.
    if ~exist('key', 'var')
        if ~exist('storedb', 'var')
            storedb = StoreDB();
        end
        key = storedb.getNewKey();
    end

    % Contrary to most similar functions, here, we get the store by
    % default. This is for the caching functionality described below.
    store = storedb.getWithShared(key);
    store_is_stale = false;

    % If the gradient has been computed before at this point (and its
    % memory is still in storedb), then we just look up the value.
    force_grad_caching = true;
    if force_grad_caching && isfield(store, 'grad__')
        grad = store.grad__;
        return;
    end
    
    % We don't normally compute the cost value, but if we get it as a side
    % result, then we may as well take note of it for caching.
    cost_computed = false;
    
    
    if isfield(problem, 'grad')
    %% Compute the gradient using grad.
    
        % Check whether this function wants to deal with storedb or not.
        switch nargin(problem.grad)
            case 1
                grad = problem.grad(x);
            case 2
                [grad, store] = problem.grad(x, store);
            case 3
                % Pass along the whole storedb (by reference), with key.
                grad = problem.grad(x, storedb, key);
                % The store structure in storedb might have been modified
                % (since it is passed by reference), so before caching
                % we'll have to update (see below).
                store_is_stale = true;
            otherwise
                up = MException('manopt:getGradient:badgrad', ...
                    'grad should accept 1, 2 or 3 inputs.');
                throw(up);
        end
    
    elseif isfield(problem, 'costgrad')
    %% Compute the gradient using costgrad.
    
        % Check whether this function wants to deal with storedb or not.
        switch nargin(problem.costgrad)
            case 1
                [cost, grad] = problem.costgrad(x);
            case 2
                [cost, grad, store] = problem.costgrad(x, store);
            case 3
                % Pass along the whole storedb (by reference), with key.
                [cost, grad] = problem.costgrad(x, storedb, key);
                store_is_stale = true;
            otherwise
                up = MException('manopt:getGradient:badcostgrad', ...
                    'costgrad should accept 1, 2 or 3 inputs.');
                throw(up);
        end
        
        cost_computed = true;
    
    elseif canGetEuclideanGradient(problem)
    %% Compute the Riemannian gradient using the Euclidean gradient.
        
        egrad = getEuclideanGradient(problem, x, storedb, key);
        % Convert to the Riemannian gradient
        switch nargin(problem.M.egrad2rgrad)
            case 2
                grad = problem.M.egrad2rgrad(x, egrad);
            case 4
                grad = problem.M.egrad2rgrad(x, egrad, storedb, key);
            otherwise
                up = MException('manopt:getGradient:egrad2rgrad', ...
                    'egrad2rgrad should accept 2 or 4 inputs.');
                throw(up);
        end
        store_is_stale = true;
    
    elseif canGetPartialGradient(problem)
    %% Compute the gradient using a full partial gradient.
        
        d = problem.ncostterms;
        grad = getPartialGradient(problem, x, 1:d, storedb, key);
        store_is_stale = true;
        
    elseif canGetDirectionalDerivative(problem)
    %% Compute gradient based on directional derivatives; expensive!
    
        B = tangentorthobasis(problem.M, x);
        df = zeros(size(B));
        for k = 1 : numel(B)
            df(k) = getDirectionalDerivative(problem, x, B{k}, storedb, key);
        end
        grad = lincomb(problem.M, x, B, df);
        store_is_stale = true;

    else
    %% Attempt the computation of an approximation of the gradient.
        
        grad = getApproxGradient(problem, x, storedb, key);
        store_is_stale = true;
        
    end

    % If we are not sure that the store structure is up to date, update.
    if store_is_stale
        store = storedb.getWithShared(key);
    end
    
    % Cache here.
    if force_grad_caching
        store.grad__ = grad; 
    end
    % If we got the gradient via costgrad, then the cost has also been
    % computed and we can cache it.
    if cost_computed
        store.cost__ = cost;
    end

    storedb.setWithShared(store, key);
    
end
