/**
File:		MachineLearning/Kernel/FgARDKernel<Scalar>.cpp

Author:		
Email:		
Site:       

Copyright (c) 2017 . All rights reserved.
*/

#include <NeMachineLearningPCH.h>
#include <MachineLearning/FgARDKernel.h>
#include <MachineLearning/CommonUtil.h>

namespace NeuralEngine
{
	namespace MachineLearning
	{
		// Explicit instantiations of ARDKernel for the two supported scalar types.
		// NOTE(review): these appear *before* the out-of-line member definitions below. Per
		// [temp.explicit], an explicit instantiation definition only instantiates members that
		// are already defined at that point; mainstream compilers tend to defer this to the end
		// of the translation unit, but moving these two lines to the end of the file would be
		// the portable form — verify on all target toolchains.
		template class ARDKernel<float>;
		template class ARDKernel<double>;

		/// Builds an ARD kernel over `numDim` input dimensions.
		///
		/// The kernel carries numDim + 1 parameters: dVariance (initialised to 1) followed by one
		/// dInvScale entry per input dimension (each initialised to 1, stored with the kernel dtype).
		template<typename Scalar>
		ARDKernel<Scalar>::ARDKernel(int numDim)
			: IKernel<Scalar>(eARDKernel, numDim + 1)
			, dVariance(1.0f)
		{
			const af::array unitScales = af::constant(1.0f, numDim, m_dType);
			dInvScale = unitScales;
		}

		/// Destructor; all members release their own resources.
		template<typename Scalar>
		ARDKernel<Scalar>::~ARDKernel()
		{
		}

		/// Computes the ARD kernel matrix between two input sets.
		///
		/// outMatrix(i, j) = dVariance^2 * exp(-0.5 * sum_q ((inX1(i,q) - inX2(j,q)) / dInvScale(q))^2)
		///
		/// @param inX1       first input set, one row per point (columns are input dimensions).
		/// @param inX2       second input set, same number of columns as inX1.
		/// @param outMatrix  receives the kernel matrix (rows follow inX1, columns follow inX2 —
		///                   presumably, given CommonUtil::SquareDistance; TODO confirm).
		///
		/// The squared variance and the per-dimension scale are computed locally. The previous
		/// implementation temporarily doubled all log-parameters and halved them again. That
		/// exp/log round trip perturbed the stored parameters by rounding error on every call, and
		/// it left them squared if an exception escaped between the two updates. For positive
		/// parameters the two forms are mathematically identical.
		template<typename Scalar>
		void ARDKernel<Scalar>::ComputeKernelMatrix(const af::array& inX1, const af::array& inX2, af::array& outMatrix)
		{
			const Scalar variance = dVariance * dVariance;

			// dInvScale is stored as a column; transpose to a row so it tiles across data rows.
			const af::array scale = dInvScale.T();

			const af::array n2 = CommonUtil<Scalar>::SquareDistance(
				inX1 / af::tile(scale, inX1.dims(0)),
				inX2 / af::tile(scale, inX2.dims(0)));

			outMatrix = variance * af::exp(-0.5f * n2);
		}

		/// Computes the diagonal of k(inX, inX).
		///
		/// The exponential term is exp(0) = 1 on the diagonal, so every entry equals dVariance^2.
		/// The squared variance is computed directly instead of temporarily doubling and then halving
		/// the log-parameters. That exp/log round trip drifted the stored parameters by rounding error
		/// on each call.
		///
		/// @param inX          input set; one diagonal entry is produced per row.
		/// @param outDiagonal  receives an inX.dims(0)-element vector with the kernel dtype.
		template<typename Scalar>
		void ARDKernel<Scalar>::ComputeDiagonal(const af::array& inX, af::array& outDiagonal)
		{
			outDiagonal = af::constant(dVariance * dVariance, inX.dims(0), m_dType);
		}

		/// Gradient of a loss with respect to the inputs of the symmetric kernel k(inX, inX).
		///
		/// For each input dimension q, the column gradient is
		///   2 * rowsum(dL/dK .* dK/dX_q) - diag(dL/dK) .* diag(dK/dX_q).
		///
		/// @param inX       input set (rows are points).
		/// @param indL_dK   gradient of the loss with respect to k(inX, inX).
		/// @param outdL_dX  receives an array shaped like inX.
		template<typename Scalar>
		void ARDKernel<Scalar>::LogLikGradientX(const af::array& inX, const af::array& indL_dK, af::array& outdL_dX)
		{
			outdL_dX = af::constant(0.0f, inX.dims(), m_dType);

			const int numDims = static_cast<int>(outdL_dX.dims(1));
			af::array gradK;
			for (int dim = 0; dim < numDims; ++dim)
			{
				GradX(inX, inX, dim, gradK);

				const af::array fullTerm = 2.0f * af::sum(indL_dK * gradK, 1);
				const af::array diagTerm = af::diag(indL_dK) * af::diag(gradK);
				outdL_dX(af::span, dim) = fullTerm - diagTerm;
			}
		}

		/// Gradients of a loss with respect to the inducing inputs Xu and the data inputs X,
		/// through Kuu = k(Xu, Xu) and Kuf = k(Xu, X).
		///
		/// @param inXu        inducing inputs (rows are points).
		/// @param indL_dKuu   gradient of the loss with respect to Kuu.
		/// @param inX         data inputs (rows are points, same columns as inXu).
		/// @param indL_dKuf   gradient of the loss with respect to Kuf.
		/// @param outdL_dXu   receives an array shaped like inXu.
		/// @param outdL_dX    receives an array shaped like inX.
		template<typename Scalar>
		void ARDKernel<Scalar>::LogLikGradientX(const af::array& inXu, const af::array& indL_dKuu, const af::array& inX, const  af::array& indL_dKuf, af::array& outdL_dXu, af::array& outdL_dX)
		{
			// Both outputs use the kernel dtype. Previously `m_dType` for outdL_dX sat outside the
			// af::constant call behind a comma operator. It was silently discarded, so the output
			// was always f32, even for double kernels.
			outdL_dXu = af::constant(0.0f, inXu.dims(), m_dType);
			outdL_dX = af::constant(0.0f, inX.dims(), m_dType);

			af::array dK_dX_u; // scratch buffer reused by every GradX call below
			for (int q = 0; q < outdL_dXu.dims(1); q++)
			{
				// Contribution to dL/dXu through Kuu: 2x the row sums minus the diagonal term.
				GradX(inXu, inXu, q, dK_dX_u);
				outdL_dXu(af::span, q) = 2 * af::sum(indL_dKuu * dK_dX_u, 1) - af::diag(indL_dKuu) * af::diag(dK_dX_u);

				// Contribution to dL/dXu through Kuf.
				GradX(inXu, inX, q, dK_dX_u);
				outdL_dXu(af::span, q) += af::sum(indL_dKuf * dK_dX_u, 1);

				// Contribution to dL/dX through Kuf (its gradient is used transposed).
				GradX(inX, inXu, q, dK_dX_u);
				outdL_dX(af::span, q) = af::sum(indL_dKuf.T() * dK_dX_u, 1);
			}
		}

		/// Gradient of a loss with respect to inX1 through K = k(inX1, inX2).
		///
		/// For each input dimension q:
		///   dK/dX1_q(i, j) = -(inX1(i,q) - inX2(j,q)) / dInvScale(q)^2 * K(i, j)
		///   outdL_dX(:, q) = rowsum(dK/dX1_q .* indL_dK)
		///
		/// The kernel matrix is computed once and reused for every dimension. Unused locals and the
		/// dead, commented-out alternative implementation have been removed.
		///
		/// @param inX1      inputs being differentiated (rows are points).
		/// @param inX2      second input set.
		/// @param indL_dK   gradient of the loss with respect to K.
		/// @param outdL_dX  receives an array shaped like inX1.
		template<typename Scalar>
		void ARDKernel<Scalar>::LogLikGradientX(const af::array& inX1, const af::array& inX2, const af::array& indL_dK, af::array& outdL_dX)
		{
			const int numData = inX1.dims(0);
			const int numData2 = inX2.dims(0);

			af::array K;
			ComputeKernelMatrix(inX1, inX2, K);

			outdL_dX = af::constant(0.0, inX1.dims(), m_dType);
			for (int q = 0; q < inX1.dims(1); q++)
			{
				// Squared per-dimension scale for column q (a 1-element array).
				const af::array scale2 = af::pow(dInvScale(q), 2.0f);

				// x1_q / l_q^2 broadcast across columns, x2_q^T / l_q^2 broadcast across rows.
				const af::array K1 = af::tile(inX1(af::span, q) / af::tile(scale2, numData), 1, numData2);
				const af::array K2 = af::tile((inX2(af::span, q) / af::tile(scale2, numData2)).T(), numData, 1);

				// Defensive: K is expected to have inX1.dims(0) rows; transpose otherwise.
				const af::array dK_dX = (K1.dims(0) == K.dims(0)) ? af::array(-(K1 - K2) * K)
					: af::array(-(K1 - K2) * K.T());

				outdL_dX(af::span, q) = af::sum(dK_dX * indL_dK, 1);
			}
		}

		/// Gradient of a loss with respect to the kernel parameters, through K = k(inX1, inX2).
		///
		/// Slot 0 is sum(indL_dK .* K) / dVariance^2. Slot i + 1 is the derivative for input
		/// dimension i. Slot i + 1 always uses the general cross-set formula (the same one the
		/// dlogZ_dv overload uses), which is valid whether or not inX1 and inX2 coincide. The
		/// previous special case, selected only when the row counts matched, omitted the
		/// `/ dInvScale(i)^3` factor and used an inconsistent sign, so it disagreed with the general
		/// formula.
		///
		/// The output is allocated with the kernel dtype (it was previously always f32). dVariance^2
		/// is computed locally instead of by temporarily doubling the log-parameters, which drifted
		/// the stored parameters through an exp/log round trip.
		///
		/// @param inX1          first input set (rows are points).
		/// @param inX2          second input set.
		/// @param indL_dK       gradient of the loss with respect to K.
		/// @param outdL_dParam  receives inX1.dims(1) + 1 values.
		template<typename Scalar>
		void ARDKernel<Scalar>::LogLikGradientParam(const af::array& inX1, const af::array& inX2, const af::array& indL_dK, af::array& outdL_dParam)
		{
			af::array K;
			ComputeKernelMatrix(inX1, inX2, K);

			const Scalar variance = dVariance * dVariance;
			const af::array covGradK = indL_dK * K;

			outdL_dParam = af::constant(0.0f, inX1.dims(1) + 1, m_dType);
			outdL_dParam(0) = af::sum(af::sum(covGradK / variance));

			for (int i = 0; i < inX1.dims(1); i++)
			{
				const af::array x1 = inX1(af::span, i);
				const af::array x2 = inX2(af::span, i);
				outdL_dParam(i + 1) = -(0.5 * af::sum(af::matmulTN(covGradK, (x1 * x1)))
					+ 0.5 * af::sum(af::matmul(covGradK, (x2 * x2)))
					- af::matmul(af::matmulTN(x1, covGradK), x2)) / af::pow(dInvScale(i), 3);
			}
		}

		/// Gradient of a loss with respect to the kernel parameters through K = k(inX1, inX2),
		/// plus an optional extra term for the variance slot.
		///
		/// Slot 0 is (sum(indL_dK .* K) / dVariance^2 + sum(*dlogZ_dv)) * 2. Slot i + 1 is the
		/// derivative for input dimension i.
		///
		/// Changes from the previous version:
		///  - Removed the UnscaledDistance / invDist / dL_dn computation. Its result was never used
		///    (only referenced by commented-out code).
		///  - The output uses the kernel dtype (it was always f32 before).
		///  - dlogZ_dv may be null; it is then skipped, matching LogLikGradientCompundKfu and
		///    PsiDerivatives (it was previously dereferenced unconditionally).
		///  - dVariance^2 is computed locally instead of via a drifting exp/log parameter round trip.
		///
		/// @param inX1          first input set (rows are points).
		/// @param inX2          second input set.
		/// @param indL_dK       gradient of the loss with respect to K.
		/// @param outdL_dParam  receives inX1.dims(1) + 1 values.
		/// @param dlogZ_dv      optional extra gradient summed into slot 0; may be null.
		template<typename Scalar>
		void ARDKernel<Scalar>::LogLikGradientParam(const af::array& inX1, const af::array& inX2, const af::array& indL_dK, af::array& outdL_dParam, const af::array* dlogZ_dv)
		{
			af::array K;
			ComputeKernelMatrix(inX1, inX2, K);

			const Scalar variance = dVariance * dVariance;
			const af::array covGradK = indL_dK * K;

			outdL_dParam = af::constant(0.0f, inX1.dims(1) + 1, m_dType);
			outdL_dParam(0) = af::sum(af::sum(covGradK / variance));

			for (int i = 0; i < inX1.dims(1); i++)
			{
				const af::array x1 = inX1(af::span, i);
				const af::array x2 = inX2(af::span, i);
				outdL_dParam(i + 1) = -(0.5 * af::sum(af::matmulTN(covGradK, (x1 * x1)))
					+ 0.5 * af::sum(af::matmul(covGradK, (x2 * x2)))
					- af::matmul(af::matmulTN(x1, covGradK), x2)) / af::pow(dInvScale(i), 3);
			}

			if (dlogZ_dv != nullptr)
				outdL_dParam(0) += af::sum(af::sum(*dlogZ_dv));

			outdL_dParam(0) *= 2.0f;
		}

		/// Derivative of k(inX1, inX2) with respect to column q of inX1.
		///
		/// outdK_dX(i, j) = -(inX1(i,q) - inX2(j,q)) / dInvScale(q)^2 * K(i, j), where
		/// K = ComputeKernelMatrix(inX1, inX2).
		///
		/// The unused `scale` local, which performed a wasted af::diag on every call, has been removed.
		///
		/// @param inX1      inputs whose column q is the differentiation variable.
		/// @param inX2      second input set.
		/// @param q         input dimension (column index).
		/// @param outdK_dX  receives an inX1.dims(0) x inX2.dims(0) matrix.
		template<typename Scalar>
		void ARDKernel<Scalar>::GradX(const af::array& inX1, const af::array& inX2, int q, af::array& outdK_dX)
		{
			const int numData = inX1.dims(0);
			const int numData2 = inX2.dims(0);

			// Squared per-dimension scale for column q (a 1-element array).
			const af::array scale2 = af::pow(dInvScale(q), 2.0f);

			// x1_q / l_q^2 broadcast across columns, x2_q^T / l_q^2 broadcast across rows.
			const af::array K1 = af::tile(inX1(af::span, q) / af::tile(scale2, numData), 1, numData2);
			const af::array K2 = af::tile((inX2(af::span, q) / af::tile(scale2, numData2)).T(), numData, 1);

			af::array K;
			ComputeKernelMatrix(inX1, inX2, K);

			// Defensive: K is expected to have inX1.dims(0) rows; transpose otherwise.
			if (K1.dims(0) == K.dims(0))
				outdK_dX = -(K1 - K2) * K;
			else
				outdK_dX = -(K1 - K2) * K.T();
		}

		/// Gradient of the kernel diagonal with respect to the inputs.
		///
		/// The diagonal does not depend on the inputs, so this is an all-zero array shaped like inX
		/// with the kernel dtype.
		template<typename Scalar>
		void ARDKernel<Scalar>::DiagGradX(const af::array& inX, af::array& outDiagdK_dX)
		{
			const auto shape = inX.dims();
			outDiagdK_dX = af::constant(0.0f, shape, m_dType);
		}

		/// Gradient contribution of the kernel diagonal with respect to the parameters.
		///
		/// Currently returns zeros for all inX.dims(1) + 1 parameters; inCovDiag is not used.
		/// (An earlier variant filling slot 1 with sum(inCovDiag) was left disabled.)
		template<typename Scalar>
		void ARDKernel<Scalar>::DiagGradParam(const af::array& inX, const af::array& inCovDiag, af::array& outDiagdK_dParam)
		{
			const auto numParams = inX.dims(1) + 1;
			outDiagdK_dParam = af::constant(0.0f, numParams, m_dType);
		}

		/// Gradients of a loss through Kfu = k(inX, inXu) with respect to the kernel parameters,
		/// the inducing inputs inXu and (optionally) the data inputs inX.
		///
		/// @param indL_dKfu     gradient of the loss with respect to Kfu (N x M, matching kfu).
		/// @param inX           data inputs, N rows x Q columns (reshaped to N x 1 x Q below).
		/// @param inXu          inducing inputs, M rows x Q columns.
		/// @param outdL_dParam  required; receives GetNumParameter() values.
		/// @param outdL_dXu     required; receives an M x Q array.
		/// @param dlogZ_dv      optional; if non-null, its total is added to slot 0 before the x2.
		/// @param outdL_dX      optional; if non-null, receives an N x Q array.
		///
		/// Slot 0 of the log-parameters is doubled for the duration of the call, so dVariance holds
		/// the squared stored value. GetParameters() near the end therefore returns that squared
		/// value in slot 0. The final `*= GetParameters()` together with the `*= 2` on slot 0 looks
		/// like a chain-rule conversion to log-parameter space — NOTE(review): confirm this intent.
		template<typename Scalar>
		void ARDKernel<Scalar>::LogLikGradientCompundKfu(const af::array & indL_dKfu, const af::array & inX, const af::array & inXu, 
			af::array * outdL_dParam, af::array * outdL_dXu, const af::array* dlogZ_dv, af::array * outdL_dX)
		{
			af::array Lpsi1, Zmu, kfu, tempParam;

			int iN = inX.dims(0);	// N: number of data points
			int ik = inXu.dims(0);	// M: number of inducing points
			int iq = inXu.dims(1);	// Q: input dimensionality

			// Computed before the parameter change below (ComputeKernelMatrix squares dVariance itself).
			ComputeKernelMatrix(inX, inXu, kfu);

			// Temporarily double log(dVariance) so dVariance holds its square; restored at the end.
			tempParam = GetLogParameters();
			tempParam(0) *= 2.0;
			SetLogParameters(tempParam);

			*outdL_dParam = af::constant(0.0f, GetNumParameter(), (m_dType));

			af::array dInvScale2 = af::pow(dInvScale, 2);

			// Elementwise dL/dKfu .* Kfu (N x M).
			Lpsi1 = indL_dKfu * kfu;
			// Pairwise differences Xu(m, q) - X(n, q), laid out N x M x Q.
			Zmu = af::tile(af::moddims(inXu, 1, ik, iq), iN, 1, 1) - af::tile(af::moddims(inX, iN, 1, iq), 1, ik, 1);

			// Variance slot: sum(Lpsi1) / dVariance (dVariance is currently squared).
			(*outdL_dParam)(0) = af::sum(af::sum(Lpsi1)) / dVariance;
			if (dlogZ_dv)
				(*outdL_dParam)(0) += (af::sum(af::sum(*dlogZ_dv)));

			(*outdL_dParam)(0) *= 2.0f;

			// Scale slots: sum over N and M of Lpsi1 .* Zmu^2 / dInvScale^3, one value per dimension.
			(*outdL_dParam)(af::seq(1, GetNumParameter() - 1)) = af::moddims(af::sum(af::sum(af::tile(Lpsi1, 1, 1, iq) * af::pow(Zmu, 2)
				/ af::tile(af::moddims(af::pow(dInvScale, 3), 1, 1, iq), iN, ik))), iq, 1, 1);

			// Inducing inputs: minus the sum over data points (dim 0) of Lpsi1 .* Zmu / dInvScale^2.
			*outdL_dXu = af::moddims(-af::sum(af::tile(Lpsi1, 1, 1, iq) * Zmu / af::tile(af::moddims(dInvScale2, 1, 1, iq), iN, ik), 0), ik, iq);

			// Data inputs: sum over inducing points (dim 1) of the same term, positive sign.
			if (outdL_dX != nullptr)
				*outdL_dX = af::moddims(af::sum(af::tile(Lpsi1, 1, 1, iq) * Zmu / af::tile(af::moddims(dInvScale2, 1, 1, iq), iN, ik), 1), iN, iq);

			// Multiply by the current (slot-0-squared) parameter values.
			(*outdL_dParam) *= GetParameters();

			// Undo the temporary doubling of log(dVariance).
			tempParam(0) /= 2.0;
			SetLogParameters(tempParam);
		}

		/// Gradients of a loss term through Kuu = k(inXu, inXu), given a weight matrix inM combined
		/// elementwise with Kuu, with respect to the kernel parameters and the inducing inputs.
		///
		/// @param inXu          inducing inputs, M rows x Q columns.
		/// @param inM           M x M matrix multiplied elementwise with Kuu (presumably dL/dKuu — TODO confirm).
		/// @param outdL_dParam  required; receives GetNumParameter() values.
		/// @param outdL_dXu     required; receives an array shaped like inXu.
		///
		/// Unlike the Kfu/Psi paths, *all* log-parameters are doubled here. While doubled, dVariance
		/// and every dInvScale entry hold their squares; the originals are restored at the end.
		/// NOTE(review): slot 0 here is sum(inM .* Kuu). It is neither divided by dVariance nor
		/// multiplied by GetParameters(), unlike LogLikGradientCompundKfu. Confirm the callers
		/// expect this convention.
		template<typename Scalar>
		void ARDKernel<Scalar>::LogGradientCompoundKuu(const af::array & inXu, const af::array & inM, af::array * outdL_dParam, af::array * outdL_dXu)
		{
			// NOTE(review): `af::array x(m_dType)` likely resolves to af::array's size constructor
			// (dtype enum -> dim_t) rather than setting a dtype. Harmless, since every variable
			// is reassigned below, but misleading — verify.
			af::array kuu(m_dType), Ml(m_dType), Xl(m_dType), Xbar(m_dType), Mbar1(m_dType), Mbar2(m_dType);

			int ik = inXu.dims(0);	// M: number of inducing points
			int iq = inXu.dims(1);	// Q: input dimensionality

			// Computed before the parameter change below.
			ComputeKernelMatrix(inXu, inXu, kuu);

			// Square every parameter (double all log-parameters); undone at the end.
			SetLogParameters(GetLogParameters() * 2.0);

			// Row vector of squared per-dimension scales.
			af::array scale = /*exp*/(dInvScale).T();

			*outdL_dParam = af::constant(0.0f, GetNumParameter(), (m_dType));

			(*outdL_dParam)(0) = af::sum(af::sum(inM * kuu));
			
			Ml = 0.5 * inM * kuu;
			// Inputs divided by the (un-squared) scale.
			Xl = inXu / af::tile(af::sqrt(scale), ik, 1);
			// Scale slots: 1^T (Ml^T Xl^2) + 1^T (Ml Xl^2) - 2 * 1^T (Xl .* (Ml Xl)), as a column.
			(*outdL_dParam)(af::seq(1, GetNumParameter() - 1)) = (af::matmulTN(af::constant(1.0f, Ml.dims(0), m_dType), af::matmulTN(Ml, af::pow(Xl, 2))) 
				+ af::matmulTN(af::constant(1.0f, Ml.dims(0), m_dType), af::matmul(Ml, af::pow(Xl, 2)))
				- 2.0 * af::matmulTN(af::constant(1.0f, Xl.dims(0), m_dType), (Xl * af::matmul(Ml, Xl)))).T();

			// Inputs divided by the squared scale.
			Xbar = inXu / af::tile(scale, ik, 1);
			Mbar1 = -inM.T() * kuu;
			Mbar2 = -inM * kuu;
			// For each of Mbar1/Mbar2: Xbar scaled by the column sums (1^T Mbar) minus Mbar * Xbar.
			*outdL_dXu = (Xbar * af::tile(af::matmulTN(af::constant(1.0f, Mbar1.dims(0), m_dType), Mbar1), iq, 1).T() - af::matmul(Mbar1, Xbar))
				+ (Xbar * af::tile(af::matmulTN(af::constant(1.0f, Mbar2.dims(0), m_dType), Mbar2), iq, 1).T() - af::matmul(Mbar2, Xbar));

			/**outdL_dXu = (Xbar * tile(sum(Mbar1), iq, 1).T() - matmul(Mbar1, Xbar))
				+ (Xbar * tile(sum(Mbar2), iq, 1).T() - matmul(Mbar2, Xbar));*/

			// Restore the original parameters (halve all log-parameters).
			SetLogParameters(GetLogParameters() / 2.0);
		}

		/// Sets the kernel parameters from a linear-space vector.
		///
		/// @param param  element 0 is stored in dVariance; elements 1 .. GetNumParameter() - 1
		///               are stored in dInvScale.
		template<typename Scalar>
		void ARDKernel<Scalar>::SetParameters(const af::array& param)
		{
			const auto lastIndex = GetNumParameter() - 1;
			dVariance = param(0).scalar<Scalar>();
			dInvScale = param(af::seq(1, lastIndex));
		}

		/// Returns the kernel parameters in linear space, in the kernel dtype:
		/// [dVariance, dInvScale(0), ..., dInvScale(D - 1)].
		template<typename Scalar>
		af::array ARDKernel<Scalar>::GetParameters()
		{
			const auto count = GetNumParameter();
			af::array packed = af::constant(0.0f, count, (m_dType));
			packed(0) = dVariance;
			packed(af::seq(1, count - 1)) = dInvScale;
			return packed;
		}

		/// Sets the kernel parameters from a log-space vector (each entry is exponentiated).
		///
		/// @param param  log(dVariance) followed by log(dInvScale) per input dimension.
		template<typename Scalar>
		void ARDKernel<Scalar>::SetLogParameters(const af::array& param)
		{
			const af::array linear = af::exp(param);
			dInvScale = linear(af::seq(1, GetNumParameter() - 1));
			dVariance = linear(0).scalar<Scalar>();
		}

		/// Returns the natural log of the kernel parameters:
		/// [log(dVariance), log(dInvScale(0)), ..., log(dInvScale(D - 1))].
		template<typename Scalar>
		af::array ARDKernel<Scalar>::GetLogParameters()
		{
			const auto count = GetNumParameter();
			af::array packed = af::constant(0.0f, count, (m_dType));
			packed(0) = dVariance;
			packed(af::seq(1, count - 1)) = dInvScale;
			return af::log(packed);
		}

		/// Initialises the parameters from a data statistic.
		///
		/// Every scale is set to inMedian / 2 (plus 1e-16 to keep the log finite), and dVariance is
		/// set to 0.5.
		template<typename Scalar>
		void ARDKernel<Scalar>::InitParameters(Scalar inMedian)
		{
			const auto logScale = log(inMedian / 2.0 + 1e-16);
			af::array logParams = af::constant(logScale, GetNumParameter(), m_dType);
			logParams(0) = log(0.5);

			SetLogParameters(logParams);
		}

		////////////////////////////////////////////////////////////////////////////////////////////////////
		/// PSI statistics
		////////////////////////////////////////////////////////////////////////////////////////////////////

		/// Computes the Psi0, Psi1 and Psi2 statistics.
		///
		/// log(dVariance) is doubled for the duration of the call, so that outPsi0 and the
		/// Psi1/Psi2 helpers (which read dVariance directly) see its squared value. It is halved
		/// again before returning.
		///
		/// @param outPsi0  inMu.dims(0) entries, each equal to the squared dVariance.
		/// @param outPsi1  see ComputePsi1.
		/// @param outPsi2  see ComputePsi2.
		template<typename Scalar>
		void ARDKernel<Scalar>::ComputePsiStatistics(const af::array & inZ, const af::array & inMu, const af::array & inS, af::array & outPsi0, af::array & outPsi1, af::array & outPsi2)
		{
			af::array logParams = GetLogParameters();
			logParams(0) *= 2.0;
			SetLogParameters(logParams);

			outPsi0 = af::constant(dVariance, inMu.dims(0), (m_dType));
			ComputePsi1(inZ, inMu, inS, outPsi1);
			ComputePsi2(inZ, inMu, inS, outPsi2);

			logParams(0) /= 2.0;
			SetLogParameters(logParams);
		}

		/// Computes the Psi1 statistic (N x M).
		///
		///   psi1(n, m) = dVariance * exp(-0.5 * sum_q [ log(S(n,q)/l_q^2 + 1) + (Mu(n,q) - Z(m,q))^2 / (S(n,q) + l_q^2) ])
		///
		/// with l_q = dInvScale(q).
		///
		/// @param inZ      inducing inputs; M = rows (reshaped to 1 x M x Q below).
		/// @param inMu     means; reshaped to N x 1 x Q below, so presumably N x Q — TODO confirm.
		/// @param inS      variances; N = rows, Q = columns.
		/// @param outPsi1  receives the N x M result.
		///
		/// NOTE(review): dVariance is used as stored. ComputePsiStatistics doubles log(dVariance)
		/// before calling this, so the squared value is used on that path only.
		template<typename Scalar>
		void ARDKernel<Scalar>::ComputePsi1(const af::array & inZ, const af::array & inMu, const af::array & inS, af::array & outPsi1)
		{
			// NOTE(review): `af::array x(m_dType)` likely picks af::array's size constructor rather
			// than setting a dtype; harmless since each is reassigned below — verify.
			af::array tiledLenghtscale2(m_dType), psi1_logdenom(m_dType), psi1_log(m_dType);

			int iN = inS.dims(0);	// N
			int iq = inS.dims(1);	// Q
			int ik = inZ.dims(0);	// M

			// Squared scales tiled to N x Q.
			tiledLenghtscale2 = af::tile(af::pow(dInvScale, 2.0).T(), iN);

			// Per-point log-determinant term, summed over dimensions (N x 1).
			psi1_logdenom = af::sum(af::log(inS / tiledLenghtscale2 + 1.0), 1);
			// (logdenom + sum_q (Mu - Z)^2 / (S + l^2)) / -2, evaluated over an N x M x Q grid and reduced on dim 2.
			psi1_log = (af::tile(psi1_logdenom, 1, ik) + af::sum(af::pow(af::tile(af::moddims(inMu, iN, 1, iq), 1, ik) -
				af::tile(af::moddims(inZ, 1, ik, iq), iN), 2.0) * af::tile(af::moddims(1.0 / (inS + tiledLenghtscale2), iN, 1, iq), 1, ik), 2)) / (-2.0);

			outPsi1 = dVariance * af::exp(psi1_log);
		}

		/// Computes the Psi2 statistic, laid out M x M x N (one M x M slice per data point).
		///
		/// Each slice is dVariance^2 * exp(logdenom(n) + exp1(m, m') + exp2(m, m', n)), with
		///   logdenom(n)     = -0.5 * sum_q log(2 S(n,q)/l_q^2 + 1)
		///   exp1(m, m')     = -0.25 * sum_q (Z(m,q) - Z(m',q))^2 / l_q^2
		///   exp2(m, m', n)  = sum_q [ -Mu^2 + 2 Mu Zhat - Zhat^2 ] / (2 S + l_q^2),  Zhat = (Z(m) + Z(m')) / 2
		/// where l_q = dInvScale(q).
		///
		/// @param inZ      inducing inputs; M = rows.
		/// @param inMu     means, used with the same shape as inS (presumably N x Q — TODO confirm).
		/// @param inS      variances; N = rows, Q = columns.
		/// @param outPsi2  receives the M x M x N result.
		///
		/// NOTE(review): dVariance is used as stored. ComputePsiStatistics doubles log(dVariance)
		/// before calling this, so the squared value is used on that path only.
		template<typename Scalar>
		void ARDKernel<Scalar>::ComputePsi2(const af::array & inZ, const af::array & inMu, const af::array & inS, af::array & outPsi2)
		{
			// NOTE(review): `af::array x(m_dType)` likely picks af::array's size constructor rather
			// than setting a dtype; harmless since each is reassigned below — verify.
			af::array lenghtscale2(m_dType), psi2_logdenom(m_dType), psi2_exp1(m_dType), Z_hat(m_dType), denom(m_dType),
				psi2_exp2(m_dType);

			int iN = inS.dims(0);	// N
			int iq = inS.dims(1);	// Q
			int ik = inZ.dims(0);	// M

			lenghtscale2 = pow(dInvScale, 2.0);

			// N x 1.
			psi2_logdenom = af::sum(af::log(2.0 * inS / af::tile(lenghtscale2.T(), iN) + 1.0), 1) / (-2.0);
			// M x M: pairwise inducing-point distance term.
			psi2_exp1 = af::sum(af::pow(af::tile(af::moddims(inZ, ik, 1, iq), 1, ik) - af::tile(af::moddims(inZ, 1, ik, iq), ik), 2.0) 
				/ af::tile(af::moddims(lenghtscale2, 1, 1, iq), ik, ik), 2) / (-4.0);
			// M x M x Q midpoints of inducing-point pairs.
			Z_hat = (af::tile(af::moddims(inZ, ik, 1, iq), 1, ik) + af::tile(af::moddims(inZ, 1, ik, iq), ik)) / 2.0;
			// N x Q.
			denom = 1.0 / (2.0 * inS + af::tile(lenghtscale2.T(), iN));

			// M x M x N: quadratic term expanded as matrix products over the flattened M*M pairs.
			psi2_exp2 = -af::tile(af::moddims(af::sum((af::pow(inMu, 2.0) * denom), 1), 1, 1, iN), ik, ik) +
				af::moddims((2.0 * af::matmulNT(inMu * denom, af::moddims(Z_hat, ik * ik, iq)) - af::matmulNT(denom, af::moddims(af::pow(Z_hat, 2.0), ik * ik, iq))).T(), ik, ik, iN);

			outPsi2 = dVariance * dVariance * af::exp(af::tile(af::moddims(psi2_logdenom, 1, 1, iN), ik, ik) + af::tile(psi2_exp1, 1, 1, iN) + psi2_exp2);
		}

		/// Combines the Psi1 and Psi2 derivative contributions into gradients with respect to the
		/// kernel parameters, the inducing inputs, and the variational means/variances.
		///
		/// log(dVariance) is doubled for the whole computation. As a result, the derivative helpers
		/// and the final GetParameters() multiplication all see the squared dVariance. It is halved
		/// again before returning. indL_dPsi0 is not used.
		///
		/// @param dlogZ_dv  optional; if non-null, its total is added to parameter slot 0 before the x2.
		template<typename Scalar>
		void ARDKernel<Scalar>::PsiDerivatives(const af::array& indL_dPsi0, const af::array & inPsi1, const af::array & indL_dPsi1, const af::array & inPsi2, 
			const af::array & indL_dPsi2, const af::array & inXu, const af::array & inMu, const af::array & inS, af::array & outdL_dParam, 
			af::array & outdL_dXu, af::array & outdL_dMu, af::array & outdL_dS, const af::array* dlogZ_dv)
		{
			af::array logParams = GetLogParameters();
			logParams(0) *= 2.0;
			SetLogParameters(logParams);

			af::array gParam1, gXu1, gMu1, gS1;
			Psi1Derivative(inPsi1, indL_dPsi1, inXu, inMu, inS, gParam1, gXu1, gMu1, gS1);

			af::array gParam2, gXu2, gMu2, gS2;
			Psi2Derivative(inPsi2, indL_dPsi2, inXu, inMu, inS, gParam2, gXu2, gMu2, gS2);

			outdL_dParam = gParam1 + gParam2;
			outdL_dXu = gXu1 + gXu2;
			outdL_dMu = gMu1 + gMu2;
			outdL_dS = gS1 + gS2;

			if (dlogZ_dv != nullptr)
				outdL_dParam(0) += af::sum(af::sum(*dlogZ_dv));

			outdL_dParam(0) *= 2.0f;
			outdL_dParam *= GetParameters();

			logParams(0) /= 2.0;
			SetLogParameters(logParams);
		}

		/// Derivatives of sum(indL_dPsi1 .* inPsi1) with respect to the parameters, the inducing
		/// inputs inZ, the means inMu and the variances inS.
		///
		/// @param inPsi1        Psi1 statistic (N x M).
		/// @param indL_dPsi1    gradient of the loss with respect to Psi1 (N x M).
		/// @param inZ           inducing inputs; M = rows.
		/// @param inMu          means; reshaped to N x 1 x Q below.
		/// @param inS           variances; N = rows, Q = columns.
		/// @param outdL_dParam  GetNumParameter() values: slot 0 for dVariance, then one per dimension.
		/// @param outdL_dXu     M x Q.
		/// @param outdL_dMu     N x Q.
		/// @param outdL_dS      N x Q.
		///
		/// Slot 0 divides by dVariance as currently stored. PsiDerivatives doubles log(dVariance)
		/// before calling this.
		template<typename Scalar>
		void ARDKernel<Scalar>::Psi1Derivative(const af::array& inPsi1, const af::array& indL_dPsi1, const af::array& inZ, const af::array& inMu,
			const af::array& inS, af::array& outdL_dParam, af::array& outdL_dXu, af::array& outdL_dMu, af::array& outdL_dS)
		{
			// NOTE(review): `af::array x(m_dType)` likely picks af::array's size constructor rather
			// than setting a dtype; harmless since each is reassigned below — verify.
			af::array lenghtscale2(m_dType), Lpsi1(m_dType), Zmu(m_dType), denom(m_dType), Zmu2_denom(m_dType);
			
			outdL_dParam = af::constant(0.0f, GetNumParameter(), (m_dType));
				
			int iN = inS.dims(0);	// N
			int iq = inS.dims(1);	// Q
			int ik = inZ.dims(0);	// M

			lenghtscale2 = af::pow(dInvScale, 2.0);

			// Elementwise dL/dPsi1 .* Psi1 (N x M).
			Lpsi1 = indL_dPsi1 * inPsi1;
			Zmu = af::tile(af::moddims(inZ, 1, ik, iq), iN) - af::tile(af::moddims(inMu, iN, 1, iq), 1, ik); // NxMxQ
			// 1 / (S + l^2), N x Q.
			denom = 1.0 / (inS + af::tile(lenghtscale2.T(), iN));
			Zmu2_denom = af::pow(Zmu, 2.0) * af::tile(af::moddims(denom, iN, 1, iq), 1, ik); // #NxMxQ
					
			outdL_dParam(0) = af::sum(af::sum(Lpsi1)) / dVariance;
			// Sum over inducing points (dim 1) -> N x Q.
			outdL_dMu = af::moddims(af::sum(af::tile(Lpsi1, 1, 1, iq) * Zmu * af::tile(af::moddims(denom, iN, 1, iq), 1, ik, 1), 1), iN, iq);
			outdL_dS = af::moddims(af::sum(af::tile(Lpsi1, 1, 1, iq) * (Zmu2_denom - 1.0) * af::tile(af::moddims(denom, iN, 1, iq), 1, ik, 1), 1), iN, iq) / 2.0;
			// Sum over data points (dim 0) -> M x Q, with the opposite sign of the Mu gradient.
			outdL_dXu = -af::moddims(af::sum(af::tile(Lpsi1, 1, 1, iq) * Zmu * af::tile(af::moddims(denom, iN, 1, iq), 1, ik, 1), 0), ik, iq);
			// Per-dimension scale slots: sum over N and M of Lpsi1 .* (Zmu^2/(S+l^2) + S/l^2) .* l/(S+l^2).
			outdL_dParam(af::seq(1, GetNumParameter() - 1)) = af::moddims(af::sum(af::sum(af::tile(Lpsi1, 1, 1, iq) * (Zmu2_denom + af::tile(af::moddims(inS 
				/ af::tile(lenghtscale2.T(), iN), iN, 1, iq), 1, ik, 1)) * af::tile(af::moddims(denom * af::tile(dInvScale.T(), iN), iN, 1, iq), 1, ik, 1), 0), 1), iq);
		}

		/// Derivatives of sum(indL_dPsi2 .* inPsi2) with respect to the parameters, the inducing
		/// inputs inXu, the means inMu and the variances inS.
		///
		/// @param inPsi2        Psi2 statistic (M x M x N).
		/// @param indL_dPsi2    gradient of the loss with respect to Psi2 (M x M x N); symmetrised below.
		/// @param inXu          inducing inputs; M = rows, Q = columns.
		/// @param inMu          means, used with the same shape as inS (presumably N x Q — TODO confirm).
		/// @param inS           variances; N = rows, Q = columns.
		/// @param outdL_dParam  GetNumParameter() values: slot 0 for dVariance, then one per dimension.
		/// @param outdL_dXu     M x Q.
		/// @param outdL_dMu     N x Q.
		/// @param outdL_dS      N x Q.
		///
		/// Slot 0 divides by dVariance as currently stored. PsiDerivatives doubles log(dVariance)
		/// before calling this.
		template<typename Scalar>
		void ARDKernel<Scalar>::Psi2Derivative(const af::array & inPsi2, const af::array & indL_dPsi2, const af::array & inXu, const af::array & inMu, const af::array & inS, 
			af::array & outdL_dParam, af::array & outdL_dXu, af::array & outdL_dMu, af::array & outdL_dS)
		{
			int iN = inS.dims(0);	// N
			int iq = inS.dims(1);	// Q
			int ik = inXu.dims(0);	// M

			// NOTE(review): Lpsi1, Zmu, Zmu2_denom and _dL_dl are never used. Also,
			// `af::array x(m_dType)` likely picks af::array's size constructor rather than setting
			// a dtype — verify.
			af::array lenghtscale2(m_dType), Lpsi1(m_dType), Zmu(m_dType), denom(m_dType), denom2(m_dType), Zmu2_denom(m_dType),
				dL_dpsi2(m_dType), Lpsi2(m_dType), Lpsi2sum(m_dType), tmp(m_dType), Lpsi2Z(m_dType), Lpsi2Z2(m_dType), 
				Lpsi2Z2p(m_dType), Lpsi2Zhat(m_dType), Lpsi2Zhat2(m_dType), Lpsi2_N(m_dType), Lpsi2_M(m_dType), _dL_dl(m_dType);

			outdL_dParam = af::constant(0.0f, GetNumParameter(), (m_dType));

			lenghtscale2 = af::pow(dInvScale, 2.0);
			// 1 / (2S + l^2) and its square, N x Q.
			denom = 1.0 / (2.0 * inS + af::tile(lenghtscale2.T(), iN));
			denom2 = af::pow(denom, 2.0);

			// Symmetrise the incoming gradient in its first two dimensions.
			dL_dpsi2 = (indL_dPsi2 + indL_dPsi2.T()) / 2.0;
			Lpsi2 = dL_dpsi2 * inPsi2;
			Lpsi2sum = af::moddims(af::sum(af::sum(Lpsi2, 0), 1), iN); // N

			tmp = af::moddims(af::moddims(af::matmul(Lpsi2, af::tile(inXu, 1, 1, iN)), ik * iq, iN).T(), iN, ik, iq); // NxKxQ
			Lpsi2Z = af::moddims(af::sum(tmp, 1), iN, iq); // NxQ
			Lpsi2Z2 = af::moddims(af::sum(af::matmul(Lpsi2, af::tile(af::pow(inXu, 2.0), 1, 1, iN)), 0), iq, iN).T(); // NxQ
			Lpsi2Z2p = af::moddims(af::sum(tmp * af::tile(af::moddims(inXu, 1, ik, iq), iN), 1), iN, iq); // NxQ
			Lpsi2Zhat = Lpsi2Z;
			Lpsi2Zhat2 = (Lpsi2Z2 + Lpsi2Z2p) / 2.0;

			outdL_dParam(0) = af::sum(Lpsi2sum) * 2.0 / dVariance;
			outdL_dMu = (-2.0 * denom) * (inMu * af::tile(Lpsi2sum, 1, iq) - Lpsi2Zhat);
			outdL_dS = (2.0 * af::pow(denom, 2.0)) * (af::pow(inMu, 2.0) * af::tile(Lpsi2sum, 1, iq) - 2.0 * inMu * Lpsi2Zhat + Lpsi2Zhat2) - denom * af::tile(Lpsi2sum, 1, iq);
			// Lpsi2 summed over data points (M x M) and over one inducing axis (N x M).
			Lpsi2_N = af::sum(Lpsi2, 2);
			Lpsi2_M = af::moddims(af::sum(Lpsi2, 1), ik, iN).T();
			outdL_dXu = -af::tile(af::sum(Lpsi2_N, 1), 1, iq) * inXu / af::tile(lenghtscale2.T(), ik) + af::matmul(Lpsi2_N, inXu) / af::tile(lenghtscale2.T(), ik) +
				2.0 * af::matmulTN(Lpsi2_M, inMu * denom) - af::matmulTN(Lpsi2_M, denom) * inXu -
				af::moddims(af::sum(af::moddims(af::matmul(af::moddims(Lpsi2, ik * ik, iN), denom), ik, ik, iq) * af::tile(af::moddims(inXu, 1, ik, iq), ik), 1), ik, iq);

			// Per-dimension scale slots: 2 * l_q * (column sums of the bracketed N x Q expression).
			outdL_dParam(af::seq(1, GetNumParameter() - 1)) = 2.0 * dInvScale * af::sum((inS / af::tile(lenghtscale2.T(), iN) * denom +
				af::pow(inMu * denom, 2.0)) * af::tile(Lpsi2sum, 1, iq) + (Lpsi2Z2 - Lpsi2Z2p) / (2.0 * af::tile(af::pow(lenghtscale2.T(), 2.0), iN)) -
				(2.0 * inMu * denom2) * Lpsi2Zhat + denom2 * Lpsi2Zhat2).T();
		}
	}
}