/**
File:		MachineLearning/Models/GPModels/FgGPBaseModel.h

Author:		
Email:		
Site:       

Copyright (c) 2017 . All rights reserved.
*/

#pragma once

#include <MachineLearning/FgIModel.h>
#include <MachineLearning/FgGaussLikelihoodLayer.h>
#include <MachineLearning/FgProbitLikelihoodLayer.h>


namespace NeuralEngine
{
	namespace MachineLearning
	{
		namespace GPModels
		{
			template<typename Scalar>
			class GaussLikLayer;

			template<typename Scalar>
			class ProbitLikLayer;

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	
			/// 	Base class with abstract and basic function definitions. All GP models will be derived 
			/// 	from this class.			
			/// </summary>
			///
			/// <remarks>	HmetalT, 26.10.2017. </remarks>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			template<typename Scalar>
			class NE_IMPEXP GPBaseModel : public IModel<Scalar>
			{
			public:

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Constructor. </summary>
				///
				/// <remarks>	, 26.03.2018. </remarks>
				///
				/// <param name="Y">	The data af::array to process. </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				GPBaseModel(const af::array& Y, LogLikType lType = LogLikType::Gaussian, ModelType mtype = ModelType::GPR);

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Default Constructor. </summary>
				///
				/// <remarks>	Hmetal T, 29.11.2017. </remarks>
				///
				////////////////////////////////////////////////////////////////////////////////////////////////////
				GPBaseModel();

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Destructor. </summary>
				///
				/// <remarks>	, 23.04.2018. </remarks>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual ~GPBaseModel();

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Optimizes the model parameters for best fit. </summary>
				///
				/// <remarks>	Hmetal T, 29.11.2017. </remarks>
				///
				/// <param name="method">			(Optional) the optimization method. </param>
				/// <param name="tol">				(Optional) the tolerance. </param>
				/// <param name="reinit_hypers">	(Optional) true to re hypers. </param>
				/// <param name="maxiter">			(Optional) max iterations. </param>
				/// <param name="mb_size">			(Optional) batch size. </param>
				/// <param name="LSType">			(Optional) linesearch type. </param>
				/// <param name="disp">				(Optional) true to disp. </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual void Optimise(
					OptimizerType method = L_BFGS,
					Scalar tol = 0.0,
					bool reinit_hypers = true,
					int maxiter = 1000,
					int mb_size = 0,
					LineSearchType lsType = MoreThuente,
					bool disp = true,
					int* cycle = nullptr
				);

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Initializes the model. </summary>
				///
				/// <remarks>	Hmetal T, 29.11.2017. </remarks>
				///
				/// <returns>	true if it succeeds, false if it fails. </returns>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual bool Init();

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Predict noise free functions values \f$\mathbf{F}_*\f$. </summary>
				///
				/// <remarks>	Hmetal T, 05/05/2020. </remarks>
				///
				/// <param name="testInputs">	The test inputs. </param>
				/// <param name="mf">		 	[in,out] mean of function values. </param>
				/// <param name="vf">		 	[in,out] The variance of function values. </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual void PredictF(const af::array& testInputs, af::array& mf, af::array& vf);

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Prediction of test outputs \f$\mathbf{Y}_*\f$. </summary>
				///
				/// <remarks>	, 12.06.2018. </remarks>
				///
				/// <param name="my"> 	[in,out] The posterior mean function. </param>
				/// <param name="vy"> 	[in,out] The posterior covariance function. </param>
				/// <param name="testX">	[in,out] The test inputs. </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual void PredictY(const af::array& testInputs, af::array& my, af::array& vy);

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Generate function samples from posterior. </summary>
				///
				/// <remarks>	Hmetal T, 18/06/2019. </remarks>
				///
				/// <param name="outFunctions">	[in,out] The out functions. </param>
				/// <param name="inputs">	   	The inputs. </param>
				/// <param name="numSamples">  	Number of samples. </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual void SampleY(const af::array inputs, int numSamples, af::array& outFunctions);

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Adds training data to the model. </summary>
				///
				/// <remarks>	Hmetal T, 29.11.2017. </remarks>
				///
				/// <param name="Ytrain">	[in,out] The training data. </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual void AddData(const af::array Ytrain);


				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Gets the training data set Y. </summary>
				///
				/// <remarks>	, 26.03.2018. </remarks>
				///
				/// <returns>	The training data. </returns>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				af::array GetTrainingData();

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Sets training data Y. </summary>
				///
				/// <remarks>	, 27.03.2018. </remarks>
				///
				/// <param name="data">	[in,out] The data. </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				void SetTrainingData(af::array& data);

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Gets number of parameters. </summary>
				///
				/// <remarks>	, 26.06.2018. </remarks>
				///
				/// <returns>	The number parameters. </returns>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual int GetNumParameters();

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Sets the parameters for each optimization iteration. </summary>
				///
				/// <remarks>	, 26.06.2018. </remarks>
				///
				/// <param name="param">	The parameter. </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual void SetParameters(const af::array& param);

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Gets the parameters for each optimization iteration. </summary>
				///
				/// <remarks>	, 26.06.2018. </remarks>
				///
				/// <param name="param">	The parameter. </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual af::array GetParameters();

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Updates the parameters. </summary>
				///
				/// <remarks>	, 26.06.2018. </remarks>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual void UpdateParameters();

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Sets the likelihood parameters to be fixed or not for optimization. </summary>
				///
				/// <remarks>	Hmetal T, 17/12/2019. </remarks>
				///
				/// <param name="isfixed">	True if is fixed. </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual void FixLikelihoodParameters(bool isfixed);

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Sets fixation for hyperparameters. </summary>
				///
				/// <remarks>	Hmetal T, 16/12/2019. </remarks>
				///
				/// <param name="isfixed">	True if isfixed. </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				//virtual void FixKernelParameters(bool isfixed) = 0;

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Sets the start indexes array of each sequence. </summary>
				///
				/// <remarks>	Hmetal T, 18/07/2022. </remarks>
				///
				/// <param name="segments">	The segments. </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				void SetSegments(af::array segments);

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Gets the start index array for the sequences. </summary>
				///
				/// <remarks>	Hmetal T, 18/07/2022. </remarks>
				///
				/// <returns>	The segments. </returns>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				af::array GetSegments();

			protected:

				bool bInit;						//!< check if model is initialized
				af::array afY;					//!< training dataset, mean substracted
				af::array afBias;				//!< the bias
				af::array afSegments;			//!< Index of starting positions for all trials
				

				LikelihoodBaseLayer<Scalar>* likLayer;	//!< liklihood layer
				

				/*std::vector fixed_params
				bool updated = false*/

			private:
				friend class boost::serialization::access;

				friend class GaussLikLayer<Scalar>;
				friend class ProbitLikLayer<Scalar>;
				
				template<class Archive>
				void serialize(Archive& ar, unsigned int version)
				{
					ar& boost::serialization::base_object<IModel<Scalar>>(*this);
					//ar& boost::serialization::make_nvp("IModel", boost::serialization::base_object<IModel<Scalar>>(*this));

					ar.register_type<GaussLikLayer<Scalar>>();
					ar.register_type<ProbitLikLayer<Scalar>>();

					ar& BOOST_SERIALIZATION_NVP(afY);
					ar& BOOST_SERIALIZATION_NVP(afBias);
					ar& BOOST_SERIALIZATION_NVP(afSegments);
					ar& BOOST_SERIALIZATION_NVP(bInit);
					ar& BOOST_SERIALIZATION_NVP(likLayer);
				}
			};
		}
	}
}