/**
File:		MachineLearning/Kernel/FgCompoundKernel.cpp

Author:		
Email:		
Site:       

Copyright (c) 2017 . All rights reserved.
*/

#include <NeMachineLearningPCH.h>
#include <MachineLearning/FgCompoundKernel.h>

namespace NeuralEngine
{
	namespace MachineLearning
	{
		// Explicit instantiations of CompoundKernel for the two supported scalar types.
		// NOTE(review): per [temp.explicit], an explicit instantiation definition only
		// instantiates the members that are already defined at its point of appearance.
		// Every out-of-line member in this file is defined *below* these two lines, so a
		// strictly conforming compiler may not emit them here. Consider moving these
		// lines to the end of the namespace. TODO confirm with the toolchains in use.
		template class CompoundKernel<float>;
		template class CompoundKernel<double>;

		/**
		 * Builds an empty compound kernel whose value is the sum of its sub-kernels.
		 * The base is tagged eCompoundKernel; the trailing 0 presumably is the
		 * initial parameter count (TODO confirm against IKernel).
		 */
		template<typename Scalar>
		CompoundKernel<Scalar>::CompoundKernel() : IKernel<Scalar>(eCompoundKernel, 0)
		{
			// No sub-kernels yet: AddKernel() fills vKernel / vIndex and grows iNumParam.
		}

		/**
		 * Destroys every sub-kernel handed over through AddKernel() and empties the
		 * kernel and column-index lists.
		 *
		 * NOTE(review): this class owns raw pointers. Unless the header disables or
		 * deep-implements copying, a copied CompoundKernel would delete the same
		 * sub-kernels twice. Confirm the copy constructor/assignment are handled.
		 */
		template<typename Scalar>
		CompoundKernel<Scalar>::~CompoundKernel()
		{
			// An index loop is used so the element type never has to be spelled here.
			// IKernel is a dependent base, so a bare "IKernel*" is not portable.
			for (std::size_t i = 0; i < vKernel.size(); ++i)
			{
				delete vKernel[i];
			}
			vKernel.clear();
			vIndex.clear();
		}

		/**
		 * Appends a sub-kernel that operates on the input columns listed in index.
		 *
		 * Ownership of kernel passes to this CompoundKernel; it is deleted in the
		 * destructor. The sub-kernel's parameters are appended after those of the
		 * kernels added before it. The parameter get/set and gradient routines
		 * walk vKernel in this same order.
		 *
		 * @param kernel Sub-kernel to add; must not be null.
		 * @param index  Column indices of the input matrix used by this sub-kernel.
		 */
		template<typename Scalar>
		void CompoundKernel<Scalar>::AddKernel(IKernel<Scalar>* kernel, af::array index)
		{
			// Reject null before storing it: the destructor and every compute routine
			// dereference the stored pointers.
			LogAssert(kernel != NULL, "Cannot add a null kernel.");

			vKernel.push_back(kernel);
			vIndex.push_back(index);

			// iNumParam is presumably declared in the dependent base IKernel<Scalar>;
			// the this-> makes the lookup portable.
			this->iNumParam += kernel->GetNumParameter();
		}

		/**
		 * Evaluates the compound kernel as the sum of all sub-kernel matrices.
		 *
		 * Sub-kernel k only receives the columns of inX1 / inX2 selected by vIndex[k].
		 *
		 * @param inX1      First input set; its row count gives the output row count.
		 * @param inX2      Second input set; its row count gives the output column count.
		 * @param outMatrix Receives the summed kernel matrix.
		 */
		template<typename Scalar>
		void CompoundKernel<Scalar>::ComputeKernelMatrix(const af::array& inX1, const af::array& inX2, af::array& outMatrix)
		{
			LogAssert(vKernel.size() > 0, "No kernels to compute.");

			outMatrix = af::constant(0.0, inX1.dims(0), inX2.dims(0), this->m_dType);

			af::array partial;
			for (std::size_t k = 0, count = vKernel.size(); k < count; ++k)
			{
				const af::array& columns = vIndex[k];
				vKernel[k]->ComputeKernelMatrix(inX1(af::span, columns), inX2(af::span, columns), partial);
				outMatrix += partial;
			}
		}

		/**
		 * Diagonal of the compound kernel matrix K(inX, inX), as the sum of
		 * each sub-kernel's diagonal.
		 *
		 * @param inX         Input points; one output entry per row.
		 * @param outDiagonal Receives a column vector with inX.dims(0) entries.
		 */
		template<typename Scalar>
		void CompoundKernel<Scalar>::ComputeDiagonal(const af::array& inX, af::array& outDiagonal)
		{
			LogAssert(vKernel.size() > 0, "No kernels to compute.");

			outDiagonal = af::constant(0.0, inX.dims(0), 1, this->m_dType);

			af::array partialDiag(this->m_dType);
			for (std::size_t k = 0, count = vKernel.size(); k < count; ++k)
			{
				// Sub-kernel k only sees the input columns listed in vIndex[k].
				const af::array columns = inX(af::span, vIndex[k]);
				vKernel[k]->ComputeDiagonal(columns, partialDiag);
				outDiagonal += partialDiag;
			}
		}

		/**
		 * Gradient of the objective with respect to the inputs X, given dL/dK.
		 *
		 * Each sub-kernel only receives the columns vIndex[i] of inX, so its
		 * gradient can only cover those columns (assumed rows(inX) x |vIndex[i]|).
		 * The partial gradient is therefore accumulated into exactly those columns
		 * of the full-size result.
		 *
		 * @param inX      Input points (rows) by input dimensions (columns).
		 * @param indL_dK  Derivative of the objective w.r.t. the kernel matrix.
		 * @param outdL_dX Receives dL/dX with the same dimensions as inX.
		 */
		template<typename Scalar>
		void CompoundKernel<Scalar>::LogLikGradientX(const af::array& inX, const af::array& indL_dK, af::array& outdL_dX)
		{
			LogAssert(vKernel.size() > 0, "No kernels to compute.");

			outdL_dX = af::constant(0.0, inX.dims(), this->m_dType);
			af::array x(this->m_dType), tmpOutdL_dX(this->m_dType);
			for (std::size_t i = 0; i < vKernel.size(); ++i)
			{
				x = inX(af::span, vIndex[i]);
				vKernel[i]->LogLikGradientX(x, indL_dK, tmpOutdL_dX);
				// Scatter into the sub-kernel's own columns instead of adding a
				// narrower matrix onto the whole N x D result.
				outdL_dX(af::span, vIndex[i]) += tmpOutdL_dX;
			}
		}

		/**
		 * Gradients of the objective with respect to the inducing inputs Xu and the
		 * inputs X, given dL/dKuu and dL/dKuf.
		 *
		 * Every sub-kernel works on the columns vIndex[i] of both inXu and inX. Its
		 * partial gradients are accumulated into those same columns of the
		 * full-size outputs.
		 *
		 * @param inXu      Inducing inputs.
		 * @param indL_dKuu Derivative of the objective w.r.t. Kuu.
		 * @param inX       Inputs.
		 * @param indL_dKuf Derivative of the objective w.r.t. Kuf.
		 * @param outdL_dXu Receives dL/dXu, same dimensions as inXu.
		 * @param outdL_dX  Receives dL/dX, same dimensions as inX.
		 */
		template<typename Scalar>
		void CompoundKernel<Scalar>::LogLikGradientX(const af::array& inXu, const af::array& indL_dKuu, const af::array& inX, const af::array& indL_dKuf, af::array& outdL_dXu, af::array& outdL_dX)
		{
			LogAssert(vKernel.size() > 0, "No kernels to compute.");

			outdL_dX = af::constant(0.0, inX.dims(), this->m_dType);
			outdL_dXu = af::constant(0.0, inXu.dims(), this->m_dType);

			af::array x(this->m_dType), xu(this->m_dType), tmpOutdL_dX(this->m_dType), tmpOutdL_dXu(this->m_dType);
			for (std::size_t i = 0; i < vKernel.size(); ++i)
			{
				x = inX(af::span, vIndex[i]);
				xu = inXu(af::span, vIndex[i]);
				vKernel[i]->LogLikGradientX(xu, indL_dKuu, x, indL_dKuf, tmpOutdL_dXu, tmpOutdL_dX);
				// The sub-kernel's gradients only span its own columns.
				outdL_dXu(af::span, vIndex[i]) += tmpOutdL_dXu;
				outdL_dX(af::span, vIndex[i]) += tmpOutdL_dX;
			}
		}

		/**
		 * Gradient of the objective with respect to all kernel hyperparameters.
		 *
		 * The result is a 1 x iNumParam row. Sub-kernel i's gradient occupies the
		 * contiguous slot that follows the slots of kernels 0..i-1, matching
		 * GetParameters()/SetParameters().
		 *
		 * @param inX1         First input set.
		 * @param inX2         Second input set.
		 * @param indL_dK      Derivative of the objective w.r.t. the kernel matrix.
		 * @param outdL_dParam Receives the concatenated parameter gradient.
		 */
		template<typename Scalar>
		void CompoundKernel<Scalar>::LogLikGradientParam(const af::array& inX1, const af::array& inX2, const af::array& indL_dK, af::array& outdL_dParam)
		{
			LogAssert(vKernel.size() > 0, "No kernels to compute.");

			outdL_dParam = af::constant(0.0, 1, this->iNumParam, this->m_dType);

			af::array tmpoutdL_dParam(this->m_dType);
			int startVal = 0;
			for (std::size_t i = 0; i < vKernel.size(); ++i)
			{
				const int numParam = vKernel[i]->GetNumParameter();
				if (numParam <= 0)
				{
					// No slot to fill; af::seq(start, start - 1) would be an empty/invalid range.
					continue;
				}
				const int endVal = startVal + numParam;
				vKernel[i]->LogLikGradientParam(inX1(af::span, vIndex[i]), inX2(af::span, vIndex[i]), indL_dK, tmpoutdL_dParam);
				outdL_dParam(af::seq(startVal, endVal - 1)) = tmpoutdL_dParam;
				startVal = endVal;
			}
		}

		/**
		 * Derivative of the kernel matrix with respect to input dimension q.
		 *
		 * q is a column of the full input. A sub-kernel only sees its own columns
		 * vIndex[i], so q must be translated to its position inside vIndex[i]
		 * before being forwarded. Sub-kernels that do not use column q contribute
		 * nothing.
		 *
		 * @param inX1     First input set.
		 * @param inX2     Second input set.
		 * @param q        Column (input dimension) of the full input.
		 * @param outdK_dX Receives the summed derivative, rows(inX1) x rows(inX2).
		 */
		template<typename Scalar>
		void CompoundKernel<Scalar>::GradX(const af::array& inX1, const af::array& inX2, int q, af::array& outdK_dX)
		{
			LogAssert(vKernel.size() > 0, "No kernels to compute.");

			outdK_dX = af::constant(0.0, inX1.dims(0), inX2.dims(0), this->m_dType);

			af::array tmpMtx(this->m_dType);
			for (std::size_t i = 0; i < vKernel.size(); ++i)
			{
				// Positions in vIndex[i] whose column equals q. af::where returns
				// indices into vIndex[i], i.e. local column numbers of the slice.
				const af::array localPos = af::where(vIndex[i] == q);
				if (localPos.isempty())
				{
					continue;
				}

				// af::where yields u32; scalar() reads the first match.
				const int localQ = static_cast<int>(localPos.scalar<unsigned int>());
				vKernel[i]->GradX(inX1(af::span, vIndex[i]), inX2(af::span, vIndex[i]), localQ, tmpMtx);
				outdK_dX += tmpMtx;
			}
		}

		/**
		 * Derivative of the kernel diagonal with respect to the inputs.
		 *
		 * Each sub-kernel only receives the columns vIndex[i] of inX. Its result is
		 * accumulated into those same columns of the full-size output.
		 *
		 * @param inX          Input points (rows) by input dimensions (columns).
		 * @param outDiagdK_dX Receives the derivative, same dimensions as inX.
		 */
		template<typename Scalar>
		void CompoundKernel<Scalar>::DiagGradX(const af::array& inX, af::array& outDiagdK_dX)
		{
			LogAssert(vKernel.size() > 0, "No kernels to compute.");

			outDiagdK_dX = af::constant(0.0, inX.dims(), this->m_dType);
			af::array tmpMtx(this->m_dType);
			for (std::size_t i = 0; i < vKernel.size(); ++i)
			{
				vKernel[i]->DiagGradX(inX(af::span, vIndex[i]), tmpMtx);
				// Scatter into the sub-kernel's own columns.
				outDiagdK_dX(af::span, vIndex[i]) += tmpMtx;
			}
		}

		/**
		 * Derivative of the kernel diagonal with respect to all hyperparameters.
		 *
		 * The result is a 1 x iNumParam row. Sub-kernel i fills the contiguous slot
		 * after kernels 0..i-1, the same layout used by GetParameters().
		 *
		 * @param inX              Input points.
		 * @param inCovDiag        Forwarded unchanged to every sub-kernel.
		 * @param outDiagdK_dParam Receives the concatenated derivative.
		 */
		template<typename Scalar>
		void CompoundKernel<Scalar>::DiagGradParam(const af::array& inX, const af::array& inCovDiag, af::array& outDiagdK_dParam)
		{
			LogAssert(vKernel.size() > 0, "No kernels to compute.");

			outDiagdK_dParam = af::constant(0.0, 1, this->iNumParam, this->m_dType);

			af::array tmpoutdL_dParam;
			int startVal = 0;
			for (std::size_t i = 0; i < vKernel.size(); ++i)
			{
				const int numParam = vKernel[i]->GetNumParameter();
				if (numParam <= 0)
				{
					continue;
				}
				const int endVal = startVal + numParam;
				vKernel[i]->DiagGradParam(inX(af::span, vIndex[i]), inCovDiag, tmpoutdL_dParam);
				// Linear range [startVal, endVal - 1]. Writing (startVal, endVal - 1)
				// would address a single (row, column) element instead.
				outDiagdK_dParam(af::seq(startVal, endVal - 1)) = tmpoutdL_dParam;
				startVal = endVal;
			}
		}

		/**
		 * Distributes a flat parameter vector over the sub-kernels.
		 *
		 * Sub-kernel i receives the contiguous slice that follows the slices of
		 * kernels 0..i-1. This is the layout produced by GetParameters().
		 *
		 * @param param Concatenated parameters of all sub-kernels.
		 */
		template<typename Scalar>
		void CompoundKernel<Scalar>::SetParameters(const af::array & param)
		{
			int istart = 0;
			for (std::size_t i = 0; i < vKernel.size(); ++i)
			{
				const int numParam = vKernel[i]->GetNumParameter();
				if (numParam <= 0)
				{
					// Nothing to assign; an empty af::seq range would be invalid.
					continue;
				}
				const int iend = istart + numParam;
				vKernel[i]->SetParameters(param(af::seq(istart, iend - 1)));
				istart = iend;
			}
		}

		/**
		 * Collects the parameters of all sub-kernels into one 1 x N row, in the
		 * order the kernels were added.
		 *
		 * @return The concatenated parameter vector (N = GetNumParameter()).
		 */
		template<typename Scalar>
		af::array CompoundKernel<Scalar>::GetParameters()
		{
			af::array param = af::constant(0.0f, 1, this->GetNumParameter(), this->m_dType);

			int istart = 0;
			for (std::size_t i = 0; i < vKernel.size(); ++i)
			{
				const int numParam = vKernel[i]->GetNumParameter();
				if (numParam <= 0)
				{
					// A parameter-less sub-kernel owns no slot.
					continue;
				}
				const int iend = istart + numParam;
				param(af::seq(istart, iend - 1)) = vKernel[i]->GetParameters();
				istart = iend;
			}

			return param;
		}

		/**
		 * Derivatives of the Psi1 statistic for the compound kernel.
		 *
		 * NOTE(review): not implemented. The body does nothing, so outdL_dParam,
		 * outdL_dXu and *outdL_dX are left exactly as the caller passed them in.
		 * A real implementation presumably needs each sub-kernel's own Psi1 rather
		 * than the combined inPsi1. TODO confirm and implement.
		 */
		template<typename Scalar>
		void CompoundKernel<Scalar>::Psi1Derivative(const af::array & inPsi1, const af::array & indL_dpsi1, const af::array & inZ, const af::array & inMu,
			const af::array & inSu, af::array & outdL_dParam, af::array & outdL_dXu, af::array * outdL_dX)
		{
			// Currently a no-op; the casts only silence unused-parameter warnings.
			(void)inPsi1;
			(void)indL_dpsi1;
			(void)inZ;
			(void)inMu;
			(void)inSu;
			(void)outdL_dParam;
			(void)outdL_dXu;
			(void)outdL_dX;
		}
	}
}