/**
File:		MachineLearning/Optimization/Base/IGradientOptimizationMethod.cpp

Author:		
Email:		
Site:       

Copyright (c) 2017 . All rights reserved.
*/

#pragma once

#include <NeMachineLearningPCH.h>
#include <MachineLearning/BaseOptimizationMethod.h>

namespace NeuralEngine
{
	namespace MachineLearning
	{
		template class BaseOptimizationMethod<float>;
		template class BaseOptimizationMethod<double>;

		template<typename Scalar>
		BaseOptimizationMethod<Scalar>::BaseOptimizationMethod(int numberOfVariables)
			: _display(true), m_dtype(CommonUtil<Scalar>::CheckDType())
		{
			if (numberOfVariables <= 0)
				LogError("numberOfVariables < 0");

			_function = new NonlinearObjectiveFunction<Scalar>(numberOfVariables);

			_x(m_dtype);

			init(numberOfVariables);
		}

		template<typename Scalar>
		BaseOptimizationMethod<Scalar>::BaseOptimizationMethod(int numberOfVariables, std::function<Scalar(const af::array&, af::array&)> function)
			: _display(true), m_dtype(CommonUtil<Scalar>::CheckDType())
		{
			if (function == nullptr)
				LogError("NullpointerExeption: function");

			_function = new NonlinearObjectiveFunction<Scalar>(numberOfVariables, function);

			_x(m_dtype);

			init(numberOfVariables);
			_function->SetFunction(function);
		}

		template<typename Scalar>
		BaseOptimizationMethod<Scalar>::BaseOptimizationMethod(NonlinearObjectiveFunction<Scalar>* function)
			: _display(true), m_dtype(CommonUtil<Scalar>::CheckDType())
		{
			if (function == nullptr)
				LogError("NullpointerExeption: function");

			_function = function;

			_x(m_dtype);

			init(function->GetNumberOfVariables());
			//SetFunction(function->GetFunction());
		}

		template<typename Scalar>
		void BaseOptimizationMethod<Scalar>::Display(bool display)
		{
			_display = display;
		}

		template<typename Scalar>
		BaseOptimizationMethod<Scalar>::~BaseOptimizationMethod() 
		{
			if (_function != nullptr) delete _function;
		}

		/*template<typename Scalar>
		std::function<Scalar(const af::array&)> BaseOptimizationMethod<Scalar>::GetFunction()
		{
			return _function->GetFunction();
		}

		template<typename Scalar>
		void BaseOptimizationMethod<Scalar>::SetFunction(std::function<Scalar(const af::array&)> f)
		{
			_function->SetFunction(f);
		}*/

		template<typename Scalar>
		int BaseOptimizationMethod<Scalar>::GetNumberOfVariables()
		{
			return _numVariables;
		}

		template<typename Scalar>
		void BaseOptimizationMethod<Scalar>::SetNumberOfVariables(int n)
		{
			_numVariables = n;
		}

		template<typename Scalar>
		af::array BaseOptimizationMethod<Scalar>::GetSolution()
		{
			return _x;
		}

		template<typename Scalar>
		void BaseOptimizationMethod<Scalar>::SetSolution(af::array& x)
		{
			if (x.isempty()) LogError("Empty Array");

			if (x.isrow()) LogError("Vector should be a column vector!");

			if (x.dims(0) != GetNumberOfVariables()) LogError("DimensionMismatchException value");

			_x = x;
		}

		template<typename Scalar>
		Scalar BaseOptimizationMethod<Scalar>::GetValue()
		{
			return _value;
		}

		template<typename Scalar>
		void BaseOptimizationMethod<Scalar>::SetValue(Scalar v)
		{
			_value = v;
		}

		template<typename Scalar>
		void BaseOptimizationMethod<Scalar>::init(int numberOfVariables)
		{
			SetNumberOfVariables(numberOfVariables);
			SetSolution(af::randu(numberOfVariables, m_dtype));
		}

		template<typename Scalar>
		bool BaseOptimizationMethod<Scalar>::Maximize(af::array& values, int* cycle)
		{
			SetSolution(values);
			return Maximize(cycle);
		}

		template<typename Scalar>
		bool BaseOptimizationMethod<Scalar>::Minimize(af::array& values, int* cycle)
		{
			SetSolution(values);
			return Minimize(cycle);
		}

		template<typename Scalar>
		bool BaseOptimizationMethod<Scalar>::Maximize(int* cycle)
		{
			if (_function->GetFunction() == nullptr) LogError("Invalid Operation: function");

			auto f = _function->GetFunction();

			f = [&f](af::array x, af::array grad) -> Scalar { return -f(x, grad); };

			_function->SetFunction(f);

			bool success = Optimize(cycle);

			//_function->SetFunction(f);

			_value = _function->Value(GetSolution());

			return success;
		}

		template<typename Scalar>
		bool BaseOptimizationMethod<Scalar>::Minimize(int* cycle)
		{
			if (_function->GetFunction() == nullptr) LogError("Invalid Operation: function");

			bool success = Optimize(cycle);

			_value = _function->Value(GetSolution());

			return success;
		}
	}
}