/**
File:		MachineLearning/Kernel/FgInterDomainKernel.cpp

Author:		
Email:		
Site:       

Copyright (c) 2017 . All rights reserved.
*/

#include <NeMachineLearningPCH.h>
#include <MachineLearning/FgInterDomainKernel.h>

namespace NeuralEngine
{
	namespace MachineLearning
	{
		// Explicit instantiations for the two supported scalar types.
		// NOTE(review): per [temp.explicit], an explicit instantiation definition only
		// instantiates the members that are already defined at the point of instantiation.
		// These lines precede every out-of-line member definition in this file, so a
		// conforming compiler may not emit those members (MSVC is more lenient). Consider
		// moving both lines to the end of the namespace block.
		template class InterDomainKernel<float>;
		template class InterDomainKernel<double>;

		// Creates an inter-domain kernel with no child kernels attached. The base is
		// constructed with kernel type eInterKernel and 0 (presumably the initial parameter
		// count, which AddSubKernel / AddWindowKernel then increase).
		template<typename Scalar>
		InterDomainKernel<Scalar>::InterDomainKernel()
			: IKernel<Scalar>(eInterKernel, 0)
		{
			// The destructor deletes these whenever they are non-null, so they must never
			// hold indeterminate values, even if a child kernel is never added.
			kSubKernel = nullptr;
			kWindowKernel = nullptr;
		}

		// Destroys both child kernels; this object owns whatever was passed to
		// AddSubKernel / AddWindowKernel. `delete` on a null pointer is a no-op, so no
		// explicit checks are needed.
		// NOTE(review): owning raw pointers — unless copy/move are suppressed in the header,
		// copying an InterDomainKernel would double-delete; std::unique_ptr would be safer.
		template<typename Scalar>
		InterDomainKernel<Scalar>::~InterDomainKernel()
		{
			delete kSubKernel;
			delete kWindowKernel;
		}

		// Attaches the sub-kernel (ownership passes to this object; the destructor deletes
		// it) and adds its parameter count to this kernel's total.
		// NOTE(review): calling this twice overwrites the previous pointer without deleting
		// it and counts parameters for both kernels.
		template<typename Scalar>
		void InterDomainKernel<Scalar>::AddSubKernel(IKernel<Scalar>* kernel)
		{
			const auto added = kernel->GetNumParameter();
			kSubKernel = kernel;
			iNumParam += added;
		}

		// Attaches the window kernel (ownership passes to this object; the destructor
		// deletes it) and adds its parameter count to this kernel's total.
		// NOTE(review): calling this twice overwrites the previous pointer without deleting
		// it and counts parameters for both kernels.
		template<typename Scalar>
		void InterDomainKernel<Scalar>::AddWindowKernel(IKernel<Scalar>* kernel)
		{
			const auto added = kernel->GetNumParameter();
			kWindowKernel = kernel;
			iNumParam += added;
		}

		// Computes the kernel matrix between inX1 and inX2. Inputs with the same number of
		// rows go to the sub-kernel; all others go to the window kernel.
		// NOTE(review): presumably this separates Kuu from Kuf; a Kuf whose data set has as
		// many rows as the inducing set would be routed to the sub-kernel — confirm intent.
		template<typename Scalar>
		void InterDomainKernel<Scalar>::ComputeKernelMatrix(const af::array& inX1, const af::array& inX2, af::array& outMatrix)
		{
			const bool sameRowCount = (inX1.dims(0) == inX2.dims(0));
			auto* const chosen = sameRowCount ? kSubKernel : kWindowKernel;
			chosen->ComputeKernelMatrix(inX1, inX2, outMatrix);
		}

		// Kernel-matrix diagonal for inX, taken entirely from the window kernel.
		template<typename Scalar>
		void InterDomainKernel<Scalar>::ComputeDiagonal(const af::array& inX, af::array& outDiagonal)
		{
			auto& window = *kWindowKernel;
			window.ComputeDiagonal(inX, outDiagonal);
		}

		// NOTE(review): not implemented — the body does nothing, so outdL_dX is left
		// exactly as the caller passed it.
		template<typename Scalar>
		void InterDomainKernel<Scalar>::LogLikGradientX(const af::array& inX, const af::array& indL_dK, af::array& outdL_dX)
		{
			(void)inX;
			(void)indL_dK;
			(void)outdL_dX;
		}

		// NOTE(review): not implemented — the body does nothing, so outdL_dXu and outdL_dX
		// are left exactly as the caller passed them.
		template<typename Scalar>
		void InterDomainKernel<Scalar>::LogLikGradientX(const af::array& inXu, const af::array& indL_dKuu, const af::array& inX, const af::array& indL_dKuf, af::array& outdL_dXu, af::array& outdL_dX)
		{
			(void)inXu;
			(void)indL_dKuu;
			(void)inX;
			(void)indL_dKuf;
			(void)outdL_dXu;
			(void)outdL_dX;
		}

		// NOTE(review): not implemented — the body does nothing, so outdL_dX is left
		// exactly as the caller passed it.
		template<typename Scalar>
		void InterDomainKernel<Scalar>::LogLikGradientX(const af::array& inX1, const af::array& inX2, const af::array& indL_dK, af::array& outdL_dX)
		{
			(void)inX1;
			(void)inX2;
			(void)indL_dK;
			(void)outdL_dX;
		}

		// NOTE(review): not implemented — the body does nothing, so outdL_dParam is left
		// exactly as the caller passed it.
		template<typename Scalar>
		void InterDomainKernel<Scalar>::LogLikGradientParam(const af::array& inX1, const af::array& inX2, const af::array& indL_dK, af::array& outdL_dParam)
		{
			(void)inX1;
			(void)inX2;
			(void)indL_dK;
			(void)outdL_dParam;
		}

		// NOTE(review): not implemented — the body does nothing (dlogZ_dv is ignored), so
		// outdL_dParam is left exactly as the caller passed it.
		template<typename Scalar>
		void InterDomainKernel<Scalar>::LogLikGradientParam(const af::array& inX1, const af::array& inX2, const af::array& indL_dK, af::array& outdL_dParam, const af::array* dlogZ_dv)
		{
			(void)inX1;
			(void)inX2;
			(void)indL_dK;
			(void)outdL_dParam;
			(void)dlogZ_dv;
		}

		// NOTE(review): not implemented — the body does nothing, so outdK_dX is left
		// exactly as the caller passed it.
		template<typename Scalar>
		void InterDomainKernel<Scalar>::GradX(const af::array& inX1, const af::array& inX2, int q, af::array& outdK_dX)
		{
			(void)inX1;
			(void)inX2;
			(void)q;
			(void)outdK_dX;
		}

		// NOTE(review): not implemented — the body does nothing, so outDiagdK_dX is left
		// exactly as the caller passed it.
		template<typename Scalar>
		void InterDomainKernel<Scalar>::DiagGradX(const af::array& inX, af::array& outDiagdK_dX)
		{
			(void)inX;
			(void)outDiagdK_dX;
		}

		// NOTE(review): not implemented — the body does nothing, so outDiagdK_dParam is
		// left exactly as the caller passed it.
		template<typename Scalar>
		void InterDomainKernel<Scalar>::DiagGradParam(const af::array& inX, const af::array& inCovDiag, af::array& outDiagdK_dParam)
		{
			(void)inX;
			(void)inCovDiag;
			(void)outDiagdK_dParam;
		}

		// Distributes a flat parameter vector to the child kernels. The vector is laid out
		// as [sub-kernel parameters | window-kernel parameters].
		//
		// Empty segments are skipped. ArrayFire counts a negative end index from the back,
		// so af::seq(0, -1) would address the WHOLE vector. For a segment starting at k > 0,
		// af::seq(k, k - 1) is an inverted range.
		template<typename Scalar>
		void InterDomainKernel<Scalar>::SetParameters(const af::array& param)
		{
			const int nSub = kSubKernel->GetNumParameter();
			const int nWindow = kWindowKernel->GetNumParameter();

			if (nSub > 0)
				kSubKernel->SetParameters(param(af::seq(0, nSub - 1)));

			// The window segment ends at nSub + nWindow. The previous code set the end to
			// nWindow alone, which selected the wrong slice whenever nSub > 0.
			if (nWindow > 0)
				kWindowKernel->SetParameters(param(af::seq(nSub, nSub + nWindow - 1)));
		}

		// Distributes a flat log-parameter vector to the child kernels. The vector is laid
		// out as [sub-kernel log-parameters | window-kernel log-parameters].
		//
		// Both segments are guarded against being empty. ArrayFire counts a negative end
		// index from the back, so af::seq(0, -1) would hand the WHOLE vector to a
		// zero-parameter sub-kernel.
		template<typename Scalar>
		void InterDomainKernel<Scalar>::SetLogParameters(const af::array& param)
		{
			const int nSub = kSubKernel->GetNumParameter();
			const int nWindow = kWindowKernel->GetNumParameter();

			if (nSub > 0)
				kSubKernel->SetLogParameters(param(af::seq(0, nSub - 1)));

			if (nWindow > 0)
				kWindowKernel->SetLogParameters(param(af::seq(nSub, nSub + nWindow - 1)));
		}

		// Collects the child kernels' log-parameters into one vector of length
		// GetNumParameter(). The layout is [sub-kernel | window-kernel]; entries not
		// written remain 0.
		//
		// Empty segments are skipped. With ArrayFire's negative-from-the-end indexing,
		// af::seq(0, -1) would address the whole vector rather than nothing.
		template<typename Scalar>
		af::array InterDomainKernel<Scalar>::GetLogParameters()
		{
			af::array param = af::constant(0.0f, GetNumParameter(), (m_dType));

			const int nSub = kSubKernel->GetNumParameter();
			const int nWindow = kWindowKernel->GetNumParameter();

			if (nSub > 0)
				param(af::seq(0, nSub - 1)) = kSubKernel->GetLogParameters();

			if (nWindow > 0)
				param(af::seq(nSub, nSub + nWindow - 1)) = kWindowKernel->GetLogParameters();

			return param;
		}

		// Collects the child kernels' parameters (not log-parameters) into one vector of
		// length GetNumParameter(). The layout is [sub-kernel | window-kernel]; entries not
		// written remain 0.
		//
		// The previous version called GetLogParameters() on both children, so it returned
		// log-space values from a getter that is paired with SetParameters().
		// Empty segments are skipped because af::seq(0, -1) would address the whole vector.
		template<typename Scalar>
		af::array InterDomainKernel<Scalar>::GetParameters()
		{
			af::array param = af::constant(0.0f, GetNumParameter(), (m_dType));

			const int nSub = kSubKernel->GetNumParameter();
			const int nWindow = kWindowKernel->GetNumParameter();

			if (nSub > 0)
				param(af::seq(0, nSub - 1)) = kSubKernel->GetParameters();

			if (nWindow > 0)
				param(af::seq(nSub, nSub + nWindow - 1)) = kWindowKernel->GetParameters();

			return param;
		}

		// Forwards the same median value to both child kernels, sub-kernel first.
		template<typename Scalar>
		void InterDomainKernel<Scalar>::InitParameters(Scalar inMedian)
		{
			IKernel<Scalar>* const children[] = { kSubKernel, kWindowKernel };
			for (IKernel<Scalar>* child : children)
				child->InitParameters(inMedian);
		}

		// Gradient contributions through Kfu. Only the window kernel is queried. Its
		// parameter gradient is written into the window segment of *outdL_dParam. The
		// sub-kernel segment stays zero. outdL_dXu / outdL_dX are filled directly by the
		// window kernel.
		//
		// The write is skipped when the window kernel has no parameters, a case that
		// SetLogParameters already anticipates. Otherwise af::seq(istart, istart - 1) is
		// an inverted range, or the whole vector when istart == 0.
		template<typename Scalar>
		void InterDomainKernel<Scalar>::LogLikGradientCompundKfu(const af::array& indL_dKfu, const af::array& inX, const af::array& inXu, af::array* outdL_dParam, af::array* outdL_dXu, const af::array* dlogZ_dv, af::array* outdL_dX)
		{
			*outdL_dParam = af::constant(0.0, iNumParam, (m_dType));

			af::array tmpParam;
			kWindowKernel->LogLikGradientCompundKfu(indL_dKfu, inX, inXu, &tmpParam, outdL_dXu, dlogZ_dv, outdL_dX);

			const int istart = kSubKernel->GetNumParameter();
			const int iend = istart + kWindowKernel->GetNumParameter();
			if (iend > istart)
				(*outdL_dParam)(af::seq(istart, iend - 1)) = tmpParam;
		}

		// Gradient contributions through Kuu. Only the sub-kernel is queried. Its parameter
		// gradient is written into the sub-kernel segment of *outdL_dParam. The window
		// segment stays zero. outdL_dXu is filled directly by the sub-kernel.
		//
		// The write is skipped for a zero-parameter sub-kernel. af::seq(0, -1) would
		// otherwise address the entire gradient vector.
		template<typename Scalar>
		void InterDomainKernel<Scalar>::LogGradientCompoundKuu(const af::array& inXu, const af::array& inCovDiag, af::array* outdL_dParam, af::array* outdL_dXu)
		{
			*outdL_dParam = af::constant(0.0, iNumParam, (m_dType));

			af::array tmpParam;
			kSubKernel->LogGradientCompoundKuu(inXu, inCovDiag, &tmpParam, outdL_dXu);

			const int nSub = kSubKernel->GetNumParameter();
			if (nSub > 0)
				(*outdL_dParam)(af::seq(0, nSub - 1)) = tmpParam;
		}

		//////////////////////////////////////////////////////////////////////////////////////////////////////
		///// PSI statistics
		//////////////////////////////////////////////////////////////////////////////////////////////////////
		
		// Psi statistics (psi0, psi1, psi2) for inducing inputs inXu and input distribution
		// (inMu, inS), taken entirely from the window kernel.
		template<typename Scalar>
		void InterDomainKernel<Scalar>::ComputePsiStatistics(const af::array& inXu, const af::array& inMu, const af::array& inS, af::array& outPsi0, af::array& outPsi1, af::array& outPsi2)
		{
			auto& window = *kWindowKernel;
			window.ComputePsiStatistics(inXu, inMu, inS, outPsi0, outPsi1, outPsi2);
		}

		// Derivatives of the psi statistics, delegated to the window kernel. Its parameter
		// gradient is placed in the window segment of outdL_dParam. The sub-kernel segment
		// stays zero. outdL_dXu, outdL_dMu and outdL_dS are filled directly by the window
		// kernel.
		//
		// The write is skipped when the window kernel has no parameters. Otherwise
		// af::seq(istart, istart - 1) is an inverted range, or the whole vector when
		// istart == 0.
		template<typename Scalar>
		void InterDomainKernel<Scalar>::PsiDerivatives(const af::array& indL_dPsi0, const af::array& inPsi1, const af::array& indL_dPsi1, const af::array& inPsi2, const af::array& indL_dPsi2, const af::array& inXu, 
			const af::array& inMu, const af::array& inS, af::array& outdL_dParam, af::array& outdL_dXu, af::array& outdL_dMu, af::array& outdL_dS, const af::array* dlogZ_dv)
		{
			outdL_dParam = af::constant(0.0, iNumParam, (m_dType));

			af::array tmpParam;
			kWindowKernel->PsiDerivatives(indL_dPsi0, inPsi1, indL_dPsi1, inPsi2, indL_dPsi2, inXu, inMu, inS, tmpParam, outdL_dXu, outdL_dMu, outdL_dS, dlogZ_dv);

			const int istart = kSubKernel->GetNumParameter();
			const int iend = istart + kWindowKernel->GetNumParameter();
			if (iend > istart)
				outdL_dParam(af::seq(istart, iend - 1)) = tmpParam;
		}
	}
}