/**
File:		MachineLearning/Optimization/Unconstrained/Linesearch/Armijo.cpp

Author:		
Email:		
Site:       

Copyright (c) 2019 . All rights reserved.
*/

#include <NeMachineLearningPCH.h>
#include <MachineLearning/FgArmijoLineSearch.h>
//#include <MachineLearning/CommonUtil.h>
//#include <cmath>
//#include <math.h>
//#include <limits>
//#include <iomanip>

namespace NeuralEngine::MachineLearning
{
	template class ArmijoLineSearch<float>;
	template class ArmijoLineSearch<double>;

	template<typename Scalar>
	ArmijoLineSearch<Scalar>::ArmijoLineSearch()
	{
	}

	template<typename Scalar>
	void ArmijoLineSearch<Scalar>::Linesearch(NonlinearObjectiveFunction<Scalar>& objFunc, Scalar& fx, af::array& x, af::array& grad, Scalar& step, const af::array& drt, const af::array& xp)
	{
		// Decreasing and increasing factors
		const Scalar dec = 0.5;
		const Scalar inc = 2.1;

		const Scalar ftol = 1e-4;

		const int max_linesearch = 20;

		// Check the value of step
		if (step <= 0)
			std::invalid_argument("'step' must be positive");

		// Save the function value at the current x
		const Scalar fx_init = fx;
		// Projection of gradient on the search direction
		const Scalar dg_init = matmulTN(grad, drt).scalar<Scalar>();
		// Make sure d points to a descent direction
		if (dg_init > 0)
			std::logic_error("the moving direction increases the objective function value");

		Scalar dg_test = ftol * dg_init;
		Scalar width;

		int iter;
		for (iter = 0; iter < max_linesearch; iter++)
		{
			// x_{k+1} = x_k + step * d_k
			x = xp + step * drt;
			// Evaluate this candidate
			fx = objFunc.Value(x);
			grad = objFunc.Gradient(x);

			if (fx > fx_init + step * dg_test)
			{
				width = dec;
			}
			else {
				// Armijo condition is met
				//if (linesearch == LBFGS_LINESEARCH_BACKTRACKING_ARMIJO)
					break;

				//const Scalar dg = matmulTN(grad, drt).scalar<float>();
				//if (dg < wolfe * dg_init)
				//{
				//	width = inc;
				//}
				//else {
				//	// Regular Wolfe condition is met
				//	if (linesearch == LBFGS_LINESEARCH_BACKTRACKING_WOLFE)
				//		break;

				//	if (dg > -wolfe * dg_init)
				//	{
				//		width = dec;
				//	}
				//	else {
				//		// Strong Wolfe condition is met
				//		break;
				//	}
				//}
			}

			if (iter >= max_linesearch)
				throw std::runtime_error("the line search routine reached the maximum number of iterations");

			//if (step < min_step)
			//	throw std::runtime_error("the line search step became smaller than the minimum value allowed");

			//if (step > max_step)
			//	throw std::runtime_error("the line search step became larger than the maximum value allowed");

			step *= width;
		}
	}
}