﻿/**
File:		MachineLearning/GPModels/SparseGPModels/AEP/FgAEPSparseGPLayer.h

Author:		
Email:		
Site:       

Copyright (c) 2017 . All rights reserved.
*/

#pragma once

#include <MachineLearning/FgSparseGPBaseLayer.h>

namespace NeuralEngine
{
	namespace MachineLearning
	{
		namespace GPModels
		{
			namespace AEP
			{

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Sparse GP layer used by the approximate expectation propagation (AEP) models. </summary>
				///
				/// <remarks>
				/// 	Holds all variables for the FITC approximation and (power) EP. Defines a subset of
				/// 	\f$\mathbf{X}\f$: selects \f$k\f$ inducing inputs \f$\mathbf{X_u}\f$, computes the subset
				/// 	kernel matrix \f$\mathbf{K_{uu}}\f$ and its inverse.
				/// 	
				/// 	Notation: a superscript \f$\setminus n\f$ denotes a cavity quantity, i.e. one from which
				/// 	the contribution of data point \f$n\f$ has been removed.
				/// 			
				///		For more information see
				///		<a href="https://pdfs.semanticscholar.org/99f9/3283e415ae21bd42a90031cd3972f3bfbc9d.pdf" target="_blank">this paper</a>.
				///				
				/// 	Hmetal T, 05/05/2018. 
				/// </remarks>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				template<typename Scalar>
				class NE_IMPEXP SGPLayer : public SparseGPBaseLayer<Scalar>
				{
				public:

					////////////////////////////////////////////////////////////////////////////////////////////////////
					/// <summary>	Constructor. </summary>
					///
					/// <remarks>	, 15.05.2018. </remarks>
					///
					/// <param name="numPoints"> 	Number of points. </param>
					/// <param name="numPseudos">	Number of pseudo (inducing) inputs. </param>
					/// <param name="outputDim"> 	The output dimension. </param>
					/// <param name="inputDim">  	The input dimension. </param>
					////////////////////////////////////////////////////////////////////////////////////////////////////
					SGPLayer(int numPoints, int numPseudos, int outputDim, int inputDim);

					////////////////////////////////////////////////////////////////////////////////////////////////////
					/// <summary>	Destructor. </summary>
					///
					/// <remarks>	, 15.05.2018. </remarks>
					////////////////////////////////////////////////////////////////////////////////////////////////////
					virtual ~SGPLayer();

					////////////////////////////////////////////////////////////////////////////////////////////////////
					/// <summary>	Forward prediction through the cavity distribution. </summary>
					///
					/// <remarks>	
					///		Computes the cavity mean \f$m^{\setminus n}_{f}\f$ and covariance
					///		\f$V^{\setminus n}_{ff}\f$ of this layer's output from the cavity distribution.
					///								
					/// 	, 16.05.2018. 
					/// </remarks>
					///
					/// <param name="outMout">	[in,out] The cavity mean \f$m^{\setminus n}_{f}\f$. </param>
					/// <param name="outVout">	[in,out] The cavity variance \f$V^{\setminus n}_{ff}\f$. </param>
					/// <param name="psi1out">	If non-null, output for the Psi1 statistics (presumably the ones later
					/// 						passed to BackpropGradientsMM). </param>
					/// <param name="psi2out">	If non-null, output for the Psi2 statistics (presumably the ones later
					/// 						passed to BackpropGradientsMM). </param>
					/// <param name="mx">	  	The input means. </param>
					/// <param name="vx">	  	(Optional) If non-null, the input variances; presumably nullptr means
					/// 						deterministic inputs. </param>
					/// <param name="alpha">  	(Optional) The alpha, weighting for the alpha-divergence. </param>
					/// <param name="mode">	  	(Optional) Propagation mode; moment matching by default. </param>
					////////////////////////////////////////////////////////////////////////////////////////////////////
					void ForwardPredictionCavity(af::array& outMout, af::array& outVout, af::array* psi1out, af::array* psi2out, const af::array& mx, const af::array* vx = nullptr, Scalar alpha = 1.0, PropagationMode mode = PropagationMode::MomentMatching);

					////////////////////////////////////////////////////////////////////////////////////////////////////
					/// <summary>	Forward prediction through the random-input cavity using Monte Carlo sampling. </summary>
					///
					/// <remarks>	Hmetal T, 10/07/2019. </remarks>
					///
					/// <param name="mout"> 	[in,out] The cavity output mean. </param>
					/// <param name="vout"> 	[in,out] The cavity output variance. </param>
					/// <param name="xout"> 	[in,out] The sampled inputs. </param>
					/// <param name="eps">  	[in,out] The noise used to draw the samples (presumably the same array
					/// 					that BackpropGradientsMC expects as its eps argument). </param>
					/// <param name="mx">   	The input means. </param>
					/// <param name="vx">   	The input variances. </param>
					/// <param name="alpha">	(Optional) The alpha, weighting for the alpha-divergence. </param>
					////////////////////////////////////////////////////////////////////////////////////////////////////
					void ForwardPredictionRandomCavityMC(af::array& mout, af::array& vout, af::array& xout, af::array& eps, const af::array& mx, const af::array& vx, Scalar alpha = 1.0);

					////////////////////////////////////////////////////////////////////////////////////////////////////
					/// <summary>	Back propagation of gradients through the usual GP regression task. </summary>
					///
					/// <remarks>	Hmetal T, 11/07/2019. </remarks>
					///
					/// <param name="m">		  	Cavity prediction mean from the current layer. </param>
					/// <param name="v">		  	Cavity prediction variance from the current layer. </param>
					/// <param name="dlogZ_dm">   	Gradient of log Z w.r.t. the cavity prediction mean. </param>
					/// <param name="dlogZ_dv">   	Gradient of log Z w.r.t. the cavity prediction variance. </param>
					/// <param name="x">		  	The training inputs. </param>
					/// <param name="outGrad_cav">	(Optional) If non-null, receives the gradients of the cavity parameters. </param>
					/// <param name="alpha">	  	(Optional) The alpha, ratio for the alpha-divergence. </param>
					///
					/// <returns>	Gradients of the hyper parameters. </returns>
					////////////////////////////////////////////////////////////////////////////////////////////////////
					af::array BackpropGradientsReg(const af::array& m, const af::array& v, const af::array& dlogZ_dm, const af::array& dlogZ_dv, const af::array& x,
						std::map<std::string, af::array>* outGrad_cav = nullptr, Scalar alpha = 1.0);

					////////////////////////////////////////////////////////////////////////////////////////////////////
					/// <summary>	Back propagation of gradients through moment matching. </summary>
					///
					/// <remarks>	Hmetal T, 11/07/2019. </remarks>
					///
					/// <param name="m">			Cavity prediction mean from the current layer. </param>
					/// <param name="v">			Cavity prediction variance from the current layer. </param>
					/// <param name="dlogZ_dm">		Gradient of log Z w.r.t. the cavity prediction mean. </param>
					/// <param name="dlogZ_dv">		Gradient of log Z w.r.t. the cavity prediction variance. </param>
					/// <param name="psi1">			Cavity Psi1 statistics. </param>
					/// <param name="psi2">			Cavity Psi2 statistics. </param>
					/// <param name="mx">			Cavity prediction mean from the lower layer. </param>
					/// <param name="vx">			Cavity prediction variance from the lower layer. </param>
					/// <param name="outGrad_cav">	[in,out] Receives the gradients of the cavity parameters. </param>
					/// <param name="alpha">		(Optional) The alpha, ratio for the alpha-divergence. </param>
					///
					/// <returns>	Gradients of the hyper parameters (assumed, by analogy with
					/// 			BackpropGradientsReg — confirm). </returns>
					////////////////////////////////////////////////////////////////////////////////////////////////////
					af::array BackpropGradientsMM(const af::array& m, const af::array& v, const af::array& dlogZ_dm, const af::array& dlogZ_dv,
						const af::array& psi1, const af::array& psi2, const af::array& mx, const af::array& vx, std::map<std::string, af::array>* outGrad_cav, Scalar alpha = 1.0);

					////////////////////////////////////////////////////////////////////////////////////////////////////
					/// <summary>	Back propagation of gradients through Monte Carlo sampling. </summary>
					///
					/// <remarks>	Hmetal T, 11/07/2019. </remarks>
					///
					/// <param name="m">		   	Cavity prediction mean from the current layer. </param>
					/// <param name="v">		   	Cavity prediction variance from the current layer. </param>
					/// <param name="eps">		   	The sampling noise (presumably the eps produced by
					/// 							ForwardPredictionRandomCavityMC). </param>
					/// <param name="dlogZ_dm">	   	Gradient of log Z w.r.t. the cavity prediction mean. </param>
					/// <param name="dlogZ_dv">	   	Gradient of log Z w.r.t. the cavity prediction variance. </param>
					/// <param name="x">		   	The training inputs. </param>
					/// <param name="outGradInput">	[in,out] Receives the gradients of the cavity/input parameters. </param>
					/// <param name="alpha">	   	(Optional) The alpha, ratio for the alpha-divergence. </param>
					///
					/// <returns>	Gradients of the hyper parameters. </returns>
					////////////////////////////////////////////////////////////////////////////////////////////////////
					af::array BackpropGradientsMC(const af::array& m, const af::array& v, const af::array& eps, const af::array& dlogZ_dm, const af::array& dlogZ_dv, const af::array& x,
						std::map<std::string, af::array>* outGradInput, Scalar alpha = 1.0);

					////////////////////////////////////////////////////////////////////////////////////////////////////
					/// <summary>	Computes the weighted sum of the log-partitions of prior, posterior and cavity. </summary>
					///
					/// <remarks>	, 25.06.2018. </remarks>
					///
					/// <param name="alpha">	The alpha. </param>
					///
					/// <returns>	The calculated phi. </returns>
					////////////////////////////////////////////////////////////////////////////////////////////////////
					Scalar ComputePhi(Scalar alpha);

					////////////////////////////////////////////////////////////////////////////////////////////////////
					/// <summary>	Update step of Power EP. </summary>
					///
					/// <remarks>	
					/// 	NOTE(review): grad_cav is taken by value, so the map is copied on every call.
					/// 	Switching to a const reference requires changing the out-of-line definition too.
					/// 			
					/// 	Hmetal T, 08/06/2018. 
					/// </remarks>
					///
					/// <param name="n">	   	[in,out] The indexes to compute. </param>
					/// <param name="grad_cav">	The gradients w.r.t. the cavity quantities \f$m_f^{\setminus n}\f$ and
					/// 					\f$V_{ff}^{\setminus n}\f$. </param>
					/// <param name="alpha">   	The alpha. </param>
					/// <param name="decay">   	(Optional) The decay (presumably a damping factor; 0 by default). </param>
					////////////////////////////////////////////////////////////////////////////////////////////////////
					void UpdateFactor(af::array& n, std::map<std::string, af::array> grad_cav, Scalar alpha, Scalar decay = 0);

					////////////////////////////////////////////////////////////////////////////////////////////////////
					/// <summary>	Updates the parameters. </summary>
					///
					/// <remarks>	, 26.06.2018. </remarks>
					////////////////////////////////////////////////////////////////////////////////////////////////////
					void UpdateParameters() override;

				protected:
					////////////////////////////////////////////////////////////////////////////////////////////////////
					/// <summary>	Default constructor (kept non-public; presumably for serialization). </summary>
					///
					/// <remarks>	Hmetal T, 02/07/2018. </remarks>
					////////////////////////////////////////////////////////////////////////////////////////////////////
					SGPLayer() { }

					////////////////////////////////////////////////////////////////////////////////////////////////////
					/// <summary>	Forward prediction through the posterior for deterministic inputs. </summary>
					///
					/// <param name="mx">  	The inputs (no input variance is taken). </param>
					/// <param name="mout">	If non-null, receives the predictive mean. </param>
					/// <param name="vout">	If non-null, receives the predictive variance. </param>
					////////////////////////////////////////////////////////////////////////////////////////////////////
					void ForwardPredictionDeterministicPost(const af::array& mx, af::array* mout, af::array* vout) override;

					////////////////////////////////////////////////////////////////////////////////////////////////////
					/// <summary>	Forward prediction through the deterministic cavity. </summary>
					///
					/// <remarks>	
					///		Projection step. Computes the cavity posterior mean and covariance function,
					///		
					///		\f[ m^{\setminus n}_{f} = K_{fu}K_{uu}^{-1}\left(T^{\setminus n}_{2,u}\right)^{-1}T^{\setminus n}_{1,u}, \f]
					///			
					///		\f[ V^{\setminus n}_{ff} = K_{ff} - Q_{ff} + K_{fu}K_{uu}^{-1}\left(T^{\setminus n}_{2,u}\right)^{-1}K_{uu}^{-1}K_{uf}, \f]
					///		
					///		where \f$Q_{ff} = K_{fu}K_{uu}^{-1}K_{uf}\f$.
					///								
					/// 	, 16.05.2018. 
					/// </remarks>
					///
					/// <param name="outMout">	[in,out] The cavity mean \f$m^{\setminus n}_{f}\f$. </param>
					/// <param name="outVout">	[in,out] The cavity variance \f$V^{\setminus n}_{ff}\f$. </param>
					/// <param name="kfuOut"> 	If non-null, output for \f$K_{fu}\f$ (assumed from the name — confirm). </param>
					/// <param name="mx">	  	The inputs. </param>
					/// <param name="alpha">  	(Optional) The alpha, weighting for the alpha-divergence. </param>
					////////////////////////////////////////////////////////////////////////////////////////////////////
					void ForwardPredictionDeterministicCavity(af::array& outMout, af::array& outVout, af::array* kfuOut, const af::array& mx, Scalar alpha = 1.0);

					////////////////////////////////////////////////////////////////////////////////////////////////////
					/// <summary>	Forward prediction through the random-input cavity. </summary>
					///
					/// <remarks>	, 16.05.2018. </remarks>
					///
					/// <param name="mout">   	[in,out] The cavity mean \f$m^{\setminus n}_{f}\f$. </param>
					/// <param name="vout">   	[in,out] The cavity variance \f$V^{\setminus n}_{ff}\f$. </param>
					/// <param name="psi1out">	If non-null, output for the Psi1 statistics. </param>
					/// <param name="psi2out">	If non-null, output for the Psi2 statistics. </param>
					/// <param name="mx">	  	The input means. </param>
					/// <param name="vx">	  	The input variances. </param>
					/// <param name="mode">	  	Propagation mode. </param>
					/// <param name="alpha">  	(Optional) The alpha, ratio for the alpha-divergence. </param>
					////////////////////////////////////////////////////////////////////////////////////////////////////
					void ForwardPredictionRandomCavity(af::array& mout, af::array& vout, af::array* psi1out, af::array* psi2out, const af::array& mx, const af::array& vx, PropagationMode mode, Scalar alpha = 1.0);

					////////////////////////////////////////////////////////////////////////////////////////////////////
					/// <summary>	Forward prediction through the random-input cavity via moment matching. </summary>
					///
					/// <remarks>	Hmetal T, 10/07/2019. </remarks>
					///
					/// <param name="mout">   	[in,out] The cavity output mean. </param>
					/// <param name="vout">   	[in,out] The cavity output variance. </param>
					/// <param name="psi1out">	If non-null, output for the Psi1 statistics. </param>
					/// <param name="psi2out">	If non-null, output for the Psi2 statistics. </param>
					/// <param name="mx">	  	The input means. </param>
					/// <param name="vx">	  	The input variances. </param>
					/// <param name="alpha">  	(Optional) The alpha, weighting for the alpha-divergence. </param>
					////////////////////////////////////////////////////////////////////////////////////////////////////
					void ForwardPredictionRandomCavityMM(af::array& mout, af::array& vout, af::array* psi1out, af::array* psi2out, const af::array& mx, const af::array& vx, Scalar alpha = 1.0);

					////////////////////////////////////////////////////////////////////////////////////////////////////
					/// <summary>	Forward prediction through the posterior for random inputs. </summary>
					///
					/// <remarks>	, 16.05.2018. </remarks>
					///
					/// <param name="mx">  	The input means. </param>
					/// <param name="vx">  	The input variances. </param>
					/// <param name="mout">	[in,out] The predictive mean. </param>
					/// <param name="vout">	[in,out] The predictive variance. </param>
					/// <param name="mode">	(Optional) Propagation mode; moment matching by default. </param>
					////////////////////////////////////////////////////////////////////////////////////////////////////
					void ForwardPredictionRandomPost(const af::array& mx, const af::array& vx, af::array& mout, af::array& vout, PropagationMode mode = PropagationMode::MomentMatching) override;

					////////////////////////////////////////////////////////////////////////////////////////////////////
					/// <summary>	Forward prediction through the posterior via moment matching. </summary>
					///
					/// <remarks>	Hmetal T, 17/09/2019. </remarks>
					///
					/// <param name="mx">  	The input means. </param>
					/// <param name="vx">  	The input variances. </param>
					/// <param name="mout">	[in,out] The predictive mean. </param>
					/// <param name="vout">	[in,out] The predictive variance. </param>
					////////////////////////////////////////////////////////////////////////////////////////////////////
					void ForwardPredictionRandomPostMM(const af::array& mx, const af::array& vx, af::array& mout, af::array& vout);

					////////////////////////////////////////////////////////////////////////////////////////////////////
					/// <summary>	Calculates the cavity. </summary>
					///
					/// <remarks>
					/// 	Deletion step. The cavity for data point \f$n\f$,
					/// 	\f$q^{\setminus n}(f) \propto q^{*}(f) / t_n^{\alpha}(u)\f$, has a form similar to the
					/// 	posterior, but its natural parameters are modified by deletion,
					/// 	
					/// 		\f[ T^{\setminus n}_{1,u} = T_{1,u} - \alpha T_{1,n} \f]
					/// 	
					/// 			and
					/// 	
					/// 		\f[ T^{\setminus n}_{2,u} = T_{2,u} - \alpha T_{2,n}, \f]
					/// 	
					/// 	yielding the new mean \f$\hat{\mathbf{m}}_u\f$ and covariance \f$\hat{\mathbf{S}}_u\f$
					/// 	of the cavity distribution.
					/// 	
					/// 	, 16.05.2018.
					/// </remarks>
					///
					/// <param name="alpha">	(Optional) The alpha, weighting for the alpha-divergence. </param>
					////////////////////////////////////////////////////////////////////////////////////////////////////
					void ComputeCavity(Scalar alpha = 1.0);

					////////////////////////////////////////////////////////////////////////////////////////////////////
					/// <summary>	Calculates the gradient contributions of the cavity distribution. </summary>
					///
					/// <remarks>	, 22.06.2018. </remarks>
					///
					/// <param name="dMucav">	  	[in,out] The gradient w.r.t. the cavity mean. </param>
					/// <param name="dSucav">	  	[in,out] The gradient w.r.t. the cavity covariance. </param>
					/// <param name="out_dT1">	  	[in,out] The gradient of natural parameter 1. </param>
					/// <param name="out_dT2">	  	[in,out] The gradient of natural parameter 2. </param>
					/// <param name="out_dInvKuu">	[in,out] The gradient of inverse Kuu. </param>
					/// <param name="alpha">	  	(Optional) The alpha. </param>
					////////////////////////////////////////////////////////////////////////////////////////////////////
					void ComputeCavityGradientU(af::array& dMucav, af::array& dSucav, af::array& out_dT1, af::array& out_dT2, af::array& out_dInvKuu, Scalar alpha = 1.0);

					////////////////////////////////////////////////////////////////////////////////////////////////////
					/// <summary>	Calculates the gradient contributions of the posterior. </summary>
					///
					/// <remarks>	, 25.06.2018. </remarks>
					///
					/// <param name="dMu">		  	[in,out] The gradient w.r.t. the posterior mean. </param>
					/// <param name="dSu">		  	[in,out] The gradient w.r.t. the posterior covariance. </param>
					/// <param name="out_dT1">	  	[in,out] The gradient of natural parameter 1. </param>
					/// <param name="out_dT2">	  	[in,out] The gradient of natural parameter 2. </param>
					/// <param name="out_dInvKuu">	[in,out] The gradient of inverse Kuu. </param>
					////////////////////////////////////////////////////////////////////////////////////////////////////
					void ComputePosteriorGradientU(af::array& dMu, af::array& dSu, af::array& out_dT1, af::array& out_dT2, af::array& out_dInvKuu);

					////////////////////////////////////////////////////////////////////////////////////////////////////
					/// <summary>	Calculates the energy contribution phi of the prior. </summary>
					///
					/// <remarks>	, 26.06.2018. </remarks>
					///
					/// <returns>	The calculated phi prior. </returns>
					////////////////////////////////////////////////////////////////////////////////////////////////////
					Scalar ComputePhiPrior();

					////////////////////////////////////////////////////////////////////////////////////////////////////
					/// <summary>	Calculates the energy contribution phi of the posterior. </summary>
					///
					/// <remarks>	, 26.06.2018. </remarks>
					///
					/// <returns>	The calculated phi posterior. </returns>
					////////////////////////////////////////////////////////////////////////////////////////////////////
					Scalar ComputePhiPosterior();

					////////////////////////////////////////////////////////////////////////////////////////////////////
					/// <summary>	Calculates the energy contribution phi of the cavity. </summary>
					///
					/// <remarks>	, 26.06.2018. </remarks>
					///
					/// <returns>	The calculated phi cavity. </returns>
					////////////////////////////////////////////////////////////////////////////////////////////////////
					Scalar ComputePhiCavity();

					// AEP variables
					af::array afGamma;
					af::array afBeta;
					af::array afGammaHat;
					af::array afBetaHat;
					af::array afSuMuMu;
					af::array afBetaStochastic;
					af::array afBetaHatStochastic;

					// Cavity temp variables
					af::array afMuHat;			// cavity mean of q(u)
					af::array afSuHat;			// cavity covariance of q(u)
					af::array afInvSuHat;		// cavity natural parameter T^{\n}_{2,u}
					af::array afInvSuMuHat;		// cavity natural parameter T^{\n}_{1,u}
					af::array afSuMuMuHat;

				private:
					friend class boost::serialization::access;
					template<typename> friend class SDGPR;

					// Boost.Serialization hook. Only the SparseGPBaseLayer state is archived; the AEP and
					// cavity arrays declared above are NOT serialized (the per-member lines below are
					// disabled), so presumably they are recomputed after loading — confirm.
					template<class Archive>
					void serialize(Archive& ar, unsigned int version)
					{
						ar& boost::serialization::base_object<SparseGPBaseLayer<Scalar>>(*this);
						//ar& boost::serialization::make_nvp("SparseGPBaseLayer", boost::serialization::base_object<SparseGPBaseLayer<Scalar>>(*this));
						/*ar& BOOST_SERIALIZATION_NVP(afGamma);
						ar& BOOST_SERIALIZATION_NVP(afBeta);
						ar& BOOST_SERIALIZATION_NVP(afGammaHat);
						ar& BOOST_SERIALIZATION_NVP(afBetaHat);
						ar& BOOST_SERIALIZATION_NVP(afMuHat);
						ar& BOOST_SERIALIZATION_NVP(afSuHat);
						ar& BOOST_SERIALIZATION_NVP(afInvSuHat);
						ar& BOOST_SERIALIZATION_NVP(afInvSuMuHat);
						ar& BOOST_SERIALIZATION_NVP(afSuMuMuHat);
						ar& BOOST_SERIALIZATION_NVP(afSuMuMu);
						ar& BOOST_SERIALIZATION_NVP(afBetaStochastic);
						ar& BOOST_SERIALIZATION_NVP(afBetaHatStochastic);*/
					}
				};
			}
		}
	}
}

