/**
File:		MachineLearning/GPModels/Models/FgIModel.h

Author:		
Email:		
Site:       

Copyright (c) 2017 . All rights reserved.
*/

#pragma once

#include <NeMachineLearningLib.h>
#include <MachineLearning/IGradientOptimizationMethod.h>

namespace NeuralEngine
{
	namespace MachineLearning
	{
		enum class ModelType
		{
			NONE = 0,
			GPR = 1,
			GPLVM = 2,
			DGPR = 3,
			DGPLVM = 4,
			SSM = 5
		};

		////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Description of the layer. </summary>
			///
			/// <remarks>	Hmetal T, 13/09/2019. </remarks>
			////////////////////////////////////////////////////////////////////////////////////////////////////
		class HiddenLayerDescription
		{
		public:

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Constructor. </summary>
			///
			/// <remarks>	Hmetal T, 13/09/2019. </remarks>
			///
			/// <param name="numPseudos">		  	Number of pseudo inputs. </param>
			/// <param name="numHiddenDimensions">	Number of hidden dimensions. </param>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			HiddenLayerDescription(int numPseudos, int numHiddenDimensions)
			{
				iNumPseudos = numPseudos;
				iNumHidden = numHiddenDimensions;
			}

			HiddenLayerDescription()
			{
				iNumPseudos = 0;
				iNumHidden = 0;
			}

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Gets number pseudo inputs. </summary>
			///
			/// <remarks>	Hmetal T, 13/09/2019. </remarks>
			///
			/// <returns>	The number pseudo inputs. </returns>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			int GetNumPseudoInputs() { return iNumPseudos; }

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Gets number hidden dimensions. </summary>
			///
			/// <remarks>	Hmetal T, 13/09/2019. </remarks>
			///
			/// <returns>	The number hidden dimensions. </returns>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			int GetNumHiddenDimensions() { return iNumHidden; }

		private:
			int iNumPseudos;
			int iNumHidden;

			friend class boost::serialization::access;

			template<class Archive>
			void serialize(Archive& ar, unsigned int version)
			{
				//ar& boost::serialization::base_object<GPLVMBaseModel<Scalar>>(*this);

				ar& BOOST_SERIALIZATION_NVP(iNumPseudos);
				ar& BOOST_SERIALIZATION_NVP(iNumHidden);
			}
		};

		////////////////////////////////////////////////////////////////////////////////////////////////////
		/// <summary>	
		/// 	Base class with abstract and basic function definitions. All models will be derived 
		/// 	from this class.			
		/// </summary>
		///
		/// <remarks>	HmetalT, 26.10.2017. </remarks>
		////////////////////////////////////////////////////////////////////////////////////////////////////
		template<typename Scalar>
		class NE_IMPEXP IModel
		{
		public:

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Cost function the given x inputs. </summary>
			///
			/// <remarks>	Hmetal T, 29.11.2017. </remarks>
			///
			/// <param name="x">	   	The parameter input. </param>
			/// <param name="outGradient">	[in,out] The gradient. </param>
			///
			/// <returns>	A Scalar. </returns>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			virtual Scalar Function(const af::array& x, af::array& outGradient);

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Gets number of parameters to be optimized. </summary>
			///
			/// <remarks>	, 26.06.2018. </remarks>
			///
			/// <returns>	The number parameters. </returns>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			virtual int GetNumParameters() = 0;

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Sets the parameters for each optimization iteration. </summary>
			///
			/// <remarks>	, 26.06.2018. </remarks>
			///
			/// <param name="param">	The parameter. </param>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			virtual void SetParameters(const af::array& param) = 0;

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Gets the parameters for each optimization iteration. </summary>
			///
			/// <remarks>	, 26.06.2018. </remarks>
			///
			/// <param name="param">	The parameter. </param>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			virtual af::array GetParameters() = 0;

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Updates the parameters. </summary>
			///
			/// <remarks>	, 26.06.2018. </remarks>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			virtual void UpdateParameters() = 0;

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Gets data lenght. </summary>
			///
			/// <remarks>	Hmetal T, 16/04/2019. </remarks>
			///
			/// <returns>	The data lenght. </returns>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			int GetDataLenght();

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Gets data dimensionality. </summary>
			///
			/// <remarks>	Hmetal T, 16/04/2019. </remarks>
			///
			/// <returns>	The data dimensionality. </returns>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			int GetDataDimensionality();

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Gets model type. </summary>
			///
			/// <remarks>	Hmetal T, 16/04/2019. </remarks>
			///
			/// <returns>	The model type. </returns>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			ModelType GetModelType();

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Sets batch size. </summary>
			///
			/// <remarks>	Hmetal T, 16/04/2019. </remarks>
			///
			/// <param name="size">	The size. </param>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			virtual void SetBatchSize(int size);

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Gets batch size. </summary>
			///
			/// <remarks>	Hmetal T, 16/04/2019. </remarks>
			///
			/// <returns>	The batch size. </returns>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			int GetBatchSize();

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Sets the batch indexes. </summary>
			///
			/// <remarks>	Hmetal T, 31/08/2020. </remarks>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			void SetIndexes(af::array& indexes);

		protected:

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Constructor. </summary>
			///
			/// <remarks>	Hmetal T, 16/04/2019. </remarks>
			///
			/// <param name="numData">	   	Number of data samples. </param>
			/// <param name="numDimension">	Number of data dimensions. </param>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			IModel(int numData, int numDimension, ModelType type);

			

			ModelType mType;

			int iN;							//!< dataset length
			int iD;							//!< dataset dimension
			int iBatchSize;					//!< size of the batch
			af::array afIndexes;	//!< indexes of /f$\mathbf{X}/f$ for batch learning

			af::dtype m_dType;			//!< floating point precision flag for af::array

		private:
			friend class boost::serialization::access;

			template<class Archive>
			void serialize(Archive& ar, unsigned int version)
			{
				ar& BOOST_SERIALIZATION_NVP(iN);
				ar& BOOST_SERIALIZATION_NVP(iD);
				ar& BOOST_SERIALIZATION_NVP(iBatchSize);
				ar& BOOST_SERIALIZATION_NVP(afIndexes);
				ar& BOOST_SERIALIZATION_NVP(mType);
				ar& BOOST_SERIALIZATION_NVP(m_dType);
			}
		};

		////////////////////////////////////////////////////////////////////////////////////////////////////
		/// <summary>	Saves a model. </summary>
		///
		/// <remarks>	, 27.03.2018. </remarks>
		///
		/// <typeparam name="T">	Generic type parameter. </typeparam>
		/// <param name="file"> 	The file. </param>
		/// <param name="model">	[in,out] If non-null, the model. </param>
		////////////////////////////////////////////////////////////////////////////////////////////////////
		template<class T>
		void SaveModel(const std::string& file, T* model)
		{
			std::ofstream ofs(file.c_str(), std::ios::binary);
			boost::archive::binary_oarchive oa(ofs);
			oa << BOOST_SERIALIZATION_NVP(*model);
			ofs.close();
		}

		////////////////////////////////////////////////////////////////////////////////////////////////////
		/// <summary>	Loads a model. </summary>
		///
		/// <remarks>	, 27.03.2018. </remarks>
		///
		/// <typeparam name="T">	Generic type parameter. </typeparam>
		/// <param name="file">	The file. </param>
		///
		/// <returns>	null if it fails, else the model. </returns>
		////////////////////////////////////////////////////////////////////////////////////////////////////
		template<class T>
		T* LoadModel(const std::string& file)
		{
			T* model = new T();
			std::ifstream ifs(file.c_str(), std::ios::binary);
			boost::archive::binary_iarchive ia(ifs);
			ia >> BOOST_SERIALIZATION_NVP(*model);
			ifs.close();
			model->UpdateParameters();
			return model;
		}
	}
}
