function preconfun = preconhessiansolve(problem, options)
% Preconditioner based on the inverse Hessian, by solving linear systems.
%
% function preconfun = preconhessiansolve(problem)
% function preconfun = preconhessiansolve(problem, options)
%
% Input:
%
% A Manopt problem structure (already containing the manifold and enough
% information to compute the Hessian of the cost) and an options structure
% (optional, currently ignored). Notice that if the Hessian is not positive
% definite, then its inverse is not positive definite either and this
% preconditioner is not suitable.
%
% If the Hessian cannot be computed on 'problem', a warning is issued. An
% approximation of the Hessian will be used instead, and the present
% preconditioner will attempt to invert that (although it may not be a
% linear operator). If no approximate Hessian is provided either, a generic
% approximation is used. Behavior is unspecified.
%
% Output:
% 
% Returns a function handle, encapsulating a generic preconditioner of the
% Hessian based on solving linear systems of the form:
%   Hessian(x)[preconfun(x, xdot)] = xdot,
% where x is the point on the manifold, xdot is the input to the
% preconditioner (a tangent vector) and preconfun(x, xdot) is returned
% (also a tangent vector). The solve may be approximate.
% 
% The returned preconfun has this calling pattern:
% 
%   function precxdot = preconfun(x, xdot)
%   function precxdot = preconfun(x, xdot, storedb)
%   function precxdot = preconfun(x, xdot, storedb, key)
% 
% x is a point on the manifold problem.M, xdot is a tangent vector to that
% manifold at x, storedb is a StoreDB object, and key is the StoreDB key to
% point x.
%
% Usage:
%
% Typically, the user will set problem.M and other fields to define the
% cost, the gradient and the Hessian (typically, problem.cost, problem.grad
% and problem.hess, or problem.egrad and problem.ehess). Then, to use this
% generic purpose Hessian preconditioner:
%
%   problem.precon = preconhessiansolve(problem, options);
%
% Passing that problem structure to the conjugategradients solver
% (which uses preconditioning) configured in steepest descent mode results
% in a type of Riemannian Newton method.
%
% See also: conjugategradients

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, April 9, 2015.
% Contributors: 
% Change log: 

    % Check availability of the Hessian, or at least of an approximation.
    if ~canGetHessian(problem) && ~canGetApproxHessian(problem)
        % Note: we do not give a warning if an approximate Hessian is
        % explicitly given in the problem description, as in that case the
        % user seems to be aware of the issue.
        warning('manopt:getHessian:approx', ...
               ['No Hessian provided. Using an FD approximation instead.\n' ...
                'To disable this warning: warning(''off'', ''manopt:getHessian:approx'')']);
        problem.approxhess = approxhessianFD(problem);
    end

    % Set local defaults here, and merge with user options, if any.
    localdefaults = struct();
    if ~exist('options', 'var') || isempty(options)
        options = struct();
    end
    options = mergeOptions(localdefaults, options);

    % Build and return the function handle here. This extra construct via
    % funhandle makes it possible to make storedb and key optional.
    preconfun = @funhandle;
    function precxdot = funhandle(x, xdot, storedb, key)
        % Allow omission of the key, and even of storedb.
        if ~exist('key', 'var')
            if ~exist('storedb', 'var')
                storedb = StoreDB();
            end
            key = storedb.getNewKey();
        end 
        precxdot = hessiansolvehelper(options, problem, x, xdot, ...
                                      storedb, key);
    end
    
end


function precxdot = hessiansolvehelper(options, problem, x, xdot, storedb, key)
% This function does the actual work.
    
    % Exclude the case where xdot is zero
    norm_xdot = problem.M.norm(x, xdot);
    if norm_xdot < eps
        precxdot = problem.M.zerovec(x);
        return;
    end
    
    % Get a shorthand for the Hessian of the cost on M at x.
    hessian = @(u) getHessian(problem, x, u, storedb, key);
    
    % Setup an optimization problem on the tangent space to problem.M at x.
    M = problem.M;
    tgtspace = tangentspacefactory(M, x);
    prblm.M = tgtspace;
    prblm.cost = @cost;
    prblm.grad = @grad;
    prblm.hess = @(u, udot) 2*hessian(hessian(udot))/norm_xdot;
    
    function [f, store] = cost(u, store)
        if ~isfield(store, 'residue')
            Hu = hessian(u);
            store.residue = M.lincomb(x, 1, Hu, -1, xdot);
        end
        f = M.norm(x, store.residue).^2 / norm_xdot;
    end
    function [g, store] = grad(u, store)
        if ~isfield(store, 'residue')
            Hu = hessian(u);
            store.residue = M.lincomb(x, 1, Hu, -1, xdot);
        end
        g = 2 * hessian(store.residue) / norm_xdot;
    end
    
    % checkgradient(prblm); pause;
    % checkhessian(prblm); pause;
    
    localdefaults.solver = @trustregions;
    localdefaults.verbosity = 0;
    % Merge local defaults with user options, if any.
    if ~exist('options', 'var') || isempty(options)
        options = struct();
    end
    options = mergeOptions(localdefaults, options);
    
    % Solve the linear system by solving the optimization problem.
    precxdot = manoptsolve(prblm, M.zerovec(), options);
    
end
