/**
File:		MachineLearning/Kernel/FgIKernel.cpp

Author:		
Email:		
Site:       

Copyright (c) 2017 . All rights reserved.
*/

#include <NeMachineLearningPCH.h>
#include <MachineLearning/FgIKernel.h>
#include <MachineLearning/FgGaussHermiteQuadrature.h>

namespace NeuralEngine
{
	namespace MachineLearning
	{
		template class IKernel<float>;
		template class IKernel<double>;

		template<typename Scalar>
		IKernel<Scalar>::IKernel(KernelType type, int numParameters)
			: eType(type), iNumParam(numParameters), m_dType(CommonUtil<Scalar>::CheckDType()), afXs()
		{
			if (eType != KernelType::eARDKernel/* || eType != KernelType::eLinearKernel*/)
			{
				sDegree = 13;
				GaussHermiteQuadrature<Scalar>::Compute(sDegree, afGHx, afGHw);
				afGHx *= sqrt(2.0);
				afGHw *= 1.0 / sqrt(af::Pi);
				bCacheK = false;
			}
		}

		template<typename Scalar>
		IKernel<Scalar>::~IKernel() 
		{ 
		}

		template<typename Scalar>
		int IKernel<Scalar>::GetNumParameter()
		{
			return iNumParam;
		}

		template<typename Scalar>
		KernelType IKernel<Scalar>::GetKernelType()
		{
			return eType;
		}

		////////////////////////////////////////////////////////////////////////////////////////////////////
		/// PSI statistics
		////////////////////////////////////////////////////////////////////////////////////////////////////
		
		template<typename Scalar>
		void IKernel<Scalar>::ComputePsiStatistics(const af::array& inXu, const af::array& inMu, const af::array& inS, af::array& outPsi0, af::array& outPsi1, af::array& outPsi2)
		{
			int iN = inS.dims(0);
			int iq = inS.dims(1);
			int ik = inXu.dims(0);

			af::array sqS, X, kernDiag, Kfu;

			//if (bCacheK) //afXs = self.comp_K(Z, qX)
			/*else*/ sqS = af::sqrt(inS);

			outPsi0 = af::constant(0.0, iN, m_dType);
			outPsi1 = af::constant(0.0, iN, ik, m_dType);
			outPsi2 = af::constant(0.0, ik, ik, iN, m_dType);

			for (auto i = 0; i < sDegree; i++)
			{
				if (bCacheK) X = afXs(i);
				else X = af::tile(afGHx(i), sqS.dims()) * sqS + inMu;

				ComputeDiagonal(X, kernDiag);
				outPsi0 += af::tile(afGHw(i), kernDiag.dims()) * kernDiag;
				ComputeKernelMatrix(X, inXu, Kfu);
				outPsi1 += af::tile(afGHw(i), Kfu.dims()) * Kfu;
				outPsi2 += af::tile(afGHw(i), ik, ik, iN) * af::tile(af::moddims(Kfu.T(), ik, 1, iN), 1, ik) * af::tile(af::moddims(Kfu.T(), 1, ik, iN), ik);
			}
		}

		template<typename Scalar>
		void IKernel<Scalar>::PsiDerivatives(const af::array& indL_dPsi0, const af::array& inPsi1, const af::array& indL_dPsi1, const af::array& inPsi2, const af::array& indL_dPsi2, 
			const af::array& inXu, const af::array& inMu, const af::array& inS, af::array& outdL_dParam, af::array& outdL_dXu, af::array& outdL_dMu, af::array& outdL_dS, const af::array* dlogZ_dv)
		{
			int iN = inS.dims(0);
			int iq = inS.dims(1);
			int ik = inXu.dims(0);

			af::array sqS, X, kernDiag, Kfu, dL_dParam_i, dX, dL_dkfu, dK_dX, dL_dXu_i, dL_dX_i;

			//if (bCacheK) //afXs = self.comp_K(Z, qX)
			/*else*/ sqS = sqrt(inS);

			/*dtheta_old = kern.gradient.copy()
			dtheta = np.zeros_like(kern.gradient)*/

			outdL_dParam = af::constant(0.0, GetNumParameter(), m_dType);

			outdL_dXu = af::constant(0.0, inXu.dims(), m_dType);
			outdL_dMu = af::constant(0.0, inMu.dims(), m_dType);
			outdL_dS = af::constant(0.0, inS.dims(), m_dType);

			for (auto i = 0; i < sDegree; i++)
			{
				if (bCacheK) X = afXs(i);
				else X = tile(afGHx(i), sqS.dims()) * sqS + inMu;

				/*indL_dPsi0_i = indL_dPsi0 * tile(afGHw(i), indL_dPsi0.dims());
				DiagGradParam(X, indL_dPsi0_i, dL_dParam_i);
				outdL_dParam += dL_dParam_i;
				DiagGradX(X, dL_dpsi0_i, dX);*/

				ComputeKernelMatrix(X, inXu, Kfu);
				dL_dkfu = indL_dPsi1 + af::moddims(af::sum(af::tile(af::moddims(Kfu.T(), ik, 1, iN), 1, ik) * (indL_dPsi2 + indL_dPsi2.T())), ik, iN).T();
				dL_dkfu *= af::tile(afGHw(i), dL_dkfu.dims());
				//LogLikGradientParam(X, inXu, dL_dkfu, dL_dParam_i, dlogZ_dv);

				//LogLikGradientX(X, inXu, dL_dkfu, dL_dX_i); // dL_dKuf_dX
				//LogLikGradientX(inXu, X, dL_dkfu.T(), dL_dXu_i); // dL_dKfu_dXu

				LogLikGradientCompundKfu(dL_dkfu, X, inXu, &dL_dParam_i, &dL_dXu_i, dlogZ_dv, &dL_dX_i);

				outdL_dParam += dL_dParam_i;

				outdL_dMu += dL_dX_i;
				outdL_dS += dL_dX_i * af::tile(afGHx(i), dL_dX_i.dims()) / (2.0 * sqS);
				outdL_dXu += dL_dXu_i;
			}
		}
	}
}