/**
File:		MachineLearning/Optimization/Unconstrained/FgNadamSolver.cpp

Author:		
Email:		
Site:       

Copyright (c) 2019 . All rights reserved.
*/

#include <NeMachineLearningPCH.h>
#include <MachineLearning/FgNadamSolver.h>
#include <cmath>
#include <math.h>
#include <limits>
#include <iomanip>

namespace NeuralEngine::MachineLearning
{
	// Explicit instantiations of NadamSolver for the line-search types listed
	// below. The member functions are defined in this .cpp file, so each
	// Scalar / line-search combination that should be usable from other
	// translation units is instantiated here.

	// Single-precision (float) solvers.
	template class NadamSolver<float, ArmijoBacktracking>;
	template class NadamSolver<float, ArmijoBracketing>;
	template class NadamSolver<float, MoreThuente>;
	template class NadamSolver<float, StrongWolfeBacktracking>;
	template class NadamSolver<float, StrongWolfeBracketing>;
	template class NadamSolver<float, WolfeBacktracking>;
	template class NadamSolver<float, WolfeBracketing>;

	// Double-precision (double) solvers.
	template class NadamSolver<double, ArmijoBacktracking>;
	template class NadamSolver<double, ArmijoBracketing>;
	template class NadamSolver<double, MoreThuente>;
	template class NadamSolver<double, StrongWolfeBacktracking>;
	template class NadamSolver<double, StrongWolfeBracketing>;
	template class NadamSolver<double, WolfeBacktracking>;
	template class NadamSolver<double, WolfeBracketing>;

	/// Creates a Nadam solver for a problem with `numberOfVariables` unknowns.
	///
	/// The argument is forwarded to the base optimization method. The solver's
	/// own settings start at: beta1 = 0.9, beta2 = 0.999, alpha = 0.001,
	/// epsilon = 1e-8, delta = 1e-8, decay = 4e-3, step bounds [1e-20, 1e+20],
	/// and a cumulative beta1 product of 1.
	template<typename Scalar, LineSearchType LSType>
	NadamSolver<Scalar, LSType>::NadamSolver(int numberOfVariables)
		: BaseGradientOptimizationMethod<Scalar, LSType>(numberOfVariables)
		, min_step(1e-20)
		, max_step(1e+20)
		, sBeta1(0.9)
		, sBeta2(0.999)
		, sAlpha(0.001)
		, sEpsilon(1e-8)
		, delta(1e-8)
		, sDecay(4e-3)
		, sCumBeta1(1)
	{}

	/// Creates a Nadam solver for a problem with `numberOfVariables` unknowns
	/// whose objective is supplied as a callable.
	///
	/// Both arguments are forwarded unchanged to the base optimization method.
	/// The solver's own settings start at: beta1 = 0.9, beta2 = 0.999,
	/// alpha = 0.001, epsilon = 1e-8, delta = 1e-8, decay = 4e-3,
	/// step bounds [1e-20, 1e+20], and a cumulative beta1 product of 1.
	template<typename Scalar, LineSearchType LSType>
	NadamSolver<Scalar, LSType>::NadamSolver(int numberOfVariables, std::function<Scalar(const af::array&, af::array&)> function)
		: BaseGradientOptimizationMethod<Scalar, LSType>(numberOfVariables, function),
		// beta2 is 0.999 here as in the other two constructors, so the default
		// does not depend on which constructor the caller picked.
		min_step(1e-20), max_step(1e+20), sBeta1(0.9), sBeta2(0.999), sAlpha(0.001), sEpsilon(1e-8), delta(1e-8), sDecay(4e-3), sCumBeta1(1)
	{
	}

	/// Creates a Nadam solver around an existing objective-function object.
	///
	/// The pointer is forwarded to the base optimization method. The solver's
	/// own settings start at: beta1 = 0.9, beta2 = 0.999, alpha = 0.001,
	/// epsilon = 1e-8, delta = 1e-8, decay = 4e-3, step bounds [1e-20, 1e+20],
	/// and a cumulative beta1 product of 1.
	template<typename Scalar, LineSearchType LSType>
	NadamSolver<Scalar, LSType>::NadamSolver(NonlinearObjectiveFunction<Scalar> * function)
		: BaseGradientOptimizationMethod<Scalar, LSType>(function)
		, min_step(1e-20)
		, max_step(1e+20)
		, sBeta1(0.9)
		, sBeta2(0.999)
		, sAlpha(0.001)
		, sEpsilon(1e-8)
		, delta(1e-8)
		, sDecay(4e-3)
		, sCumBeta1(1)
	{}

	/// Destructor. Performs no work of its own.
	template<typename Scalar, LineSearchType LSType>
	NadamSolver<Scalar, LSType>::~NadamSolver() = default;

	/// Sets beta1, the decay factor of the moving average of the gradient.
	template<typename Scalar, LineSearchType LSType>
	void NadamSolver<Scalar, LSType>::SetBeta1(Scalar beta1)
	{
		this->sBeta1 = beta1;
	}

	/// Sets beta2, the decay factor of the moving average of the squared gradient.
	template<typename Scalar, LineSearchType LSType>
	void NadamSolver<Scalar, LSType>::SetBeta2(Scalar beta2)
	{
		this->sBeta2 = beta2;
	}

	/// Sets alpha, the base step size used by Optimize().
	template<typename Scalar, LineSearchType LSType>
	void NadamSolver<Scalar, LSType>::SetAlpha(Scalar alpha)
	{
		this->sAlpha = alpha;
	}

	/// Sets epsilon, the constant added to the denominator of the update step.
	template<typename Scalar, LineSearchType LSType>
	void NadamSolver<Scalar, LSType>::SetEpsilon(Scalar epsilon)
	{
		this->sEpsilon = epsilon;
	}

	/// Sets the decay value. Optimize() uses it both to shrink the step size
	/// over iterations and as the exponent scale of the momentum schedule.
	template<typename Scalar, LineSearchType LSType>
	void NadamSolver<Scalar, LSType>::SetDecay(Scalar decay)
	{
		this->sDecay = decay;
	}

	/// Returns beta1, the decay factor of the moving average of the gradient.
	template<typename Scalar, LineSearchType LSType>
	Scalar NadamSolver<Scalar, LSType>::GetBeta1()
	{
		return this->sBeta1;
	}

	/// Returns beta2, the decay factor of the moving average of the squared gradient.
	template<typename Scalar, LineSearchType LSType>
	Scalar NadamSolver<Scalar, LSType>::GetBeta2()
	{
		return this->sBeta2;
	}

	/// Returns alpha, the base step size used by Optimize().
	template<typename Scalar, LineSearchType LSType>
	Scalar NadamSolver<Scalar, LSType>::GetAlpha()
	{
		return this->sAlpha;
	}

	/// Returns epsilon, the constant added to the denominator of the update step.
	template<typename Scalar, LineSearchType LSType>
	Scalar NadamSolver<Scalar, LSType>::GetEpsilon()
	{
		return this->sEpsilon;
	}

	/// Returns the decay value used for the step-size and momentum schedules.
	template<typename Scalar, LineSearchType LSType>
	Scalar NadamSolver<Scalar, LSType>::GetDecay()
	{
		return this->sDecay;
	}

	/**
	Minimizes the objective held in _function with the Nadam update rule.

	The run starts from the current solution (GetSolution()), iterates until
	one of the stopping criteria is met and stores the final iterate with
	SetSolution().

	Stopping criteria, checked after every iteration:
	  - norm(gradient) <= _tolerance * max(norm(x), 1)
	  - the iteration count reaches maxIterations
	  - the relative change of the objective, |f_prev - f| / |f|, is below delta

	@param cycle	Optional output. When non-null it receives the 1-based index
					of the most recently executed iteration.
	@return			Always true.
	*/
	template<typename Scalar, LineSearchType LSType>
	bool NadamSolver<Scalar, LSType>::Optimize(int* cycle)
	{
		const int n = GetNumberOfVariables();
		af::array x = GetSolution();

		Scalar fx = _function->Value(x);
		af::array m_grad = _function->Gradient(x);

		// Moment estimates restart from zero on every call ...
		af::array ms = af::constant(0.0, n, m_dtype);
		af::array vs = af::constant(0.0, n, m_dtype);

		// ... and so must the running product of the scheduled beta1 values,
		// otherwise a second call would pair fresh moments with a stale product.
		sCumBeta1 = 1;

		Scalar xnorm = af::norm(x);
		Scalar gnorm = af::norm(m_grad);

		// print out some useful information, if specified
		if (_display)
		{
			std::cout << "Numerical Optimization via Nadam\n=================================\n\n";
			std::cout << "Starting Value: " << fx << std::endl << std::endl;
		}

		int k = 0;
		if (cycle)
			*cycle = (k + 1);

		bool progressing = true;
		do
		{
			const Scalar fpast = fx;

			// Base step size, optionally shrunk as 1 / (1 + decay * k).
			Scalar step = sAlpha;
			if (sDecay > 0)
				step *= (1.0 / (1.0 + sDecay * k));

			const Scalar t = static_cast<Scalar>(k + 1);

			// NOTE(review): step_t already contains the Adam-style bias correction
			// sqrt(1 - beta2^t) / (1 - beta1^t), and the update below applies
			// biasCorrection1/2/3 again. Left unchanged to keep existing numerical
			// behaviour; confirm whether the double correction is intended.
			const Scalar step_t = step * (std::sqrt(1.0 - std::pow(sBeta2, t)) / (1.0 - std::pow(sBeta1, t)));

			// main computation
			ms = (sBeta1 * ms) + (1.0 - sBeta1) * m_grad;				// updates the moving averages of the gradient
			vs = (sBeta2 * vs) + (1.0 - sBeta2) * (m_grad * m_grad);	// updates the moving averages of the squared gradient

			// Scheduled momentum for the current step (t) and the next one (t + 1).
			const Scalar beta1T = sBeta1 * (1 - (0.5 *
				std::pow(0.96, t * sDecay)));

			const Scalar beta1T1 = sBeta1 * (1 - (0.5 *
				std::pow(0.96, (t + 1) * sDecay)));

			sCumBeta1 *= beta1T;

			const Scalar biasCorrection1 = 1.0 - sCumBeta1;
			const Scalar biasCorrection2 = 1.0 - std::pow(sBeta2, t);
			const Scalar biasCorrection3 = 1.0 - (sCumBeta1 * beta1T1);

			/* Note :- af::sqrt(vs) + epsilon * sqrt(biasCorrection2) is approximated
			* as af::sqrt(vs) + epsilon
			*/
			x -= (step_t * (((1 - beta1T) / biasCorrection1) * m_grad
				+ (beta1T1 / biasCorrection3) * ms) * std::sqrt(biasCorrection2))
				/ (af::sqrt(vs) + sEpsilon);

			fx = _function->Value(x);									// evaluate function and gradient
			m_grad = _function->Gradient(x);

			if (_display)
				std::cout << "Cycle: " << k + 1 << "\t\tf(x): " << fx << "\t\t\tStep Size: " << step_t << std::endl;

			if (cycle)
				*cycle = (k + 1);

			xnorm = af::norm(x);
			gnorm = af::norm(m_grad);

			// Relative change of the objective. The magnitude of fx is used so a
			// negative objective (e.g. a log-likelihood) does not end the loop after
			// one iteration; for fx == 0 the run continues only if the value moved.
			const Scalar change = std::abs(fpast - fx);
			const Scalar scale = std::abs(fx);
			progressing = (scale > 0) ? ((change / scale) >= delta) : (change > 0);

			k++;
		} while (gnorm > _tolerance * std::max<Scalar>(xnorm, 1.0) && k < maxIterations && progressing);

		SetSolution(x);
		return true;
	}
}