/**
File:		MachineLearning/GPModels/Models/GPModels/FgAEPSparseGPSSM.h

Author:		
Email:		
Site:       

Copyright (c) 2020 . All rights reserved.
*/

#pragma once

#include <MachineLearning/FgSparseGPSSMBaseModel.h>

namespace NeuralEngine
{
	namespace MachineLearning
	{
		namespace GPModels
		{
			namespace AEP
			{
				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary> Sparse GPSSM via Approximated Expectation Propagation (AEP). </summary>
				///
				/// <remarks>
				///		<para>
				/// 		Instead of taking one Gaussian portion out to form the cavity, we take out a
				/// 		fraction defined by the parameter \f$\alpha\f$, which can also be seen as a ratio parameter
				/// 		between VFE and Power EP with the FITC approximation.
				/// 	</para>
				/// 	<para>	
				/// 		The lack of an additional approximation for the latent function \f$f\f$ in this EP 
				/// 		scheme results in a prohibitive computational complexity of \f$O(T^3)\f$. The approach 
				/// 		proposed in this class, in contrast, employs Power EP to provide approximate Bayesian 
				/// 		estimates for both \f$f\f$ and \f$\mathbf{x}\f$ simultaneously in a computationally 
				/// 		and analytically tractable manner. Importantly, Power EP offers a flexible approximate 
				/// 		inference framework which has EP and structured variational inference (VI) as special cases.
				///		</para>
				/// 	<para>
				///			References:
				///			<list type="bullet">
				///				<item>
				///					<description><a href="http://mlg.eng.cam.ac.uk/thang/docs/papers/thesis-thang.pdf" target="_blank">
				///						Bui, T. D. (2018). Efficient Deterministic Approximate Bayesian Inference for Gaussian
				///						Process models (Doctoral thesis). https://doi.org/10.17863/CAM.20913 </a>
				///					</description>
				///				</item>
				///			</list>
				///		</para>
				/// 	
				/// 	
				/// 	, 24.11.2019. 
				/// </remarks>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				template<typename Scalar>
				class NE_IMPEXP SGPSSM : public SparseGPSSMBaseModel<Scalar>
				{
				public:
					////////////////////////////////////////////////////////////////////////////////////////////////////
					/// <summary>	Constructor. </summary>
					///
					/// <remarks>	Hmetal T, 04/05/2020. </remarks>
					///
					/// <param name="Y">			  	Observation data. </param>
					/// <param name="latentDimension">	The latent dimension. </param>
					/// <param name="numInducing">	  	(Optional) Number of inducing inputs; defaults to 200. </param>
					/// <param name="alpha">		  	(Optional) The fraction parameter \f$\alpha\f$; defaults to 1.0.
					/// 								NOTE(review): presumably stored in <c>dAlpha</c> - the definition
					/// 								is not part of this header, confirm. </param>
					/// <param name="priorMean">	  	(Optional) The prior mean; defaults to 0.0. </param>
					/// <param name="priorVariance">  	(Optional) The prior variance; defaults to 1.0. </param>
					/// <param name="xControl">		  	(Optional) The control. Taken by value, so the caller's array
					/// 								is not modified through this parameter; defaults to an empty
					/// 								<c>af::array</c>. </param>
					/// <param name="probMode">		  	(Optional) The propagation mode; defaults to
					/// 								<c>PropagationMode::MonteCarlo</c>. </param>
					/// <param name="lType">		  	(Optional) The likelihood type; defaults to
					/// 								<c>LogLikType::Gaussian</c>. </param>
					/// <param name="GPemission">	  	(Optional) True to use a non-linear emission function;
					/// 								defaults to false. </param>
					/// <param name="controlToEmiss"> 	(Optional) True to control to emission; defaults to false.
					/// 								NOTE(review): presumably routes the control input to the
					/// 								emission function as well - confirm against the implementation. </param>
					/// <param name="emethod">		  	(Optional) The embed method; defaults to <c>XInit::pca</c>. </param>
					////////////////////////////////////////////////////////////////////////////////////////////////////
					SGPSSM(const af::array& Y, int latentDimension, int numInducing = 200, Scalar alpha = 1.0, Scalar priorMean = 0.0,
						Scalar priorVariance = 1.0, af::array xControl = af::array(),
						PropagationMode probMode = PropagationMode::MonteCarlo, LogLikType lType = LogLikType::Gaussian,
						bool GPemission = false, bool controlToEmiss = false, XInit emethod = XInit::pca);

					////////////////////////////////////////////////////////////////////////////////////////////////////
					/// <summary>	Destructor. </summary>
					////////////////////////////////////////////////////////////////////////////////////////////////////
					virtual ~SGPSSM();

					////////////////////////////////////////////////////////////////////////////////////////////////////
					/// <summary>	Evaluates the cost function for the given parameter inputs. </summary>
					///
					/// <remarks>	Hmetal T, 29.11.2017. </remarks>
					///
					/// <param name="x">		  	The parameters to be optimized. </param>
					/// <param name="outGradient">	[in,out] The out gradient. </param>
					///
					/// <returns>	The value of the cost function as a Scalar. </returns>
					////////////////////////////////////////////////////////////////////////////////////////////////////
					virtual Scalar Function(const af::array& x, af::array& outGradient) override;

					////////////////////////////////////////////////////////////////////////////////////////////////////
					/// <summary>	Default constructor. </summary>
					///
					/// <remarks>
					/// 	NOTE(review): presumably provided so that instances can be created before being
					/// 	populated by <c>serialize</c> - confirm against the serialization call sites.
					/// </remarks>
					////////////////////////////////////////////////////////////////////////////////////////////////////
					SGPSSM();

				protected:

					////////////////////////////////////////////////////////////////////////////////////////////////////
					/// <summary>	Computes the cavity distribution. </summary>
					///
					/// <remarks>	Hmetal T, 11/05/2020. </remarks>
					///
					/// <param name="mcav">	[in,out] The mcav. </param>
					/// <param name="vcav">	[in,out] The vcav. </param>
					/// <param name="cav1">	[in,out] The first cav. </param>
					/// <param name="cav2">	[in,out] The second cav. </param>
					////////////////////////////////////////////////////////////////////////////////////////////////////
					virtual void CavityLatents(af::array& mcav, af::array& vcav, af::array& cav1, af::array& cav2);

					////////////////////////////////////////////////////////////////////////////////////////////////////
					/// <summary>	Calculates the tilted transition. </summary>
					///
					/// <remarks>	Hmetal T, 11/05/2020. </remarks>
					///
					/// <param name="mprob">	   	The mprob. </param>
					/// <param name="vprob">	   	The vprob. </param>
					/// <param name="mcav_t1">	   	The first mcav t. </param>
					/// <param name="vcav_t1">	   	The first vcav t. </param>
					/// <param name="scaleLogZDyn">	The scale log z coordinate dynamic. </param>
					/// <param name="dlogZ_dmProb">	[in,out] By its name, the gradient of log Z w.r.t. <c>mprob</c>
					/// 							(confirm in the implementation). </param>
					/// <param name="dlogZ_dvProb">	[in,out] By its name, the gradient of log Z w.r.t. <c>vprob</c>
					/// 							(confirm in the implementation). </param>
					/// <param name="dlogZ_dmt">	[in,out] By its name, the gradient of log Z w.r.t. <c>mcav_t1</c>
					/// 							(confirm in the implementation). </param>
					/// <param name="dlogZ_dvt">	[in,out] By its name, the gradient of log Z w.r.t. <c>vcav_t1</c>
					/// 							(confirm in the implementation). </param>
					/// <param name="dlogZ_sn">		[in,out] Scalar passed by non-const reference; by its name, a
					/// 							gradient of log Z w.r.t. a quantity called sn - TODO confirm
					/// 							what sn denotes. </param>
					///
					/// <returns>	The calculated tilted transition. </returns>
					////////////////////////////////////////////////////////////////////////////////////////////////////
					Scalar ComputeTiltedTransition(const af::array& mprob, const af::array& vprob, const af::array& mcav_t1, const af::array& vcav_t1, 
						Scalar scaleLogZDyn, af::array& dlogZ_dmProb, af::array& dlogZ_dvProb, af::array& dlogZ_dmt, af::array& dlogZ_dvt, Scalar& dlogZ_sn);

					////////////////////////////////////////////////////////////////////////////////////////////////////
					/// <summary>	Posterior gradient w.r.t \f$\mathbf{X}\f$. </summary>
					///
					/// <remarks>	Hmetal T, 11/05/2020. </remarks>
					///
					/// <returns>	The gradient as an af::array. </returns>
					////////////////////////////////////////////////////////////////////////////////////////////////////
					virtual af::array PosteriorGradientLatents();

					////////////////////////////////////////////////////////////////////////////////////////////////////
					/// <summary>	Cavity gradient w.r.t \f$\mathbf{X}\f$. </summary>
					///
					/// <remarks>	Hmetal T, 11/05/2020. </remarks>
					///
					/// <param name="cav1">	The first cav. </param>
					/// <param name="cav2">	The second cav. </param>
					///
					/// <returns>	The gradient as an af::array. </returns>
					////////////////////////////////////////////////////////////////////////////////////////////////////
					virtual af::array CavityGradientLatents(const af::array& cav1, const af::array& cav2);

					////////////////////////////////////////////////////////////////////////////////////////////////////
					/// <summary>	LogZ gradient w.r.t \f$\mathbf{X}\f$. </summary>
					///
					/// <remarks>	Hmetal T, 11/05/2020. </remarks>
					///
					/// <param name="cav1">		 	The first cav. </param>
					/// <param name="cav2">		 	The second cav. </param>
					/// <param name="dmcav_up">  	The dmcav up. </param>
					/// <param name="dvcav_up">  	The dvcav up. </param>
					/// <param name="dmcav_prev">	The dmcav previous. </param>
					/// <param name="dvcav_prev">	The dvcav previous. </param>
					/// <param name="dmcav_next">	The dmcav next. </param>
					/// <param name="dvcav_next">	The dvcav next. </param>
					///
					/// <returns>	The gradient as an af::array. </returns>
					////////////////////////////////////////////////////////////////////////////////////////////////////
					virtual af::array LogZGradientLatents(const af::array& cav1, const af::array& cav2, const af::array& dmcav_up, const af::array& dvcav_up,
						const af::array& dmcav_prev, const af::array& dvcav_prev, const af::array& dmcav_next, const af::array& dvcav_next);

					////////////////////////////////////////////////////////////////////////////////////////////////////
					/// <summary>	Calculates the phi prior. </summary>
					///
					/// <remarks>	Hmetal T, 11/05/2020. </remarks>
					///
					/// <returns>	The calculated phi prior latents. </returns>
					////////////////////////////////////////////////////////////////////////////////////////////////////
					virtual Scalar ComputePhiPriorLatents();

					////////////////////////////////////////////////////////////////////////////////////////////////////
					/// <summary>	Calculates the phi cavity. </summary>
					///
					/// <remarks>	Hmetal T, 11/05/2020. </remarks>
					///
					/// <returns>	The calculated phi cavity latents. </returns>
					////////////////////////////////////////////////////////////////////////////////////////////////////
					virtual Scalar  ComputePhiCavityLatents();

					////////////////////////////////////////////////////////////////////////////////////////////////////
					/// <summary>	Calculates the phi posterior. </summary>
					///
					/// <remarks>	Hmetal T, 11/05/2020. </remarks>
					///
					/// <returns>	The calculated phi posterior latents. </returns>
					////////////////////////////////////////////////////////////////////////////////////////////////////
					virtual Scalar  ComputePhiPosteriorLatents();

				private:
					Scalar dAlpha;	//!< Fraction parameter; written to / read from the archive in serialize().

					// Grants Boost.Serialization access to the private serialize() member below.
					friend class boost::serialization::access;

					////////////////////////////////////////////////////////////////////////////////////////////////////
					/// <summary>	Serializes this object through the given archive. </summary>
					///
					/// <remarks>
					/// 	The base-class sub-object is processed first, followed by <c>dAlpha</c>.
					/// </remarks>
					///
					/// <typeparam name="Archive">	Type of the archive. </typeparam>
					/// <param name="ar">	  	[in,out] The archive. </param>
					/// <param name="version">	The class version; not used in this function. </param>
					////////////////////////////////////////////////////////////////////////////////////////////////////
					template<class Archive>
					void serialize(Archive& ar, unsigned int version)
					{
						// Base sub-object is passed without a name-value pair wrapper.
						ar& boost::serialization::base_object<SparseGPSSMBaseModel<Scalar>>(*this);
						// Disabled alternative kept for reference: the same base sub-object wrapped in a named pair.
						//ar& boost::serialization::make_nvp("SparseGPSSMBaseModel", boost::serialization::base_object<SparseGPSSMBaseModel<Scalar>>(*this));
						// The fraction parameter is wrapped in a name-value pair named after the member.
						ar& BOOST_SERIALIZATION_NVP(dAlpha);
					}
				};

			}
		}
	}
}
/** @example AEP_SGPSSM_Examples.cpp */