/**
File:		MachineLearning/Optimization/Unconstrained/FgAdamSolver.cpp

Author:		
Email:		
Site:       

Copyright (c) 2019 . All rights reserved.
*/

#include <NeMachineLearningPCH.h>
#include <MachineLearning/FgAdamSolver.h>
#include <cmath>
#include <math.h>
#include <limits>
#include <iomanip>

namespace NeuralEngine::MachineLearning
{
	// Explicit instantiations: every line-search policy, for both float and double.
	// NOTE(review): these lines precede the member definitions in this file. Per [temp.explicit], an
	// explicit instantiation definition only instantiates members already defined at that point, so
	// this placement relies on compiler-specific behaviour; placing them after the last member
	// definition is the portable layout.
	template class AdamSolver<float, ArmijoBacktracking>;
	template class AdamSolver<double, ArmijoBacktracking>;

	template class AdamSolver<float, ArmijoBracketing>;
	template class AdamSolver<double, ArmijoBracketing>;

	template class AdamSolver<float, MoreThuente>;
	template class AdamSolver<double, MoreThuente>;

	template class AdamSolver<float, StrongWolfeBacktracking>;
	template class AdamSolver<double, StrongWolfeBacktracking>;

	template class AdamSolver<float, StrongWolfeBracketing>;
	template class AdamSolver<double, StrongWolfeBracketing>;

	template class AdamSolver<float, WolfeBacktracking>;
	template class AdamSolver<double, WolfeBacktracking>;

	template class AdamSolver<float, WolfeBracketing>;
	template class AdamSolver<double, WolfeBracketing>;

	/// Constructs an Adam solver over `numberOfVariables` variables.
	///
	/// Initial hyper-parameters: beta1 = 0.9, beta2 = 0.999, alpha = 0.01, epsilon = 1e-8, decay = 0.
	/// `delta` (1e-8) is the relative change in f(x) below which Optimize() stops iterating.
	/// `min_step` / `max_step` are initialised here but are not read by Optimize().
	template<typename Scalar, LineSearchType LSType>
	AdamSolver<Scalar, LSType>::AdamSolver(int numberOfVariables)
		: BaseGradientOptimizationMethod<Scalar, LSType>(numberOfVariables)
		, min_step(1e-20)
		, max_step(1e+20)
		, sBeta1(0.9)
		, sBeta2(0.999)
		, sAlpha(0.01)
		, sEpsilon(1e-8)
		, delta(1e-8)
		, sDecay(0.0)
	{
	}

	/// Constructs an Adam solver over `numberOfVariables` variables, forwarding `function` unchanged
	/// to the base class (presumably it returns f(x) and writes the gradient into its second
	/// argument — confirm against BaseGradientOptimizationMethod).
	///
	/// Initial hyper-parameters: beta1 = 0.9, beta2 = 0.999, alpha = 0.01, epsilon = 1e-8, decay = 0.
	/// beta2 was previously 0.99 here only, unlike the other AdamSolver constructors, which use 0.999.
	/// `delta` (1e-8) is the relative change in f(x) below which Optimize() stops iterating.
	/// `min_step` / `max_step` are initialised here but are not read by Optimize().
	template<typename Scalar, LineSearchType LSType>
	AdamSolver<Scalar, LSType>::AdamSolver(int numberOfVariables, std::function<Scalar(const af::array&, af::array&)> function)
		: BaseGradientOptimizationMethod<Scalar, LSType>(numberOfVariables, function)
		, min_step(1e-20)
		, max_step(1e+20)
		, sBeta1(0.9)
		, sBeta2(0.999)
		, sAlpha(0.01)
		, sEpsilon(1e-8)
		, delta(1e-8)
		, sDecay(0.0)
	{
	}

	/// Constructs an Adam solver for the objective `function`, which is passed to the base class.
	/// Ownership of `function` is decided by the base class — not visible here.
	///
	/// Initial hyper-parameters: beta1 = 0.9, beta2 = 0.999, alpha = 0.01, epsilon = 1e-8, decay = 0.
	/// `delta` (1e-8) is the relative change in f(x) below which Optimize() stops iterating.
	template<typename Scalar, LineSearchType LSType>
	AdamSolver<Scalar, LSType>::AdamSolver(NonlinearObjectiveFunction<Scalar> * function)
		: BaseGradientOptimizationMethod<Scalar, LSType>(function)
		, min_step(1e-20)
		, max_step(1e+20)
		, sBeta1(0.9)
		, sBeta2(0.999)
		, sAlpha(0.01)
		, sEpsilon(1e-8)
		, delta(1e-8)
		, sDecay(0.0)
	{
	}

	/// Destructor; performs no explicit cleanup of its own.
	template<typename Scalar, LineSearchType LSType>
	AdamSolver<Scalar, LSType>::~AdamSolver() {}

	/// Sets beta1, the decay factor of the moving average of the gradient used by Optimize().
	template<typename Scalar, LineSearchType LSType>
	void AdamSolver<Scalar, LSType>::SetBeta1(Scalar beta1)
	{
		this->sBeta1 = beta1;
	}

	/// Sets beta2, the decay factor of the moving average of the squared gradient used by Optimize().
	template<typename Scalar, LineSearchType LSType>
	void AdamSolver<Scalar, LSType>::SetBeta2(Scalar beta2)
	{
		this->sBeta2 = beta2;
	}

	/// Sets alpha, the base step size (learning rate) that Optimize() starts each cycle from.
	template<typename Scalar, LineSearchType LSType>
	void AdamSolver<Scalar, LSType>::SetAlpha(Scalar alpha)
	{
		this->sAlpha = alpha;
	}

	/// Sets epsilon, the constant added to the square-root denominator of the update to keep it away from zero.
	template<typename Scalar, LineSearchType LSType>
	void AdamSolver<Scalar, LSType>::SetEpsilon(Scalar epsilon)
	{
		this->sEpsilon = epsilon;
	}

	/// Sets the learning-rate decay; when positive, cycle k uses alpha / (1 + decay * k).
	template<typename Scalar, LineSearchType LSType>
	void AdamSolver<Scalar, LSType>::SetDecay(Scalar decay)
	{
		this->sDecay = decay;
	}

	/// Returns beta1, the decay factor of the gradient moving average.
	template<typename Scalar, LineSearchType LSType>
	Scalar AdamSolver<Scalar, LSType>::GetBeta1()
	{
		return this->sBeta1;
	}

	/// Returns beta2, the decay factor of the squared-gradient moving average.
	template<typename Scalar, LineSearchType LSType>
	Scalar AdamSolver<Scalar, LSType>::GetBeta2()
	{
		return this->sBeta2;
	}

	/// Returns alpha, the base step size.
	template<typename Scalar, LineSearchType LSType>
	Scalar AdamSolver<Scalar, LSType>::GetAlpha()
	{
		return this->sAlpha;
	}

	/// Returns epsilon, the denominator stabiliser of the update.
	template<typename Scalar, LineSearchType LSType>
	Scalar AdamSolver<Scalar, LSType>::GetEpsilon()
	{
		return this->sEpsilon;
	}

	/// Returns the learning-rate decay factor (0 disables decay).
	template<typename Scalar, LineSearchType LSType>
	Scalar AdamSolver<Scalar, LSType>::GetDecay()
	{
		return this->sDecay;
	}

	/// Minimises the objective with Adam, starting from the current solution.
	///
	/// Per cycle k (t = k + 1), with g the gradient at the current x:
	///   step   = sAlpha, scaled by 1 / (1 + sDecay * k) when sDecay > 0
	///   ms     = sBeta1 * ms + (1 - sBeta1) * g          -- moving average of the gradient
	///   vs     = sBeta2 * vs + (1 - sBeta2) * g * g      -- moving average of the squared gradient
	///   step_t = step * sqrt(1 - sBeta2^t) / (1 - sBeta1^t)
	///   x     -= step_t * ms / (sqrt(vs) + sEpsilon)
	/// Both bias corrections are folded into step_t, so the raw moments ms / vs are used in the
	/// update. (Dividing ms / vs by their bias factors as well would apply the correction twice.)
	///
	/// At least one cycle always runs. Iteration then continues while all of these hold:
	///   ||g|| > _tolerance * max(||x||, 1),   k < maxIterations,   f(x) > 0,
	///   |(f_prev - f(x)) / f(x)| >= delta.
	///
	/// @param cycle  Optional; when non-null, receives the number of the last completed cycle.
	/// @return       Always true. The final iterate is stored with SetSolution().
	template<typename Scalar, LineSearchType LSType>
	bool AdamSolver<Scalar, LSType>::Optimize(int* cycle)
	{
		const int n = GetNumberOfVariables();
		af::array x = GetSolution();

		Scalar fx = _function->Value(x);
		af::array m_grad = _function->Gradient(x);

		// Moment estimates start at zero; this initial bias is what step_t compensates for.
		af::array ms = af::constant(0.0, n, m_dtype);
		af::array vs = af::constant(0.0, n, m_dtype);

		// Only read by the loop condition, after being refreshed inside the loop body.
		Scalar xnorm = 0;
		Scalar gnorm = 0;
		Scalar fpast = fx;

		if (_display)
		{
			std::cout << "Numerical Optimization via Adam\n=================================\n\n";
			std::cout << "Starting Value: " << fx << std::endl << std::endl;
		}

		int k = 0;
		do
		{
			fpast = fx;

			Scalar step = sAlpha;
			if (sDecay > 0) step *= (1.0 / (1.0 + sDecay * k));	// inverse-time learning-rate decay

			const Scalar t = static_cast<Scalar>(k + 1);
			const Scalar step_t = step * (std::sqrt(1.0 - std::pow(sBeta2, t)) / (1.0 - std::pow(sBeta1, t)));

			// main computation: update the moving averages of the gradient and squared gradient
			ms = (sBeta1 * ms) + (1.0 - sBeta1) * m_grad;
			vs = (sBeta2 * vs) + (1.0 - sBeta2) * (m_grad * m_grad);

			// step_t already carries both bias corrections, so use the raw moments here.
			x -= step_t * (ms / (af::sqrt(vs) + sEpsilon));

			// evaluate function and gradient at the new point
			fx = _function->Value(x);
			m_grad = _function->Gradient(x);

			if (_display)
			{
				std::cout << "Cycle: " << k + 1 << "\t\tf(x): " << fx << "\t\t\tStep Size: " << step_t << std::endl;
			}

			if (cycle)
				*cycle = (k + 1);

			xnorm = af::norm(x);
			gnorm = af::norm(m_grad);
			k++;
		} while (gnorm > _tolerance * std::max<Scalar>(xnorm, 1.0)
			&& k < maxIterations
			&& fx > 0										// tested before the ratio below so it never divides by zero
			&& std::abs((fpast - fx) / fx) >= delta);		// relative improvement still above threshold

		SetSolution(x);

		if (_display)
			std::cout << "Done." << std::endl << std::endl;

		return true;
	}
}