/**
File:		MachineLearning/Optimization/Unconstrained/Linesearch/FgArmijoBracketingLineSearch.cpp

Author:		
Email:		
Site:       

Copyright (c) 2019 . All rights reserved.
*/

#include <NeMachineLearningPCH.h>
#include <MachineLearning/FgArmijoBracketingLineSearch.h>
//#include <MachineLearning/CommonUtil.h>
//#include <cmath>
//#include <math.h>
//#include <limits>
//#include <iomanip>

namespace NeuralEngine::MachineLearning
{
	/**
	 * Backtracking line search enforcing the Armijo (sufficient decrease) condition
	 *
	 *     f(xp + step * drt) <= f(xp) + ftol * step * (grad^T drt)
	 *
	 * Starting from the caller-supplied `step`, every rejected trial shrinks the step to the
	 * midpoint of the bracket [step_lo, step_hi]. Only the Armijo condition is tested (no Wolfe
	 * curvature check), so step_lo stays at 0 and each rejection halves the step.
	 *
	 * @param objFunc Objective; Value(x) is evaluated at every trial point, Gradient(x) only at
	 *                the returned point.
	 * @param fx      In: function value at xp. Out: function value at the returned x.
	 * @param x       Out: xp + step * drt for the returned step.
	 * @param grad    In: gradient at xp (used for the directional derivative).
	 *                Out: gradient at the returned x.
	 * @param step    In: initial trial step, must be > 0. Out: the step that produced x.
	 * @param drt     Search direction; grad^T drt must not be positive.
	 * @param xp      Point the search starts from.
	 * @throws std::invalid_argument if step is not positive (or is NaN).
	 * @throws std::logic_error if drt increases the objective (grad^T drt > 0).
	 *
	 * If no trial satisfies the Armijo condition within max_linesearch evaluations, the last
	 * evaluated point is returned as a best effort (the condition may not hold there).
	 * NOTE(review): confirm callers tolerate this instead of expecting an exception.
	 */
	template<typename Scalar>
	void ArmijoBracketingLineSearch<Scalar>::Linesearch(NonlinearObjectiveFunction<Scalar>& objFunc, Scalar& fx, af::array& x, af::array& grad, Scalar& step, const af::array& drt, const af::array& xp)
	{
		// Sufficient-decrease constant (c1) of the Armijo condition.
		constexpr Scalar ftol = Scalar(1e-4);
		// Maximum number of trial steps (function evaluations).
		constexpr int max_linesearch = 20;

		// Written as !(step > 0) so that a NaN step is rejected as well.
		if (!(step > 0))
			throw std::invalid_argument("'step' must be positive");

		// Function value and directional derivative at the starting point xp.
		const Scalar fx_init = fx;
		const Scalar dg_init = matmulTN(grad, drt).scalar<Scalar>();

		// The direction must not point uphill, otherwise no positive step can decrease f.
		if (dg_init > 0)
			throw std::logic_error("the moving direction increases the objective function value");

		// Required decrease per unit step.
		const Scalar dg_test = ftol * dg_init;

		// Lower and upper end of the current search bracket.
		Scalar step_lo = 0;
		Scalar step_hi = std::numeric_limits<Scalar>::infinity();

		for (int iter = 0; iter < max_linesearch; ++iter)
		{
			// x_{k+1} = x_k + step * d_k
			x = xp + step * drt;
			fx = objFunc.Value(x);

			// Armijo condition met. Phrased as `<=` so a NaN function value counts as a failure
			// and triggers backtracking instead of being accepted.
			if (fx <= fx_init + step * dg_test)
				break;

			// Insufficient decrease: the current step becomes the new upper bound.
			step_hi = step;

			LogAssert(step_lo < step_hi, "Lower bound equal or higher than upper bound.");

			// Out of trials: keep `step` consistent with the last evaluated x/fx.
			if (iter + 1 == max_linesearch)
				break;

			// Continue the search in the middle of the current bracket.
			step = step_lo / 2 + step_hi / 2;
		}

		// Rejected trials only need the function value; the gradient is required just for the
		// returned point (evaluated right after Value(x) at the same x).
		grad = objFunc.Gradient(x);
	}

	// Explicit instantiations must follow the member definitions: an explicit instantiation
	// definition only instantiates class members that are already defined at that point.
	template class ArmijoBracketingLineSearch<float>;
	template class ArmijoBracketingLineSearch<double>;
}