/**
File:		MachineLearning/Kernel/FgLinearKernel<Scalar>.cpp

Author:		
Email:		
Site:       

Copyright (c) 2017 . All rights reserved.
*/

#include <NeMachineLearningPCH.h>
#include <MachineLearning/FgLinearKernel.h>

namespace NeuralEngine
{
	namespace MachineLearning
	{
		// Explicit instantiations: LinearKernel is emitted from this translation unit
		// for the float and double scalar types.
		template class LinearKernel<float>;
		template class LinearKernel<double>;

		// Default constructor: a linear kernel with a single parameter.
		// Delegates to LinearKernel(int) with numdims = 1, which constructs the
		// IKernel base with (eLinearKernel, 1) and fills dVariance with 1.0.
		template<typename Scalar>
		LinearKernel<Scalar>::LinearKernel()
			: LinearKernel(1)
		{
		}

		// Constructs a linear kernel with `numdims` parameters.
		//
		// The IKernel base is constructed with (eLinearKernel, numdims) and dVariance
		// is initialised to a length-`numdims` array of ones created with m_dType.
		template<typename Scalar>
		LinearKernel<Scalar>::LinearKernel(int numdims)
			: IKernel<Scalar>(eLinearKernel, numdims)
			, dVariance(af::constant(1.0f, numdims, m_dType))
		{
		}

		// Destructor: nothing to release explicitly; members clean up themselves.
		template<typename Scalar>
		LinearKernel<Scalar>::~LinearKernel() = default;

		// Computes the kernel matrix between the rows of inX1 and the rows of inX2.
		//
		// inX1, inX2 : input matrices (one sample per row).
		// outMatrix  : receives the inX1.dims(0) x inX2.dims(0) kernel matrix.
		//
		// The log-parameters are doubled on entry (which squares dVariance through
		// SetLogParameters/GetLogParameters) and halved again before returning.
		template<typename Scalar>
		void LinearKernel<Scalar>::ComputeKernelMatrix(const af::array& inX1, const af::array& inX2, af::array& outMatrix)
		{
			SetLogParameters(GetLogParameters() * 2.0);

			const auto numRows1 = inX1.dims(0);
			const auto numRows2 = inX2.dims(0);

			if (GetNumParameter() != 1)
			{
				// One weight per input column is required.
				LogAssert(dVariance.dims(0) == inX1.dims(1), "Dimension Missmatch for Variance.");

				// Scale every column of both inputs by the square root of its weight,
				// so the product of the scaled inputs carries the full weight.
				const af::array rowScale = af::sqrt(dVariance).T();
				const af::array scaledX1 = inX1 * af::tile(rowScale, numRows1);
				const af::array scaledX2 = inX2 * af::tile(rowScale, numRows2);
				outMatrix = af::matmulNT(scaledX1, scaledX2);
			}
			else
			{
				// Single shared weight: scale the plain inner-product matrix.
				const af::array gram = af::matmulNT(inX1, inX2);
				outMatrix = af::tile(dVariance, numRows1, numRows2) * gram;
			}

			SetLogParameters(GetLogParameters() / 2.0);
		}

		// Computes the weighted sum of squares of every row of inX.
		//
		// inX         : input matrix (one sample per row).
		// outDiagonal : receives sum over columns of inX^2 times the weights.
		//
		// The log-parameters are doubled while the result is computed and halved
		// again afterwards.
		template<typename Scalar>
		void LinearKernel<Scalar>::ComputeDiagonal(const af::array& inX, af::array& outDiagonal)
		{
			SetLogParameters(GetLogParameters() * 2.0);

			const af::array squaredX = inX * inX;

			// Single parameter: one weight replicated over all entries;
			// otherwise one weight per column, replicated over the rows.
			const af::array weights = (GetNumParameter() == 1)
				? af::tile(dVariance, inX.dims())
				: af::tile(dVariance.T(), inX.dims(0));

			outDiagonal = af::sum(squaredX * weights, 1);

			SetLogParameters(GetLogParameters() / 2.0);
		}

		// Accumulates the gradient with respect to inX, one input column at a time.
		//
		// inX      : input matrix.
		// indL_dK  : partial derivatives with respect to the kernel matrix.
		// outdL_dX : receives an array with the dimensions of inX.
		template<typename Scalar>
		void LinearKernel<Scalar>::LogLikGradientX(const af::array& inX, const af::array& indL_dK, af::array& outdL_dX)
		{
			outdL_dX = af::constant(0.0f, inX.dims(), (m_dType));

			// Scratch array; GradX overwrites it on every iteration.
			af::array kernelGrad(m_dType);

			const int numCols = static_cast<int>(outdL_dX.dims(1));
			for (int col = 0; col < numCols; ++col)
			{
				GradX(inX, inX, col, kernelGrad);

				const af::array rowSums = af::sum(indL_dK * kernelGrad, 1);
				const af::array diagTerm = af::diag(indL_dK) * af::diag(kernelGrad);

				// Twice the row sums, minus the product of the two diagonals.
				outdL_dX(af::span, col) = 2.0f * rowSums - diagTerm;
			}
		}

		// Accumulates gradients with respect to inXu and inX from the partial
		// derivatives of the Kuu and Kuf kernel matrices.
		//
		// inXu, indL_dKuu : first input set and the partials for its own kernel matrix.
		// inX,  indL_dKuf : second input set and the partials for the cross kernel matrix.
		// outdL_dXu       : receives an array with the dimensions of inXu.
		// outdL_dX        : receives an array with the dimensions of inX.
		template<typename Scalar>
		void LinearKernel<Scalar>::LogLikGradientX(const af::array& inXu, const af::array& indL_dKuu, const af::array& inX, const af::array& indL_dKuf, af::array& outdL_dXu, af::array& outdL_dX)
		{
			outdL_dXu = af::constant(0.0f, inXu.dims(), (m_dType));
			outdL_dX = af::constant(0.0f, inX.dims(), (m_dType));

			// Scratch array reused for all three GradX calls of an iteration.
			af::array kernelGrad(m_dType);

			const int numCols = static_cast<int>(outdL_dXu.dims(1));
			for (int col = 0; col < numCols; ++col)
			{
				// Kuu contribution to the inXu gradient.
				GradX(inXu, inXu, col, kernelGrad);
				outdL_dXu(af::span, col) = 2 * af::sum(indL_dKuu * kernelGrad, 1) - af::diag(indL_dKuu) * af::diag(kernelGrad);

				// Kuf contribution to the inXu gradient.
				GradX(inXu, inX, col, kernelGrad);
				outdL_dXu(af::span, col) += af::sum(indL_dKuf * kernelGrad, 1);

				// Kuf contribution to the inX gradient (uses the transposed partials).
				GradX(inX, inXu, col, kernelGrad);
				outdL_dX(af::span, col) = af::sum(indL_dKuf.T() * kernelGrad, 1);
			}
		}

		// Computes the gradient with respect to inX1 in closed form.
		//
		// inX1, inX2 : the two input sets the kernel matrix was built from.
		// indL_dK    : partial derivatives with respect to that kernel matrix.
		// outdL_dX   : receives the gradient.
		//
		// When both inputs have the same number of rows the partials are
		// symmetrised, multiplied with inX1 and the weights are doubled; otherwise
		// the partials are multiplied with inX2 and the weights are used as they are.
		// The log-parameters are doubled for the duration of the call.
		template<typename Scalar>
		void LinearKernel<Scalar>::LogLikGradientX(const af::array& inX1, const af::array& inX2, const af::array& indL_dK, af::array& outdL_dX)
		{
			SetLogParameters(GetLogParameters() * 2.0);

			af::array partials;
			af::array rightFactor;
			af::array weights;

			if (inX1.dims(0) == inX2.dims(0))
			{
				partials = (indL_dK + indL_dK.T()) / 2.0;
				rightFactor = inX1;
				weights = 2.0 * dVariance;
			}
			else
			{
				partials = indL_dK;
				rightFactor = inX2;
				weights = dVariance;
			}

			const af::array product = af::matmul(partials, rightFactor);

			if (GetNumParameter() == 1)
				outdL_dX = product * af::tile(weights, inX1.dims());
			else
				outdL_dX = product * af::tile(weights.T(), inX1.dims(0));

			SetLogParameters(GetLogParameters() / 2.0);
		}

		// Computes the gradient with respect to the kernel parameter(s).
		//
		// inX1, inX2   : the two input sets the kernel matrix was built from.
		// indL_dK      : partial derivatives with respect to that kernel matrix.
		// outdL_dParam : receives a single value when the kernel has one parameter,
		//                otherwise one value per input column.
		//
		// When both inputs have the same number of rows the partials are replaced
		// by their symmetric part (indL_dK + indL_dK^T) / 2.
		template<typename Scalar>
		void LinearKernel<Scalar>::LogLikGradientParam(const af::array& inX1, const af::array& inX2, const af::array& indL_dK, af::array& outdL_dParam)
		{
			af::array dL_dKtmp;

			// The halving has to be part of the assigned expression; written as
			// "(dL_dKtmp = a + a.T()) / 2.0;" the division result is thrown away.
			if (inX1.dims(0) == inX2.dims(0))
				dL_dKtmp = (indL_dK + indL_dK.T()) / 2.0;
			else
				dL_dKtmp = indL_dK;

			if (GetNumParameter() == 1)
				outdL_dParam = af::sum(af::sum(af::matmulNT(inX1, inX2) * dL_dKtmp));
			else
				outdL_dParam = af::sum(af::matmul(dL_dKtmp, inX2) * inX1).T();
		}

		// Computes the gradient with respect to the kernel parameter(s), adds the
		// summed dlogZ_dv term and doubles the result.
		//
		// inX1, inX2   : the two input sets the kernel matrix was built from.
		// indL_dK      : partial derivatives with respect to that kernel matrix.
		// outdL_dParam : receives a single value when the kernel has one parameter,
		//                otherwise one value per input column.
		// dlogZ_dv     : optional extra term; its total sum is added to every entry
		//                of the result. Skipped when null.
		//
		// When both inputs have the same number of rows the partials are replaced
		// by their symmetric part (indL_dK + indL_dK^T) / 2.
		template<typename Scalar>
		void LinearKernel<Scalar>::LogLikGradientParam(const af::array& inX1, const af::array& inX2, const af::array& indL_dK, af::array& outdL_dParam, const af::array* dlogZ_dv)
		{
			af::array dL_dKtmp;

			// The halving has to be part of the assigned expression; written as
			// "(dL_dKtmp = a + a.T()) / 2.0;" the division result is thrown away.
			if (inX1.dims(0) == inX2.dims(0))
				dL_dKtmp = (indL_dK + indL_dK.T()) / 2.0;
			else
				dL_dKtmp = indL_dK;

			if (GetNumParameter() == 1)
				outdL_dParam = af::sum(af::sum(af::matmul(dL_dKtmp, inX2) * inX1));
			else
				outdL_dParam = af::sum(af::matmul(dL_dKtmp, inX2) * inX1).T();

			// Guard the optional pointer, as the other gradient routines of this
			// class do, instead of dereferencing it unconditionally.
			if (dlogZ_dv)
				outdL_dParam += af::tile(af::sum(af::sum(*dlogZ_dv)), outdL_dParam.dims());

			outdL_dParam *= 2.0f;
		}

		// Builds the matrix of kernel derivatives with respect to input column q.
		//
		// inX1, inX2 : the two input sets.
		// q          : index of the input column to differentiate against.
		// outdK_dX   : receives an inX1.dims(0)-row matrix whose every row is
		//              column q of inX1 (equal row counts) or of inX2 (otherwise),
		//              scaled by the weight(s).
		//
		// When both inputs have the same number of rows, column q is additionally
		// added on the main diagonal, so the diagonal entries are doubled.
		// The log-parameters are doubled for the duration of the call.
		template<typename Scalar>
		void LinearKernel<Scalar>::GradX(const af::array& inX1, const af::array& inX2, int q, af::array& outdK_dX)
		{
			int numData = inX1.dims(0);
			int numData2 = inX2.dims(0);

			SetLogParameters(GetLogParameters() * 2.0);

			if (numData == numData2)
			{
				const af::array column = inX1(af::span, q);

				outdK_dX = af::tile(column.T(), numData, 1);

				// af::diag in create mode (extract == false) expects a vector and
				// builds the diagonal matrix from it. The diagonal of the tiled
				// matrix is exactly `column`, so pass that vector rather than the
				// full matrix.
				outdK_dX += af::diag(column, 0, false);
			}
			else
				outdK_dX = af::tile(inX2(af::span, q).T(), numData, 1);

			if (GetNumParameter() == 1)
				outdK_dX *= af::tile(dVariance, outdK_dX.dims());
			else
				outdK_dX *= af::tile(dVariance(q), outdK_dX.dims());

			SetLogParameters(GetLogParameters() / 2.0);
		}

		// Computes 2 * weight * inX element-wise.
		//
		// inX          : input matrix.
		// outDiagdK_dX : receives an array with the dimensions of inX.
		//
		// NOTE(review): dVariance is used as stored here; the log-parameter doubling
		// applied by ComputeDiagonal is not performed - confirm this is intended.
		template<typename Scalar>
		void LinearKernel<Scalar>::DiagGradX(const af::array& inX, af::array& outDiagdK_dX)
		{
			// Single parameter: one weight for every entry;
			// otherwise one weight per column, replicated over the rows.
			const af::array weights = (GetNumParameter() == 1)
				? af::tile(dVariance, inX.dims())
				: af::tile(dVariance.T(), inX.dims(0));

			outDiagdK_dX = 2.0f * weights * inX;
		}

		// Accumulates, over all rows of inX, the row's squared norm times
		// inCovDiag(i) divided by dVariance, and returns that sum multiplied by
		// dVariance in a 1x1 array.
		//
		// inX              : input matrix (one sample per row).
		// inCovDiag        : per-row factor, indexed linearly by the row index.
		// outDiagdK_dParam : receives the 1x1 result.
		//
		// NOTE(review): the element read-back and the 1x1 output presumably assume
		// a single-parameter kernel - confirm against callers.
		template<typename Scalar>
		void LinearKernel<Scalar>::DiagGradParam(const af::array& inX, const af::array& inCovDiag, af::array& outDiagdK_dParam)
		{
			Scalar gParam = 0.0f;

			// scalar<T>() requires the array's element type to be exactly T, so the
			// value is converted to the dtype that corresponds to Scalar first.
			// Reading it back as float fails as soon as the data is double.
			const af::dtype scalarType = static_cast<af::dtype>(af::dtype_traits<Scalar>::af_type);

			for (int i = 0; i < inX.dims(0); i++)
			{
				const af::array row = inX(i, af::span);

				// row * row^T is the 1x1 squared norm of the row. Transposing the
				// second operand beforehand makes the product dimension-inconsistent
				// for more than one column.
				const af::array term = af::matmulNT(row, row) * inCovDiag(i) / dVariance;

				gParam += term.as(scalarType).scalar<Scalar>();
			}

			outDiagdK_dParam = af::array(1, 1, (m_dType));
			outDiagdK_dParam(0) = gParam * dVariance; // consider log space
		}

		// Resets every log-parameter to log(0.5).
		//
		// inMedian : not used by this kernel.
		template<typename Scalar>
		void LinearKernel<Scalar>::InitParameters(Scalar inMedian)
		{
			const af::array initialLogParams = af::constant(log(0.5), GetNumParameter(), m_dType);
			SetLogParameters(initialLogParams);
		}

		// Computes the gradients of the Kfu term with respect to the kernel
		// parameter(s), inXu and (optionally) inX.
		//
		// indL_dKfu    : partial derivatives with respect to Kfu.
		// inX, inXu    : the two input sets.
		// outdL_dParam : must not be null; receives one value (single parameter)
		//                or one value per input column. The summed dlogZ_dv is
		//                added when given and the result is doubled.
		// outdL_dXu    : must not be null; receives the gradient for inXu.
		// dlogZ_dv     : optional extra term, skipped when null.
		// outdL_dX     : optional; receives the gradient for inX when not null.
		//
		// The log-parameters are doubled for the duration of the call.
		template<typename Scalar>
		void LinearKernel<Scalar>::LogLikGradientCompundKfu(const af::array& indL_dKfu, const af::array& inX, const af::array& inXu, 
			af::array* outdL_dParam, af::array* outdL_dXu, const af::array* dlogZ_dv, af::array* outdL_dX)
		{
			int iN = inX.dims(0);
			int ik = inXu.dims(0);
			int iq = inXu.dims(1);

			SetLogParameters(GetLogParameters() * 2.0);

			// Shared by the parameter gradient and the inX gradient.
			const af::array dKfuXu = af::matmul(indL_dKfu, inXu);
			const af::array Lpsi1 = dKfuXu * inX;

			if (GetNumParameter() == 1)
			{
				*outdL_dParam = af::sum(af::sum(Lpsi1));

				*outdL_dXu = af::matmulTN(indL_dKfu, inX) * af::tile(dVariance, ik, iq);

				if (outdL_dX != nullptr)
					*outdL_dX = dKfuXu * af::tile(dVariance, iN, iq);
			}
			else
			{
				*outdL_dParam = af::sum(Lpsi1).T();

				*outdL_dXu = af::matmulTN(indL_dKfu, inX) * af::tile(dVariance.T(), ik);

				if (outdL_dX != nullptr)
					*outdL_dX = dKfuXu * af::tile(dVariance.T(), iN);
			}

			if (dlogZ_dv)
				*outdL_dParam += af::tile(af::sum(af::sum(*dlogZ_dv)), outdL_dParam->dims());

			(*outdL_dParam) *= 2.0;

			SetLogParameters(GetLogParameters() / 2.0);
		}

		// Computes the gradients of the Kuu term with respect to the kernel
		// parameter(s) and inXu.
		//
		// inXu         : input set.
		// inM          : matrix of partials; its symmetric part (inM + inM^T) / 2 is used.
		// outdL_dParam : must not be null; receives one value (single parameter)
		//                or one value per input column.
		// outdL_dXu    : must not be null; receives the gradient for inXu.
		//
		// The log-parameters are doubled for the duration of the call.
		template<typename Scalar>
		void LinearKernel<Scalar>::LogGradientCompoundKuu(const af::array& inXu, const af::array& inM, af::array* outdL_dParam, af::array* outdL_dXu)
		{
			SetLogParameters(GetLogParameters() * 2.0);

			const int numRows = inXu.dims(0);
			const int numCols = inXu.dims(1);

			const af::array symmetricM = (inM + inM.T()) / 2.0;
			const af::array symMXu = af::matmul(symmetricM, inXu);
			const af::array weighted = symMXu * inXu;

			if (GetNumParameter() == 1)
			{
				*outdL_dParam = af::sum(af::sum(weighted));
				*outdL_dXu = symMXu * 2.0 * af::tile(dVariance, numRows, numCols);
			}
			else
			{
				*outdL_dParam = af::sum(weighted).T();
				*outdL_dXu = symMXu * 2.0 * af::tile(dVariance.T(), numRows);
			}

			SetLogParameters(GetLogParameters() / 2.0);
		}

		// Replaces the stored parameter array with `param` as given.
		template<typename Scalar>
		void LinearKernel<Scalar>::SetParameters(const af::array & param)
		{
			this->dVariance = param;
		}

		// Returns the stored parameter array.
		template<typename Scalar>
		af::array LinearKernel<Scalar>::GetParameters()
		{
			return this->dVariance;
		}

		// Sets the parameters from their logarithms: stores exp(param).
		template<typename Scalar>
		void LinearKernel<Scalar>::SetLogParameters(const af::array& param)
		{
			dVariance = af::exp(param);
		}

		// Returns the element-wise logarithm of the stored parameters.
		template<typename Scalar>
		af::array LinearKernel<Scalar>::GetLogParameters()
		{
			return af::log(dVariance);
		}

		// Computes the Psi0, Psi1 and Psi2 statistics from inXu, inMu and inS.
		//
		// inXu    : first input set (its row count sizes Psi1/Psi2).
		// inMu    : matrix combined with inXu in Psi1 and squared in Psi0.
		// inS     : matrix added to inMu^2 in Psi0; has the row/column counts used
		//           for the reshapes below.
		// outPsi0 : receives the weighted row sums of (inMu^2 + inS).
		// outPsi1 : receives inMu * (weights .* inXu)^T.
		// outPsi2 : receives a 3-D array with one slice per row of inS.
		//
		// The log-parameters are doubled for the duration of the call.
		template<typename Scalar>
		void LinearKernel<Scalar>::ComputePsiStatistics(const af::array& inXu, const af::array& inMu, const af::array& inS, af::array& outPsi0, af::array& outPsi1, af::array& outPsi2)
		{
			const int numPoints = inS.dims(0);
			const int numDims = inS.dims(1);
			const int numInducing = inXu.dims(0);

			SetLogParameters(GetLogParameters() * 2.0);

			const af::array secondMoment = af::pow(inMu, 2.0) + inS;

			// Weights laid out to match inXu and inMu respectively.
			af::array xuWeights;
			af::array muWeights;
			if (GetNumParameter() == 1)
			{
				xuWeights = af::tile(dVariance, inXu.dims());
				muWeights = af::tile(dVariance, inMu.dims());
			}
			else
			{
				xuWeights = af::tile(dVariance.T(), inXu.dims(0));
				muWeights = af::tile(dVariance.T(), inMu.dims(0));
			}

			const af::array scaledXu = xuWeights * inXu;

			outPsi0 = af::sum(muWeights * secondMoment, 1);
			outPsi1 = af::matmulNT(inMu, scaledXu);

			// Per-slice outer product of the rows of Psi1.
			const af::array psi1AsColumns = af::moddims(outPsi1.T(), numInducing, 1, numPoints);
			const af::array psi1AsRows = af::moddims(psi1AsColumns, 1, numInducing, numPoints);
			const af::array outerTerm = af::tile(psi1AsColumns, 1, numInducing) * af::tile(psi1AsRows, numInducing, 1, 1);

			// Per-slice product of the inS-weighted scaledXu with scaledXu^T.
			const af::array sPerSlice = af::tile(af::moddims(inS.T(), 1, numDims, numPoints), numInducing);
			const af::array sTerm = af::matmulNT(sPerSlice * af::tile(scaledXu, 1, 1, numPoints), scaledXu);

			outPsi2 = outerTerm + sTerm;

			SetLogParameters(GetLogParameters() / 2.0);
		}

		// Computes the gradients of the Psi statistics with respect to the kernel
		// parameter(s), inXu, inMu and inS.
		//
		// The outputs are first filled by Psi2Derivative (from indL_dPsi2) and then
		// the Psi1 contributions are added. If dlogZ_dv is given its total sum is
		// added to every parameter gradient entry; the parameter gradient is
		// doubled at the end.
		//
		// indL_dPsi0, inPsi1 and inPsi2 are accepted but not referenced here.
		// NOTE(review): no Psi0 contribution is accumulated - confirm that is intended.
		//
		// The log-parameters are doubled for the duration of the call, so
		// Psi2Derivative runs with the doubled log-parameters in place.
		template<typename Scalar>
		void LinearKernel<Scalar>::PsiDerivatives(const af::array& indL_dPsi0, const af::array& inPsi1, const af::array& indL_dPsi1, const af::array& inPsi2, const af::array& indL_dPsi2, 
			const af::array& inXu, const af::array& inMu, const af::array& inS, af::array& outdL_dParam, af::array& outdL_dXu, af::array& outdL_dMu, af::array& outdL_dS, const af::array* dlogZ_dv)
		{
			SetLogParameters(GetLogParameters() * 2.0);

			Psi2Derivative(indL_dPsi2, inXu, inMu, inS, outdL_dParam, outdL_dXu, outdL_dMu, outdL_dS);

			const af::array dL_dpsi1_mu = af::matmulTN(indL_dPsi1, inMu);

			if (GetNumParameter() == 1)
			{
				outdL_dParam += af::sum(af::sum(dL_dpsi1_mu * inXu));
				outdL_dMu += af::matmul(indL_dPsi1, inXu) * af::tile(dVariance, indL_dPsi1.dims(0), inXu.dims(1));
				outdL_dXu += dL_dpsi1_mu * af::tile(dVariance, dL_dpsi1_mu.dims());
			}
			else
			{
				outdL_dParam += af::sum(dL_dpsi1_mu * inXu).T();
				outdL_dMu += af::matmul(indL_dPsi1, inXu) * af::tile(dVariance.T(), indL_dPsi1.dims(0));
				outdL_dXu += dL_dpsi1_mu * af::tile(dVariance.T(), dL_dpsi1_mu.dims(0));
			}

			if (dlogZ_dv)
				outdL_dParam += af::tile(af::sum(af::sum(*dlogZ_dv)), outdL_dParam.dims());

			(outdL_dParam) *= 2.0f;

			SetLogParameters(GetLogParameters() / 2.0);
		}

		// Computes the Psi2 part of the gradients with respect to the kernel
		// parameter(s), inXu, inMu and inS. All four outputs are overwritten.
		//
		// indL_dPsi2 : partials with respect to Psi2; either a 2-D matrix or a 3-D
		//              array with one slice per row of inS (selected via numdims()).
		//
		// dVariance is read as currently stored; PsiDerivatives calls this function
		// after doubling the log-parameters, and nothing is changed or restored here.
		template<typename Scalar>
		void LinearKernel<Scalar>::Psi2Derivative(const af::array& indL_dPsi2, const af::array& inXu, const af::array& inMu, const af::array& inS, 
			af::array& outdL_dParam, af::array& outdL_dXu, af::array& outdL_dMu, af::array& outdL_dS)
		{
			af::array Z_expect, dL_dpsi2T, common_expect, Z2_expect, Z1_expect, common_sum;

			af::array variance2 = af::pow(dVariance, 2.0);

			int iN = inS.dims(0);
			int ik = inXu.dims(0);
			int iq = inS.dims(1);

			if (GetNumParameter() == 1)
				common_sum = af::matmulNT(inMu, af::tile(dVariance, inXu.dims()) * inXu); //Nxk
			else
				common_sum = af::matmulNT(inMu, af::tile(dVariance.T(), inXu.dims(0)) * inXu); //Nxk

			if (indL_dPsi2.numdims() == 2)
			{
				// Single matrix of partials shared by all rows.
				Z_expect = af::sum(af::matmul(indL_dPsi2, inXu) * inXu);
				dL_dpsi2T = indL_dPsi2 + indL_dPsi2.T();
				common_expect = af::matmul(common_sum, af::matmul(dL_dpsi2T, inXu));

				// Row-wise inner products of common_sum with dL_dpsi2T, i.e.
				// common_sum * dL_dpsi2T^T (Nxk). af::dot only accepts vectors, so
				// the matrix product is used; its shape also matches the
				// matmulTN(Z2_expect, ...) below.
				Z2_expect = af::matmulNT(common_sum, dL_dpsi2T);
				Z1_expect = af::matmul(dL_dpsi2T, inXu);

				if (GetNumParameter() == 1)
				{
					outdL_dParam = af::sum(2.0 * af::sum(inS) * af::tile(dVariance, 1, iq) * Z_expect + af::sum(common_expect * inMu));
					outdL_dMu = common_expect * af::tile(dVariance, iN, iq);
					outdL_dS = af::tile(Z_expect * af::tile(variance2, 1, iq), inS.dims(0));
					outdL_dXu = af::tile(variance2, 1, iq) * af::sum(inS) * Z1_expect + af::matmulTN(Z2_expect, af::tile(dVariance, inMu.dims()) * inMu);
				}	
				else
				{
					outdL_dParam = (2.0 * af::sum(inS) * dVariance.T() * Z_expect + af::sum(common_expect * inMu)).T();
					outdL_dMu = common_expect * af::tile(dVariance.T(), iN);
					outdL_dS = af::tile(Z_expect * variance2.T(), inS.dims(0));
					outdL_dXu = variance2.T() * af::sum(inS) * Z1_expect + af::matmulTN(Z2_expect, af::tile(dVariance.T(), inMu.dims(0)) * inMu);
				}
					
			}
			else
			{
				// One matrix of partials per row (kxkxN).
				// NOTE(review): Z2_expect below reduces along dimension 1 of the
				// tiled product - confirm the reduction axis against the derivation.
				Z_expect = af::moddims(af::sum(af::moddims(af::moddims(matmul(indL_dPsi2, af::tile(inXu, 1, 1, iN)), ik * iq, iN).T(), iN, ik, iq)
					* af::tile(af::moddims(inXu, 1, ik, iq), iN), 1), iN, iq); //NxQ
				dL_dpsi2T = indL_dPsi2 + indL_dPsi2.T(); //kxkxN

				common_expect = af::moddims(af::sum(af::tile(common_sum, 1, 1, iq)
					* af::moddims(af::moddims(matmul(dL_dpsi2T, af::tile(inXu, 1, 1, iN)), ik * iq, iN).T(), iN, ik, iq), 1), iN, iq); //NxQ
				Z2_expect = af::moddims(af::sum(tile(af::moddims(common_sum.T(), ik, 1, iN), 1, ik, 1) * dL_dpsi2T, 1), ik, iN).T(); //Nxk
				Z1_expect = af::moddims(af::moddims(af::matmul(dL_dpsi2T, af::tile(inXu, 1, 1, iN)), ik * iq, iN).T(), iN, ik, iq); // NxkxQ

				if (GetNumParameter() == 1)
				{
					outdL_dParam = af::sum(2.0 * af::tile(dVariance, 1, iq) * af::sum(inS * Z_expect) + af::sum(common_expect * inMu));
					outdL_dMu = common_expect * af::tile(dVariance, iN, iq);
					outdL_dS = Z_expect * af::tile(variance2, iN, iq);
					outdL_dXu = af::tile(variance2, ik, iq) * af::moddims(af::sum(af::tile(af::moddims(inS, iN, 1, iq), 1, ik, 1) * Z1_expect), ik, iq) + af::matmulTN(Z2_expect, tile(dVariance, iN, iq) * inMu);
				}
				else
				{
					outdL_dParam = (2.0 * dVariance.T() * af::sum(inS * Z_expect) + af::sum(common_expect * inMu)).T();
					outdL_dMu = common_expect * af::tile(dVariance.T(), iN);
					outdL_dS = Z_expect * af::tile(variance2.T(), iN);
					outdL_dXu = af::tile(variance2.T(), ik) * af::moddims(sum(tile(af::moddims(inS, iN, 1, iq), 1, ik, 1) * Z1_expect), ik, iq) + af::matmulTN(Z2_expect, tile(dVariance.T(), iN) * inMu);
				}
			}
		}
	}
}