/**
File:		MachineLearning/GPModels/SparseGPModels/AEP/FgAEPSparseGPLayer.cpp

Author:		
Email:		
Site:       

Copyright (c) 2017 . All rights reserved.
*/

#include <NeMachineLearningPCH.h>
#include <MachineLearning/FgAEPSparseGPLayer.h>
#include <MachineLearning/FgKMeans.h>

// Number of Monte Carlo samples used by the MC code paths in this file: it is the
// leading dimension of the noise array drawn in ForwardPredictionRandomCavityMC and
// the factor by which the batch is stacked in BackpropGradientsMC.
const int MC_NO_SAMPLES = 5;

namespace NeuralEngine::MachineLearning::GPModels::AEP
{
	// Explicitly instantiate the layer for float and double.
	// NOTE(review): these instantiations appear before the member definitions below.
	// The C++ standard only instantiates members that are already defined at this
	// point, so toolchains other than MSVC may not emit the member functions -
	// consider moving both lines to the end of the namespace. TODO confirm.
	template class SGPLayer<float>;
	template class SGPLayer<double>;

	

	// Constructs the layer: forwards all four sizes to SparseGPBaseLayer and then
	// zero-fills the buffers assigned below, each with a third axis of size iD.
	template<typename Scalar>
	SGPLayer<Scalar>::SGPLayer(int numPoints, int numPseudos, int outputDim, int inputDim)
		: SparseGPBaseLayer<Scalar>(numPoints, numPseudos, outputDim, inputDim)
	{
		// Small factories so the two shapes used here are spelled only once.
		const auto zeroColumn = [&]() { return af::constant(0.0f, ik, 1, iD, m_dType); };   // ik x 1 x iD
		const auto zeroSquare = [&]() { return af::constant(0.0f, ik, ik, iD, m_dType); };  // ik x ik x iD

		afInvSuMu = zeroColumn();
		afInvSu = zeroSquare();
		afSu = zeroSquare();

		afSuMuMuHat = zeroSquare();
		afSuMuMu = zeroSquare();
		afBetaHatStochastic = zeroSquare();
		afBetaStochastic = zeroSquare();
	}

	// Nothing to release explicitly; members clean up through their own destructors.
	template<typename Scalar>
	SGPLayer<Scalar>::~SGPLayer() = default;

	// Cavity forward pass dispatcher.
	// When an input variance is supplied (vx != nullptr) the random-input path is
	// taken with the requested propagation mode; otherwise the deterministic path
	// is used and psi1out is handed over as its optional kfu output.
	template<typename Scalar>
	void SGPLayer<Scalar>::ForwardPredictionCavity(af::array& mout, af::array& vout, af::array* psi1out, af::array* psi2out, const af::array& mx, const af::array* vx, Scalar alpha, PropagationMode mode)
	{
		const bool hasInputVariance = (vx != nullptr);
		if (hasInputVariance)
		{
			ForwardPredictionRandomCavity(mout, vout, psi1out, psi2out, mx, *vx, mode, alpha);
			return;
		}

		ForwardPredictionDeterministicCavity(mout, vout, psi1out, mx, alpha);
	}

	/*void SGPLayer<Scalar>::ForwardPredictionPost(af::array& mout, af::array& vout, af::array* mx, af::array* vx)
	{
	}*/

	// Backward pass for deterministic inputs x.
	//
	// Combines the incoming gradients dlogZ_dm / dlogZ_dv with the cavity, posterior and
	// prior terms and returns one flat gradient vector of length GetNumParameters().
	// The vector is filled in this order:
	//   1. kernel hyper-parameters            (skipped when isFixedHypers)
	//   2. collapsed style gradients          (only when mStyles is set)
	//   3. dT1                                (ik * iD entries)
	//   4. dT2                                (ik * (ik + 1) / 2 * iD entries)
	//   5. first iq columns of dL/dXu         (skipped when isFixedInducing)
	//
	// m, v      : not read by this function.
	// grad_cav  : optional output; when non-null it is cleared and receives the
	//             entries "dmcav" and "dvcav".
	// alpha     : power parameter used for the scale factors below and forwarded to
	//             ComputeCavityGradientU.
	//
	// The cached cavity quantities (afGammaHat, afBetaHat, afSuHat, afMuHat, afInvSuHat,
	// afInvSuMuHat) are read but not recomputed here - presumably a forward cavity pass
	// has run beforehand; verify against callers.
	template<typename Scalar>
	af::array SGPLayer<Scalar>::BackpropGradientsReg(const af::array& m, const af::array& v, const af::array& dlogZ_dm, const af::array& dlogZ_dv, const af::array& x, 
		std::map<std::string, af::array>* grad_cav, Scalar alpha)
	{
		int batchN = x.dims(0);

		// weights of the posterior / cavity / prior contributions (beta and scale_prior are not used below)
		Scalar beta = (iN - alpha) * 1.0 / iN;
		Scalar scale_post = iN * 1.0 / alpha - 1.0;
		Scalar scale_cav = -iN * 1.0 / alpha;
		Scalar scale_prior = 1.0f;

		// running [iStart, iEnd) window into grad_hyper
		int iStart = 0, iEnd = 0;

		af::array paramTemp, tiledInvKuu, diagdlogZ_dv, dGammaHat, dBetaHat, dvcav, dmcav, dL_dkfu, dL_dParamKfu, dL_dParamKuu, dL_dXuKuu, dL_dXuKfu,
			kfuInvKuu, dmucav, dSucav, dmu, dSu, KuuinvSmmd, dKuuinv, dKuuinvPrior, dInvKuuPost, dInvKuuCav, dKuuinv_m, dKuuinv_v, dT1cav,
			dT2cav, dT1, dT2, M_inner, dL_dXu, dK_dX, dL_dParam, tiledKfu, dlogZ_dmMod;
			
		af::array grad_hyper = af::constant(0.0, GetNumParameters(), m_dType);

		// presumably refreshes afKfu for the current batch - TODO confirm
		ComputeKfu(x);

		// one column / one diagonal matrix per output dimension
		dlogZ_dmMod = af::moddims(dlogZ_dm, batchN, 1, iD);
		diagdlogZ_dv = af::diag(dlogZ_dv, 0, false);

		// tile matrices per dimension for faster parallel computation
		kfuInvKuu = af::tile(af::matmul(afKfu, afInvKuu), 1, 1, iD);
		tiledKfu = af::tile(afKfu, 1, 1, iD);
		tiledInvKuu = af::tile(afInvKuu, 1, 1, iD);
		
		// compute gradients w.r.t. GammaHat and BetaHat
		dGammaHat = af::matmulTN(dlogZ_dmMod, tiledKfu).T();
		dBetaHat = af::matmul(af::matmulTN(tiledKfu, diagdlogZ_dv), tiledKfu);

		// compute gradients w.r.t. m and v
		dvcav = af::matmul(tiledInvKuu, af::matmul(dBetaHat, tiledInvKuu));
		dmcav = af::matmul(tiledInvKuu, dGammaHat);

		// compute gradients w.r.t. kfu
		dL_dkfu = af::sum(af::matmulNT(dlogZ_dmMod, afGammaHat), 2); // dkuf_m contribution
		dL_dkfu += 2.0f * af::sum(af::matmul(af::matmul(diagdlogZ_dv, tiledKfu), afBetaHat), 2); // dkfu = dkuf_m + dkuf_v
		
		// kernel turns dL/dKfu into gradients w.r.t. its parameters and the inducing inputs
		kernel->LogLikGradientCompundKfu(dL_dkfu, x, afXu, &dL_dParamKfu, &dL_dXuKfu, &dlogZ_dv);
		//dL_dParamKfu *= kernel->GetParameters(); // consider log space

		// compute grads wrt cavity mean and covariance
		dmucav = af::matmulTN(dlogZ_dmMod, kfuInvKuu).T();
		dSucav = af::matmul(af::matmulTN(kfuInvKuu, diagdlogZ_dv), kfuInvKuu);

		// add in contribution from the normalizing factor
		dmucav += (scale_cav * afInvSuMuHat);
		dSucav += scale_cav * (0.5 * afInvSuHat - 0.5 * af::matmulNT(afInvSuMuHat, afInvSuMuHat));
		
		ComputeCavityGradientU(dmucav, dSucav, dT1cav, dT2cav, dInvKuuCav, alpha);
		
		// compute grads wrt posterior mean and covariance
		dmu = scale_post * afInvSuMu;
		dSu = scale_post * (0.5 * afInvSu - 0.5 * af::matmulNT(afInvSuMu, afInvSuMu));
		
		ComputePosteriorGradientU(dmu, dSu, dT1, dT2, dInvKuuPost);
	
		// contribution from phi prior term
		dKuuinvPrior = -0.5 * iD * afKuu;
		// accumulate cavity and posterior parts of the T1 / T2 gradients
		dT1 += dT1cav;
		dT2 += dT2cav;
		dKuuinv = dKuuinvPrior + dInvKuuPost + dInvKuuCav;
		
		// get contribution of Ahat and Bhat to Kuu and add to Minner
		dKuuinv_m = af::sum(af::matmulNT(dGammaHat, afMuHat), 2);
		KuuinvSmmd = af::matmul(tiledInvKuu, afSuHat);
		dKuuinv_v = 2.0f * af::sum(af::matmulTN(KuuinvSmmd, dBetaHat), 2) - af::sum(dBetaHat, 2);
		dKuuinv += dKuuinv_m + dKuuinv_v;
		
		// map the gradient w.r.t. inv(Kuu) back to a gradient w.r.t. Kuu
		M_inner = -af::matmul(afInvKuu, af::matmul(dKuuinv, afInvKuu));
		kernel->LogGradientCompoundKuu(afXu, M_inner, &dL_dParamKuu, &dL_dXuKuu);

		dL_dParam = dL_dParamKfu + 2 * dL_dParamKuu;
		dL_dXu = dL_dXuKfu + dL_dXuKuu;

		// Collecting gradients
		if (!isFixedHypers)
		{
			iEnd = GetKernel()->GetNumParameter();
			grad_hyper(af::seq(iStart, iEnd - 1)) = dL_dParam;
		}

		if (mStyles)
		{
			// style columns of dL_dXu start right after the first iq columns
			int sStart = iq, sEnd = iq;
			af::array gradStyleFlat;
			for (auto style = mStyles->begin(); style != mStyles->end(); style++)
			{
				sStart = sEnd; sEnd += style->second.GetNumSubstyles();
				gradStyleFlat = CommonUtil<Scalar>::Join(gradStyleFlat, style->second.GetInducingGradientCollapsed(dL_dXu(af::span, af::seq(sStart, sEnd - 1))));
			}

			iStart = iEnd;
			iEnd += gradStyleFlat.dims(0);
			grad_hyper(af::seq(iStart, iEnd - 1)) = gradStyleFlat;
		}

		iStart = iEnd;
		iEnd += ik * iD;
		grad_hyper(af::seq(iStart, iEnd - 1)) = af::flat(dT1);

		iStart = iEnd;
		iEnd += (ik * (ik + 1) / 2) * iD;
		grad_hyper(af::seq(iStart, iEnd - 1)) = af::flat(dT2);

		if (!isFixedInducing)
		{
			/*iStart = iEnd;
			iEnd += ik * iq;
			grad_hyper(af::seq(iStart, iEnd - 1)) = af::flat(dL_dXu);*/

			// only the first iq columns of dL_dXu are stored here
			iStart = iEnd;
			iEnd += ik * iq;
			grad_hyper(af::seq(iStart, iEnd - 1)) = af::flat(dL_dXu(af::span, af::seq(0, iq - 1)));
		}

		if (grad_cav != nullptr)
		{
			grad_cav->clear();
			grad_cav->insert(std::pair<std::string, af::array>("dmcav", dmcav));
			grad_cav->insert(std::pair<std::string, af::array>("dvcav", dvcav));
		}

		return grad_hyper;
	}

	// Backward pass for the moment-matching propagation mode (random inputs with
	// mean mx and variance vx, using the psi statistics psi1 / psi2).
	//
	// Returns one flat gradient vector of length GetNumParameters(), filled in this order:
	//   1. kernel hyper-parameters            (skipped when isFixedHypers)
	//   2. collapsed style gradients          (only when mStyles is set)
	//   3. dT1                                (ik * iD entries)
	//   4. packed T2 gradient out_dT2         (ik * (ik + 1) / 2 * iD entries)
	//   5. first iq columns of dL/dXu         (skipped when isFixedInducing)
	//
	// v            : not read by this function.
	// outGradInput : optional output; when non-null it is cleared and receives the
	//                entries "dL_dmx" and "dL_dvx" produced by kernel->PsiDerivatives.
	//
	// The cached cavity quantities (afGammaHat, afBetaHatStochastic, afSuHat, afMuHat,
	// afSuMuMuHat) are read but not recomputed here - presumably a forward cavity pass
	// has run beforehand; verify against callers.
	template<typename Scalar>
	af::array SGPLayer<Scalar>::BackpropGradientsMM(const af::array & m, const af::array & v, const af::array & dlogZ_dm, const af::array & dlogZ_dv, 
		const af::array & psi1, const af::array & psi2, const af::array & mx, const af::array & vx, std::map<std::string, af::array>* outGradInput, Scalar alpha)
	{
		int batchN = mx.dims(0);

		// weights of the cavity / posterior / prior contributions
		Scalar beta = (iN - alpha) * 1.0 / iN;
		Scalar scale_post = iN * 1.0 / alpha - 1.0;
		Scalar scale_cav = -iN * 1.0 / alpha;
		Scalar scale_prior = 1.0f;

		// running [iStart, iEnd) window into grad_hyper
		int iStart = 0, iEnd = 0;

		// index sets used to pack each ik x ik gradient slice into ik * (ik + 1) / 2 entries
		af::array triIdx = CommonUtil<Scalar>::TriUpperIdx(ik);
		af::array diagIdx = CommonUtil<Scalar>::DiagIdx(ik);

		af::array dmAll, dGammaHat, dBetaHat, dPsi1, dPsi2, dPsi0, dL_dParam, dL_dXu, dL_dmx, dL_dvx, paramTemp, tiledInvKuu, dmcav, dvcav, dT1,
			dT2, dvcavinv, dKuuinv_via_vcav, dKuuinv_via_GammaHat, KuuinvSmmd, dKuuinv_via_Bhat, dKuuinv, Minner, dT2_R, out_dT2, dT2_Rd, 
			T2_Rd, M_all, dL_dParamKuu, dL_dXuKuu;

		af::array grad_hyper = af::constant(0.0, GetNumParameters(), m_dType);

		// tile matrices per dimension for faster parallel computation
		tiledInvKuu = af::tile(afInvKuu, 1, 1, iD);

		// compute grads wrt GammaHat and BetaHat
		// dmAll folds the dependence of the variance on the mean (the -m^2 term) into the mean gradient
		dmAll = dlogZ_dm - 2.0 * dlogZ_dv * m;
		dGammaHat = af::matmulTN(af::moddims(dmAll, batchN, 1, iD), af::tile(psi1, 1, 1, iD)).T();
		//dBetaHat = moddims(matmul(dlogZ_dv.T(), moddims(psi2, ik * ik, batchN).T()).T(), ik, ik, iD);
		dBetaHat = af::moddims(af::sum(af::tile(af::moddims(dlogZ_dv, batchN, 1, iD), 1, ik * ik, 1) 
			* af::tile(af::moddims(psi2, ik * ik, batchN).T(), 1, 1, iD), 0), ik, ik, iD);

		// compute grads wrt Psi1 and Psi2
		// NOTE(review): dPsi0 is sized with iN while the other terms use batchN - confirm this is intended
		dPsi0 = af::constant(.5 * iD, iN, m_dType);
		dPsi1 = af::matmulNT(dmAll, af::moddims(afGammaHat, ik, iD));
		dPsi2 = af::sum(af::tile(af::moddims(dlogZ_dv, 1, 1, batchN, iD), ik, ik, 1, 1) 
			* af::tile(af::moddims(afBetaHatStochastic, ik, ik, 1, iD), 1, 1, batchN, 1), 3);

		// kernel converts the psi gradients into gradients w.r.t. its parameters, Xu, mx and vx
		kernel->PsiDerivatives(dPsi0, psi1, dPsi1, psi2, dPsi2, afXu, mx, vx, dL_dParam, dL_dXu, dL_dmx, dL_dvx, &dlogZ_dv);
		//dL_dParam *= kernel->GetParameters(); // consider log space
		
		// compute grads wrt cavity mean and covariance
		dvcav = af::matmul(af::matmul(tiledInvKuu, dBetaHat), tiledInvKuu);
		dmcav = 2.0 * af::matmul(dvcav, afMuHat) + af::matmul(tiledInvKuu, dGammaHat);
		dvcav += beta * af::matmulNT(dmcav, T1);
		dvcavinv = -af::matmul(af::matmul(afSuHat, dvcav), afSuHat);
		dT2 = beta * dvcavinv;
		dT1 = beta * af::matmul(afSuHat, dmcav);
		dKuuinv_via_vcav = af::sum(dvcavinv, 2);

		// get contribution of GammaHat and BetaHat to Kuu and add to Minner
		dKuuinv_via_GammaHat = af::matmulNT(af::moddims(dGammaHat, ik, iD), af::moddims(afMuHat, ik, iD));
		KuuinvSmmd = af::matmul(tiledInvKuu, afSuMuMuHat);
		dKuuinv_via_Bhat = 2.0 * af::matmulNT(af::moddims(KuuinvSmmd.T(), ik, ik * iD), af::moddims(dBetaHat.T(), ik, ik * iD)) - af::sum(dBetaHat, 2);
		dKuuinv = dKuuinv_via_GammaHat + dKuuinv_via_Bhat + dKuuinv_via_vcav;
		Minner = scale_post * af::sum(afSuMuMu, 2) + scale_cav * af::sum(afSuMuMuHat, 2) - 2.0 * dKuuinv;
		// add the posterior and cavity normaliser terms to the T1 / T2 gradients
		dT2 = -0.5 * scale_post * afSuMuMu - 0.5 * scale_cav * beta * afSuMuMuHat + dT2;
		dT1 = scale_post * afMu + scale_cav * beta * afMuHat + dT1;

		// push the T2 gradient through T2_R and pack one row per output dimension
		dT2_R = af::matmul(T2_R, dT2 + dT2.T());
		out_dT2 = af::constant(0.0f, iD, ik * (ik + 1) / 2, m_dType);
		for (int d = 0; d < iD; d++)
		{
			dT2_Rd = dT2_R(af::span, af::span, d);
			T2_Rd = T2_R(af::span, af::span, d);
			// diagonal entries are additionally multiplied by the diagonal of T2_R
			dT2_Rd(diagIdx) = dT2_Rd(diagIdx) * T2_Rd(diagIdx);
			out_dT2(d, af::span) = dT2_Rd(triIdx);
		}

		// combined weight matrix handed to the kernel for the Kuu gradient
		M_all = 0.5 * (scale_prior * iD * afInvKuu + af::matmul(afInvKuu, af::matmul(Minner, afInvKuu)));
		kernel->LogGradientCompoundKuu(afXu, M_all, &dL_dParamKuu, &dL_dXuKuu);

		dL_dParam += 2 * dL_dParamKuu;
		dL_dXu += dL_dXuKuu;

		// Collecting gradients
		if (!isFixedHypers)
		{
			iEnd = GetKernel()->GetNumParameter();
			grad_hyper(af::seq(iStart, iEnd - 1)) = dL_dParam;
		}

		if (mStyles)
		{
			// style columns of dL_dXu start right after the first iq columns
			int sStart = iq, sEnd = iq;
			af::array gradStyleFlat;
			for (auto style = mStyles->begin(); style != mStyles->end(); style++)
			{
				sStart = sEnd; sEnd += style->second.GetNumSubstyles();
				gradStyleFlat = CommonUtil<Scalar>::Join(gradStyleFlat, style->second.GetInducingGradientCollapsed(dL_dXu(af::span, af::seq(sStart, sEnd - 1))));
			}

			iStart = iEnd;
			iEnd += gradStyleFlat.dims(0);
			grad_hyper(af::seq(iStart, iEnd - 1)) = gradStyleFlat;
		}

		iStart = iEnd;
		iEnd += ik * iD;
		grad_hyper(af::seq(iStart, iEnd - 1)) = af::flat(dT1);

		iStart = iEnd;
		iEnd += (ik * (ik + 1) / 2) * iD;

		grad_hyper(af::seq(iStart, iEnd - 1)) = af::flat(out_dT2);

		if (!isFixedInducing)
		{
			// only the first iq columns of dL_dXu are stored here
			iStart = iEnd;
			iEnd += ik * iq;
			grad_hyper(af::seq(iStart, iEnd - 1)) = af::flat(dL_dXu(af::span, af::seq(0, iq - 1)));
		}

		if (outGradInput != nullptr)
		{
			outGradInput->clear();
			outGradInput->insert(std::pair<std::string, af::array>("dL_dmx", dL_dmx));
			outGradInput->insert(std::pair<std::string, af::array>("dL_dvx", dL_dvx));
		}

		return grad_hyper;
	}

	// Backward pass for the Monte Carlo propagation mode.
	//
	// x holds the sampled inputs with MC_NO_SAMPLES along the first axis and the batch
	// along the second; samples are stacked into one MC_NO_SAMPLES * batchN row block
	// before the kernel is evaluated.
	//
	// Returns one flat gradient vector of length GetNumParameters(), filled in this order:
	//   1. kernel hyper-parameters            (skipped when isFixedHypers)
	//   2. collapsed style gradients          (only when mStyles is set)
	//   3. dT1                                (ik * iD entries)
	//   4. packed T2 gradient detaT2          (ik * (ik + 1) / 2 * iD entries)
	//   5. first iq columns of dL/dXu         (skipped when isFixedInducing)
	//
	// mcav       : only its leading dimension is inspected, to decide whether the input
	//              gradients are returned per sample or summed over the sample axis.
	// vcav, eps  : used to turn dL/dx into a gradient w.r.t. the input variance.
	// outGradCav : optional output; when non-null it is cleared and receives the
	//              entries "dL_dmx" and "dL_dvx".
	//
	// The cached cavity quantities (afGammaHat, afBetaHat, afSuHat, afMuHat, afSuMuMuHat)
	// are read but not recomputed here - presumably a forward cavity pass has run
	// beforehand; verify against callers.
	template<typename Scalar>
	af::array SGPLayer<Scalar>::BackpropGradientsMC(const af::array& mcav, const af::array& vcav, const af::array& eps, const af::array& dlogZ_dm, const af::array& dlogZ_dv, const af::array& x, std::map<std::string, af::array>* outGradCav, Scalar alpha)
	{
		int batchN = x.dims(1);

		// weights of the cavity / posterior / prior contributions
		Scalar beta = (iN - alpha) * 1.0 / iN;
		Scalar scale_post = iN * 1.0 / alpha - 1.0;
		Scalar scale_cav = -iN * 1.0 / alpha;
		Scalar scale_prior = 1.0f;

		// running [iStart, iEnd) window into grad_hyper
		int iStart = 0, iEnd = 0;

		// number of style columns that follow the iq latent columns of x
		int styleDim = 0;

		if (mStyles)
			for (auto style = mStyles->begin(); style != mStyles->end(); style++)
				styleDim += style->second.GetNumSubstyles();

		// stack all samples into rows
		af::array x_stk = af::moddims(x.T(), MC_NO_SAMPLES * batchN, iq + styleDim);

		af::array dlogZ_dvstk = af::moddims(dlogZ_dv.T(), MC_NO_SAMPLES * batchN, iD);

		// presumably refreshes afKfu for the stacked samples - TODO confirm
		ComputeKfu(x_stk);

		af::array tiledKfu = af::tile(afKfu, 1, 1, iD);
		af::array tiledInvKuu = af::tile(afInvKuu, 1, 1, iD);

		af::array dlogZ_dmMod = af::moddims(dlogZ_dm.T(), MC_NO_SAMPLES * batchN, 1, iD);
		af::array diagdlogZ_dv = af::diag(af::moddims(dlogZ_dv.T(), MC_NO_SAMPLES * batchN, iD), 0, false);

		// compute gradients w.r.t. kfu
		af::array dL_dkfu = af::sum(af::matmulNT(dlogZ_dmMod, afGammaHat), 2); // dkuf_m contribution
		dL_dkfu += 2.0f * af::sum(af::matmul(af::matmul(diagdlogZ_dv, tiledKfu), afBetaHat), 2); // dkfu = dkuf_m + dkuf_v

		// kernel turns dL/dKfu into gradients w.r.t. its parameters, Xu and the stacked inputs
		af::array dL_dParamKfu, dL_dXuKfu, dL_dX;
		kernel->LogLikGradientCompundKfu(dL_dkfu, x_stk, afXu, &dL_dParamKfu, &dL_dXuKfu, &dlogZ_dvstk, &dL_dX);

		// compute grads w.r.t. T1 and T2
		af::array kfuInvKuu = af::tile(af::matmul(afKfu, afInvKuu), 1, 1, iD);
		af::array SKuuinvKuf = af::matmul(kfuInvKuu, afSuHat);
		af::array dSinv_via_v = -af::matmul(af::matmulTN(SKuuinvKuf, diagdlogZ_dv), SKuuinvKuf);
		af::array dSinv_via_m = -af::matmulTT(af::matmulTN(dlogZ_dmMod, SKuuinvKuf), afMuHat);
		af::array dSinv = dSinv_via_m + dSinv_via_v;

		af::array dSinvM = matmulTN(SKuuinvKuf, dlogZ_dmMod);
		af::array dT2 = beta * dSinv;
		af::array dT1 = beta * dSinvM;

		// add the posterior and cavity normaliser terms
		dT2 += -0.5 * scale_post * afSuMuMu - 0.5 * scale_cav * beta * afSuMuMuHat;
		dT1 += scale_post * afMu + scale_cav * beta * afMuHat;

		// push the T2 gradient through T2_R and pack one row per output dimension
		af::array dT2_R = af::matmul(T2_R, dT2 + dT2.T());

		// Allocate with the layer dtype (as BackpropGradientsMM does) so the double
		// instantiation does not round the T2 gradient through single precision.
		af::array detaT2 = af::constant(0.0f, iD, ik * (ik + 1) / 2, m_dType);
		af::array diagIdx = CommonUtil<Scalar>::DiagIdx(ik);
		af::array triIdx = CommonUtil<Scalar>::TriUpperIdx(ik);
		for (int d = 0; d < iD; d++)
		{
			af::array dT2_Rd = dT2_R(af::span, af::span, d);
			af::array T2_Rd = T2_R(af::span, af::span, d);
			// diagonal entries are additionally multiplied by the diagonal of T2_R
			dT2_Rd(diagIdx) = dT2_Rd(diagIdx) * T2_Rd(diagIdx);
			detaT2(d, af::span) = dT2_Rd(triIdx);
		}

		// get contribution of Ahat and Bhat to Kuu and add to Minner
		// compute gradients w.r.t. GammaHat and BetaHat
		af::array dGammaHat = af::matmulTN(dlogZ_dmMod, tiledKfu).T();
		af::array dBetaHat = af::matmul(af::matmulTN(tiledKfu, diagdlogZ_dv), tiledKfu);

		af::array dKuuinv_m = af::sum(af::matmulNT(dGammaHat, afMuHat), 2);
		af::array KuuinvSmmd = af::matmul(tiledInvKuu, afSuHat);
		af::array dKuuinv_v1 = 2.0f * af::sum(af::matmulTN(KuuinvSmmd, dBetaHat), 2) - af::sum(dBetaHat, 2);
		af::array dKuuinv_v2 = af::sum(dSinv, 2);

		af::array dKuuinv = dKuuinv_m + dKuuinv_v1 + dKuuinv_v2;

		// combined weight matrix handed to the kernel for the Kuu gradient
		af::array Minner = scale_post * af::sum(afSuMuMu, 2) + scale_cav * af::sum(afSuMuMuHat, 2)  - 2.0 * dKuuinv;
		af::array M_all = 0.5 * (scale_prior * iD * afInvKuu + matmul(afInvKuu, matmul(Minner, afInvKuu)));

		af::array dL_dParamKuu, dL_dXuKuu;
		kernel->LogGradientCompoundKuu(afXu, M_all, &dL_dParamKuu, &dL_dXuKuu);

		af::array dL_dParam = dL_dParamKfu + 2 * dL_dParamKuu;
		af::array dL_dXu = dL_dXuKfu + dL_dXuKuu;

		// Collecting gradients
		af::array grad_hyper = af::constant(0.0, GetNumParameters(), m_dType);
		if (!isFixedHypers)
		{
			iEnd = GetKernel()->GetNumParameter();
			grad_hyper(af::seq(iStart, iEnd - 1)) = dL_dParam;
		}

		if (mStyles)
		{
			// style columns of dL_dXu start right after the first iq columns
			int sStart = iq, sEnd = iq;
			af::array gradStyleFlat;
			for (auto style = mStyles->begin(); style != mStyles->end(); style++)
			{
				sStart = sEnd; sEnd += style->second.GetNumSubstyles();
				gradStyleFlat = CommonUtil<Scalar>::Join(gradStyleFlat, style->second.GetInducingGradientCollapsed(dL_dXu(af::span, af::seq(sStart, sEnd - 1))));
			}

			iStart = iEnd;
			iEnd += gradStyleFlat.dims(0);
			grad_hyper(af::seq(iStart, iEnd - 1)) = gradStyleFlat;
		}

		iStart = iEnd;
		iEnd += ik * iD;
		grad_hyper(af::seq(iStart, iEnd - 1)) = af::flat(dT1);

		iStart = iEnd;
		iEnd += (ik * (ik + 1) / 2) * iD;
		grad_hyper(af::seq(iStart, iEnd - 1)) = af::flat(detaT2);

		if (!isFixedInducing)
		{
			// only the first iq columns of dL_dXu are stored here
			iStart = iEnd;
			iEnd += ik * iq;
			grad_hyper(af::seq(iStart, iEnd - 1)) = af::flat(dL_dXu(af::span, af::seq(0, iq - 1)));
		}

		if (outGradCav != nullptr)
		{
			af::array dmcav;
			af::array dvcav;

			// undo the row stacking: back to MC_NO_SAMPLES x batchN x (iq + styleDim)
			dL_dX = af::moddims(dL_dX, batchN, MC_NO_SAMPLES, iq + styleDim).T();

			if (mcav.dims(0) == MC_NO_SAMPLES)
			{
				// inputs were already given per sample: keep per-sample gradients
				dmcav = dL_dX;
				dvcav = (dL_dX * eps) / (2.0 * af::sqrt(vcav));
			}
			else
			{
				// inputs were given per data point: sum the gradients over the sample axis
				dmcav = moddims(af::sum(dL_dX, 0), batchN, iq + styleDim);
				dvcav = af::moddims(af::sum(dL_dX * eps, 0), batchN, iq + styleDim) / (2.0 * af::sqrt(vcav));
			}

			outGradCav->clear();
			outGradCav->insert(std::pair<std::string, af::array>("dL_dmx", dmcav));
			outGradCav->insert(std::pair<std::string, af::array>("dL_dvx", dvcav));
		}

		return grad_hyper;
	}

	// Weighted sum of the prior, posterior and cavity phi terms for the given alpha:
	//   phi = 1 * phi_prior + (iN / alpha - 1) * phi_post - (iN / alpha) * phi_cav
	template<typename Scalar>
	Scalar SGPLayer<Scalar>::ComputePhi(Scalar alpha)
	{
		// weights of the three terms
		const Scalar weightPrior = 1.0f;
		const Scalar weightPost = iN * 1.0f / alpha - 1.0f;
		const Scalar weightCav = -iN * 1.0f / alpha;

		// evaluated in the order prior, posterior, cavity
		const Scalar termPrior = ComputePhiPrior();
		const Scalar termPost = ComputePhiPosterior();
		const Scalar termCav = ComputePhiCavity();

		return weightPrior * termPrior + weightPost * termPost + weightCav * termCav;
	}

	// Power-EP style update of the factor parameters T1 / T2 for the entries selected by n.
	//
	// n        : index array selecting the entries of T1 / T2 to update; its first
	//            dimension is taken as the number of samples.
	// grad_cav : map expected to contain "dmcav" and "dvcav" (taken by value; operator[]
	//            default-inserts an empty array when a key is missing).
	// alpha    : power parameter used when blending old and new factor values.
	// decay    : damping weight applied only in the parallel (numSamples != 1) branch.
	//
	// NOTE(review): this routine indexes afSuHat, T1 and T2 with a
	// (sample, output-dim, ik[, ik]) layout, whereas the rest of this file stores them
	// as ik x ik x iD / ik x 1 x iD - confirm the routine is still in use and correct.
	template<typename Scalar>
	void SGPLayer<Scalar>::UpdateFactor(af::array & n, std::map<std::string, af::array> grad_cav, Scalar alpha, Scalar decay)
	{
		int numSamples = n.dims(0);
		af::array munew(m_dType), inner(m_dType), Sunew(m_dType), Suinvnew(m_dType), SuinvMunew(m_dType), t1_frac(m_dType), 
			t2_frac(m_dType), t1_old(m_dType), t2_old(m_dType), t1_new(m_dType), t2_new(m_dType);
		af::array dmcav = grad_cav["dmcav"];
		af::array dvcav = grad_cav["dvcav"];

		// perform PowerEP update
		munew = afMuHat + af::sum(afSuHat * af::tile(af::moddims(dmcav, numSamples, iD, 1, ik), 1, 1, ik, 1), 3);

		// NOTE(review): both buffers are created without m_dType, so they are f32 even
		// for the double instantiation.
		Sunew = af::constant(0.0f, numSamples, iD, ik, ik);
		Suinvnew = af::constant(0.0f, numSamples, iD, ik, ik);
		// Inside this loop the int index `n` shadows the af::array parameter `n`.
		for (int n = 0; n < numSamples; n++)
		{
			for (int d = 0; d < iD; d++)
			{
				// inner = dmcav^T dmcav - 2 * dvcav for this (sample, dimension) pair
				inner = af::matmulTN(af::moddims(dmcav(n, d, af::span), 1, ik), af::moddims(dmcav(n, d, af::span), 1, ik)) - 2 
					* af::moddims(dvcav(n, d, af::span, af::span), ik, ik, 1, 1);
				inner = af::matmul(inner, af::moddims(afSuHat(n, d, af::span, af::span), ik, ik, 1, 1));
				// Sunew = SuHat - SuHat * inner
				Sunew(n, d, af::span, af::span) = af::moddims(af::matmul(af::moddims(afSuHat(n, d, af::span, af::span), ik, ik, 1, 1), inner), 1, 1, ik, ik);
				Sunew(n, d, af::span, af::span) = afSuHat(n, d, af::span, af::span) - Sunew(n, d, af::span, af::span);

				Suinvnew(n, d, af::span, af::span) = af::moddims(af::inverse(af::moddims(Sunew(n, d, af::span, af::span), ik, ik, 1, 1)), 1, 1, ik, ik);
			}
		}

		SuinvMunew = af::sum(Suinvnew * af::tile(af::moddims(munew, numSamples, iD, 1, ik), 1, 1, ik, 1), 3);

		// difference between the new natural parameters and the cavity ones
		t2_frac = Suinvnew - afInvSuHat;
		t1_frac = SuinvMunew - afInvSuMuHat;
		// from here on `n` is the af::array parameter again
		t1_old = T1(n, af::span, af::span);
		t2_old = T2(n, af::span, af::span, af::span);
		t1_new = (1.0 - alpha) * t1_old + t1_frac;
		t2_new = (1.0 - alpha) * t2_old + t2_frac;

		if (numSamples == 1)
		{
			// sequential update
			T1(n, af::span, af::span) = t1_new;
			T2(n, af::span, af::span, af::span) = t2_new;

			afMu = munew(0, af::span, af::span);
			afSu = Sunew(0, af::span, af::span, af::span);

			afInvSuMu = SuinvMunew(0, af::span, af::span);
			afInvSu = Suinvnew(0, af::span, af::span, af::span);
		}
		else
		{
			// parallel update
			T1(n, af::span, af::span) = decay * t1_old + (1 - decay) * t1_new;
			T2(n, af::span, af::span, af::span) = decay * t2_old + (1 - decay) * t2_new;
			UpdateParameters();
		}
	}

	// Cavity prediction for deterministic inputs mx.
	// Recomputes the cavity for alpha, evaluates the kernel against the inducing
	// inputs afXu and writes the predictive mean / variance (one column per output
	// dimension) to mout / vout. When kfuOut is non-null it receives a copy of the
	// tiled cross-covariance.
	template<typename Scalar>
	void SGPLayer<Scalar>::ForwardPredictionDeterministicCavity(af::array& mout, af::array& vout, af::array* kfuOut, const af::array& mx, Scalar alpha)
	{
		const int rowCount = mx.dims(0);
		af::array priorDiag(m_dType), crossCov(m_dType);

		ComputeCavity(alpha);

		kernel->ComputeDiagonal(mx, priorDiag);
		kernel->ComputeKernelMatrix(mx, afXu, crossCov);

		// replicate along the third axis, one slice per output dimension
		priorDiag = af::tile(priorDiag, 1, 1, iD);
		crossCov = af::tile(crossCov, 1, 1, iD);

		// mean: Kfu * GammaHat
		mout = af::moddims(af::matmul(crossCov, afGammaHat), rowCount, iD);

		// variance: diag(Kff) + diag(Kfu * BetaHat * Kfu^T)
		af::array quadDiag = af::diag(af::matmulNT(af::matmul(crossCov, afBetaHat), crossCov));
		vout = af::moddims(priorDiag + quadDiag, rowCount, iD);

		if (kfuOut != nullptr)
		{
			*kfuOut = crossCov.copy();
		}
	}

	// Cavity prediction for random inputs: dispatches on the propagation mode.
	// Only MomentMatching performs work; every other mode leaves the outputs untouched.
	template<typename Scalar>
	void SGPLayer<Scalar>::ForwardPredictionRandomCavity(af::array& mout, af::array& vout, af::array* psi1out, af::array* psi2out, const af::array& mx, const af::array& vx, PropagationMode mode, Scalar alpha)
	{
		if (mode == PropagationMode::MomentMatching)
		{
			ForwardPredictionRandomCavityMM(mout, vout, psi1out, psi2out, mx, vx, alpha);
		}
		// PropagationMode::Linear: intentionally no work.
		// PropagationMode::MonteCarlo is not routed from here:
		//	ForwardPredictionRandomCavityMC(mout, vout, psi1out, psi2out, mx, vx, alpha);
	}

	// Moment-matching cavity prediction for random inputs (mean mx, variance vx).
	//
	// Recomputes the cavity for alpha, asks the kernel for the psi statistics and
	// writes the predictive mean / variance (one column per output dimension) to
	// mout / vout.
	//
	// psi1out / psi2out : optional outputs for the psi1 / psi2 statistics. Either may
	//                     be nullptr, in which case a local scratch array is used
	//                     (the deterministic cavity path treats its kfu output as
	//                     optional in the same way).
	template<typename Scalar>
	void SGPLayer<Scalar>::ForwardPredictionRandomCavityMM(af::array & mout, af::array & vout, af::array* psi1out, af::array* psi2out, const af::array & mx, const af::array & vx, Scalar alpha)
	{
		const auto numSamples = mx.dims(0);

		// Fall back to scratch storage so a null output pointer is never dereferenced.
		af::array psi1Scratch(m_dType), psi2Scratch(m_dType);
		af::array& psi1 = (psi1out != nullptr) ? *psi1out : psi1Scratch;
		af::array& psi2 = (psi2out != nullptr) ? *psi2out : psi2Scratch;

		ComputeCavity(alpha);

		af::array psi0(m_dType);
		kernel->ComputePsiStatistics(afXu, mx, vx, psi0, psi1, psi2);

		// mean: psi1 * GammaHat
		mout = af::moddims(af::matmul(af::tile(psi1, 1, 1, iD), afGammaHat), numSamples, iD);

		// per sample i and output dimension: sum over all entries of psi2(:, :, i) .* BetaHatStochastic
		af::array Bhatpsi2 = af::constant(0.0, numSamples, 1, iD, m_dType);
		for (uint i = 0; i < numSamples; i++)
		{
			af::array psi2Slice = psi2(af::span, af::span, i);
			Bhatpsi2(i, 0, af::span) = af::sum(af::sum(af::tile(psi2Slice, 1, 1, iD) * afBetaHatStochastic));
		}

		// variance: psi0 + sum(psi2 .* BetaHatStochastic) - mean^2
		vout = af::moddims(af::tile(psi0, 1, 1, iD) + Bhatpsi2 - af::moddims(af::pow(mout, 2), numSamples, 1, iD), numSamples, iD);
	}

	// Monte Carlo cavity prediction for random inputs.
	//
	// Draws eps = 0.1 * randn with shape MC_NO_SAMPLES x batch x (iq + styleDim), forms
	// the samples xout = eps * sqrt(vx) + mx, runs the deterministic cavity prediction
	// on all samples stacked as rows and reshapes mout / vout back to
	// MC_NO_SAMPLES x batch x iD.
	// When the first dimension of mx already equals MC_NO_SAMPLES, mx / vx are combined
	// with eps directly; otherwise they are tiled along a new leading sample axis.
	template<typename Scalar>
	void SGPLayer<Scalar>::ForwardPredictionRandomCavityMC(af::array& mout, af::array& vout, af::array& xout, af::array& eps, const af::array& mx, const af::array& vx, Scalar alpha)
	{
		// number of style columns that follow the iq latent columns
		int styleDim = 0;
		if (mStyles)
		{
			for (auto& entry : *mStyles)
				styleDim += entry.second.GetNumSubstyles();
		}
		const int inputDim = iq + styleDim;

		const bool leadingDimIsSamples = (mx.dims(0) == MC_NO_SAMPLES);
		const int batch_size = leadingDimIsSamples ? static_cast<int>(mx.dims(1)) : static_cast<int>(mx.dims(0));

		eps = af::randn(MC_NO_SAMPLES, batch_size, inputDim, m_dType) * 0.1;
		if (leadingDimIsSamples)
		{
			xout = eps * af::sqrt(vx) + mx;
		}
		else
		{
			af::array stdTiled = af::tile(af::moddims(af::sqrt(vx), 1, batch_size, inputDim), MC_NO_SAMPLES);
			af::array meanTiled = af::tile(af::moddims(mx, 1, batch_size, inputDim), MC_NO_SAMPLES);
			xout = eps * stdTiled + meanTiled;
		}

		// stack all samples into one (MC_NO_SAMPLES * batch) x inputDim matrix
		af::array stackedInputs = af::moddims(xout.T(), MC_NO_SAMPLES * batch_size, inputDim);

		ForwardPredictionDeterministicCavity(mout, vout, nullptr, stackedInputs, alpha);

		// undo the stacking
		mout = af::moddims(mout, batch_size, MC_NO_SAMPLES, iD).T();
		vout = af::moddims(vout, batch_size, MC_NO_SAMPLES, iD).T();
	}

	// Posterior prediction for deterministic inputs mx.
	// Evaluates the kernel against the inducing inputs afXu and writes the predictive
	// mean / variance (one column per output dimension) through mout / vout, which
	// must both be non-null.
	template<typename Scalar>
	void SGPLayer<Scalar>::ForwardPredictionDeterministicPost(const af::array& mx, af::array* mout, af::array* vout)
	{
		const int rowCount = mx.dims(0);
		af::array priorDiag(m_dType), crossCov(m_dType);

		kernel->ComputeDiagonal(mx, priorDiag);
		kernel->ComputeKernelMatrix(mx, afXu, crossCov);

		// replicate along the third axis, one slice per output dimension
		priorDiag = af::tile(priorDiag, 1, 1, iD);
		crossCov = af::tile(crossCov, 1, 1, iD);

		// mean: Kfu * Gamma
		*mout = af::moddims(af::matmul(crossCov, afGamma), rowCount, iD);

		// variance: diag(Kff) + diag(Kfu * Beta * Kfu^T)
		af::array quadDiag = af::diag(af::matmulNT(af::matmul(crossCov, afBeta), crossCov));
		*vout = af::moddims(priorDiag + quadDiag, rowCount, iD);
	}

	// Posterior prediction for random inputs: dispatches on the propagation mode.
	// Only MomentMatching performs work; Linear and MonteCarlo leave the outputs untouched.
	template<typename Scalar>
	void SGPLayer<Scalar>::ForwardPredictionRandomPost(const af::array & mx, const af::array & vx, af::array & mout, af::array & vout, PropagationMode mode)
	{
		if (mode == PropagationMode::MomentMatching)
		{
			ForwardPredictionRandomPostMM(mx, vx, mout, vout);
		}
		// PropagationMode::Linear and PropagationMode::MonteCarlo: intentionally no work.
	}

	// Moment-matching posterior prediction for random inputs (mean mx, variance vx).
	// Asks the kernel for the psi statistics and writes the predictive mean / variance
	// (one column per output dimension) to mout / vout.
	template<typename Scalar>
	void SGPLayer<Scalar>::ForwardPredictionRandomPostMM(const af::array & mx, const af::array & vx, af::array & mout, af::array & vout)
	{
		const auto rowCount = mx.dims(0);

		af::array psi0(m_dType), psi1(m_dType), psi2(m_dType);
		kernel->ComputePsiStatistics(afXu, mx, vx, psi0, psi1, psi2);

		// mean: psi1 * Gamma
		mout = af::moddims(af::matmul(af::tile(psi1, 1, 1, iD), afGamma), rowCount, iD);

		// per sample and output dimension: sum over all entries of psi2(:, :, row) .* BetaStochastic
		af::array weightedPsi2 = af::constant(0.0, rowCount, 1, iD, m_dType);
		for (uint row = 0; row < rowCount; ++row)
		{
			af::array psi2Slice = psi2(af::span, af::span, row);
			weightedPsi2(row, 0, af::span) = af::sum(af::sum(af::tile(psi2Slice, 1, 1, iD) * afBetaStochastic));
		}

		// variance: psi0 + sum(psi2 .* BetaStochastic) - mean^2
		af::array meanSquared = af::moddims(af::pow(mout, 2), rowCount, 1, iD);
		vout = af::moddims(af::tile(psi0, 1, 1, iD) + weightedPsi2 - meanSquared, rowCount, iD);
	}

	// Recomputes all cached cavity quantities for the given alpha:
	//   beta                = (iN - alpha) / iN
	//   afInvSuMuHat        = beta * T1
	//   afInvSuHat          = inv(Kuu) + beta * T2
	//   afSuHat             = inverse of afInvSuHat, slice by slice
	//   afMuHat             = afSuHat * afInvSuMuHat
	//   afGammaHat          = inv(Kuu) * afMuHat
	//   afBetaHat           = -inv(Kuu) + inv(Kuu) * afSuHat * inv(Kuu)
	//   afSuMuMuHat         = afSuHat + afMuHat * afMuHat^T
	//   afBetaHatStochastic = -inv(Kuu) + inv(Kuu) * afSuMuMuHat * inv(Kuu)
	template<typename Scalar>
	void SGPLayer<Scalar>::ComputeCavity(Scalar alpha)
	{
		Scalar beta = (iN - alpha) / iN;
		af::array invKuuPerDim = af::tile(afInvKuu, 1, 1, iD);

		afInvSuMuHat = beta * T1;
		afInvSuHat = invKuuPerDim + beta * T2;

		afSuHat = af::constant(0.0f, ik, ik, iD, m_dType);

		// Invert one output dimension at a time; if the plain inverse throws,
		// retry that slice with CommonUtil::PDInverse.
		for (int d = 0; d < iD; ++d)
		{
			try
			{
				afSuHat(af::span, af::span, d) = af::inverse(afInvSuHat(af::span, af::span, d));
			}
			catch (...)
			{
				afSuHat(af::span, af::span, d) = CommonUtil<Scalar>::PDInverse(afInvSuHat(af::span, af::span, d));
			}
		}

		afMuHat = af::matmul(afSuHat, afInvSuMuHat);
		afGammaHat = af::matmul(invKuuPerDim, afMuHat);

		af::array invKuuTimesSu = af::matmul(invKuuPerDim, afSuHat);
		afBetaHat = -invKuuPerDim + af::matmul(invKuuTimesSu, invKuuPerDim);

		afSuMuMuHat = afSuHat + af::matmulNT(afMuHat, afMuHat);

		af::array secondMomentTimesInvKuu = af::matmul(afSuMuMuHat, invKuuPerDim);
		afBetaHatStochastic = -invKuuPerDim + af::matmul(invKuuPerDim, secondMomentTimesInvKuu);
	}

	// Converts gradients w.r.t. the cavity mean / covariance into gradients w.r.t.
	// T1, the packed T2 parameters and inv(Kuu).
	//
	// dMucav      : gradient w.r.t. the cavity mean (read only).
	// dSucav      : gradient w.r.t. the cavity covariance; modified in place - the
	//               mean contribution dMucav * (beta * T1)^T is added to it.
	// out_dT1     : receives beta * SuHat * dMucav.
	// out_dT2     : receives an iD x ik * (ik + 1) / 2 array with one packed row per
	//               output dimension, allocated with the layer dtype.
	// out_dInvKuu : receives the gradient w.r.t. the inverse cavity covariance summed
	//               over the third axis.
	// alpha       : power parameter; beta = (iN - alpha) / iN.
	template<typename Scalar>
	void SGPLayer<Scalar>::ComputeCavityGradientU(af::array& dMucav, af::array& dSucav, af::array& out_dT1, af::array& out_dT2, af::array& out_dInvKuu, Scalar alpha)
	{
		Scalar beta = (iN - alpha) * 1.0 / iN;

		// the mean depends on the covariance, so fold that path into dSucav
		dSucav += af::matmulNT(dMucav, beta * T1);

		// gradient w.r.t. the inverse covariance: -SuHat * dSucav * SuHat
		af::array dInvSucav = -af::matmul(af::matmul(afSuHat, dSucav), afSuHat);

		out_dInvKuu = af::sum(dInvSucav, 2);

		out_dT1 = beta * af::matmul(afSuHat, dMucav);
		af::array dT2 = beta * dInvSucav;

		// push the T2 gradient through T2_R and pack one row per output dimension
		af::array dT2_R = af::matmul(T2_R, dT2 + dT2.T());

		// Allocate with m_dType (as BackpropGradientsMM does) so the double
		// instantiation does not round the result through single precision.
		out_dT2 = af::constant(0.0f, iD, ik * (ik + 1) / 2, m_dType);
		af::array diagIdx = CommonUtil<Scalar>::DiagIdx(ik);
		af::array triIdx = CommonUtil<Scalar>::TriUpperIdx(ik);
		for (int d = 0; d < iD; d++)
		{
			af::array dT2_Rd = dT2_R(af::span, af::span, d);
			af::array T2_Rd = T2_R(af::span, af::span, d);
			// diagonal entries are additionally multiplied by the diagonal of T2_R
			dT2_Rd(diagIdx) = dT2_Rd(diagIdx) * T2_Rd(diagIdx);
			out_dT2(d, af::span) = dT2_Rd(triIdx);
		}
	}

	// Converts gradients w.r.t. the posterior mean / covariance into gradients w.r.t.
	// T1, the packed T2 parameters and inv(Kuu).
	//
	// dMu         : gradient w.r.t. the posterior mean (read only).
	// dSu         : gradient w.r.t. the posterior covariance; modified in place - the
	//               mean contribution dMu * T1^T is added to it.
	// out_dT1     : receives Su * dMu.
	// out_dT2     : receives an iD x ik * (ik + 1) / 2 array with one packed row per
	//               output dimension, allocated with the layer dtype.
	// out_dInvKuu : receives the gradient w.r.t. the inverse posterior covariance
	//               summed over the third axis.
	template<typename Scalar>
	void SGPLayer<Scalar>::ComputePosteriorGradientU(af::array & dMu, af::array & dSu, af::array & out_dT1, af::array & out_dT2, af::array & out_dInvKuu)
	{
		// the mean depends on the covariance, so fold that path into dSu
		dSu += af::matmulNT(dMu, T1);

		// gradient w.r.t. the inverse covariance: -Su * dSu * Su
		af::array dInvSu = -af::matmul(af::matmul(afSu, dSu), afSu);

		out_dInvKuu = af::sum(dInvSu, 2);

		out_dT1 = af::matmul(afSu, dMu);
		af::array dT2 = dInvSu;

		// push the T2 gradient through T2_R and pack one row per output dimension
		af::array dT2_R = af::matmul(T2_R, dT2 + dT2.T());

		// Allocate with m_dType (as BackpropGradientsMM does) so the double
		// instantiation does not round the result through single precision.
		out_dT2 = af::constant(0.0f, iD, ik * (ik + 1) / 2, m_dType);
		af::array diagIdx = CommonUtil<Scalar>::DiagIdx(ik);
		af::array triIdx = CommonUtil<Scalar>::TriUpperIdx(ik);
		for (int d = 0; d < iD; d++)
		{
			af::array dT2_Rd = dT2_R(af::span, af::span, d);
			af::array T2_Rd = T2_R(af::span, af::span, d);
			// diagonal entries are additionally multiplied by the diagonal of T2_R
			dT2_Rd(diagIdx) = dT2_Rd(diagIdx) * T2_Rd(diagIdx);
			out_dT2(d, af::span) = dT2_Rd(triIdx);
		}
	}

	// Prior phi term: 0.5 * iD * log|Kuu|.
	template<typename Scalar>
	Scalar SGPLayer<Scalar>::ComputePhiPrior()
	{
		const Scalar logDetKuu = CommonUtil<Scalar>::LogDet(afKuu);
		return static_cast<Scalar>(iD * 0.5 * logDetKuu);
	}

	// Posterior phi term, summed over the iD output dimensions:
	//   0.5 * sum_d log|Su_d| + 0.5 * sum_d Mu_d^T * solve(Su_d, Mu_d)
	template<typename Scalar>
	Scalar SGPLayer<Scalar>::ComputePhiPosterior()
	{
		Scalar logDetTotal = 0.0;
		af::array solved = af::constant(0.0, ik, 1, iD, (m_dType));

		for (int dim = 0; dim < iD; ++dim)
		{
			logDetTotal += CommonUtil<Scalar>::LogDet(afSu(af::span, af::span, dim));
			// solved(:, 0, dim) = Su_dim \ Mu_dim
			solved(af::span, 0, dim) = CommonUtil<Scalar>::SolveQR(afSu(af::span, af::span, dim), afMu(af::span, 0, dim));
		}

		Scalar result = 0.5 * logDetTotal;
		result += 0.5 * af::sum<Scalar>(af::sum(afMu * solved));
		return result;
	}

	// Cavity phi term, summed over the iD output dimensions:
	//   0.5 * sum_d log|SuHat_d| + 0.5 * sum_d MuHat_d^T * solve(SuHat_d, MuHat_d)
	template<typename Scalar>
	Scalar SGPLayer<Scalar>::ComputePhiCavity()
	{
		Scalar logDetTotal = 0.0;
		af::array solved = af::constant(0.0, ik, 1, iD, (m_dType));

		for (int dim = 0; dim < iD; ++dim)
		{
			logDetTotal += CommonUtil<Scalar>::LogDet(afSuHat(af::span, af::span, dim));
			// solved(:, 0, dim) = SuHat_dim \ MuHat_dim
			solved(af::span, 0, dim) = CommonUtil<Scalar>::SolveQR(afSuHat(af::span, af::span, dim), afMuHat(af::span, 0, dim));
		}

		Scalar result = 0.5 * logDetTotal;
		result += 0.5 * af::sum<Scalar>(af::sum(afMuHat * solved));
		return result;
	}

	// Refreshes the cached posterior quantities after the base class has updated its
	// own state:
	//   afGamma          = inv(Kuu) * afMu
	//   afBeta           = -inv(Kuu) + inv(Kuu) * afSu * inv(Kuu)
	//   afSuMuMu         = afSu + afMu * afMu^T
	//   afBetaStochastic = -inv(Kuu) + inv(Kuu) * afSuMuMu * inv(Kuu)
	template<typename Scalar>
	void SGPLayer<Scalar>::UpdateParameters()
	{
		// The base is a dependent type, so it must be named with its template
		// argument (as the constructor does) to compile on conforming compilers.
		SparseGPBaseLayer<Scalar>::UpdateParameters();

		af::array tiledInvKuu = af::tile(afInvKuu, 1, 1, iD);

		afGamma = af::matmul(tiledInvKuu, afMu);
		afBeta = -tiledInvKuu + af::matmul(af::matmul(tiledInvKuu, afSu), tiledInvKuu);
		afSuMuMu = afSu + af::matmulNT(afMu, afMu);

		afBetaStochastic = -tiledInvKuu + af::matmul(af::matmul(tiledInvKuu, afSuMuMu), tiledInvKuu);
	}
}