/**
File:		MachineLearning/Models/GPModels/FgSparseGPBaseModel.h

Author:		
Email:		
Site:       

Copyright (c) 2017 . All rights reserved.
*/

#pragma once

#include <MachineLearning/FgGPBaseModel.h>
#include <MachineLearning/FgSparseGPBaseLayer.h>

namespace NeuralEngine
{
	namespace MachineLearning
	{
		namespace GPModels
		{
			namespace PowerEP
			{
				template<typename Scalar>
				class SGPLayer;
			}

			namespace AEP
			{
				template<typename Scalar>
				class SGPLayer;
			}

			namespace VFE
			{
				//template<Scalar>
				class SGPLayer;
			}

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary> Base class for all sparse GP models. </summary>
			///
			/// <remarks>	
			/// 	Sparse approximations are used for larger
			/// 	data sets to reduce memory size and computational complexity.  This is
			/// 	done by introducing a subset of inducing points or pseudo inputs to approximate
			/// 	the full set. The inversion of the kernel matrix depends only on those points
			/// 	and reduces the computationsl complexity from O(N^3) to O(k^2N), where k is the
			/// 	number of inducing points and N the length of the data set.
			///
			/// 	For more information see:
			/// 	<a href="https://pdfs.semanticscholar.org/99f9/3283e415ae21bd42a90031cd3972f3bfbc9d.pdf" target="_blank">
			/// 	
			/// 	, 21.03.2018. 
			/// </remarks>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			template<typename Scalar>
			class NE_IMPEXP SparseGPBaseModel : public GPBaseModel<Scalar>
			{
			public:

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Constructor. </summary>
				///
				/// <remarks>	, 21.03.2018. </remarks>
				///
				/// <param name="X">		  	[in,out] The training inputs. </param>
				/// <param name="Y">		  	[in,out] The training data. </param>
				/// <param name="numInducing">	Number of inducings points. </param>
				/// <param name="lType">	  	The likelihood or objective type. </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				SparseGPBaseModel(const af::array& Y, const af::array& X, int numInducing = 200, LogLikType lType = LogLikType::Gaussian);

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Default constructor. </summary>
				///
				/// <remarks>	, 26.03.2018. </remarks>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				SparseGPBaseModel();

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Destructor. </summary>
				///
				/// <remarks>	, 15.05.2018. </remarks>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual ~SparseGPBaseModel();

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Predict noise free functions values \f$\mathbf{F}_*\f$. </summary>
				///
				/// <remarks>	Hmetal T, 05/05/2020. </remarks>
				///
				/// <param name="testInputs">	The test inputs. </param>
				/// <param name="mf">		 	[in,out] mean of function values. </param>
				/// <param name="vf">		 	[in,out] The variance of function values. </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual void PredictF(const af::array& testInputs, af::array& mf, af::array& vf) override;

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Generate function samples from posterior. </summary>
				///
				/// <remarks>	Hmetal T, 18/06/2019. </remarks>
				///
				/// <param name="outFunctions">	[in,out] The out functions. </param>
				/// <param name="inputs">	   	The inputs. </param>
				/// <param name="numSamples">  	Number of samples. </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual void SampleY(const af::array inputs, int numSamples, af::array& outFunctions) override;

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Gets training inputs X. </summary>
				///
				/// <remarks>	, 27.03.2018. </remarks>
				///
				/// <returns>	The training inputs. </returns>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				af::array GetTrainingInputs();

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Gets training inputs X. </summary>
				///
				/// <remarks>	, 27.03.2018. </remarks>
				///
				/// <param name="inputs">	[in,out] The inputs. </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				void SetTrainingInputs(af::array& inputs);

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Gets pseudo inputs. </summary>
				///
				/// <remarks>	Hmetal T, 17/06/2019. </remarks>
				///
				/// <returns>	The pseudo inputs. </returns>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				af::array GetPseudoInputs();

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Initializes the model. </summary>
				///
				/// <remarks>	Hmetal T, 29.11.2017. </remarks>
				///
				/// <returns>	true if it succeeds, false if it fails. </returns>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual bool Init() override;

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Gets number of parameters. </summary>
				///
				/// <remarks>	, 26.06.2018. </remarks>
				///
				/// <returns>	The number parameters. </returns>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual int GetNumParameters() override;

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Sets the parameters for each optimization iteration. </summary>
				///
				/// <remarks>	, 26.06.2018. </remarks>
				///
				/// <param name="param">	The parameter. </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual void SetParameters(const af::array& param) override;

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Gets the parameters for each optimization iteration. </summary>
				///
				/// <remarks>	, 26.06.2018. </remarks>
				///
				/// <param name="param">	The parameter. </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual af::array GetParameters() override;

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Updates the parameters. </summary>
				///
				/// <remarks>	Hmetal T, 23/03/2020. </remarks>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual void UpdateParameters() override;

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Sets fixation for hyperparameters. </summary>
				///
				/// <remarks>	Hmetal T, 16/12/2019. </remarks>
				///
				/// <param name="isfixed">	True if isfixed. </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual void FixKernelParameters(bool isfixed);

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Set fixation for inducing inputs. </summary>
				///
				/// <remarks>	Hmetal T, 16/12/2019. </remarks>
				///
				/// <param name="isfixed">	True if isfixed. </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual void FixInducing(bool isfixed);

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Gets the gp layer. </summary>
				///
				/// <remarks>	Hmetal T, 18/12/2019. </remarks>
				///
				/// <returns>	Null if it fails, else the gp layer. </returns>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				SparseGPBaseLayer<Scalar>* GetGPLayer();

			protected:
				int ik;			//!< number of inducing inputs
				int iq;			//!< latent dimension
				af::array afX;	//!< training inputs

				SparseGPBaseLayer<Scalar>* gpLayer; //!< gp layer

			private:
				friend class AEP::SGPLayer<Scalar>;
				friend class PowerEP::SGPLayer<Scalar>;
				//friend class VFE::SGPLayer/*<Scalar>*/;

				friend class boost::serialization::access;

				template<class Archive>
				void serialize(Archive& ar, unsigned int version)
				{
					ar& boost::serialization::base_object<GPBaseModel<Scalar>>(*this);

					//ar& boost::serialization::make_nvp("GPBaseModel", boost::serialization::base_object<GPBaseModel<Scalar>>(*this));

					ar.register_type<AEP::SGPLayer<Scalar>>();
					ar.register_type<PowerEP::SGPLayer<Scalar>>();
					/*ar.register_type<VFE::SGPLayer>();*/

					ar& BOOST_SERIALIZATION_NVP(ik);
					ar& BOOST_SERIALIZATION_NVP(iq);
					ar& BOOST_SERIALIZATION_NVP(afX);
					ar& BOOST_SERIALIZATION_NVP(gpLayer);
				}
			};
		}
	}
}
