/**
File:		MachineLearning/Kernel/FgTensorKernel<Scalar>.cpp

Author:		
Email:		
Site:       

Copyright (c) 2017 . All rights reserved.
*/

#include <NeMachineLearningPCH.h>
#include <MachineLearning/FgTensorKernel.h>

namespace NeuralEngine
{
	namespace MachineLearning
	{
		// Explicit instantiation definitions for the two supported scalar types.
		// NOTE(review): these appear BEFORE the member definitions below. Per the C++
		// standard, an explicit instantiation definition of a class template only
		// instantiates the members that are already defined at that point, so a
		// conforming compiler may not emit the members defined later in this file.
		// Consider moving these two lines to the end of the namespace — TODO confirm
		// against the compilers this project targets.
		template class TensorKernel<float>;
		template class TensorKernel<double>;

		/// Default constructor. The body is empty; it only forwards the eTensorKernel
		/// tag and the value 0 to the IKernel<Scalar> base constructor.
		/// NOTE(review): the 0 is presumably the initial parameter count (AddKernel()
		/// later increases iNumParam) — confirm against the IKernel constructor.
		template<typename Scalar>
		TensorKernel<Scalar>::TensorKernel()
			: IKernel<Scalar>(eTensorKernel, 0)
		{
		}

		/// Destructor. Deletes every kernel pointer stored in vKernel (this object
		/// therefore owns the kernels handed to AddKernel()) and empties both the
		/// kernel list and the index list.
		template<typename Scalar>
		TensorKernel<Scalar>::~TensorKernel()
		{
			// `auto*` avoids spelling the element type; the previous
			// `std::vector<IKernel*>::iterator` named the class template without
			// its template argument, which only lenient compilers accept.
			for (auto* kernel : vKernel)
			{
				delete kernel;
			}
			vKernel.clear();

			vIndex.clear();
		}

		/// Registers a sub-kernel together with the index array that selects its input columns.
		///
		/// @param kernel  Sub-kernel to append. The pointer is deleted in the destructor,
		///                so ownership passes to this object. A null pointer is rejected.
		/// @param index   Column selector; used as `X(af::span, index)` whenever this
		///                sub-kernel is evaluated.
		///
		/// iNumParam is increased by the sub-kernel's parameter count.
		template<typename Scalar>
		void TensorKernel<Scalar>::AddKernel(IKernel<Scalar>* kernel, const af::array index)
		{
			// Reject null before it is stored: it would be dereferenced right below and
			// again by every later computation and by the destructor's bookkeeping.
			LogAssert(kernel != nullptr, "Cannot add a null kernel.");
			if (kernel == nullptr)
				return; // in case LogAssert only reports and does not abort

			vKernel.push_back(kernel);
			vIndex.push_back(index);

			iNumParam += kernel->GetNumParameter();
		}

		/// Computes the kernel matrix as the element-wise product of all sub-kernel matrices.
		///
		/// @param inX1       First input; sub-kernel i sees the columns `inX1(af::span, vIndex[i])`.
		/// @param inX2       Second input, sliced the same way.
		/// @param outMatrix  Overwritten with an inX1.dims(0) x inX2.dims(0) matrix.
		template<typename Scalar>
		void TensorKernel<Scalar>::ComputeKernelMatrix(const af::array& inX1, const af::array& inX2, af::array& outMatrix)
		{
			LogAssert(vKernel.size() > 0, "No kernels to compute.");

			// Start from all ones so the running product is neutral.
			outMatrix = af::constant(1.0, inX1.dims(0), inX2.dims(0), (m_dType));

			// Pure output buffer for the sub-kernels. Default-constructed on purpose:
			// `af::array x(m_dType)` would select the (dim_t, dtype) constructor and
			// allocate an m_dType-element array instead of setting a type.
			af::array factor;
			const int numKernels = static_cast<int>(vKernel.size());
			for (int i = 0; i < numKernels; i++)
			{
				vKernel[i]->ComputeKernelMatrix(inX1(af::span, vIndex[i]), inX2(af::span, vIndex[i]), factor);
				outMatrix *= factor;
			}
		}

		/// Computes the kernel diagonal as the element-wise product of all sub-kernel diagonals.
		///
		/// @param inX          Input; sub-kernel i sees the columns `inX(af::span, vIndex[i])`.
		/// @param outDiagonal  Overwritten with a vector of inX.dims(0) entries.
		template<typename Scalar>
		void TensorKernel<Scalar>::ComputeDiagonal(const af::array& inX, af::array& outDiagonal)
		{
			LogAssert(vKernel.size() > 0, "No kernels to compute.");

			outDiagonal = af::constant(1.0, inX.dims(0), (m_dType));

			// Default-constructed output buffer; `af::array x(m_dType)` would allocate
			// an m_dType-element array through the (dim_t, dtype) constructor.
			af::array factor;
			const int numKernels = static_cast<int>(vKernel.size());
			for (int i = 0; i < numKernels; i++)
			{
				vKernel[i]->ComputeDiagonal(inX(af::span, vIndex[i]), factor);
				outDiagonal *= factor;
			}
		}

		/// Gradient of the objective with respect to inX for the kernel evaluated on (inX, inX).
		///
		/// @param inX       Input points.
		/// @param indL_dK   Derivative of the objective with respect to the kernel matrix.
		/// @param outdL_dX  Overwritten; same dimensions as inX, filled one column per input dimension.
		///
		/// For every column q the kernel derivative is obtained from GradX(inX, inX, q, ...).
		/// NOTE(review): the "2 * sum - diag * diag" form presumably accounts for inX appearing
		/// in both kernel arguments, with the diagonal counted once — confirm.
		template<typename Scalar>
		void TensorKernel<Scalar>::LogLikGradientX(const af::array& inX, const af::array& indL_dK, af::array& outdL_dX)
		{
			LogAssert(vKernel.size() > 0, "No kernels to compute.");

			outdL_dX = af::constant(0.0f, inX.dims(), (m_dType));

			// Loop-invariant, so extract it once.
			const af::array dL_dK_diag = af::diag(indL_dK);

			// Output buffer for GradX; default-constructed because `af::array x(m_dType)`
			// would allocate an m_dType-element array rather than set a type.
			af::array dK_dXq;
			const int numDims = static_cast<int>(outdL_dX.dims(1));
			for (int q = 0; q < numDims; q++)
			{
				GradX(inX, inX, q, dK_dXq);
				outdL_dX(af::span, q) = 2.0f * af::sum(indL_dK * dK_dXq, 1) - dL_dK_diag * af::diag(dK_dXq);
			}
		}

		/// Gradients with respect to the points inXu and inX from the two kernel matrices
		/// K(inXu, inXu) and K(inXu, inX).
		///
		/// @param inXu        First point set (the "u" set).
		/// @param indL_dKuu   Derivative of the objective with respect to K(inXu, inXu).
		/// @param inX         Second point set.
		/// @param indL_dKuf   Derivative of the objective with respect to K(inXu, inX).
		/// @param outdL_dXu   Overwritten; same dimensions as inXu.
		/// @param outdL_dX    Overwritten; same dimensions as inX.
		///
		/// The column loop runs over outdL_dXu.dims(1) and writes the same column of
		/// outdL_dX, so inX is assumed to have the same number of columns as inXu.
		template<typename Scalar>
		void TensorKernel<Scalar>::LogLikGradientX(const af::array& inXu, const af::array& indL_dKuu, const af::array& inX, const af::array& indL_dKuf, af::array& outdL_dXu, af::array& outdL_dX)
		{
			LogAssert(vKernel.size() > 0, "No kernels to compute.");

			outdL_dXu = af::constant(0.0f, inXu.dims(), (m_dType));
			outdL_dX = af::constant(0.0f, inX.dims(), (m_dType));

			// Loop-invariant, so extract it once.
			const af::array dL_dKuu_diag = af::diag(indL_dKuu);

			// Scratch buffer reused (overwritten) by each of the three GradX calls below.
			// Default-constructed: `af::array x(m_dType)` would allocate an
			// m_dType-element array rather than set a type.
			af::array dK_dXq;
			const int numDims = static_cast<int>(outdL_dXu.dims(1));
			for (int q = 0; q < numDims; q++)
			{
				// Contribution of K(inXu, inXu) to dL/dXu.
				GradX(inXu, inXu, q, dK_dXq);
				outdL_dXu(af::span, q) = 2 * af::sum(indL_dKuu * dK_dXq, 1) - dL_dKuu_diag * af::diag(dK_dXq);

				// Contribution of K(inXu, inX) to dL/dXu.
				GradX(inXu, inX, q, dK_dXq);
				outdL_dXu(af::span, q) += af::sum(indL_dKuf * dK_dXq, 1);

				// Contribution of K(inXu, inX) to dL/dX; the kernel arguments are swapped,
				// hence the transposed dL/dKuf.
				GradX(inX, inXu, q, dK_dXq);
				outdL_dX(af::span, q) = af::sum(indL_dKuf.T() * dK_dXq, 1);
			}
		}

		/// Gradient with respect to the first input of the product kernel K(inX1, inX2).
		///
		/// @param inX1      First input.
		/// @param inX2      Second input.
		/// @param indL_dK   Derivative of the objective with respect to K(inX1, inX2).
		/// @param outdL_dX  Overwritten. The per-kernel gradients are joined along dimension 1
		///                  in kernel order.
		///
		/// NOTE(review): joining in kernel order only lines up with the columns of inX1 if
		/// the vIndex entries are consecutive, non-overlapping column ranges — confirm.
		template<typename Scalar>
		void TensorKernel<Scalar>::LogLikGradientX(const af::array& inX1, const af::array& inX2, const af::array& indL_dK, af::array& outdL_dX)
		{
			LogAssert(vKernel.size() > 0, "No kernels to compute.");

			outdL_dX = af::array();

			const int numKernels = static_cast<int>(vKernel.size());
			af::array factor, othersProduct, partialGrad;
			for (int cnt = 0; cnt < numKernels; cnt++)
			{
				// Element-wise product of every sub-kernel matrix except kernel `cnt`
				// (the same set KSlash(cnt) returns, visited in the same order).
				othersProduct = af::constant(1.0, inX1.dims(0), inX2.dims(0), (m_dType));
				for (int i = 0; i < numKernels; i++)
				{
					if (i == cnt)
						continue;

					vKernel[i]->ComputeKernelMatrix(inX1(af::span, vIndex[i]), inX2(af::span, vIndex[i]), factor);
					othersProduct *= factor;
				}

				// Product rule: the sensitivity seen by kernel `cnt` is dL/dK scaled by the
				// product of the remaining kernels. The previous code computed that product
				// but passed the unscaled indL_dK, as the parameter gradients do not.
				vKernel[cnt]->LogLikGradientX(inX1(af::span, vIndex[cnt]), inX2(af::span, vIndex[cnt]), indL_dK * othersProduct, partialGrad);

				outdL_dX = CommonUtil<Scalar>::Join(outdL_dX, partialGrad, 1);
			}
		}

		/// Gradient of the objective with respect to the concatenated sub-kernel parameters.
		///
		/// @param inX1          First input.
		/// @param inX2          Second input.
		/// @param indL_dK       Derivative of the objective with respect to K(inX1, inX2).
		/// @param outdL_dParam  Overwritten with iNumParam entries; each sub-kernel's gradient
		///                      occupies its own consecutive slice, in kernel order.
		template<typename Scalar>
		void TensorKernel<Scalar>::LogLikGradientParam(const af::array& inX1, const af::array& inX2, const af::array& indL_dK, af::array& outdL_dParam)
		{
			LogAssert(vKernel.size() > 0, "No kernels to compute.");

			outdL_dParam = af::constant(0.0, iNumParam, (m_dType));

			const int numKernels = static_cast<int>(vKernel.size());
			af::array factor, othersProduct, kernelGrad;
			int sliceBegin = 0;
			for (int cnt = 0; cnt < numKernels; cnt++)
			{
				// Element-wise product of every sub-kernel matrix except kernel `cnt`
				// (the same set KSlash(cnt) returns, visited in the same order).
				othersProduct = af::constant(1.0, inX1.dims(0), inX2.dims(0), (m_dType));
				for (int i = 0; i < numKernels; i++)
				{
					if (i == cnt)
						continue;

					vKernel[i]->ComputeKernelMatrix(inX1(af::span, vIndex[i]), inX2(af::span, vIndex[i]), factor);
					othersProduct *= factor;
				}

				vKernel[cnt]->LogLikGradientParam(inX1(af::span, vIndex[cnt]), inX2(af::span, vIndex[cnt]), indL_dK * othersProduct, kernelGrad);

				const int sliceEnd = sliceBegin + vKernel[cnt]->GetNumParameter();
				// A kernel without parameters owns no slice; af::seq(b, b - 1) would be
				// invalid or end-relative, so skip the write as the other accessors do.
				if (sliceBegin != sliceEnd)
					outdL_dParam(af::seq(sliceBegin, sliceEnd - 1)) = kernelGrad;
				sliceBegin = sliceEnd;
			}
		}

		/// Gradient of the objective with respect to the concatenated sub-kernel parameters;
		/// variant that forwards an extra dlogZ_dv pointer to every sub-kernel.
		///
		/// @param inX1          First input.
		/// @param inX2          Second input.
		/// @param indL_dK       Derivative of the objective with respect to K(inX1, inX2).
		/// @param outdL_dParam  Overwritten with iNumParam entries; each sub-kernel's gradient
		///                      occupies its own consecutive slice, in kernel order.
		/// @param dlogZ_dv      Passed through unchanged to the sub-kernels; not read here.
		template<typename Scalar>
		void TensorKernel<Scalar>::LogLikGradientParam(const af::array& inX1, const af::array& inX2, const af::array& indL_dK, af::array& outdL_dParam, const af::array* dlogZ_dv)
		{
			LogAssert(vKernel.size() > 0, "No kernels to compute.");

			outdL_dParam = af::constant(0.0, iNumParam, (m_dType));

			const int numKernels = static_cast<int>(vKernel.size());
			af::array factor, othersProduct, kernelGrad;
			int sliceBegin = 0;
			for (int cnt = 0; cnt < numKernels; cnt++)
			{
				// Element-wise product of every sub-kernel matrix except kernel `cnt`
				// (the same set KSlash(cnt) returns, visited in the same order).
				othersProduct = af::constant(1.0, inX1.dims(0), inX2.dims(0), (m_dType));
				for (int i = 0; i < numKernels; i++)
				{
					if (i == cnt)
						continue;

					vKernel[i]->ComputeKernelMatrix(inX1(af::span, vIndex[i]), inX2(af::span, vIndex[i]), factor);
					othersProduct *= factor;
				}

				vKernel[cnt]->LogLikGradientParam(inX1(af::span, vIndex[cnt]), inX2(af::span, vIndex[cnt]), indL_dK * othersProduct, kernelGrad, dlogZ_dv);

				const int sliceEnd = sliceBegin + vKernel[cnt]->GetNumParameter();
				// A kernel without parameters owns no slice; af::seq(b, b - 1) would be
				// invalid or end-relative, so skip the write as the other accessors do.
				if (sliceBegin != sliceEnd)
					outdL_dParam(af::seq(sliceBegin, sliceEnd - 1)) = kernelGrad;
				sliceBegin = sliceEnd;
			}
		}

		/// Derivative of the product kernel matrix K(inX1, inX2) with respect to input column q.
		///
		/// @param inX1      First input.
		/// @param inX2      Second input.
		/// @param q         Column (input dimension) to differentiate with respect to.
		/// @param outdK_dX  Overwritten with an inX1.dims(0) x inX2.dims(0) matrix. Stays zero
		///                  if no sub-kernel's index array contains q.
		template<typename Scalar>
		void TensorKernel<Scalar>::GradX(const af::array& inX1, const af::array& inX2, int q, af::array& outdK_dX)
		{
			LogAssert(vKernel.size() > 0, "No kernels to compute.");

			outdK_dX = af::constant(0.0, inX1.dims(0), inX2.dims(0), (m_dType));

			const int numKernels = static_cast<int>(vKernel.size());
			af::array factor, othersProduct, kernelGrad;
			for (int cnt = 0; cnt < numKernels; cnt++)
			{
				// Position(s) of column q inside this sub-kernel's index array.
				const af::array localPos = af::where(vIndex[cnt] == q);
				if (localPos.isempty())
					continue; // this sub-kernel does not depend on column q

				// Element-wise product of every sub-kernel matrix except kernel `cnt`
				// (the same set KSlash(cnt) returns, visited in the same order).
				othersProduct = af::constant(1.0, inX1.dims(0), inX2.dims(0), (m_dType));
				for (int i = 0; i < numKernels; i++)
				{
					if (i == cnt)
						continue;

					vKernel[i]->ComputeKernelMatrix(inX1(af::span, vIndex[i]), inX2(af::span, vIndex[i]), factor);
					othersProduct *= factor;
				}

				// The sub-kernel only sees its own columns, so it is given the local
				// position of q rather than q itself.
				vKernel[cnt]->GradX(inX1(af::span, vIndex[cnt]), inX2(af::span, vIndex[cnt]), localPos(0).as(s32).scalar<int>(), kernelGrad);

				// Product rule: scale by the remaining kernels. The previous code computed
				// that product but added the unscaled sub-kernel gradient.
				outdK_dX += kernelGrad * othersProduct;
			}
		}

		/// Derivative of the kernel diagonal with respect to inX.
		///
		/// @param inX           Input points.
		/// @param outDiagdK_dX  Overwritten; same dimensions as inX. Each sub-kernel's
		///                      contribution is accumulated into the columns vIndex[cnt].
		template<typename Scalar>
		void TensorKernel<Scalar>::DiagGradX(const af::array& inX, af::array& outDiagdK_dX)
		{
			LogAssert(vKernel.size() > 0, "No kernels to compute.");

			outDiagdK_dX = af::constant(0.0, inX.dims(), (m_dType));

			const int numKernels = static_cast<int>(vKernel.size());
			// Default-constructed output buffers; `af::array x(m_dType)` would allocate
			// an m_dType-element array rather than set a type.
			af::array othersDiag, scratch;
			for (int cnt = 0; cnt < numKernels; cnt++)
			{
				// Element-wise product of every sub-kernel diagonal except kernel `cnt`
				// (the same set KSlash(cnt) returns, visited in the same order).
				othersDiag = af::constant(1.0, inX.dims(0), 1, (m_dType));
				for (int i = 0; i < numKernels; i++)
				{
					if (i == cnt)
						continue;

					vKernel[i]->ComputeDiagonal(inX(af::span, vIndex[i]), scratch);
					othersDiag *= scratch;
				}

				vKernel[cnt]->DiagGradX(inX(af::span, vIndex[cnt]), scratch);
				// Repeat the per-row scale across this sub-kernel's columns before multiplying.
				outDiagdK_dX(af::span, vIndex[cnt]) += scratch * af::tile(othersDiag, 1, inX(af::span, vIndex[cnt]).dims(1));
			}
		}

		/// Parameter gradient of the kernel diagonal, accumulated over the rows of inX.
		///
		/// @param inX               Input points; processed one row at a time.
		/// @param inCovDiag         Per-row weight; element i is passed as the dL/dK of row i.
		/// @param outDiagdK_dParam  Overwritten with a 1 x iNumParam row vector.
		///
		/// Each row i contributes LogLikGradientParam(inX(i,:), inX(i,:), inCovDiag(i)).
		template<typename Scalar>
		void TensorKernel<Scalar>::DiagGradParam(const af::array& inX, const af::array& inCovDiag, af::array& outDiagdK_dParam)
		{
			LogAssert(vKernel.size() > 0, "No kernels to compute.");

			outDiagdK_dParam = af::constant(0.0, 1, iNumParam, (m_dType));

			// Default-constructed output buffer; `af::array x(m_dType)` would allocate
			// an m_dType-element array rather than set a type.
			af::array rowGrad;
			const int numRows = static_cast<int>(inX.dims(0));
			for (int i = 0; i < numRows; i++)
			{
				LogLikGradientParam(inX(i, af::span), inX(i, af::span), inCovDiag(i), rowGrad);
				// LogLikGradientParam returns an iNumParam x 1 column while the accumulator
				// is 1 x iNumParam; transpose so the shapes agree for more than one parameter.
				outDiagdK_dParam += rowGrad.T();
			}
		}

		/// Forwards the same inMedian value to InitParameters() of every registered
		/// sub-kernel, in registration order.
		template<typename Scalar>
		void TensorKernel<Scalar>::InitParameters(Scalar inMedian)
		{
			for (auto* kernel : vKernel)
			{
				kernel->InitParameters(inMedian);
			}
		}

		/// Parameter and input gradients that flow through the cross kernel K(inX, inXu).
		///
		/// @param indL_dKfu     Derivative of the objective with respect to K(inX, inXu).
		/// @param inX           First point set.
		/// @param inXu          Second point set (the "u" set).
		/// @param outdL_dParam  Must not be null. Overwritten with iNumParam entries, one
		///                      consecutive slice per sub-kernel in kernel order.
		/// @param outdL_dXu     Must not be null. Overwritten; per-kernel results joined
		///                      along dimension 1 in kernel order.
		/// @param dlogZ_dv      Passed through unchanged to the sub-kernels; not read here.
		/// @param outdL_dX      Optional (may be null). When given, overwritten with the
		///                      per-kernel results joined along dimension 1.
		template<typename Scalar>
		void TensorKernel<Scalar>::LogLikGradientCompundKfu(const af::array& indL_dKfu, const af::array& inX, const af::array& inXu, 
			af::array* outdL_dParam, af::array* outdL_dXu, const af::array* dlogZ_dv, af::array* outdL_dX)
		{
			LogAssert(vKernel.size() > 0, "No kernels to compute.");

			// Clear the joined outputs so stale caller content cannot leak into the result.
			*outdL_dXu = af::array();
			if (outdL_dX != nullptr)
				*outdL_dX = af::array();

			*outdL_dParam = af::constant(0.0, iNumParam, (m_dType));

			const int numKernels = static_cast<int>(vKernel.size());
			af::array factor, othersProduct, kernelGradParam, kernelGradX, kernelGradXu;
			int sliceBegin = 0;
			for (int cnt = 0; cnt < numKernels; cnt++)
			{
				// Element-wise product of every sub-kernel matrix except kernel `cnt`
				// (the same set KSlash(cnt) returns, visited in the same order).
				othersProduct = af::constant(1.0, inX.dims(0), inXu.dims(0), m_dType);
				for (int i = 0; i < numKernels; i++)
				{
					if (i == cnt)
						continue;

					vKernel[i]->ComputeKernelMatrix(inX(af::span, vIndex[i]), inXu(af::span, vIndex[i]), factor);
					othersProduct *= factor;
				}

				vKernel[cnt]->LogLikGradientCompundKfu(indL_dKfu * othersProduct, inX(af::span, vIndex[cnt]), inXu(af::span, vIndex[cnt]), &kernelGradParam, &kernelGradXu, dlogZ_dv, &kernelGradX);

				const int sliceEnd = sliceBegin + vKernel[cnt]->GetNumParameter();
				// A kernel without parameters owns no slice; af::seq(b, b - 1) would be
				// invalid or end-relative, so skip the write as LogGradientCompoundKuu does.
				if (sliceBegin != sliceEnd)
					(*outdL_dParam)(af::seq(sliceBegin, sliceEnd - 1)) = kernelGradParam;
				sliceBegin = sliceEnd;

				*outdL_dXu = CommonUtil<Scalar>::Join(*outdL_dXu, kernelGradXu, 1);

				if (outdL_dX != nullptr)
					*outdL_dX = CommonUtil<Scalar>::Join(*outdL_dX, kernelGradX, 1);
			}
		}

		/// Parameter and input gradients that flow through the kernel K(inXu, inXu).
		///
		/// @param inXu          Point set.
		/// @param inCovDiag     Multiplied element-wise with an inXu.dims(0) x inXu.dims(0)
		///                      matrix before being handed to the sub-kernels.
		///                      NOTE(review): despite the name this presumably is dL/dKuu — confirm.
		/// @param outdL_dParam  Must not be null. Overwritten with iNumParam entries, one
		///                      consecutive slice per sub-kernel in kernel order.
		/// @param outdL_dXu     Must not be null. Overwritten; per-kernel results joined
		///                      along dimension 1 in kernel order.
		template<typename Scalar>
		void TensorKernel<Scalar>::LogGradientCompoundKuu(const af::array& inXu, const af::array& inCovDiag, af::array* outdL_dParam, af::array* outdL_dXu)
		{
			LogAssert(vKernel.size() > 0, "No kernels to compute.");

			// The result is built by joining onto *outdL_dXu, so it must start empty.
			// Without this reset, whatever the caller's array held was kept as leading
			// columns (LogLikGradientCompundKfu already clears its outputs this way).
			*outdL_dXu = af::array();

			*outdL_dParam = af::constant(0.0, iNumParam, (m_dType));

			const int numKernels = static_cast<int>(vKernel.size());
			af::array factor, othersProduct, kernelGradParam, kernelGradXu;
			int sliceBegin = 0;
			for (int cnt = 0; cnt < numKernels; cnt++)
			{
				// Element-wise product of every sub-kernel matrix except kernel `cnt`
				// (the same set KSlash(cnt) returns, visited in the same order).
				othersProduct = af::constant(1.0, inXu.dims(0), inXu.dims(0), m_dType);
				for (int i = 0; i < numKernels; i++)
				{
					if (i == cnt)
						continue;

					vKernel[i]->ComputeKernelMatrix(inXu(af::span, vIndex[i]), inXu(af::span, vIndex[i]), factor);
					othersProduct *= factor;
				}

				vKernel[cnt]->LogGradientCompoundKuu(inXu(af::span, vIndex[cnt]), inCovDiag * othersProduct, &kernelGradParam, &kernelGradXu);

				const int sliceEnd = sliceBegin + vKernel[cnt]->GetNumParameter();
				// Kernels without parameters own no slice.
				if (sliceBegin != sliceEnd)
					(*outdL_dParam)(af::seq(sliceBegin, sliceEnd - 1)) = kernelGradParam;
				sliceBegin = sliceEnd;

				*outdL_dXu = CommonUtil<Scalar>::Join(*outdL_dXu, kernelGradXu, 1);
			}
		}

		/// Returns the registered sub-kernels with the one at position kernelIndex left out.
		///
		/// @param kernelIndex  Position in vKernel to skip. A value outside the valid range
		///                     skips nothing, so every kernel is returned.
		/// @return Copies of the remaining kernel pointers in their original order;
		///         nothing is allocated or deleted for the kernels themselves.
		template<typename Scalar>
		std::vector<IKernel<Scalar>*> TensorKernel<Scalar>::KSlash(int kernelIndex)
		{
			std::vector<IKernel<Scalar>*> others;

			const int numKernels = static_cast<int>(vKernel.size());
			for (int pos = 0; pos < numKernels; ++pos)
			{
				if (pos == kernelIndex)
					continue;

				others.push_back(vKernel[pos]);
			}

			return others;
		}

		/// Distributes a concatenated parameter vector over the sub-kernels.
		///
		/// @param param  Parameters of all sub-kernels back to back, in kernel order.
		///               Sub-kernel i receives the next GetNumParameter() entries;
		///               kernels that report zero parameters are skipped.
		template<typename Scalar>
		void TensorKernel<Scalar>::SetParameters(const af::array & param)
		{
			int offset = 0;
			for (auto* kernel : vKernel)
			{
				const int count = kernel->GetNumParameter();
				if (count != 0)
				{
					kernel->SetParameters(param(af::seq(offset, offset + count - 1)));
				}
				offset += count;
			}
		}

		/// Distributes a concatenated parameter vector over the sub-kernels through
		/// their SetLogParameters() setters.
		///
		/// @param param  Values for all sub-kernels back to back, in kernel order.
		///               Sub-kernel i receives the next GetNumParameter() entries;
		///               kernels that report zero parameters are skipped.
		template<typename Scalar>
		void TensorKernel<Scalar>::SetLogParameters(const af::array& param)
		{
			int offset = 0;
			for (auto* kernel : vKernel)
			{
				const int count = kernel->GetNumParameter();
				if (count != 0)
				{
					kernel->SetLogParameters(param(af::seq(offset, offset + count - 1)));
				}
				offset += count;
			}
		}

		/// Collects GetLogParameters() of every sub-kernel into one vector.
		///
		/// @return A vector with GetNumParameter() entries of type m_dType. Each
		///         sub-kernel's values fill its own consecutive slice in kernel order;
		///         kernels that report zero parameters contribute nothing.
		template<typename Scalar>
		af::array TensorKernel<Scalar>::GetLogParameters()
		{
			af::array collected = af::constant(0.0f, GetNumParameter(), (m_dType));

			int offset = 0;
			for (auto* kernel : vKernel)
			{
				const int count = kernel->GetNumParameter();
				if (count != 0)
				{
					collected(af::seq(offset, offset + count - 1)) = kernel->GetLogParameters();
				}
				offset += count;
			}

			return collected;
		}

		/// Collects GetParameters() of every sub-kernel into one vector.
		///
		/// @return A vector with GetNumParameter() entries of type m_dType. Each
		///         sub-kernel's values fill its own consecutive slice in kernel order;
		///         kernels that report zero parameters contribute nothing.
		template<typename Scalar>
		af::array TensorKernel<Scalar>::GetParameters()
		{
			af::array collected = af::constant(0.0f, GetNumParameter(), (m_dType));

			int offset = 0;
			for (auto* kernel : vKernel)
			{
				const int count = kernel->GetNumParameter();
				if (count != 0)
				{
					collected(af::seq(offset, offset + count - 1)) = kernel->GetParameters();
				}
				offset += count;
			}

			return collected;
		}

		//////////////////////////////////////////////////////////////////////////////////////////////////////
		///// PSI statistics
		//////////////////////////////////////////////////////////////////////////////////////////////////////

		/// Computes the psi statistics of the product kernel as element-wise products of
		/// the sub-kernels' psi statistics.
		///
		/// @param inXu     Sub-kernel i sees the columns `inXu(af::span, vIndex[i])`.
		/// @param inMu     Sliced by column the same way.
		/// @param inS      Sliced by column the same way.
		/// @param outPsi0  Overwritten; initialised as a vector of inS.dims(0) ones.
		/// @param outPsi1  Overwritten; initialised as inS.dims(0) x inXu.dims(0) ones.
		/// @param outPsi2  Overwritten; initialised as inXu.dims(0) x inXu.dims(0) x inS.dims(0) ones.
		///
		/// NOTE(review): the sub-kernels are assumed to return statistics with exactly these
		/// shapes, and the product form presumably requires the sub-kernels to act on
		/// disjoint columns — confirm.
		template<typename Scalar>
		void TensorKernel<Scalar>::ComputePsiStatistics(const af::array& inXu, const af::array& inMu, const af::array& inS, af::array& outPsi0, af::array& outPsi1, af::array& outPsi2)
		{
			LogAssert(vKernel.size() > 0, "No kernels to compute.");

			const int numPoints = inS.dims(0);
			const int numInducing = inXu.dims(0);

			// All ones, so the running products below start neutral.
			outPsi0 = af::constant(1.0, numPoints, m_dType);
			outPsi1 = af::constant(1.0, numPoints, numInducing, m_dType);
			outPsi2 = af::constant(1.0, numInducing, numInducing, numPoints, m_dType);

			af::array kernelPsi0, kernelPsi1, kernelPsi2;
			const int numKernels = static_cast<int>(vKernel.size());
			for (int i = 0; i < numKernels; i++)
			{
				vKernel[i]->ComputePsiStatistics(inXu(af::span, vIndex[i]), inMu(af::span, vIndex[i]), inS(af::span, vIndex[i]), kernelPsi0, kernelPsi1, kernelPsi2);
				outPsi0 *= kernelPsi0;
				outPsi1 *= kernelPsi1;
				outPsi2 *= kernelPsi2;
			}
		}

		/// Gradients of the objective through the psi statistics of the product kernel.
		///
		/// @param indL_dPsi0    Derivative of the objective with respect to psi0.
		/// @param inPsi1        Not read here; each sub-kernel gets its own recomputed psi1.
		/// @param indL_dPsi1    Derivative of the objective with respect to psi1.
		/// @param inPsi2        Not read here; each sub-kernel gets its own recomputed psi2.
		/// @param indL_dPsi2    Derivative of the objective with respect to psi2.
		/// @param inXu          Sub-kernel i sees the columns `inXu(af::span, vIndex[i])`.
		/// @param inMu          Sliced by column the same way.
		/// @param inS           Sliced by column the same way.
		/// @param outdL_dParam  Overwritten with iNumParam entries, one consecutive slice per
		///                      sub-kernel in kernel order.
		/// @param outdL_dXu     Overwritten; per-kernel results joined along dimension 1.
		/// @param outdL_dMu     Overwritten; per-kernel results joined along dimension 1.
		/// @param outdL_dS      Overwritten; per-kernel results joined along dimension 1.
		/// @param dlogZ_dv      Passed through unchanged to the sub-kernels; not read here.
		template<typename Scalar>
		void TensorKernel<Scalar>::PsiDerivatives(const af::array& indL_dPsi0, const af::array& inPsi1, const af::array& indL_dPsi1, const af::array& inPsi2, const af::array& indL_dPsi2, 
			const af::array& inXu, const af::array& inMu, const af::array& inS, af::array& outdL_dParam, af::array& outdL_dXu, af::array& outdL_dMu, af::array& outdL_dS, const af::array* dlogZ_dv)
		{
			// Same precondition every other computation in this class checks.
			LogAssert(vKernel.size() > 0, "No kernels to compute.");

			const int numPoints = inS.dims(0);
			const int numInducing = inXu.dims(0);

			outdL_dParam = af::constant(0.0, iNumParam, (m_dType));

			// These three are built by joining onto themselves, so they must start empty.
			// Without the reset, whatever the caller's arrays held was kept as leading columns.
			outdL_dXu = af::array();
			outdL_dMu = af::array();
			outdL_dS = af::array();

			const int numKernels = static_cast<int>(vKernel.size());
			af::array othersPsi0, othersPsi1, othersPsi2;
			af::array scratchPsi0, scratchPsi1, scratchPsi2;
			af::array ownPsi0, ownPsi1, ownPsi2;
			af::array kernelGradParam, kernelGradXu, kernelGradMu, kernelGradS;

			int sliceBegin = 0;
			for (int cnt = 0; cnt < numKernels; cnt++)
			{
				// Element-wise products of the psi statistics of every sub-kernel except
				// kernel `cnt` (the same set KSlash(cnt) returns, visited in the same order).
				othersPsi0 = af::constant(1.0, numPoints, m_dType);
				othersPsi1 = af::constant(1.0, numPoints, numInducing, m_dType);
				othersPsi2 = af::constant(1.0, numInducing, numInducing, numPoints, m_dType);

				for (int i = 0; i < numKernels; i++)
				{
					if (i == cnt)
						continue;

					vKernel[i]->ComputePsiStatistics(inXu(af::span, vIndex[i]), inMu(af::span, vIndex[i]), inS(af::span, vIndex[i]), scratchPsi0, scratchPsi1, scratchPsi2);

					othersPsi0 *= scratchPsi0;
					othersPsi1 *= scratchPsi1;
					othersPsi2 *= scratchPsi2;
				}

				// Kernel `cnt` is given its own psi1/psi2 and the incoming sensitivities
				// scaled by the statistics of the remaining kernels.
				vKernel[cnt]->ComputePsiStatistics(inXu(af::span, vIndex[cnt]), inMu(af::span, vIndex[cnt]), inS(af::span, vIndex[cnt]), ownPsi0, ownPsi1, ownPsi2);
				vKernel[cnt]->PsiDerivatives(indL_dPsi0 * othersPsi0, ownPsi1, indL_dPsi1 * othersPsi1, ownPsi2, indL_dPsi2 * othersPsi2, inXu(af::span, vIndex[cnt]),
					inMu(af::span, vIndex[cnt]), inS(af::span, vIndex[cnt]), kernelGradParam, kernelGradXu, kernelGradMu, kernelGradS, dlogZ_dv);

				const int sliceEnd = sliceBegin + vKernel[cnt]->GetNumParameter();
				// Kernels without parameters own no slice.
				if (sliceBegin != sliceEnd)
					outdL_dParam(af::seq(sliceBegin, sliceEnd - 1)) = kernelGradParam;
				sliceBegin = sliceEnd;

				outdL_dXu = CommonUtil<Scalar>::Join(outdL_dXu, kernelGradXu, 1);
				outdL_dMu = CommonUtil<Scalar>::Join(outdL_dMu, kernelGradMu, 1);
				outdL_dS = CommonUtil<Scalar>::Join(outdL_dS, kernelGradS, 1);
			}
		}

		//template<typename Scalar>
		//void TensorKernel<Scalar>::ComputePsi1(const af::array& inXu, const af::array& inMu, const af::array& inS, af::array& outPsi1)
		//{
		//}

		//template<typename Scalar>
		//void TensorKernel<Scalar>::ComputePsi2(const af::array& inXu, const af::array& inMu, const af::array& inS, af::array& outPsi2)
		//{
		//}

		//template<typename Scalar>
		//void TensorKernel<Scalar>::Psi1Derivative(const af::array& inPsi1, const af::array& indL_dPsi1, const af::array& inXu, const af::array& inMu, const af::array& inS, af::array& outdL_dParam, af::array& outdL_dXu, af::array& outdL_dMu, af::array& outdL_dS)
		//{
		//}

		//template<typename Scalar>
		//void TensorKernel<Scalar>::Psi2Derivative(const af::array& inPsi2, const af::array& indL_dPsi2, const af::array& inXu, const af::array& inMu, const af::array& inS, af::array& outdL_dParam, af::array& outdL_dXu, af::array& outdL_dMu, af::array& outdL_dS)
		//{
		//}
	}
}