/**
File:		MachineLearning/Optimization/Base/BaseGradientOptimizationMethod.h

Author:		
Email:		
Site:       

Copyright (c) 2017 . All rights reserved.
*/

#pragma once

#include <NeMachineLearningLib.h>
#include <MachineLearning/BaseOptimizationMethod.h>
#include <MachineLearning/FgILineSearch.h>

namespace NeuralEngine
{
	namespace MachineLearning
	{

		////////////////////////////////////////////////////////////////////////////////////////////////////
		/// <summary>	Identifies a line-search variant. </summary>
		///
		/// <remarks>
		/// 	Used as the <c>LSType</c> template argument of BaseGradientOptimizationMethod, whose default
		/// 	is <c>MoreThuente</c>. Most names combine a step-acceptance condition (Armijo, Wolfe, strong
		/// 	Wolfe) with a search strategy (backtracking or bracketing). That reading comes from the
		/// 	names only; confirm it against the line-search implementations.
		/// 	Values are spelled out explicitly and match the implicit 0..6 numbering.
		/// </remarks>
		////////////////////////////////////////////////////////////////////////////////////////////////////
		enum LineSearchType
		{
			ArmijoBacktracking      = 0,	///< Armijo condition, backtracking search.
			ArmijoBracketing        = 1,	///< Armijo condition, bracketing search.
			MoreThuente             = 2,	///< More-Thuente line search (class template default).
			StrongWolfeBacktracking = 3,	///< Strong Wolfe conditions, backtracking search.
			StrongWolfeBracketing   = 4,	///< Strong Wolfe conditions, bracketing search.
			WolfeBacktracking       = 5,	///< Wolfe conditions, backtracking search.
			WolfeBracketing         = 6 	///< Wolfe conditions, bracketing search.
		};

		////////////////////////////////////////////////////////////////////////////////////////////////////
		/// <summary>	Base class for gradient-based optimization methods. </summary>
		///
		/// <remarks>
		/// 	HmetalT, 26.03.2017.
		///
		/// 	Stores the iteration limit, the stopping tolerance, the iteration counter of the last run
		/// 	and a line-search object shared by derived optimizers. The <c>LSType</c> template argument
		/// 	names the line-search variant (default <c>MoreThuente</c>). InitLinesearch() presumably
		/// 	uses it to create <c>linesearch</c>; confirm against the implementation.
		/// </remarks>
		///
		/// <typeparam name="Scalar">	Type of objective values and of the tolerance (presumably float/double). </typeparam>
		/// <typeparam name="LSType">	Line-search variant to use. </typeparam>
		////////////////////////////////////////////////////////////////////////////////////////////////////
		template<typename Scalar, LineSearchType LSType = MoreThuente>
		class NE_IMPEXP BaseGradientOptimizationMethod : public BaseOptimizationMethod<Scalar>//, public IGradientOptimizationMethod<Scalar>
		{
		public:
			/*/// <summary>
			///   Gets or sets a cancellation token that can be used to
			///   stop the learning algorithm while it is running.
			/// </summary>
			///
			public CancellationToken Token{ get; set; }*/

			//////////////////////////////////////////////////////////////////////////////////////////////////////
			///// <summary>
			///// 	Gets the number of variables (free parameters)
			///// 	in the optimization problem.
			///// </summary>
			/////
			///// <value>	The number of parameters. </value>
			//////////////////////////////////////////////////////////////////////////////////////////////////////
			//virtual int GetNumberOfVariables() override;

			//////////////////////////////////////////////////////////////////////////////////////////////////////
			///// <summary>
			///// 	Gets the current solution found, the values of the parameters which optimizes the
			///// 	function.
			///// </summary>
			/////
			///// <remarks>	Hmetal T, 17.03.2017. </remarks>
			/////
			///// <returns>	The solution. </returns>
			//////////////////////////////////////////////////////////////////////////////////////////////////////
			//virtual af::array GetSolution() override;

			//////////////////////////////////////////////////////////////////////////////////////////////////////
			///// <summary>
			///// 	Sets the current solution found, the values of the parameters which optimizes the
			///// 	function.
			///// </summary>
			/////
			///// <remarks>	Hmetal T, 17.03.2017. </remarks>
			/////
			///// <returns>	The solution. </returns>
			//////////////////////////////////////////////////////////////////////////////////////////////////////
			//virtual void SetSolution(af::array& x) override;

			//////////////////////////////////////////////////////////////////////////////////////////////////////
			///// <summary>
			///// 	Finds the maximum value of a function. The solution vector will be made available at the
			///// 	<see cref="Solution"/> property.
			///// </summary>
			/////
			///// <remarks>	Hmetal T, 17.03.2017. </remarks>
			/////
			///// <param name="values">	The initial solution vector to start the search. </param>
			/////
			///// <returns>
			///// 	Returns <c>true</c> if the method converged to a <see cref="Solution"/>. In this case,
			///// 	the found value will also be available at the <see cref="Value"/>
			///// 	property.
			///// </returns>
			//////////////////////////////////////////////////////////////////////////////////////////////////////
			//virtual bool Maximize(af::array& values) override;

			//////////////////////////////////////////////////////////////////////////////////////////////////////
			///// <summary>
			///// 	Finds the minimum value of a function. The solution vector will be made available at the
			///// 	<see cref="Solution"/> property.
			///// </summary>
			/////
			///// <remarks>	Hmetal T, 17.03.2017. </remarks>
			/////
			///// <param name="values">	The initial solution vector to start the search. </param>
			/////
			///// <returns>
			///// 	Returns <c>true</c> if the method converged to a <see cref="Solution"/>. In this case,
			///// 	the found value will also be available at the <see cref="Value"/>
			///// 	property.
			///// </returns>
			//////////////////////////////////////////////////////////////////////////////////////////////////////
			//virtual bool Minimize(af::array& values) override;

			//////////////////////////////////////////////////////////////////////////////////////////////////////
			///// <summary>
			///// 	Finds the maximum value of a function. The solution vector will be made available at the
			///// 	<see cref="IOptimizationMethod.Solution"/> property.
			///// </summary>
			/////
			///// <remarks>	Hmetal T, 26.03.2017. </remarks>
			/////
			///// <returns>
			///// 	Returns <c>true</c> if the method converged to a
			///// 	<see cref="IOptimizationMethod.Solution"/>. In this case, the found value will also be
			///// 	available at the <see cref="IOptimizationMethod.Value"/>
			///// 	property.
			///// </returns>
			//////////////////////////////////////////////////////////////////////////////////////////////////////
			//virtual bool Maximize() override;

			//////////////////////////////////////////////////////////////////////////////////////////////////////
			///// <summary>
			///// 	Finds the minimum value of a function. The solution vector will be made available at the
			///// 	<see cref="IOptimizationMethod.Solution"/> property.
			///// </summary>
			/////
			///// <remarks>	Hmetal T, 26.03.2017. </remarks>
			/////
			///// <returns>
			///// 	Returns <c>true</c> if the method converged to a
			///// 	<see cref="IOptimizationMethod.Solution"/>. In this case, the found value will also be
			///// 	available at the <see cref="IOptimizationMethod.Value"/>
			///// 	property.
			///// </returns>
			//////////////////////////////////////////////////////////////////////////////////////////////////////
			//virtual bool Minimize() override;

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>
			/// 	Gets the relative difference threshold to be used as stopping criteria between two
			/// 	iterations. Default is 0 (iterate until convergence).
			/// </summary>
			///
			/// <remarks>	 Admin, 3/27/2017. </remarks>
			///
			/// <returns>	The tolerance. </returns>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			Scalar GetTolerance();

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>
			/// 	Sets the relative difference threshold to be used as stopping criteria between two
			/// 	iterations. Default is 0 (iterate until convergence).
			/// </summary>
			///
			/// <remarks>	 Admin, 3/27/2017. </remarks>
			///
			/// <param name="tolerance">	The new stopping tolerance. </param>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			void SetTolerance(Scalar tolerance);

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>
			/// 	Gets the maximum number of iterations to be performed during optimization. Default is 0
			/// 	(iterate until convergence).
			/// </summary>
			///
			/// <remarks>	 Admin, 3/27/2017. </remarks>
			///
			/// <returns>	The maximum iterations. </returns>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			int GetMaxIterations();

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>
			/// 	Sets the maximum number of iterations to be performed during optimization. Default is 0
			/// 	(iterate until convergence).
			/// </summary>
			///
			/// <remarks>	 Admin, 3/27/2017. </remarks>
			///
			/// <param name="iter">	The maximum number of iterations (0 = iterate until convergence). </param>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			void SetMaxIterations(int iter);

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>
			/// 	Gets the number of iterations performed in the last optimization run (e.g. the last call
			/// 	to <c>Minimize()</c>).
			/// </summary>
			///
			/// <remarks>	 Admin, 3/27/2017. </remarks>
			///
			/// <returns>	The number of iterations performed in the previous optimization. </returns>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			int GetIterations();

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Destroys the optimizer. </summary>
			///
			/// <remarks>
			/// 	Virtual because this class is a base for concrete optimizers. Deleting a derived
			/// 	optimizer through a pointer to this class must run the derived destructor.
			/// </remarks>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			virtual ~BaseGradientOptimizationMethod();

		protected:

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>
			/// 	Initializes a new instance of the <see cref="BaseGradientOptimizationMethod"/> class.
			/// </summary>
			///
			/// <remarks>	Hmetal T, 26.03.2017. </remarks>
			///
			/// <param name="numberOfVariables">
			/// 	The number of free parameters in the optimization problem.
			/// </param>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			explicit BaseGradientOptimizationMethod(int numberOfVariables);

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>
			/// 	Initializes a new instance of the <see cref="BaseGradientOptimizationMethod"/> class.
			/// </summary>
			///
			/// <remarks>	Hmetal T, 26.03.2017. </remarks>
			///
			/// <param name="numberOfVariables">
			/// 	The number of free parameters in the optimization problem.
			/// </param>
			/// <param name="function">
			/// 	The objective function whose optimum values should be found. It takes the parameter
			/// 	vector and a writable array and returns the objective value. The writable array is
			/// 	presumably filled with the gradient; confirm against the implementation.
			/// </param>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			BaseGradientOptimizationMethod(int numberOfVariables,
				std::function<Scalar(const af::array&, af::array&)> function);
			// Former separate-gradient overload parameter, kept for reference:
			//, std::function<af::array(const af::array&)> gradient);

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>
			/// 	Initializes a new instance of the <see cref="BaseGradientOptimizationMethod"/> class.
			/// </summary>
			///
			/// <remarks>	Hmetal T, 17.03.2017. </remarks>
			///
			/// <param name="function">	The objective function and gradients whose optimum values should be found. </param>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			explicit BaseGradientOptimizationMethod(NonlinearObjectiveFunction<Scalar>* function);

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Inits linesearch. </summary>
			///
			/// <remarks>	Hmetal T, 11/06/2019. </remarks>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			void InitLinesearch();

			//std::function<af::array(af::array&)> _gradient;

			// Default member initializers apply the documented defaults (0 = iterate until
			// convergence) and a null line search. A constructor that sets these members explicitly
			// still overrides them; without the initializers they would be left indeterminate.

			/// <summary>	Maximum number of iterations; 0 means iterate until convergence. </summary>
			int maxIterations = 0;

			/// <summary>	Relative-difference stopping threshold; 0 means iterate until convergence. </summary>
			Scalar _tolerance = Scalar(0);

			/// <summary>	Number of iterations performed by the last optimization run. </summary>
			int iterations = 0;

			/// <summary>	Line-search object used by derived optimizers (null until initialized). </summary>
			/// NOTE(review): raw pointer; ownership is decided in the implementation file. If the
			/// destructor deletes it, the implicitly generated copy operations would double-free.
			/// In that case, consider deleting them or switching to std::unique_ptr.
			ILineSearch<Scalar>* linesearch = nullptr;
		};
	}
}
