/**
File:		MachineLearning/Kernel/FgIKernel.h

Author:		
Email:		
Site:       

Copyright (c) 2017 . All rights reserved.
*/

#pragma once

#include <NeMachineLearningLib.h>
#include <MachineLearning/CommonUtil.h>

namespace NeuralEngine
{
	namespace MachineLearning
	{
		////////////////////////////////////////////////////////////////////////////////////////////////////
		/// <summary>	Tag identifying which concrete kernel an IKernel instance implements. </summary>
		///
		/// <remarks>
		/// 	The numeric values are written out explicitly on purpose: IKernel serializes its type
		/// 	tag (eType) through Boost.Serialization, so existing enumerators must keep their values
		/// 	and new kernels should be appended with the next free number.
		/// </remarks>
		////////////////////////////////////////////////////////////////////////////////////////////////////
		enum KernelType
		{
			eLinearKernel             = 0,
			eRBFKernel                = 1,
			eCompoundKernel           = 2,
			eTensorKernel             = 3,
			eWhiteKernel              = 4,
			eRBFAccelerationKernel    = 5,
			eLinearAccelerationKernel = 6,
			eARDKernel                = 7,
			eStyleKernel              = 8,
			eInterKernel              = 9,
		};

		template<typename Scalar>
		class NE_IMPEXP IKernel
		{
		public:

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Constructor. </summary>
			///
			/// <remarks>	 Admin, 5/26/2017. </remarks>
			///
			/// <param name="type">	   	The type. </param>
			/// <param name="numParam">	Number of parameters. </param>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			IKernel(KernelType type, int numParameters);

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Destructor. </summary>
			///
			/// <remarks>	 Admin, 5/26/2017. </remarks>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			virtual ~IKernel();

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Gets number of parametes of the kernel. </summary>
			///
			/// <remarks>	 Admin, 5/24/2017. </remarks>
			///
			/// <returns>	The number parameter. </returns>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			virtual int GetNumParameter();

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Gets kernel type. </summary>
			///
			/// <remarks>	 Admin, 5/24/2017. </remarks>
			///
			/// <returns>	The kernel type. </returns>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			virtual KernelType GetKernelType();

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Computes the kernel matrix of the kernel. </summary>
			///
			/// <remarks>	 Admin, 5/23/2017. </remarks>
			///
			/// <param name="inX1">	 	[in,out] First n times q matrix of latent points. </param>
			/// <param name="inX2">	 	[in,out] Second m times q matrix of latent points (X'). </param>
			/// <param name="outMatrix">	[in,out] Resulting kernel matrix.< / param> </param>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			virtual void ComputeKernelMatrix(const af::array& inX1, const af::array& inX2, af::array& outMatrix) = 0;

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Calculates only diagonal elements of K. </summary>
			///
			/// <remarks>	 Admin, 5/24/2017. </remarks>
			///
			/// <param name="inX">		  	[in,out] Nxq matrix X. </param>
			/// <param name="outDiagonal">	[in,out] The out diagonal. </param>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			virtual void ComputeDiagonal(const af::array& inX, af::array& outDiagonal) = 0;

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Computes dL/dX for full fit GP. </summary>
			///
			/// <remarks>	 Admin, 5/23/2017. </remarks>
			///
			/// <returns>
			/// 	The method returns the gradients of the latent points of type ILArray&lt;double&gt;.
			/// </returns>
			///
			/// <param name="inX">	   	[in,out] Nxq Matrix of latent points. </param>
			/// <param name="indL_dK">
			/// 	[in,out] Derivative of the loglikelihood w.r.t kernel matrix K.
			/// </param>
			/// <param name="outdL_dX">	[in,out] Derivative of the loglikelihood w.r.t latent points. </param>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			virtual void LogLikGradientX(const af::array& inX, const af::array& indL_dK, af::array& outdL_dX) = 0;

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Computes dL / dX and dL / dXu for sparse approximation GP. </summary>
			///
			/// <remarks>	 Admin, 5/23/2017. </remarks>
			///
			/// <param name="inXu">			[in,out] kxq Matrix of latent subset points. </param>
			/// <param name="indL_dKuu">	[in,out] Derivative of the loglikelihood w.r.t subset kernel matrix Kuu. </param>
			/// <param name="inX">			[in,out] Nxq Matrix of latent points. </param>
			/// <param name="indL_dKuf">	[in,out] Derivative of the loglikelihood w.r.t the kernel matrix K. </param>
			/// <param name="outdL_dXu">	[in,out] Derivative of the loglikelihood w.r.t latent subset points Xu. </param>
			/// <param name="outdL_dX"> 	[in,out] Derivative of the loglikelihood w.r.t latent points X. </param>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			virtual void LogLikGradientX(const af::array& inXu, const af::array& indL_dKuu, const af::array& inX, 
				const af::array& indL_dKuf, af::array& outdL_dXu, af::array& outdL_dX) = 0;

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Computes dL / dX for sparse approximation GP. </summary>
			///
			/// <remarks>	Hmetal T, 25/11/2020. </remarks>
			///
			/// <param name="inX1">	   	The first latent points. </param>
			/// <param name="inX2">	   	The second latent points. </param>
			/// <param name="indL_dK"> 	The derivative dL / dK. </param>
			/// <param name="outdL_dX">	[in,out] The derivative dL / dX. </param>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			virtual void LogLikGradientX(const af::array& inX1, const af::array& inX2, const af::array& indL_dK, af::array& outdL_dX) { };

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Computes the gradient of LL w.r.t. the kernel parameters. </summary>
			///
			/// <remarks>	 Admin, 5/23/2017. </remarks>
			///
			/// <param name="indL_dK">
			/// 	[in,out] Derivative of the loglikelihood w.r.t the kernel matrix K.
			/// </param>
			/// <param name="outdL_dParam">	[in,out] Gradient of kernel parameters. </param>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			virtual void LogLikGradientParam(const af::array& inX1, const af::array& inX2, const af::array& indL_dK, af::array& outdL_dParam) = 0;

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Computes the gradient of LL w.r.t. the kernel parameters. </summary>
			///
			/// <remarks>	 Admin, 5/23/2017. </remarks>
			///
			/// <param name="indL_dK">
			/// 	[in,out] Derivative of the loglikelihood w.r.t the kernel matrix K.
			/// </param>
			/// <param name="outdL_dParam">	[in,out] Gradient of kernel parameters. </param>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			virtual void LogLikGradientParam(const af::array& inX1, const af::array& inX2, const af::array& indL_dK, af::array& outdL_dParam, const af::array* dlogZ_dv) { };

			virtual void LogLikGradientCompundKfu(const af::array& indL_dKfu, const af::array& inX, const af::array& inXu, 
				af::array* outdL_dParam, af::array* outdL_dXu, const af::array* dlogZ_dv = nullptr, af::array* outdL_dX = nullptr) { };

			virtual void LogGradientCompoundKuu(const af::array& inXu, const af::array& inCovDiag,
				af::array* outdL_dParam, af::array* outdL_dXu) { };

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Computes dK/dX. </summary>
			///
			/// <remarks>	 Admin, 5/23/2017. </remarks>
			///
			/// <param name="inX1">	   	[in,out] First n times q matrix of latent points. </param>
			/// <param name="inX2">	   	[in,out] Second n times q matrix of latent points (X'). </param>
			/// <param name="q">	   	The latent dimension to process. </param>
			/// <param name="outdK_dX">	[in,out] dK/dX. </param>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			virtual void GradX(const af::array& inX1, const af::array& inX2, int q, af::array& outdK_dX) = 0;

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Derivative of diagonal elemts of K w.r.t X. </summary>
			///
			/// <remarks>	 Admin, 5/23/2017. </remarks>
			///
			/// <param name="inX">		   	[in,out] The in x coordinate. </param>
			/// <param name="outDiagdK_dX">	[in,out] Derivative of diagonal elemts of K w.r.t X. </param>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			virtual void DiagGradX(const af::array& inX, af::array& outDiagdK_dX) = 0;

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Derivative of diagonal elemts of K w.r.t kernel parameters. </summary>
			///
			/// <remarks>	 Admin, 5/23/2017. </remarks>
			///
			/// <param name="inX">			   	[in,out] Nxq matrix of Latent points X. </param>
			/// <param name="inCovDiag">	   	[in,out] Diagonal of kernel matrix K. </param>
			/// <param name="outDiagdK_dParam">	[in,out] Derivative of diagonal elemts of K w.r.t kernel parameters. </param>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			virtual void DiagGradParam(const af::array& inX, const af::array& inCovDiag, af::array& outDiagdK_dParam) = 0;

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Sets the parameters. </summary>
			///
			/// <remarks>	, 26.06.2018. </remarks>
			///
			/// <param name="param">	The parameter. </param>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			virtual void SetParameters(const af::array& param) = 0;

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Sets log parameters. </summary>
			///
			/// <remarks>	Hmetal T, 06/11/2020. </remarks>
			///
			/// <param name="param">	The parameter. </param>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			virtual void SetLogParameters(const af::array& param) { };

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Gets the parameters. </summary>
			///
			/// <remarks>	, 26.06.2018. </remarks>
			///
			/// <returns>	The parameters. </returns>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			virtual af::array GetParameters() = 0;

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Gets log parameters. </summary>
			///
			/// <remarks>	Hmetal T, 06/11/2020. </remarks>
			///
			/// <returns>	The log parameters. </returns>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			virtual af::array GetLogParameters() { return 0; };

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	
			///		Initializes the parameters based on the median of the 
			///		distances of /f$\mathbf{X}/f$. 
			/// </summary>
			///
			/// <remarks>	Hmetal T, 06/11/2020. </remarks>
			///
			/// <param name="inMedian">	The in median. </param>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			virtual void InitParameters(Scalar inMedian) { };

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// PSI statistics
			////////////////////////////////////////////////////////////////////////////////////////////////////

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	PSI statistics computation. </summary>
			///
			/// <remarks>	
			/// 	An approximated psi-statistics based on Gauss-Hermite Quadrature		
			/// 			
			/// 	HmetalT, 15/07/2019. 
			/// </remarks>
			///
			/// <param name="inZ">	  	Log normalizer. </param>
			/// <param name="inMu">   	Posterior mean. </param>
			/// <param name="inS">	  	Posterior covariance. </param>
			/// <param name="outPsi0">	[in,out] Psi0. </param>
			/// <param name="outPsi1">	[in,out] Psi1. </param>
			/// <param name="outPsi2">	[in,out] Psi2. </param>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			virtual void ComputePsiStatistics(const af::array& inXu, const af::array& inMu, const af::array& inS,
				af::array& outPsi0, af::array& outPsi1, af::array& outPsi2);

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Psi derivatives. </summary>
			///
			/// <remarks>	Hmetal T, 10/11/2020. </remarks>
			///
			/// <param name="inPsi1">	   	The first in psi. </param>
			/// <param name="indL_dPsi1">  	The first ind l d psi. </param>
			/// <param name="inPsi2">	   	The second in psi. </param>
			/// <param name="indL_dPsi2">  	The second ind l d psi. </param>
			/// <param name="inXu">		   	The in xu. </param>
			/// <param name="inMu">		   	The in mu. </param>
			/// <param name="inS">		   	The in s. </param>
			/// <param name="outdL_dParam">	[in,out] The outd l d parameter. </param>
			/// <param name="outdL_dXu">   	[in,out] The outd l d xu. </param>
			/// <param name="outdL_dMu">   	[in,out] The outd l d mu. </param>
			/// <param name="outdL_dS">	   	[in,out] The outd l d s. </param>
			/// <param name="dlogZ_dv">	   	(Optional) The dlog z coordinate dv. </param>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			virtual void PsiDerivatives(const af::array& indL_dPsi0, const af::array& inPsi1, const af::array& indL_dPsi1, const af::array& inPsi2, const af::array& indL_dPsi2,
				const af::array& inXu, const af::array& inMu, const af::array& inS, af::array& outdL_dParam, af::array& outdL_dXu, 
				af::array& outdL_dMu, af::array& outdL_dS, const af::array* dlogZ_dv = nullptr);

			/*virtual void Psi1Derivative(const af::array& inPsi1, const af::array& indL_dPsi1, const af::array& inXu, const af::array& inMu,
				const af::array& inS, af::array& outdL_dParam, af::array& outdL_dZ, af::array& outdL_dMu, af::array& outdL_dS) = 0;*/

		protected:
			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Default constructor. </summary>
			///
			/// <remarks>	HmetalT, 02/07/2018. </remarks>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			IKernel() { }

			KernelType eType;
			int iNumParam;

			af::dtype m_dType;

			// Gauss-Hermite parameters
			Scalar sDegree;
			af::array afGHx;
			af::array afGHw;
			af::array afXs;
			bool bCacheK;

		private:
			friend class boost::serialization::access;

			template<class Archive>
			void serialize(Archive& ar, unsigned int version)
			{
				ar& BOOST_SERIALIZATION_NVP(eType);
				ar& BOOST_SERIALIZATION_NVP(iNumParam);
				ar& BOOST_SERIALIZATION_NVP(m_dType);
			}
		};
	}
}
