/**
File:		MachineLearning/Optimization/Unconstrained/FgNadamSolver.h

Author:		
Email:		
Site:       

Copyright (c) 2019 . All rights reserved.
*/

#pragma once

#include <MachineLearning/BaseGradientOptimizationMethod.h>

namespace NeuralEngine
{
	namespace MachineLearning
	{
		////////////////////////////////////////////////////////////////////////////////////////////////////
		/// <summary>	Nadam (Nesterov-accelerated Adam) optimizer. </summary>
		///
		/// <remarks>	
		/// 	<para>
		///			Nadam update rule. Nadam is an optimizer that combines Adam with Nesterov
		///			accelerated gradient (NAG) to improve the performance of gradient descent. As the
		///			name suggests, the idea is to use a Nesterov momentum term for the first moving
		///			average. Let's take a look at the update rule of SGD with momentum:
		///		</para>
		///		<para>
		///			$$m_t=\beta m_{t-1}+\nu g_t$$
		///			$$w_t=w_{t-1}-m_t=w_{t-1}-\beta m_{t-1}-\nu g_t$$
		///		</para>
		///		<para>
		///			As shown above, the update rule is equivalent to taking a step in the direction of
		///			the momentum vector and then taking a step in the direction of the gradient. However,
		///			the momentum step doesn't depend on the current gradient, so we can get a
		///			higher-quality gradient step direction by updating the parameters with the momentum
		///			step before computing the gradient. To achieve that, we modify the update as follows:
		///		</para>
		///		<para>
		///			$$g_t=\nabla f(w_{t-1}-\beta m_{t-1})$$
		///			$$m_t=\beta m_{t-1}+\nu g_t$$
		///			$$w_t=w_{t-1}-m_t$$
		///		</para>
		///		<para>
		///			So, with Nesterov accelerated momentum we first make a big jump in the direction
		///			of the previous accumulated gradient and then measure the gradient where we ended
		///			up to make a correction. The same method can be incorporated into Adam by changing
		///			the first moving average to a Nesterov accelerated momentum. One computational trick
		///			can be applied here: instead of updating the parameters to make the momentum step and
		///			changing them back again, we can achieve the same effect by applying the momentum step
		///			of time step t + 1 only once, during the update of the previous time step t instead
		///			of t + 1. Using this trick, the implementation of Nadam may look like this:
		///		</para>
		///		<code>
		///			for t in range(1, num_iterations + 1):
		///				g = compute_gradient(x, y)
		///				m = beta_1 * m + (1 - beta_1) * g
		///				v = beta_2 * v + (1 - beta_2) * np.power(g, 2)
		///				m_hat = beta_1 * m / (1 - np.power(beta_1, t + 1)) + (1 - beta_1) * g / (1 - np.power(beta_1, t))
		///				v_hat = v / (1 - np.power(beta_2, t))
		///				w = w - step_size * m_hat / (np.sqrt(v_hat) + epsilon)
		///		</code>
		/// 
		///		<para>
		/// 	  References:
		/// 	  <list type="bullet">
		///			<item>
		/// 	    	  <description><a href="https://openreview.net/pdf?id=OM0jvwB8jIp57ZJjtNEZ">
		/// 				Timothy Dozat (2015). "Incorporating Nesterov momentum into Adam". Stanford University.</a>
		/// 	       </description>
		/// 	    </item>
		/// 	   </list>
		/// 	</para>
		/// 	
		/// 	HmetalT, 02.05.2019. 
		/// </remarks>
		////////////////////////////////////////////////////////////////////////////////////////////////////
		template<typename Scalar, LineSearchType LSType = MoreThuente>
		class NE_IMPEXP NadamSolver : public BaseGradientOptimizationMethod<Scalar, LSType>
		{
		public:

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Creates a new instance of the Nadam optimization algorithm. </summary>
			///
			/// <remarks>	 Admin, 3/27/2017. </remarks>
			///
			/// <param name="numberOfVariables">
			/// 	The number of free parameters in the optimization problem.
			/// </param>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			NadamSolver(int numberOfVariables);

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Creates a new instance of the Nadam optimization algorithm. </summary>
			///
			/// <remarks>	 Admin, 3/27/2017. </remarks>
			///
			/// <param name="numberOfVariables">
			/// 	The number of free parameters in the function to be optimized.
			/// </param>
			/// <param name="function">
			/// 	The function to be optimized. It takes the parameter vector (const) and a second,
			/// 	writable array and returns a Scalar. NOTE(review): the writable array presumably
			/// 	receives the gradient and the return value the objective value — confirm against
			/// 	the implementation.
			/// </param>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			NadamSolver(int numberOfVariables,
				std::function<Scalar(const af::array&, af::array&)> function);

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Creates a new instance of the Nadam optimization algorithm. </summary>
			///
			/// <remarks>	 Admin, 3/27/2017. </remarks>
			///
			/// <param name="function">	The objective function and gradients whose optimum values should be found. </param>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			NadamSolver(NonlinearObjectiveFunction<Scalar>* function);

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Destructor. </summary>
			///
			/// <remarks>	, 15.08.2019. </remarks>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			~NadamSolver();

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Sets decay rate for the first moment estimates. </summary>
			///
			/// <remarks>	, 15.08.2019. </remarks>
			///
			/// <param name="beta1">	The first beta. </param>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			void SetBeta1(Scalar beta1);

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Sets decay rate for the second-moment estimates. </summary>
			///
			/// <remarks>	, 15.08.2019. </remarks>
			///
			/// <param name="beta2">	The second beta. </param>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			void SetBeta2(Scalar beta2);

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Sets the learning rate. </summary>
			///
			/// <remarks>	, 15.08.2019. </remarks>
			///
			/// <param name="alpha">	The learning rate (step size). </param>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			void SetAlpha(Scalar alpha);

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Sets an epsilon to avoid division by zero. </summary>
			///
			/// <remarks>	, 15.08.2019. </remarks>
			///
			/// <param name="epsilon">	The epsilon. </param>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			void SetEpsilon(Scalar epsilon);

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Sets initial decay rate. </summary>
			///
			/// <remarks>	, 15.08.2019. </remarks>
			///
			/// <param name="decay">	The decay. </param>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			void SetDecay(Scalar decay);

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Gets decay rate for the first moment estimates. </summary>
			///
			/// <remarks>	, 15.08.2019. </remarks>
			///
			/// <returns>	The beta 1. </returns>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			Scalar GetBeta1();

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Gets decay rate for the second-moment estimates. </summary>
			///
			/// <remarks>	, 15.08.2019. </remarks>
			///
			/// <returns>	The beta 2. </returns>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			Scalar GetBeta2();

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Gets the learning rate. </summary>
			///
			/// <remarks>	, 15.08.2019. </remarks>
			///
			/// <returns>	The learning rate (alpha). </returns>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			Scalar GetAlpha();

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Gets the epsilon. </summary>
			///
			/// <remarks>	, 15.08.2019. </remarks>
			///
			/// <returns>	The epsilon. </returns>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			Scalar GetEpsilon();

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Gets the initial decay. </summary>
			///
			/// <remarks>	, 15.08.2019. </remarks>
			///
			/// <returns>	The decay. </returns>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			Scalar GetDecay();

		protected:

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>
			/// 	Implements the actual optimization algorithm. This method should try to minimize the
			/// 	objective function.
			/// </summary>
			///
			/// <remarks>	Hmetal T, 11.04.2017. </remarks>
			///
			/// <param name="cycle">
			/// 	Optional pointer, defaults to nullptr. NOTE(review): its exact meaning (e.g. an
			/// 	iteration counter reported back to the caller) is defined by the implementation and
			/// 	the base class, not by this header — confirm before relying on it.
			/// </param>
			///
			/// <returns>	true if it succeeds, false if it fails. </returns>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			virtual bool Optimize(int* cycle = nullptr) override;

		private:
			// All members are value-initialized so that a constructor which forgets to set one
			// never leaves it indeterminate; constructor initializers/assignments still take precedence.
			Scalar min_step{};	// The minimum step length allowed in the line search.
			Scalar max_step{};	// The maximum step length allowed in the line search.

			Scalar sAlpha{};	// learning rate
			Scalar sBeta1{};	// exponential decay rate for the first moment estimates (e.g. 0.9)
			Scalar sBeta2{};	// exponential decay rate for the second-moment estimates (e.g. 0.999).
			Scalar sEpsilon{};	// small number to prevent any division by zero in the implementation
			Scalar sDecay{};	// initial decay rate, as set by SetDecay / returned by GetDecay
			Scalar delta{};		// NOTE(review): purpose not visible in this header — document once confirmed.
			Scalar sCumBeta1{};	// presumably the running product of beta1 used for bias correction — TODO confirm.
		};
	}
}

