/**
File:		MachineLearning/Models/GPModels/FgSparseDeepGPSSMBaseModel.h

Author:		
Email:		
Site:       

Copyright (c) 2022 . All rights reserved.
*/

#pragma once

#include <MachineLearning/FgGPStateSpaceBaseModel.h>
#include <MachineLearning/FgSparseGPBaseLayer.h>

namespace NeuralEngine
{
	namespace MachineLearning
	{
		namespace GPModels
		{
			// Forward declaration only: lets the model class in this header befriend
			// PowerEP::SGPLayer and register it for serialization without including its header.
			namespace PowerEP
			{
				template<class Scalar> class SGPLayer;
			}

			// Forward declaration only: lets the model class in this header befriend
			// AEP::SGPLayer and register it for serialization without including its header.
			namespace AEP
			{
				template<class Scalar> class SGPLayer;
			}

			//namespace VFE
			//{
			//	//template<Scalar>
			//	class SGPLayer;
			//}

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	
			/// 	Base class with abstract and basic function definitions. All deep GP models will be derived 
			/// 	from this class.			
			/// </summary>
			///
			/// <remarks>	HmetalT, 26.10.2017. </remarks>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			template<typename Scalar>
			class NE_IMPEXP SparseDeepGPSSMBaseModel : public GPSSBaseModel<Scalar>
			{
			public:

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Constructor taking a single hidden layer description. </summary>
				///
				/// <remarks>	
				/// 	, 26.03.2018. 
				/// 	
				/// 	NOTE(review): the default for <c>xControl</c> binds a temporary to a non-const lvalue 
				/// 	reference, which standard C++ rejects (it compiles only as an MSVC extension). It is left 
				/// 	unchanged here because the out-of-line definition has to match this signature.
				/// </remarks>
				///
				/// <param name="Y">			  	The data af::array to process. </param>
				/// <param name="latentDimension">	The latent dimension. </param>
				/// <param name="description">	  	The description for one hidden layer. </param>
				/// <param name="priorMean">	  	(Optional) the prior mean, 0.0 by default. </param>
				/// <param name="priorVariance">  	(Optional) the prior variance, 1.0 by default. </param>
				/// <param name="xControl">		  	(Optional) presumably the control inputs; empty array by default - TODO confirm. </param>
				/// <param name="probMode">		  	(Optional) the propagation mode, moment matching by default. </param>
				/// <param name="lType">		  	(Optional) the loglik type, Gaussian by default. </param>
				/// <param name="emethod">		  	(Optional) the initialization method for X, pca by default. </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				SparseDeepGPSSMBaseModel(const af::array& Y, int latentDimension, HiddenLayerDescription description,
					Scalar priorMean = 0.0, Scalar priorVariance = 1.0, af::array& xControl = af::array(), PropagationMode probMode = PropagationMode::MomentMatching,
					LogLikType lType = LogLikType::Gaussian, XInit emethod = XInit::pca);

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Constructor taking one description per hidden layer. </summary>
				///
				/// <remarks>	
				/// 	, 26.03.2018. 
				/// 	
				/// 	NOTE(review): the default for <c>xControl</c> binds a temporary to a non-const lvalue 
				/// 	reference, which standard C++ rejects (it compiles only as an MSVC extension). It is left 
				/// 	unchanged here because the out-of-line definition has to match this signature.
				/// </remarks>
				///
				/// <param name="Y">			  	The data af::array to process. </param>
				/// <param name="latentDimension">	The latent dimension. </param>
				/// <param name="descriptions">	  	The hidden layer descriptions. </param>
				/// <param name="priorMean">	  	(Optional) the prior mean, 0.0 by default. </param>
				/// <param name="priorVariance">  	(Optional) the prior variance, 1.0 by default. </param>
				/// <param name="xControl">		  	(Optional) presumably the control inputs; empty array by default - TODO confirm. </param>
				/// <param name="probMode">		  	(Optional) the propagation mode, moment matching by default. </param>
				/// <param name="lType">		  	(Optional) the loglik type, Gaussian by default. </param>
				/// <param name="emethod">		  	(Optional) the initialization method for X, pca by default. </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				SparseDeepGPSSMBaseModel(const af::array& Y, int latentDimension, std::vector<HiddenLayerDescription> descriptions,
					Scalar priorMean = 0.0, Scalar priorVariance = 1.0, af::array& xControl = af::array(), PropagationMode probMode = PropagationMode::MomentMatching,
					LogLikType lType = LogLikType::Gaussian, XInit emethod = XInit::pca);

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Default Constructor. </summary>
				///
				/// <remarks>	Hmetal T, 29.11.2017. </remarks>
				///
				////////////////////////////////////////////////////////////////////////////////////////////////////
				SparseDeepGPSSMBaseModel();

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Destructor. </summary>
				///
				/// <remarks>	, 23.04.2018. </remarks>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual ~SparseDeepGPSSMBaseModel();

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Initializes the model. </summary>
				///
				/// <remarks>	Hmetal T, 29.11.2017. </remarks>
				///
				/// <returns>	true if it succeeds, false if it fails. </returns>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual bool Init() override;

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Predict noise free functions values \f$\mathbf{F}_*\f$. </summary>
				///
				/// <remarks>	Hmetal T, 05/05/2020. </remarks>
				///
				/// <param name="testInputs">	The test inputs. </param>
				/// <param name="mf">		 	[in,out] mean of function values. </param>
				/// <param name="vf">		 	[in,out] The variance of function values. </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual void PredictF(const af::array& testInputs, af::array& mf, af::array& vf) override;

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Noise-free Forward prediction. </summary>
				///
				/// <remarks>	Hmetal T, 07/05/2020. </remarks>
				///
				/// <param name="numTimeSamples">	Number of time samples. </param>
				/// <param name="my">			 	[in,out] The mean of \f$\mathbf{Y}_*\f$. </param>
				/// <param name="vy">			 	[in,out] The variance of \f$\mathbf{Y}_*\f$. </param>
				/// <param name="numSamples">	 	(Optional) Number of samples, 200 by default; presumably only 
				/// 								relevant for sampling based propagation - TODO confirm. </param>
				/// <param name="mx">			 	[in,out] (Optional) If non-null, the mx. </param>
				/// <param name="vx">			 	[in,out] (Optional) If non-null, the vx. </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual void PredictForward(int numTimeSamples, af::array& my, af::array& vy, int numSamples = 200,
					af::array* mx = nullptr, af::array* vx = nullptr);

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Generate function samples from posterior. (Currently disabled.) </summary>
				///
				/// <remarks>	Hmetal T, 18/06/2019. </remarks>
				///
				/// <param name="outFunctions">	[in,out] The out functions. </param>
				/// <param name="inputs">	   	The inputs. </param>
				/// <param name="numSamples">  	Number of samples. </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				//virtual void SampleY(const af::array inputs, int numSamples, af::array& outFunctions) override;

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Gets number of parameters. </summary>
				///
				/// <remarks>	, 26.06.2018. </remarks>
				///
				/// <returns>	The number parameters. </returns>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual int GetNumParameters() override;

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Gets number gp layer parameters. </summary>
				///
				/// <remarks>	HmetalT, 31/03/2020. </remarks>
				///
				/// <returns>	The number gp layer parameters. </returns>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual int GetNumGPLayerParameters();

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Sets the parameters for each optimization iteration. </summary>
				///
				/// <remarks>	, 26.06.2018. </remarks>
				///
				/// <param name="param">	The parameter. </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual void SetParameters(const af::array& param) override;

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Gets the parameters for each optimization iteration. </summary>
				///
				/// <remarks>	, 26.06.2018. </remarks>
				///
				/// <returns>	The parameters. </returns>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual af::array GetParameters() override;

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Updates the parameters. </summary>
				///
				/// <remarks>	, 26.06.2018. </remarks>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual void UpdateParameters() override;

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Gets number of GP layers. </summary>
				/// 
				/// <remarks>	Hmetal T, 09/07/2019. </remarks>
				///
				/// <returns>	The number layers. </returns>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual int GetNumLayers();

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Gets vector of GP layers. </summary>
				///
				/// <remarks>	HmetalT, 09/07/2019. </remarks>
				///
				/// <returns>	The gp layers, as a vector of non-null-checked raw pointers. </returns>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual std::vector<SparseGPBaseLayer<Scalar>*> GetGPLayers();

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Sets fixation for hyperparameters. </summary>
				///
				/// <remarks>	Hmetal T, 16/12/2019. </remarks>
				///
				/// <param name="isfixed">	True if isfixed. </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual void FixKernelParameters(bool isfixed);


			protected:
				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	
				/// 	Forward prediction helper; judging by the name, the moment matching variant of 
				/// 	PredictForward - TODO confirm against the implementation. 
				/// </summary>
				///
				/// <param name="numTimeSamples">	Number of time samples. </param>
				/// <param name="my">			 	[in,out] The my. </param>
				/// <param name="vy">			 	[in,out] The vy. </param>
				/// <param name="mx">			 	[in,out] (Optional) If non-null, the mx. </param>
				/// <param name="vx">			 	[in,out] (Optional) If non-null, the vx. </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual void PredictForwardMM(int numTimeSamples, af::array& my, af::array& vy,
					af::array* mx = nullptr, af::array* vx = nullptr);

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	
				/// 	Forward prediction helper; judging by the name, the Monte Carlo (sampling) variant of 
				/// 	PredictForward - TODO confirm against the implementation. 
				/// </summary>
				///
				/// <param name="numTimeSamples">	Number of time samples. </param>
				/// <param name="my">			 	[in,out] The my. </param>
				/// <param name="vy">			 	[in,out] The vy. </param>
				/// <param name="numSamples">	 	(Optional) Number of samples, 200 by default. </param>
				/// <param name="mx">			 	[in,out] (Optional) If non-null, the mx. </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual void PredictForwardMC(int numTimeSamples, af::array& my, af::array& vy, int numSamples = 200,
					af::array* mx = nullptr);

				int iNumLayer;					//!< number of gp layers

				std::vector<int> vNumPseudosPerLayer;	//!< presumably the number of pseudo (inducing) inputs per layer - TODO confirm
				std::vector<int> vSize;					//!< per-layer sizes; exact meaning not visible in this header
				std::vector<HiddenLayerDescription> vDescription;	//!< hidden layer descriptions

				SparseGPBaseLayer<Scalar>* dynLayer;		//!< sparse Gaussian Process dynamic layer	
				std::vector<SparseGPBaseLayer<Scalar>*> gpEmissLayer; //!< sparse Gaussian Process emission layer

			private:
				friend class AEP::SGPLayer<Scalar>;
				friend class PowerEP::SGPLayer<Scalar>;
				//friend class VFE::SGPLayer/*<Scalar>*/;

				friend class boost::serialization::access;

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	
				/// 	Boost.Serialization hook: archives the base class part followed by the members of this 
				/// 	class. The same function is used for both saving and loading.
				/// </summary>
				///
				/// <param name="ar">	   	[in,out] The archive. </param>
				/// <param name="version">	The class version (unused here). </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				template<class Archive>
				void serialize(Archive& ar, unsigned int version)
				{
					ar& boost::serialization::base_object<GPSSBaseModel<Scalar>>(*this);

					//ar& boost::serialization::make_nvp("GPLVMBaseModel", boost::serialization::base_object<GPLVMBaseModel<Scalar>>(*this));

					// The concrete layer types are registered before the base-class layer pointers below
					// are archived. Because Archive is a dependent type, the 'template' keyword is
					// required for register_type<> to parse as a member template call on conforming
					// compilers.
					ar.template register_type<AEP::SGPLayer<Scalar>>();
					ar.template register_type<PowerEP::SGPLayer<Scalar>>();
					//ar.template register_type<VFE::SGPLayer>();

					ar& BOOST_SERIALIZATION_NVP(iNumLayer);
					ar& BOOST_SERIALIZATION_NVP(vNumPseudosPerLayer);
					ar& BOOST_SERIALIZATION_NVP(vSize);
					ar& BOOST_SERIALIZATION_NVP(vDescription);
					ar& BOOST_SERIALIZATION_NVP(dynLayer);
					ar& BOOST_SERIALIZATION_NVP(gpEmissLayer);
				}
			};
		}
	}
}

