/**
File:		MachineLearning/Models/GPModels/FgGPStateSpaceBaseModel.h

Author:		
Email:		
Site:       

Copyright (c) 2020 . All rights reserved.
*/

#pragma once

//#include <MachineLearning/FgINode.h>
#include <MachineLearning/FgGPBaseModel.h>

namespace NeuralEngine
{
	namespace MachineLearning
	{
		namespace GPModels
		{
			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	
			/// 	Common base class for Gaussian-process state-space (GPSS) models. All GP state space
			/// 	models are meant to derive from this class. It inherits from both GPBaseModel&lt;Scalar&gt;
			/// 	and GPNode&lt;Scalar&gt;, and holds the state shared by derived models: latent dimension,
			/// 	control inputs, prior parameters, natural-parameter factors of the latent variables and
			/// 	the propagation mode.
			/// </summary>
			///
			/// <remarks>	HmetalT, 26.10.2017. </remarks>
			///
			/// <typeparam name="Scalar">	Scalar type used for hyper-parameters and priors
			/// 							(presumably float or double - confirm against the explicit
			/// 							instantiations). </typeparam>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			template<typename Scalar>
			class NE_IMPEXP GPSSBaseModel : public GPBaseModel<Scalar>, public GPNode<Scalar>
			{
			public:
				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Constructor. </summary>
				///
				/// <remarks>	
				/// 	Hmetal T, 04/05/2020.
				/// 	
				/// 	NOTE(review): the default argument of xControl binds a non-const lvalue reference
				/// 	(af::array&amp;) to a temporary (af::array()). Standard C++ does not allow this; it only
				/// 	compiles under compiler extensions (e.g. MSVC non-permissive mode off). Changing it to
				/// 	const af::array&amp; also requires updating the out-of-line definition.
				/// </remarks>
				///
				/// <param name="Y">			  	Observation data. </param>
				/// <param name="latentDimension">	The latent dimension (stored in iq, presumably). </param>
				/// <param name="priorMean">	  	(Optional) The prior mean. Defaults to 0. </param>
				/// <param name="priorVariance">  	(Optional) The prior variance. Defaults to 1. </param>
				/// <param name="xControl">		  	(Optional) Control inputs. Defaults to an empty array
				/// 								(presumably meaning "no control inputs" - confirm). </param>
				/// <param name="probMode">		  	(Optional) Propagation mode. Defaults to
				/// 								PropagationMode::MomentMatching. </param>
				/// <param name="lType">		  	(Optional) The likelihood type. Defaults to
				/// 								LogLikType::Gaussian. </param>
				/// <param name="GPemission">	  	(Optional) True to use a non-linear (GP) emission function.
				/// 								Defaults to true. </param>
				/// <param name="controlToEmiss">	(Optional) True to also feed the control inputs to the
				/// 								emission function (presumably - see bControlToEmiss).
				/// 								Defaults to true. </param>
				/// <param name="emethod">		  	(Optional) The latent embedding/initialisation method.
				/// 								Defaults to XInit::pca. </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				GPSSBaseModel(const af::array& Y, int latentDimension, Scalar priorMean = 0.0,
					Scalar priorVariance = 1.0, af::array& xControl = af::array(), PropagationMode probMode = PropagationMode::MomentMatching, 
					LogLikType lType = LogLikType::Gaussian, bool GPemission = true, bool controlToEmiss = true, XInit emethod = XInit::pca);

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Default constructor (presumably used by deserialization - confirm). </summary>
				///
				/// <remarks>	Hmetal T, 09/05/2022. </remarks>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				GPSSBaseModel();

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Destructor. </summary>
				///
				/// <remarks>	HmetalT, 05/05/2020. </remarks>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual ~GPSSBaseModel();

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Optimizes the model parameters for best fit. Overrides the base-class version. </summary>
				///
				/// <remarks>	Hmetal T, 29.11.2017. </remarks>
				///
				/// <param name="method">			(Optional) The optimization method. Defaults to L_BFGS. </param>
				/// <param name="tol">				(Optional) The tolerance. Defaults to 0. </param>
				/// <param name="reinit_hypers">	(Optional) True to re-initialise the hyper-parameters.
				/// 								Defaults to true. </param>
				/// <param name="maxiter">			(Optional) Maximum number of iterations. Defaults to 1000. </param>
				/// <param name="mb_size">			(Optional) Mini-batch size. Defaults to 0 (presumably
				/// 								full batch - confirm). </param>
				/// <param name="lsType">			(Optional) Line-search type. Defaults to MoreThuente. </param>
				/// <param name="disp">				(Optional) True to display progress. Defaults to true. </param>
				/// <param name="cycle">			(Optional) Pointer to a cycle counter; may be null.
				/// 								Its use is defined by the implementation - confirm. </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual void Optimise(
					OptimizerType method = L_BFGS,
					Scalar tol = 0.0,
					bool reinit_hypers = true,
					int maxiter = 1000,
					int mb_size = 0,
					LineSearchType lsType = MoreThuente,
					bool disp = true,
					int* cycle = nullptr
				) override;

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Initializes the model. </summary>
				///
				/// <remarks>	Hmetal T, 29.11.2017. </remarks>
				///
				/// <param name="mx">	[in,out] Latent means (presumably the initial latent states - confirm
				/// 					against the implementation). </param>
				///
				/// <returns>	true if it succeeds, false if it fails. </returns>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual bool Init(af::array& mx);

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Gets the posterior distribution of the latent variables \f$\mathbf{X}\f$. </summary>
				///
				/// <remarks>	Hmetal T, 09/12/2019. </remarks>
				///
				/// <param name="mx">   	[in,out] The posterior mean. </param>
				/// <param name="vx">   	[in,out] The posterior variance. </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual void PosteriorLatents(af::array& mx, af::array& vx);

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Gets the number of parameters. </summary>
				///
				/// <remarks>	, 26.06.2018. </remarks>
				///
				/// <returns>	The number of parameters. </returns>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual int GetNumParameters() override;

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Sets the parameters for each optimization iteration. </summary>
				///
				/// <remarks>	, 26.06.2018. </remarks>
				///
				/// <param name="param">	The parameter vector. </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual void SetParameters(const af::array& param) override;

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Gets the parameters for each optimization iteration. </summary>
				///
				/// <remarks>	, 26.06.2018. </remarks>
				///
				/// <returns>	The parameter vector. </returns>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual af::array GetParameters() override;

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Updates the dynamic indexes (presumably afDynIndexes - confirm). </summary>
				///
				/// <remarks>	Hmetal T, 18/05/2022. </remarks>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				void UpdateDynamicIndexes();

				/// <summary>	
				/// 	Writes latent means into mx and latent variances into vx. How this differs from
				/// 	PosteriorLatents is defined in the implementation - NOTE(review): confirm.
				/// </summary>
				void GetLatents(af::array& mx, af::array& vx);

				/// <summary>	Adds a window of data to the model (semantics defined in the implementation - confirm). </summary>
				virtual void AddWindowData(af::array data);

			protected:
				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Posterior gradient of the latent inputs \f$\mathbf{X}\f$. </summary>
				///
				/// <remarks>	Hmetal T, 09/12/2019. </remarks>
				///
				/// <param name="dmx">	[in] The gradient w.r.t. the mean. </param>
				/// <param name="dvx">	[in] The gradient w.r.t. the variance. </param>
				///
				/// <returns>	The resulting gradient array. </returns>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual af::array PosteriorGradientLatents(const af::array& dmx, const af::array& dvx);

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Updates the parameters. </summary>
				///
				/// <remarks>	, 26.06.2018. </remarks>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual void UpdateParametersInternal();

				int iq;					//!< latent dimension
				int iDControlEmiss;		//!< presumably dimension of control inputs fed to the emission - confirm
				int iDControlDyn;		//!< presumably dimension of control inputs fed to the dynamics - confirm

				Scalar dSn;				//!< undocumented scalar; meaning defined in implementation - confirm (not serialized)

				bool bGPemission;		//!< true for a non-linear (GP) emission function
				bool bControlToEmiss;	//!< presumably true when control inputs also feed the emission - confirm

				Scalar dPriorMean;		//!< prior mean
				Scalar dPriorVariance;	//!< prior variance
				Scalar dPriorX1;		//!< prior \f$x_1\f$
				Scalar dPriorX2;		//!< prior \f$x_2\f$

				af::array afFactorX1;		//!< natural parameter factor 1 for latent variable
				af::array afFactorX2;		//!< natural parameter factor 2 for latent variable
				af::array afPosteriorX1;	//!< posterior natural parameter 1 for latent variable
				af::array afPosteriorX2;	//!< posterior natural parameter 2 for latent variable
				af::array afXControl;		//!< control inputs
				af::array afDynIndexes;		//!< dynamic indexes (see UpdateDynamicIndexes)

				af::array afPriorX1;		//!< prior \f$x_1\f$
				af::array afPriorX2;		//!< prior \f$x_2\f$
				af::array afPriorX1Cav;		//!< presumably the cavity counterpart of afPriorX1 - confirm
				af::array afPriorX2Cav;		//!< presumably the cavity counterpart of afPriorX2 - confirm
				af::array afPriorMean;		//!< prior mean for hierarchy mode
				af::array afPriorVariance;	//!< prior variance for hierarchy mode
				af::array afPriorMeanCav;		//!< presumably the cavity counterpart of afPriorMean - confirm
				af::array afPriorVarianceCav;	//!< presumably the cavity counterpart of afPriorVariance - confirm
				af::array afGradMean;		//!< prior mean gradient for hierarchy mode
				af::array afGradVariance;	//!< prior variance gradient for hierarchy mode
				af::array afGradMeanCav;		//!< presumably the cavity counterpart of afGradMean - confirm
				af::array afGradVarianceCav;	//!< presumably the cavity counterpart of afGradVariance - confirm

				af::array afLatentGradientX;	//!< top down gradient

				////af::array afLatentGradientX;	//!< top down gradient

				//af::array afIndexes;	//!< indexes of \f$\mathbf{X}\f$ for batch learning
				XInit eEmMethod;		//!< latent embedding/initialisation method

				PropagationMode pMode;	//!< propagation mode
			private:
				friend class boost::serialization::access;

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	
				/// 	Boost.Serialization hook. Archives the GPBaseModel&lt;Scalar&gt; base, then the fields
				/// 	listed below in this exact order. The order defines the archive layout, so do not
				/// 	reorder or insert fields without versioning.
				/// </summary>
				///
				/// <remarks>	
				/// 	Not archived here: the GPNode&lt;Scalar&gt; base, dSn, dPriorX1, dPriorX2,
				/// 	afPosteriorX1/afPosteriorX2, the afPrior* / afGrad* arrays and afLatentGradientX.
				/// 	NOTE(review): confirm these are recomputed after loading. 'version' is unused.
				/// </remarks>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				template<class Archive>
				void serialize(Archive& ar, unsigned int version)
				{
					ar& boost::serialization::base_object<GPBaseModel<Scalar>>(*this);
					//ar& boost::serialization::make_nvp("GPLVMBaseModel", boost::serialization::base_object<GPLVMBaseModel<Scalar>>(*this));

					ar& BOOST_SERIALIZATION_NVP(iq);
					ar& BOOST_SERIALIZATION_NVP(iDControlEmiss);
					ar& BOOST_SERIALIZATION_NVP(iDControlDyn);
					ar& BOOST_SERIALIZATION_NVP(bGPemission);
					ar& BOOST_SERIALIZATION_NVP(bControlToEmiss);
					ar& BOOST_SERIALIZATION_NVP(afXControl);
					ar& BOOST_SERIALIZATION_NVP(afDynIndexes);
					ar& BOOST_SERIALIZATION_NVP(dPriorMean);
					ar& BOOST_SERIALIZATION_NVP(dPriorVariance);
					ar& BOOST_SERIALIZATION_NVP(afFactorX1);
					ar& BOOST_SERIALIZATION_NVP(afFactorX2);
					ar& BOOST_SERIALIZATION_NVP(eEmMethod);
					ar& BOOST_SERIALIZATION_NVP(pMode);
				}
			};
		}
	}
}
