﻿/**
File:		MachineLearning/Optimization/Unconstrained/L-BFGS.cpp

Author:		
Email:		
Site:       

Copyright (c) 2019 . All rights reserved.
*/

#include <NeMachineLearningPCH.h>
#include <MachineLearning/FgLBFGSsolver.h>
#include <cmath>
#include <math.h>
#include <limits>
#include <iomanip>

namespace NeuralEngine::MachineLearning
{
	// Explicit instantiations of LBFGSSolver for both scalar types (float, double)
	// combined with each of the seven line-search strategies listed below.
	// NOTE(review): these instantiations appear before the member definitions in this
	// file. The C++ standard only instantiates members that are already defined at the
	// point of an explicit instantiation definition, so on compilers other than MSVC the
	// members defined further down may not be emitted -- confirm, and consider moving
	// this list to the end of the namespace.
	template class LBFGSSolver<float, ArmijoBacktracking>;
	template class LBFGSSolver<float, ArmijoBracketing>;
	template class LBFGSSolver<float, MoreThuente>;
	template class LBFGSSolver<float, StrongWolfeBacktracking>;
	template class LBFGSSolver<float, StrongWolfeBracketing>;
	template class LBFGSSolver<float, WolfeBacktracking>;
	template class LBFGSSolver<float, WolfeBracketing>;

	template class LBFGSSolver<double, ArmijoBacktracking>;
	template class LBFGSSolver<double, ArmijoBracketing>;
	template class LBFGSSolver<double, MoreThuente>;
	template class LBFGSSolver<double, StrongWolfeBacktracking>;
	template class LBFGSSolver<double, StrongWolfeBracketing>;
	template class LBFGSSolver<double, WolfeBacktracking>;
	template class LBFGSSolver<double, WolfeBracketing>;

	/**
	 * Constructs a solver from a variable count only.
	 *
	 * The count is forwarded to BaseGradientOptimizationMethod; all L-BFGS settings
	 * receive the same defaults as in the other constructors.
	 *
	 * @param numberOfVariables value passed on to the base-class constructor.
	 */
	template<typename Scalar, LineSearchType LSType>
	LBFGSSolver<Scalar, LSType>::LBFGSSolver(int numberOfVariables)
		: BaseGradientOptimizationMethod<Scalar, LSType>(numberOfVariables),
		m(6),                // number of correction pairs kept by Optimize()
		past(1),             // length of the objective-value history in Optimize()
		delta(1e-10),        // adjustable through SetDelta()
		max_linesearch(20),  // adjustable through SetMaxLinesearch()
		min_step(1e-20),     // presumably the smallest line-search step -- TODO confirm
		max_step(1e+20),     // presumably the largest line-search step -- TODO confirm
		ftol(1e-4),          // presumably the sufficient-decrease constant -- TODO confirm
		wolfe(0.9)           // presumably the curvature-condition constant -- TODO confirm
	{
	}

	/**
	 * Constructs a solver from a variable count and an objective callback.
	 *
	 * Both arguments are forwarded unchanged to BaseGradientOptimizationMethod; all
	 * L-BFGS settings receive the same defaults as in the other constructors.
	 *
	 * @param numberOfVariables value passed on to the base-class constructor.
	 * @param function          callable taking (const af::array&, af::array&) and returning
	 *                          a Scalar; handed to the base class as-is.
	 */
	template<typename Scalar, LineSearchType LSType>
	LBFGSSolver<Scalar, LSType>::LBFGSSolver(int numberOfVariables, std::function<Scalar(const af::array&, af::array&)> function)
		: BaseGradientOptimizationMethod<Scalar, LSType>(numberOfVariables, function),
		m(6),                // number of correction pairs kept by Optimize()
		past(1),             // length of the objective-value history in Optimize()
		delta(1e-10),        // adjustable through SetDelta()
		max_linesearch(20),  // adjustable through SetMaxLinesearch()
		min_step(1e-20),     // presumably the smallest line-search step -- TODO confirm
		max_step(1e+20),     // presumably the largest line-search step -- TODO confirm
		ftol(1e-4),          // presumably the sufficient-decrease constant -- TODO confirm
		wolfe(0.9)           // presumably the curvature-condition constant -- TODO confirm
	{
	}

	/**
	 * Constructs a solver around an existing objective-function object.
	 *
	 * The pointer is forwarded unchanged to BaseGradientOptimizationMethod; all L-BFGS
	 * settings receive the same defaults as in the other constructors.
	 *
	 * @param function objective passed on to the base-class constructor.
	 */
	template<typename Scalar, LineSearchType LSType>
	LBFGSSolver<Scalar, LSType>::LBFGSSolver(NonlinearObjectiveFunction<Scalar> * function)
		: BaseGradientOptimizationMethod<Scalar, LSType>(function),
		m(6),                // number of correction pairs kept by Optimize()
		past(1),             // length of the objective-value history in Optimize()
		delta(1e-10),        // adjustable through SetDelta()
		max_linesearch(20),  // adjustable through SetMaxLinesearch()
		min_step(1e-20),     // presumably the smallest line-search step -- TODO confirm
		max_step(1e+20),     // presumably the largest line-search step -- TODO confirm
		ftol(1e-4),          // presumably the sufficient-decrease constant -- TODO confirm
		wolfe(0.9)           // presumably the curvature-condition constant -- TODO confirm
	{
	}

	/** Destructor; performs no work of its own. */
	template<typename Scalar, LineSearchType LSType>
	LBFGSSolver<Scalar, LSType>::~LBFGSSolver() = default;

	/**
	 * Sets the number of correction pairs (m) kept by Optimize().
	 *
	 * Optimize() allocates its history buffers with m columns and indexes them
	 * modulo m, so a value below 1 would lead to a division by zero there.
	 * Non-positive requests are therefore clamped to 1.
	 *
	 * @param corrections requested history length; values < 1 are treated as 1.
	 */
	template<typename Scalar, LineSearchType LSType>
	void LBFGSSolver<Scalar, LSType>::SetNumCorrections(int corrections)
	{
		m = corrections > 0 ? corrections : 1;
	}

	/**
	 * Stores a new value for the delta setting.
	 *
	 * @param inDelta value copied into the delta member.
	 */
	template<typename Scalar, LineSearchType LSType>
	void LBFGSSolver<Scalar, LSType>::SetDelta(Scalar inDelta)
	{
		this->delta = inDelta;
	}

	/**
	 * Stores a new value for the max_linesearch setting.
	 *
	 * @param maxIter value copied into the max_linesearch member.
	 */
	template<typename Scalar, LineSearchType LSType>
	void LBFGSSolver<Scalar, LSType>::SetMaxLinesearch(int maxIter)
	{
		this->max_linesearch = maxIter;
	}

	/**
	 * Minimizes the attached objective with the limited-memory BFGS method.
	 *
	 * Starting from GetSolution(), each cycle runs the configured line search along
	 * the current direction, stores the new correction pair (s, y) in a ring buffer
	 * of m columns, and builds the next direction d = -H * g with the two-loop
	 * recursion. The result is written back through SetSolution().
	 *
	 * The loop ends when
	 *   - the gradient test  ||g|| <= _tolerance * max(||x||, 1)  holds,
	 *   - maxIterations is non-zero and the cycle counter exceeds it, or
	 *   - the line search produces a NaN objective value (the previous iterate is
	 *     then stored as the solution).
	 *
	 * @param cycle optional out-parameter; when non-null it receives the index of the
	 *              cycle in progress (1-based), or 0 when the starting point already
	 *              satisfies the gradient test.
	 * @return true in every terminating case listed above; the return value does not
	 *         tell the cases apart.
	 */
	template<typename Scalar, LineSearchType LSType>
	bool LBFGSSolver<Scalar, LSType>::Optimize(int* cycle)
	{
		const int n = GetNumberOfVariables();

		af::array x = GetSolution();

		// Ring buffers for the last m correction pairs.
		af::array m_alpha = af::constant(0.0f, m, m_dtype);  // alpha coefficients of the two-loop recursion
		af::array m_ys = af::constant(0.0f, m, m_dtype);     // History of the s'y values
		af::array m_y = af::constant(0.0f, n, m, m_dtype);   // History of the y vectors
		af::array m_s = af::constant(0.0f, n, m, m_dtype);   // History of the s vectors

		const int fpast = past;

		// Evaluate function and compute gradient
		Scalar fx = _function->Value(x);
		af::array m_grad = _function->Gradient(x);
		Scalar xnorm = af::norm(x);
		Scalar gnorm = af::norm(m_grad);

		// Threshold used by the descent-direction safeguard; based on the norm of the starting point.
		const Scalar abs_eps = 0.0001;
		const Scalar rel_eps = static_cast<Scalar>(abs_eps) *
			std::max<Scalar>(Scalar{ 1.0 }, xnorm);

		// History of the objective function values. It is only written below; the
		// relative-decrease test that would read it is not active.
		af::array m_fx = af::constant(static_cast<Scalar>(0.0), fpast, m_dtype);
		if (fpast > 0)
			m_fx(0) = fx;

		// Early exit if the initial x is already a minimizer
		if (gnorm <= _tolerance * std::max<Scalar>(xnorm, 1.0))
		{
			if (cycle)
				*cycle = 0;
			return true;
		}

		// Initial direction: steepest descent
		af::array m_drt = -m_grad.copy();
		// Initial step
		Scalar step = 1.0 / af::norm(m_drt);

		// print out some useful information, if specified
		if (_display)
		{
			std::cout << "Numerical Optimization via L-BFGS\n=================================\n\n";
			std::cout << "Starting Value: " << fx << std::endl << std::endl;
		}

		int k = 1;    // 1-based cycle counter
		int end = 0;  // ring-buffer slot that receives the next (s, y) pair
		for (; ; )
		{
			// Report progress through the out-parameter (write the value, not the pointer).
			if (cycle)
				*cycle = k;

			// Save the current x and gradient
			af::array m_xp = x.copy();
			af::array m_gradp = m_grad.copy();

			linesearch->Linesearch(*_function, fx, x, m_grad, step, m_drt, m_xp);

			// NaN check first, so that a broken iterate is never stored as the solution
			// by the convergence or iteration-limit exits below.
			if (std::isnan(fx))
			{
				std::cout << "NAN BREAK: Set previous solution" << std::endl;
				SetSolution(m_xp);
				return true;
			}

			// New x norm and gradient norm
			xnorm = af::norm(x);
			gnorm = af::norm(m_grad);

			// Convergence test -- gradient norm, only evaluated once fpast cycles have run
			if (fpast > 0)
			{
				if (k >= fpast && gnorm <= _tolerance * std::max<Scalar>(xnorm, 1.0))
				{
					SetSolution(x);
					return true;
				}

				m_fx(k % fpast) = fx;
			}
			// Maximum number of iterations (0 means unlimited)
			if (maxIterations != 0 && k > maxIterations)
			{
				SetSolution(x);
				return true;
			}

			if (_display)
				std::cout << "Cycle: " << k << "\t\tf(x): " << fx << "\t\t\tStep Size: " << step << std::endl;

			// Update s and y
			// s_{k+1} = x_{k+1} - x_k
			// y_{k+1} = g_{k+1} - g_k
			const af::array svec = x - m_xp;
			const af::array yvec = m_grad - m_gradp;

			// ys = y's = 1/rho
			// yy = y'y
			const Scalar ys = matmulTN(yvec, svec).scalar<Scalar>();
			const Scalar yy = CommonUtil<Scalar>::SquareEuclidean(yvec);

			m_ys(end) = ys;
			m_y.col(end) = yvec;
			m_s.col(end) = svec;

			// Recursive formula to compute d = -H * g
			m_drt = -m_grad;
			const int bound = std::min(m, k);
			end = (end + 1) % m;
			int j = end;
			// First loop: walk from the newest pair back to the oldest.
			for (int i = 0; i < bound; i++)
			{
				j = (j + m - 1) % m;
				m_alpha(j) = matmulTN(m_s.col(j), m_drt) / m_ys(j);
				m_drt -= m_alpha(j).scalar<Scalar>() * m_y.col(j);
			}

			// Initial Hessian approximation H0 = (s'y / y'y) * I
			m_drt *= (ys / yy);

			// Second loop: walk from the oldest pair forward to the newest.
			for (int i = 0; i < bound; i++)
			{
				const Scalar beta = (matmulTN(m_y.col(j), m_drt) / m_ys(j)).scalar<Scalar>();
				m_drt += (m_alpha(j).scalar<Scalar>() - beta) * m_s.col(j);
				j = (j + 1) % m;
			}

			// Safeguard: m_drt already holds -H * g, so the directional derivative g'd
			// must be clearly negative. Anything else (including NaN, e.g. after a
			// zero-length step) falls back to steepest descent.
			const Scalar directionalDerivative = af::matmulTN(m_grad, m_drt).scalar<Scalar>();
			if (!(directionalDerivative <= -abs_eps * rel_eps))
			{
				m_drt = -m_grad.copy();
			}

			// Initial step for the next line search. This assigns the function-level
			// 'step' that is handed to Linesearch().
			// NOTE(review): 1 / ||d|| follows the last assignment in the previous
			// revision; a unit step is the common choice for quasi-Newton directions --
			// confirm which one is wanted.
			step = 1.0 / af::norm(m_drt);
			k++;
		}
		return false;
	}

	//template<typename Scalar, LineSearchType LSType>
	//bool LBFGSSolver<Scalar, LSType>::Optimize()
	//{
	//	int bound = 0;
	//	int k = 0;
	//	int m_idx = 0;
	//	const int n = GetNumberOfVariables();

	//	af::array xp, gradp;

	//	af::array m_alpha = af::constant(0.0f, m, m_dtype);
	//	af::array m_xDiff = af::constant(0.0f, n, m, m_dtype);
	//	af::array m_gradDiff = af::constant(0.0f, n, m, m_dtype);

	//	af::array x = GetSolution();

	//	// Evaluate function and compute gradient
	//	Scalar fx = _function->Value(x);
	//	af::array grad = _function->Gradient(x);
	//	af::array searchDirection = -grad.copy();

	//	Scalar fpast = 0.0;
	//	Scalar scalefactor = 1.0;
	//	Scalar xnorm = af::norm(x);
	//	Scalar gnorm = af::norm(grad);
	//	const Scalar abs_eps = 0.0001;
	//	const Scalar rel_eps = static_cast<Scalar>(abs_eps) *
	//		std::max<Scalar>(Scalar{ 1.0 }, xnorm);

	//	// Early exit if the initial x is already a minimizer
	//	if (gnorm <= _tolerance * std::max<Scalar>(xnorm, 1.0))
	//	{
	//		return 1;
	//	}

	//	// Initial step
	//	Scalar step = .1;// 1.0 / af::norm(m_drt);

	//	// print out some useful information, if specified
	//	if (_display)
	//	{
	//		std::cout << "Numerical Optimization via L-BFGS\n=================================\n\n";
	//		std::cout << "Starting Value: " << fx << std::endl << std::endl;
	//	}


	//	do
	//	{
	//		// Save the current x and gradient
	//		xp = x.copy();
	//		gradp = grad.copy();
	//		fpast = fx;

	//		if (k > 0)
	//			bound = std::min<int>(m, m_idx - 1);

	//		for (int i = bound - 1; i >= 0; i--) 
	//		{
	//			// alpha_i <- rho_i*s_i^T*q
	//			const Scalar rho = 1.0 / (af::matmulTN(m_xDiff.col(i), m_gradDiff.col(i))).scalar<Scalar>();
	//			m_alpha(i) = rho * af::matmulTN(m_xDiff.col(i), searchDirection);
	//			// q <- q - alpha_i*y_i
	//			searchDirection -= m_alpha(i).scalar<Scalar>() * m_gradDiff.col(i);
	//		}

	//		// r <- H_k^0*q
	//		searchDirection = scalefactor * searchDirection;
	//		// for i k − m, k − m + 1, . . . , k − 1
	//		for (int i = 0; i < bound; i++) 
	//		{
	//			// beta <- rho_i * y_i^T * r
	//			const Scalar rho = 1.0 / (af::matmulTN(m_xDiff.col(i), m_gradDiff.col(i))).scalar<Scalar>();
	//			const Scalar beta = rho * (af::matmulTN(m_gradDiff.col(i), searchDirection)).scalar<Scalar>();
	//			// r <- r + s_i * ( alpha_i - beta)
	//			searchDirection += m_xDiff.col(i) * (m_alpha(i).scalar<Scalar>() - beta);
	//		}

	//		Scalar descentDirection = -af::matmulTN(grad, searchDirection).scalar<Scalar>();
	//		Scalar step = 1.0 / gnorm;
	//		if (descentDirection > -abs_eps * rel_eps) 
	//		{
	//			searchDirection = -grad.copy();
	//			m_idx = 0;
	//			step = 1.0;
	//		}

	//		linesearch->Linesearch(*_function, fx, x, grad, step, searchDirection, xp);

	//		// New x norm and gradient norm
	//		xnorm = af::norm(x);
	//		gnorm = af::norm(grad);

	//		// Convergence test -- gradient
	//		if (gnorm <= _tolerance * std::max<Scalar>(xnorm, 1.0))
	//		{
	//			SetSolution(x);
	//			return k;
	//		}

	//		// NaN check
	//		/*if (fx != fx)
	//		{
	//			std::cout << "NAN BREAK: Set previous solution" << std::endl;
	//			SetSolution(xp);
	//			return 1;
	//		}*/

	//		if (_display)
	//			std::cout << "Cycle: " << k << "\t\tf(x): " << fx << "\t\t\tStep Size: " << step << std::endl;


	//		af::array xDiff = x - xp;
	//		af::array gradDiff = grad - gradp;

	//		if(m_idx < m) {
	//			m_xDiff.col(m_idx) = xDiff;
	//			m_gradDiff.col(m_idx) = gradDiff;
	//		}
	//		else 
	//		{
	//			m_xDiff.cols(0, m - 2) = m_xDiff.cols(1, m - 1);
	//			m_gradDiff.cols(0, m - 2) = m_gradDiff.cols(1, m - 1);
	//			m_xDiff.col(m - 1) = xDiff;
	//			m_gradDiff.col(m - 1) = gradDiff;
	//		}

	//		k++;
	//		m_idx++;
	//	} while (k < maxIterations /*&& abs((fpast - fx) / fx) >= delta*/ && fx > 0);

	//	SetSolution(x);

	//	return true;
	//}
}