/**
File:		MachineLearning/GPModels/Models/Layers/GPLayers/FgGPBaseLayer.h

Author:		
Email:		
Site:       

Copyright (c) 2017 . All rights reserved.
*/

#pragma once

#include <MachineLearning/FgILayer.h>
#include <MachineLearning/FgKernels.h>
#include <MachineLearning/FgGPLVMBaseModel.h>

namespace NeuralEngine
{
	namespace MachineLearning
	{
		namespace GPModels
		{
			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Abstract class for different GP likelihood layers. </summary>
			///
			/// <remarks>	, 27.02.2018. </remarks>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			template<typename Scalar>
			class NE_IMPEXP GPBaseLayer : public ILayer<Scalar>
			{
			public:
				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Constructor. </summary>
				///
				/// <remarks>	, 26.04.2018. </remarks>
				///
				/// <param name="numPoints">	Number of training points. </param>
				/// <param name="outputDim">	The output dimension. </param>
				/// <param name="inputDim"> 	The input dimension. </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				GPBaseLayer(int numPoints, int outputDim, int inputDim);

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Destructor. </summary>
				///
				/// <remarks>	, 26.04.2018. </remarks>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual ~GPBaseLayer();

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Gets the kernel function. </summary>
				///
				/// <remarks>	, 26.04.2018. </remarks>
				///
				/// <returns>	null if it fails, else the kernel. </returns>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				IKernel<Scalar>* GetKernel();

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Sets a kernel function. </summary>
				///
				/// <remarks>	, 26.04.2018. </remarks>
				///
				/// <param name="kernel">	[in,out] If non-null, the kernel. </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				void SetKernel(IKernel<Scalar>* kern);

				virtual void InitParameters(af::array* X = nullptr);

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Gets number of parameters to be optimized. </summary>
				///
				/// <remarks>	, 26.06.2018. </remarks>
				///
				/// <returns>	The number parameters. </returns>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual int GetNumParameters();

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Sets the parameters for each optimization iteration. </summary>
				///
				/// <remarks>	, 26.06.2018. </remarks>
				///
				/// <param name="param">	The parameter. </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual void SetParameters(const af::array& param);

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Gets the parameters for each optimization iteration. </summary>
				///
				/// <remarks>	, 26.06.2018. </remarks>
				///
				/// <returns>	The parameters. </returns>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual af::array GetParameters();

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Forward prediction of posterior function values. </summary>
				///
				/// <remarks>	, 12.06.2018. </remarks>
				///
				/// <param name="mout"> 	[in,out] The m^{\n}_{f}. </param>
				/// <param name="vout"> 	[in,out] The V^{\n}_{ff}. </param>
				/// <param name="mx">   	[in,out] The inputs mx. </param>
				/// <param name="vx">   	[in,out] (Optional) If non-null, the variances vx. </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual void ForwardPredictionPost(const af::array* mx, const af::array* vx, af::array& mout, af::array& vout);

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Samples from posterior. </summary>
				///
				/// <remarks>	, 12.06.2018. </remarks>
				///
				/// <param name="vx">	[in,out] (Optional) If non-null, the variances vx. </param>
				///
				/// <param name="mx">	[in,out] The inputs mx. </param>
				///
				/// <param name="fsample">	[in,out] The m^{\n}_{f}. </param>
				/// <param name="inX">	  	The V^{\n}_{ff}. </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual void SampleFromPost(const af::array& inX, af::array& outfsample);

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Sets fixation for hyperparameters. </summary>
				///
				/// <remarks>	Hmetal T, 16/12/2019. </remarks>
				///
				/// <param name="isfixed">	True if isfixed. </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual void FixKernelParameters(bool isfixed);

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Updates the parameters. </summary>
				///
				/// <remarks>	, 26.06.2018. </remarks>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual void UpdateParameters();

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Sets data size. </summary>
				///
				/// <remarks>	Hmetal T, 03/09/2020. </remarks>
				///
				/// <param name="length">   	The length. </param>
				/// <param name="dimension">	The dimension. </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual void SetDataSize(int length, int dimension) override;

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Sets the syles. </summary>
				///
				/// <remarks>	Hmetal T, 25/09/2020. </remarks>
				///
				/// <param name="styles">	[in,out] If non-null, the styles. </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				void SetStyles(std::map<std::string, Style<Scalar>>* styles);

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Sets latent dimension. </summary>
				///
				/// <remarks>	Hmetal T, 28/06/2022. </remarks>
				///
				/// <param name="q">	An int to process. </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				void SetLatentDimension(int q);


			protected:
				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Default constructor. </summary>
				///
				/// <remarks>	Hmetal T, 02/07/2018. </remarks>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				GPBaseLayer() { }

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Reinitializes the parameters. </summary>
				///
				/// <remarks>	Hmetal T, 03/09/2020. </remarks>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual void ReinitParameters();

				int iq; //!< Latent dimension

				bool isFixedHypers;

				IKernel<Scalar>* kernel;						//!< kernel function
				std::map<std::string, Style<Scalar>>* mStyles;	//!< style variable

				Scalar JITTER; //!< for kernel matrix stability (positive definiteness)

			private:
				friend class boost::serialization::access;

				template<class Archive>
				void serialize(Archive& ar, unsigned int version)
				{
					ar& boost::serialization::base_object<ILayer<Scalar>>(*this);
					//ar& boost::serialization::make_nvp("ILayer", boost::serialization::base_object<ILayer<Scalar>>(*this));

					ar.register_type<LinearKernel<Scalar>>();
					//ar.register_type<RBFKernel<Scalar>>();
					ar.register_type<ARDKernel<Scalar>>();
					//ar.register_type<WhiteKernel<Scalar>>();
					ar.register_type<CompoundKernel<Scalar>>();
					//ar.register_type<RBFAccelerationKernel<Scalar>>();
					//ar.register_type<LinearAccelerationKernel<Scalar>>();

					ar& BOOST_SERIALIZATION_NVP(iq);
					ar& BOOST_SERIALIZATION_NVP(kernel);
					ar& BOOST_SERIALIZATION_NVP(isFixedHypers);
					ar& BOOST_SERIALIZATION_NVP(JITTER);
					ar& BOOST_SERIALIZATION_NVP(mStyles);
				}
			};
		}
	}
}