/**
File:		MachineLearning/GPModels/SparseGPModels/SparseGPBaseLayer.h

Author:		
Email:		
Site:       

Copyright (c) 2017 . All rights reserved.
*/

#pragma once

#include <MachineLearning/FgGPBaseLayer.h>

namespace NeuralEngine
{
	namespace MachineLearning
	{
		namespace GPModels
		{
			enum PropagationMode
			{
				MomentMatching,
				Linear,
				MonteCarlo
			};

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Abstract class for different GP likelihood layers. </summary>
			///
			/// <remarks>	, 27.02.2018. </remarks>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			template<typename Scalar>
			class NE_IMPEXP SparseGPBaseLayer : public GPBaseLayer<Scalar>
			{
			public:

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Constructor. </summary>
				///
				/// <remarks>	, 15.05.2018. </remarks>
				///
				/// <param name="numPoints"> 	Number of points. </param>
				/// <param name="numPseudos">	Number of pseudo inputs. </param>
				/// <param name="outputDim"> 	The output dimension. </param>
				/// <param name="inputDim">  	The input dimension. </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				SparseGPBaseLayer(int numPoints, int numPseudos, int outputDim, int inputDim);

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Destructor. </summary>
				///
				/// <remarks>	, 15.05.2018. </remarks>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual ~SparseGPBaseLayer();

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Forward prediction of posterior function values. </summary>
				///
				/// <remarks>	, 12.06.2018. </remarks>
				///
				/// <param name="mout"> 	[in,out] The m^{\n}_{f}. </param>
				/// <param name="vout"> 	[in,out] The V^{\n}_{ff}. </param>
				/// <param name="mx">   	[in,out] The inputs mx. </param>
				/// <param name="vx">   	[in,out] (Optional) If non-null, the variances vx. </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual void ForwardPredictionPost(const af::array* mx, const af::array* vx, af::array& mout, af::array& vout) override;

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Samples from posterior. </summary>
				///
				/// <remarks>	, 12.06.2018. </remarks>
				///
				/// <param name="vx">	[in,out] (Optional) If non-null, the variances vx. </param>
				///
				/// <param name="mx">	[in,out] The inputs mx. </param>
				///
				/// <param name="fsample">	[in,out] The m^{\n}_{f}. </param>
				/// <param name="inX">	  	The V^{\n}_{ff}. </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual void SampleFromPost(const af::array& inX, af::array& outfsample) override;

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Calculates the kernel matrix of pseudo inputs. </summary>
				///
				/// <remarks>	, 15.05.2018. </remarks>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				void ComputeKuu();

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Calculates the kernel matrix of inputes and pseudo inputs. </summary>
				///
				/// <remarks>	, 15.05.2018. </remarks>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				void ComputeKfu(const af::array& inX);

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Gets pseudo inputs. </summary>
				///
				/// <remarks>	Hmetal T, 17/06/2019. </remarks>
				///
				/// <returns>	The pseudo inputs. </returns>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				af::array GetPseudoInputs();

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Gets number of parameters to be optimized. </summary>
				///
				/// <remarks>	, 26.06.2018. </remarks>
				///
				/// <returns>	The number parameters. </returns>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual int GetNumParameters() override;

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Sets the parameters for each optimization iteration. </summary>
				///
				/// <remarks>	, 26.06.2018. </remarks>
				///
				/// <param name="param">	The parameter. </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual void SetParameters(const af::array& param) override;

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Gets the parameters for each optimization iteration. </summary>
				///
				/// <remarks>	, 26.06.2018. </remarks>
				///
				/// <returns>	The parameters. </returns>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual af::array GetParameters() override;

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Set fixation for inducing inputs. </summary>
				///
				/// <remarks>	Hmetal T, 16/12/2019. </remarks>
				///
				/// <param name="isfixed">	True if isfixed. </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual void FixInducing(bool isfixed);

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Updates the parameters. </summary>
				///
				/// <remarks>	, 26.06.2018. </remarks>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual void UpdateParameters() override;

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Initializes the parameters. </summary>
				///
				/// <remarks>	Hmetal T, 09/12/2019. </remarks>
				///
				/// <param name="X">	[in,out] (Optional) If non-null, an af::array to process. </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual void InitParameters(af::array* X = nullptr) override;

			protected:

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Default constructor. </summary>
				///
				/// <remarks>	Hmetal T, 02/07/2018. </remarks>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				SparseGPBaseLayer() { }

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Reinitializes the parameters. </summary>
				///
				/// <remarks>	Hmetal T, 03/09/2020. </remarks>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual void ReinitParameters() override;

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Deterministic forward propagation through posterior. </summary>
				///
				/// <remarks>	Hmetal T, 01.04.2019. </remarks>
				///
				/// <param name="mx">  	[in,out] The inputs mx. </param>
				/// <param name="mout">	[in,out] The m^{\n}_{f}. </param>
				/// <param name="vout">	[in,out] The V^{\n}_{ff}. </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual void ForwardPredictionDeterministicPost(const af::array& mx, af::array* mout, af::array* vout);

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Forward prediction through random posterior. </summary>
				///
				/// <remarks>	, 16.05.2018. </remarks>
				///
				/// <param name="mx">  	The inputs mx. </param>
				/// <param name="vx">  	If non-null, the variances vx. </param>
				/// <param name="mout">	[in,out] The m^{\n}_{f}. </param>
				/// <param name="vout">	[in,out] The V^{\n}_{ff}. </param>
				/// <param name="mode">	(Optional) Propagation mode. </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual void ForwardPredictionRandomPost(const af::array& mx, const af::array& vx, af::array& mout, af::array& vout, PropagationMode mode = PropagationMode::MomentMatching);

				int ik;					//< number pseudo inputs

				af::array afMu;			//< mean of /f$q(\mathbf{u})/f$
				af::array afSu;			//< covariance of /f$q(\mathbf{u})/f$
				af::array afInvSu;		//< natural parameter /f$\mathbf{K_{uu}} + T_{2,\mathbf{u}}/f$
				af::array afInvSuMu;	//< natural parameter /f$T_{1,\mathbf{u}}/f$

				// natural parameters
				af::array T1;			//< natural parameter /f$T_{1,\mathbf{u}}/f$
				af::array T2;			//< natural parameter /f$T_{2,\mathbf{u}}/f$
				af::array T2_R;			//< triangular /f$T_{2,\mathbf{u}}/f$ for parameter optimization

				af::array afXu;			//< inducing/pseudo inputs
				af::array afKuu;		//< kernel matrix inducing inputs
				af::array afInvKuu;		//< inverse kernel matrix inducing inputs
				af::array afKfu;		//< kernel matrix inputs and inducing inputs

				bool isFixedInducing;

			private:
				friend class boost::serialization::access;

				template<class Archive>
				void serialize(Archive& ar, unsigned int version)
				{
					ar& boost::serialization::base_object<GPBaseLayer<Scalar>>(*this);
					//ar& boost::serialization::make_nvp("GPBaseLayer", boost::serialization::base_object<GPBaseLayer<Scalar>>(*this));

					ar& BOOST_SERIALIZATION_NVP(ik);
					ar& BOOST_SERIALIZATION_NVP(afXu);
					/*ar& BOOST_SERIALIZATION_NVP(afKuu);
					ar& BOOST_SERIALIZATION_NVP(afInvKuu);
					ar& BOOST_SERIALIZATION_NVP(afMu);
					ar& BOOST_SERIALIZATION_NVP(afSu);
					ar& BOOST_SERIALIZATION_NVP(afInvSu);
					ar& BOOST_SERIALIZATION_NVP(afInvSuMu);
					ar& BOOST_SERIALIZATION_NVP(afKfu);*/
					ar& BOOST_SERIALIZATION_NVP(T1);
					ar& BOOST_SERIALIZATION_NVP(T2);
					ar& BOOST_SERIALIZATION_NVP(T2_R);
					ar& BOOST_SERIALIZATION_NVP(isFixedInducing);
				}
			};
		}
	}
}