/**
File:		MachineLearning/Kernel/FgLinearKernel.h

Author:		
Email:		
Site:       

Copyright (c) 2017 . All rights reserved.
*/

#pragma once

#include <MachineLearning/FgIKernel.h>

namespace NeuralEngine
{
	namespace MachineLearning
	{
		////////////////////////////////////////////////////////////////////////////////////////////////////
		/// <summary>	Linear kernel function. </summary>
		///
		/// <remarks>
		/// 	Kernel function constructed as follows:
		/// 	
		/// 	\f[k(\mathbf{x}, \mathbf{x}') = \gamma \mathbf{x}^T \mathbf{x}'.\f]
		/// 		
		/// 	Note: GPs constructed with a linear kernel are equivalent to the dual version of
		/// 	probabilistic PCA.
		/// 			
		/// 	 Admin, 5/24/2017. 
		/// </remarks>
		////////////////////////////////////////////////////////////////////////////////////////////////////
		// Declaration of the linear kernel. Only the interface is visible in this header; apart from
		// serialize(), the member functions are defined elsewhere.
		// NOTE(review): several members below are not marked `virtual`/`override`. Whether they
		// override IKernel virtuals cannot be determined from this header - confirm against IKernel.
		template<typename Scalar>
		class NE_IMPEXP LinearKernel : public IKernel<Scalar>
		{
		public:
			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Default constructor. </summary>
			///
			/// <remarks>	 Admin, 5/24/2017. </remarks>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			LinearKernel();

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Constructor taking a dimension count. </summary>
			///
			/// <remarks>
			/// 	Hmetal T, 16/10/2020.
			/// 	
			/// 	NOTE(review): this single-argument constructor is not declared explicit, so an int
			/// 	converts implicitly to a LinearKernel.
			/// </remarks>
			///
			/// <param name="numdims">
			/// 	Number of dimensions (required; no default value). How it is used is not visible in
			/// 	this header - presumably the input dimensionality; confirm against the implementation.
			/// </param>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			LinearKernel(int numdims);

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Destructor. </summary>
			///
			/// <remarks>	 Admin, 5/24/2017. </remarks>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			~LinearKernel();

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Computes the kernel matrix of the kernel. </summary>
			///
			/// <remarks>	 Admin, 5/23/2017. </remarks>
			///
			/// <param name="inX1">	 	First n times q matrix of latent points. </param>
			/// <param name="inX2">	 	Second m times q matrix of latent points (X'). </param>
			/// <param name="outMatrix">	[in,out] Resulting kernel matrix. </param>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			void ComputeKernelMatrix(const af::array& inX1, const af::array& inX2, af::array& outMatrix);

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Calculates only diagonal elements of K. </summary>
			///
			/// <remarks>	 Admin, 5/24/2017. </remarks>
			///
			/// <param name="inX">		  	Nxq matrix X. </param>
			/// <param name="outDiagonal">	[in,out] Receives the diagonal elements of K. </param>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			void ComputeDiagonal(const af::array& inX, af::array& outDiagonal);

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Computes dL/dX for full fit GP. </summary>
			///
			/// <remarks>
			/// 	Admin, 5/23/2017.
			/// 	
			/// 	The method returns nothing; the gradients of the latent points are delivered through
			/// 	outdL_dX.
			/// </remarks>
			///
			/// <param name="inX">	   	Nxq Matrix of latent points. </param>
			/// <param name="indL_dK">
			/// 	Derivative of the loglikelihood w.r.t kernel matrix K.
			/// </param>
			/// <param name="outdL_dX">	[in,out] Derivative of the loglikelihood w.r.t latent points. </param>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			void LogLikGradientX(const af::array& inX, const af::array& indL_dK, af::array& outdL_dX);

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Computes dL / dX and dL / dXu for sparse approximation GP. </summary>
			///
			/// <remarks>	 Admin, 5/23/2017. </remarks>
			///
			/// <param name="inXu">			kxq Matrix of latent subset points. </param>
			/// <param name="indL_dKuu">	Derivative of the loglikelihood w.r.t subset kernel matrix Kuu. </param>
			/// <param name="inX">			Nxq Matrix of latent points. </param>
			/// <param name="indL_dKuf">	Derivative of the loglikelihood w.r.t the kernel matrix Kuf. </param>
			/// <param name="outdL_dXu">	[in,out] Derivative of the loglikelihood w.r.t latent subset points Xu. </param>
			/// <param name="outdL_dX"> 	[in,out] Derivative of the loglikelihood w.r.t latent points X. </param>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			void LogLikGradientX(const af::array& inXu, const af::array& indL_dKuu, const af::array& inX, const af::array& indL_dKuf, af::array& outdL_dXu, af::array& outdL_dX);

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Computes dL / dX for sparse approximation GP. </summary>
			///
			/// <remarks>	Hmetal T, 25/11/2020. </remarks>
			///
			/// <param name="inX1">	   	The first latent points. </param>
			/// <param name="inX2">	   	The second latent points. </param>
			/// <param name="indL_dK"> 	The derivative dL / dK. </param>
			/// <param name="outdL_dX">	[in,out] The derivative dL / dX. </param>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			virtual void LogLikGradientX(const af::array& inX1, const af::array& inX2, const af::array& indL_dK, af::array& outdL_dX) override;

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Computes the gradient of the kernel parameters. </summary>
			///
			/// <remarks>	 Admin, 5/23/2017. </remarks>
			///
			/// <param name="inX1">
			/// 	First matrix of points (presumably n times q, as for ComputeKernelMatrix - TODO confirm).
			/// </param>
			/// <param name="inX2">
			/// 	Second matrix of points (presumably m times q, as for ComputeKernelMatrix - TODO confirm).
			/// </param>
			/// <param name="indL_dK">
			/// 	Derivative of the loglikelihood w.r.t the kernel matrix K.
			/// </param>
			/// <param name="outdL_dParam">	[in,out] Gradient of kernel parameters. </param>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			void LogLikGradientParam(const af::array& inX1, const af::array& inX2, const af::array& indL_dK, af::array& outdL_dParam);

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Computes the gradient of LL w.r.t. the kernel parameters. </summary>
			///
			/// <remarks>
			/// 	Admin, 5/23/2017.
			/// 	
			/// 	Overrides the IKernel member; differs from the four-argument overload above only by the
			/// 	additional dlogZ_dv pointer.
			/// </remarks>
			///
			/// <param name="inX1">
			/// 	First matrix of points (presumably n times q - TODO confirm).
			/// </param>
			/// <param name="inX2">
			/// 	Second matrix of points (presumably m times q - TODO confirm).
			/// </param>
			/// <param name="indL_dK">
			/// 	Derivative of the loglikelihood w.r.t the kernel matrix K.
			/// </param>
			/// <param name="outdL_dParam">	[in,out] Gradient of kernel parameters. </param>
			/// <param name="dlogZ_dv">
			/// 	Pointer to a read-only array (no default value in this overload). Its meaning is not
			/// 	visible in this header - presumably the derivative of log Z w.r.t. the variance; confirm
			/// 	against IKernel and whether null is accepted.
			/// </param>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			virtual void LogLikGradientParam(const af::array& inX1, const af::array& inX2, const af::array& indL_dK,
				af::array& outdL_dParam, const af::array* dlogZ_dv) override;

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Computes dK/dX. </summary>
			///
			/// <remarks>	 Admin, 5/23/2017. </remarks>
			///
			/// <param name="inX1">	   	First n times q matrix of latent points. </param>
			/// <param name="inX2">	   	Second n times q matrix of latent points (X'). </param>
			/// <param name="q">	   	The latent dimension to process. </param>
			/// <param name="outdK_dX">	[in,out] dK/dX. </param>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			void GradX(const af::array& inX1, const af::array& inX2, int q, af::array& outdK_dX);

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Derivative of diagonal elements of K w.r.t X. </summary>
			///
			/// <remarks>	 Admin, 5/23/2017. </remarks>
			///
			/// <param name="inX">		   	Matrix of points X (presumably Nxq - TODO confirm). </param>
			/// <param name="outDiagdK_dX">	[in,out] Derivative of diagonal elements of K w.r.t X. </param>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			void DiagGradX(const af::array& inX, af::array& outDiagdK_dX);

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Derivative of diagonal elements of K w.r.t kernel parameters. </summary>
			///
			/// <remarks>	 Admin, 5/23/2017. </remarks>
			///
			/// <param name="inX">			   	Nxq matrix of Latent points X. </param>
			/// <param name="inCovDiag">	   	Diagonal of kernel matrix K. </param>
			/// <param name="outDiagdK_dParam">	[in,out] Derivative of diagonal elements of K w.r.t kernel parameters. </param>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			void DiagGradParam(const af::array& inX, const af::array& inCovDiag, af::array& outDiagdK_dParam);

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	
			///		Initializes the parameters based on the median of the 
			///		distances of \f$\mathbf{X}\f$. 
			/// </summary>
			///
			/// <remarks>	Hmetal T, 06/11/2020. </remarks>
			///
			/// <param name="inMedian">	The median value, passed by value as Scalar. </param>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			virtual void InitParameters(Scalar inMedian) override;

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>
			/// 	Gradient computation involving Kfu (override of the IKernel member).
			/// </summary>
			///
			/// <remarks>
			/// 	NOTE(review): the exact contract is not visible in this header. Judging by the parameter
			/// 	names, this takes dL/dKfu and produces gradients w.r.t. the kernel parameters, the
			/// 	inducing points Xu and optionally the inputs X - confirm against IKernel and the
			/// 	implementation.
			/// </remarks>
			///
			/// <param name="indL_dKfu">	Presumably the derivative of the objective w.r.t. Kfu. </param>
			/// <param name="inX">			Presumably the matrix of input/latent points X. </param>
			/// <param name="inXu">			Presumably the matrix of inducing points Xu. </param>
			/// <param name="outdL_dParam">	[in,out] Pointer receiving the kernel parameter gradient (presumed). </param>
			/// <param name="outdL_dXu">	[in,out] Pointer receiving the gradient w.r.t. Xu (presumed). </param>
			/// <param name="dlogZ_dv">		(Optional) Defaults to nullptr; meaning not visible here. </param>
			/// <param name="outdL_dX">		(Optional) [in,out] Defaults to nullptr; presumably receives the gradient w.r.t. X. </param>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			virtual void LogLikGradientCompundKfu(const af::array& indL_dKfu, const af::array& inX, const af::array& inXu,
				af::array* outdL_dParam, af::array* outdL_dXu, const af::array* dlogZ_dv = nullptr, af::array* outdL_dX = nullptr) override;

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>
			/// 	Gradient computation involving Kuu (override of the IKernel member).
			/// </summary>
			///
			/// <remarks>
			/// 	NOTE(review): the exact contract is not visible in this header; the descriptions below
			/// 	are inferred from the parameter names only - confirm against IKernel and the
			/// 	implementation.
			/// </remarks>
			///
			/// <param name="inXu">			Presumably the matrix of inducing points Xu. </param>
			/// <param name="inCovDiag">	Presumably a diagonal of a covariance/kernel matrix. </param>
			/// <param name="outdL_dParam">	[in,out] Pointer receiving the kernel parameter gradient (presumed). </param>
			/// <param name="outdL_dXu">	[in,out] Pointer receiving the gradient w.r.t. Xu (presumed). </param>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			virtual void LogGradientCompoundKuu(const af::array& inXu, const af::array& inCovDiag,
				af::array* outdL_dParam, af::array* outdL_dXu) override;

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Sets the parameters. </summary>
			///
			/// <remarks>	, 26.06.2018. </remarks>
			///
			/// <param name="param">	The parameter array to set. </param>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			void SetParameters(const af::array& param);

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Gets the parameters. </summary>
			///
			/// <remarks>	, 26.06.2018. </remarks>
			///
			/// <returns>	The parameters, returned by value as an af::array. </returns>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			af::array GetParameters();

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Sets log parameters. </summary>
			///
			/// <remarks>	Hmetal T, 15/07/2019. </remarks>
			///
			/// <param name="param">	The log-parameter array to set. </param>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			virtual void SetLogParameters(const af::array& param) override;

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Gets log parameters. </summary>
			///
			/// <remarks>	Hmetal T, 15/07/2019. </remarks>
			///
			/// <returns>	The log parameters, returned by value as an af::array. </returns>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			virtual af::array GetLogParameters() override;

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// PSI statistics
			////////////////////////////////////////////////////////////////////////////////////////////////////

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	PSI statistics computation. </summary>
			///
			/// <remarks>	Hmetal T, 15/07/2019. </remarks>
			///
			/// <param name="inXu">	  	Inducing points. </param>
			/// <param name="inMu">   	Posterior mean. </param>
			/// <param name="inS">	  	Posterior covariance. </param>
			/// <param name="outPsi0">	[in,out] Psi0. </param>
			/// <param name="outPsi1">	[in,out] Psi1. </param>
			/// <param name="outPsi2">	[in,out] Psi2. </param>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			void ComputePsiStatistics(const af::array& inXu, const af::array& inMu, const af::array& inS,
				af::array& outPsi0, af::array& outPsi1, af::array& outPsi2) override;

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Parameter and variable derivatives w.r.t. all Psi statistics. </summary>
			///
			/// <remarks>	Hmetal T, 15/07/2019. </remarks>
			///
			/// <param name="indL_dPsi0">  	Psi0 derivative w.r.t. objective. </param>
			/// <param name="inPsi1">	   	Psi1. </param>
			/// <param name="indL_dPsi1">  	Psi1 derivative w.r.t. objective. </param>
			/// <param name="inPsi2">	   	Psi2. </param>
			/// <param name="indL_dPsi2">  	Psi2 derivative w.r.t. objective. </param>
			/// <param name="inXu">	  		Inducing points. </param>
			/// <param name="inMu">		   	Posterior mean. </param>
			/// <param name="inS">		   	Posterior covariance. </param>
			/// <param name="outdL_dParam">	[in,out] Kernel parameter derivatives w.r.t. Psi statistics. </param>
			/// <param name="outdL_dXu">	[in,out] Inducing inputs derivatives w.r.t. Psi statistics. </param>
			/// <param name="outdL_dMu">   	[in,out] Posterior mean derivatives w.r.t. Psi statistics. </param>
			/// <param name="outdL_dS">	   	[in,out] Posterior covariance derivatives w.r.t. Psi statistics. </param>
			/// <param name="dlogZ_dv">
			/// 	(Optional) Defaults to nullptr. Its meaning is not visible in this header - confirm
			/// 	against IKernel.
			/// </param>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			void PsiDerivatives(const af::array& indL_dPsi0, const af::array& inPsi1, const af::array& indL_dPsi1, const af::array& inPsi2, const af::array& indL_dPsi2,
				const af::array& inXu, const af::array& inMu, const af::array& inS, af::array& outdL_dParam, af::array& outdL_dXu, 
				af::array& outdL_dMu, af::array& outdL_dS, const af::array* dlogZ_dv = nullptr) override;

		protected:

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Parameter and variable derivatives w.r.t. Psi2. </summary>
			///
			/// <remarks>	Hmetal T, 15/07/2019. </remarks>
			///
			/// <param name="indL_dPsi2">  	Psi2 derivative w.r.t. objective. </param>
			/// <param name="inXu">	  		Inducing points. </param>
			/// <param name="inMu">		   	Posterior mean. </param>
			/// <param name="inS">		   	Posterior covariance. </param>
			/// <param name="outdL_dParam">	[in,out] Kernel parameter derivatives w.r.t. Psi2. </param>
			/// <param name="outdL_dXu">	   	[in,out] Inducing inputs derivatives w.r.t. Psi2. </param>
			/// <param name="outdL_dMu">   	[in,out] Posterior mean derivatives w.r.t. Psi2. </param>
			/// <param name="outdL_dS">	   	[in,out] Posterior covariance derivatives w.r.t. Psi2. </param>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			void Psi2Derivative(const af::array& indL_dPsi2, const af::array& inXu, const af::array& inMu,
				const af::array& inS, af::array& outdL_dParam, af::array& outdL_dXu, af::array& outdL_dMu, af::array& outdL_dS);

		private:
			// Grants Boost.Serialization access to the private serialize() member below.
			friend class boost::serialization::access;

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Boost.Serialization hook: archives the IKernel base part, then dVariance. </summary>
			///
			/// <param name="ar">	   	[in,out] The archive; operator&amp; is applied to each serialized item. </param>
			/// <param name="version">	Class version supplied by the archive; not used in this body. </param>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			template<class Archive>
			void serialize(Archive& ar, unsigned int version)
			{
				// Base-class state goes first, without a name-value wrapper (the named variant is kept
				// commented out below).
				ar & boost::serialization::base_object<IKernel<Scalar>>(*this);
				//ar& boost::serialization::make_nvp("IKernel", boost::serialization::base_object<IKernel<Scalar>>(*this));
				// BOOST_SERIALIZATION_NVP stores the member as a name-value pair named after the variable.
				ar& BOOST_SERIALIZATION_NVP(dVariance);
			}

			//af::array afParameter; // []{variance}
			// Variance parameter, held as an af::array; the only member archived by serialize() besides
			// the base. Presumably the gamma factor of the linear kernel - TODO confirm.
			af::array dVariance;
		};
	}
}
