/**
File:		MachineLearning/Optimization/Base/BaseGradientOptimizationMethod.cpp

Author:		
Email:		
Site:       

Copyright (c) 2017 . All rights reserved.
*/

#include <NeMachineLearningPCH.h>
#include <MachineLearning/BaseGradientOptimizationMethod.h>
#include <MachineLearning/FgArmijoLineSearch.h>
#include <MachineLearning/FgArmijoBracketingLineSearch.h>
#include <MachineLearning/FgMoreThuenteLineSearch.h>
#include <MachineLearning/FgStrongWolfeBacktrackingLineSearch.h>
#include <MachineLearning/FgStrongWolfeBracketingLineSearch.h>
#include <MachineLearning/FgWolfeBacktrackingLineSearch.h>
#include <MachineLearning/FgWolfeBracketingLineSearch.h>

namespace NeuralEngine
{
	namespace MachineLearning
	{
		template class BaseGradientOptimizationMethod<float, ArmijoBacktracking>;
		template class BaseGradientOptimizationMethod<float, ArmijoBracketing>;
		template class BaseGradientOptimizationMethod<float, MoreThuente>;
		template class BaseGradientOptimizationMethod<float, StrongWolfeBacktracking>;
		template class BaseGradientOptimizationMethod<float, StrongWolfeBracketing>;
		template class BaseGradientOptimizationMethod<float, WolfeBacktracking>;
		template class BaseGradientOptimizationMethod<float, WolfeBracketing>;

		template class BaseGradientOptimizationMethod<double, ArmijoBacktracking>;
		template class BaseGradientOptimizationMethod<double, ArmijoBracketing>;
		template class BaseGradientOptimizationMethod<double, MoreThuente>;
		template class BaseGradientOptimizationMethod<double, StrongWolfeBacktracking>;
		template class BaseGradientOptimizationMethod<double, StrongWolfeBracketing>;
		template class BaseGradientOptimizationMethod<double, WolfeBacktracking>;
		template class BaseGradientOptimizationMethod<double, WolfeBracketing>;

		template<typename Scalar, LineSearchType LSType>
		BaseGradientOptimizationMethod<Scalar, LSType>::BaseGradientOptimizationMethod(int numberOfVariables)
			: BaseOptimizationMethod<Scalar>(numberOfVariables),
			iterations(0),
			maxIterations(0),
			_tolerance(0.0)
		{
			InitLinesearch();
		}

		template<typename Scalar, LineSearchType LSType>
		BaseGradientOptimizationMethod<Scalar, LSType>::BaseGradientOptimizationMethod(int numberOfVariables,
			std::function<Scalar(const af::array&, af::array&)> function)//, std::function<af::array(const af::array&)> gradient)
			: BaseOptimizationMethod<Scalar>(numberOfVariables, function),
			iterations(0),
			maxIterations(0),
			_tolerance(0.0)
		{
			//_function->SetGradient(gradient);
			InitLinesearch();
		}

		template<typename Scalar, LineSearchType LSType>
		BaseGradientOptimizationMethod<Scalar, LSType>::BaseGradientOptimizationMethod(NonlinearObjectiveFunction<Scalar>* function)
			: BaseOptimizationMethod<Scalar>(function),
			iterations(0),
			maxIterations(0),
			_tolerance(0.0)
		{
			InitLinesearch();
			//SetGradient(function->GetGradient());
		}

		template<typename Scalar, LineSearchType LSType>
		void BaseGradientOptimizationMethod<Scalar, LSType>::InitLinesearch()
		{
			switch (LSType)
			{
			case NeuralEngine::MachineLearning::ArmijoBacktracking:
				linesearch = new ArmijoLineSearch<Scalar>();
				break;
			case NeuralEngine::MachineLearning::ArmijoBracketing:
				linesearch = new ArmijoBracketingLineSearch<Scalar>();
				break;
			case NeuralEngine::MachineLearning::MoreThuente:
				linesearch = new MoreThuenteLineSearch<Scalar>();
				break;
			case NeuralEngine::MachineLearning::StrongWolfeBacktracking:
				linesearch = new StrongWolfeBacktrackingLineSearch<Scalar>();
				break;
			case NeuralEngine::MachineLearning::StrongWolfeBracketing:
				linesearch = new StrongWolfeBracketingLineSearch<Scalar>();
				break;
			case NeuralEngine::MachineLearning::WolfeBacktracking:
				linesearch = new WolfeBacktrackingLineSearch<Scalar>();
				break;
			case NeuralEngine::MachineLearning::WolfeBracketing:
				linesearch = new WolfeBracketingLineSearch<Scalar>();
				break;
			}
		}

		template<typename Scalar, LineSearchType LSType>
		BaseGradientOptimizationMethod<Scalar, LSType>::~BaseGradientOptimizationMethod() 
		{ 
			if (linesearch != nullptr) delete linesearch;
		}


		//template<typename Scalar, LineSearchType LSType>
		//int BaseGradientOptimizationMethod<Scalar, LSType>::GetNumberOfVariables() { return BaseOptimizationMethod<Scalar>::GetNumberOfVariables(); }

		//template<typename Scalar, LineSearchType LSType>
		//af::array BaseGradientOptimizationMethod<Scalar, LSType>::GetSolution() { return BaseOptimizationMethod<Scalar>::GetSolution(); }

		//template<typename Scalar, LineSearchType LSType>
		//void BaseGradientOptimizationMethod<Scalar, LSType>::SetSolution(af::array& x) { BaseOptimizationMethod<Scalar>::SetSolution(x); }

		//template<typename Scalar, LineSearchType LSType>
		//Scalar BaseGradientOptimizationMethod<Scalar, LSType>::GetValue() { return BaseOptimizationMethod<Scalar>::GetValue(); }

		//template<typename Scalar, LineSearchType LSType>
		//bool BaseGradientOptimizationMethod<Scalar, LSType>::Maximize(af::array& values) { return BaseOptimizationMethod::Maximize(values); }

		//template<typename Scalar, LineSearchType LSType>
		//bool BaseGradientOptimizationMethod<Scalar, LSType>::Minimize(af::array& values) { return BaseOptimizationMethod::Minimize(values); }

		//template<typename Scalar, LineSearchType LSType>
		//bool BaseGradientOptimizationMethod<Scalar, LSType>::Maximize()
		//{
		//	/*if (Gradient == null)
		//		throw new InvalidOperationException("gradient");*/

		//	_function->CheckGradient(GetSolution());

		//	auto g = GetGradient();

		//	_function->SetGradient([&g](af::array x)->af::array { return g(x) * -1; });

		//	bool success = BaseOptimizationMethod::Maximize();

		//	//_gradient = g;

		//	return success;
		//}

		//template<typename Scalar, LineSearchType LSType>
		//bool BaseGradientOptimizationMethod<Scalar, LSType>::Minimize()
		//{
		//	/*if (Gradient == null)
		//		throw new InvalidOperationException("gradient");*/

		//	_function->CheckGradient(GetSolution());

		//	return BaseOptimizationMethod<Scalar>::Minimize();
		//}

		template<typename Scalar, LineSearchType LSType>
		Scalar BaseGradientOptimizationMethod<Scalar, LSType>::GetTolerance()
		{
			return _tolerance;
		}

		template<typename Scalar, LineSearchType LSType>
		void BaseGradientOptimizationMethod<Scalar, LSType>::SetTolerance(Scalar tolerance)
		{
			_tolerance = tolerance;
		}

		template<typename Scalar, LineSearchType LSType>
		int BaseGradientOptimizationMethod<Scalar, LSType>::GetMaxIterations()
		{
			return maxIterations;
		}

		template<typename Scalar, LineSearchType LSType>
		void BaseGradientOptimizationMethod<Scalar, LSType>::SetMaxIterations(int iter)
		{
			maxIterations = iter;
		}

		template<typename Scalar, LineSearchType LSType>
		int BaseGradientOptimizationMethod<Scalar, LSType>::GetIterations()
		{
			return iterations;
		}
	}
}