/**
File:		MachineLearning/Models/GPModels/FgSparseDeepGPBaseModel.h

Author:		
Email:		
Site:       

Copyright (c) 2017 . All rights reserved.
*/

#include <NeMachineLearningPCH.h>
#include <MachineLearning/FgAEPSparseDGPSSM.h>
#include <cstdlib>
#include <random>
#include <stdexcept>
//#include <ctime>

namespace NeuralEngine::MachineLearning::GPModels::AEP
{
	// Explicit instantiations for the two supported scalar types.
	// NOTE(review): strictly per [temp.explicit], an explicit instantiation definition only instantiates
	// members that are already defined at this point, and every member definition below comes later in
	// this file. MSVC instantiates them regardless; other compilers may not. Consider moving these two
	// lines to the end of the file.
	template class SDGPSSM<float>;
	template class SDGPSSM<double>;

	/**
	 * Builds the model with a single hidden-layer description.
	 *
	 * All data/model arguments are forwarded to SparseDeepGPSSMBaseModel; alpha is the AEP power
	 * (the cavity removes alpha times each latent factor, see CavityLatents).
	 * Afterwards one transition SGPLayer and iNumLayer emission SGPLayers are heap-allocated.
	 * NOTE(review): ownership of these raw allocations is not visible here — confirm who deletes them.
	 */
	template<typename Scalar>
	SDGPSSM<Scalar>::SDGPSSM(const af::array& Y, int latentDimension, HiddenLayerDescription description, Scalar alpha, 
		Scalar priorMean, Scalar priorVariance, af::array xControl, PropagationMode probMode, LogLikType lType, XInit emethod)
		: SparseDeepGPSSMBaseModel(Y, latentDimension, description, priorMean, priorVariance, xControl, probMode, lType, emethod), dAlpha(alpha)
	{
		// Transition GP. The first size argument is presumably the number of transitions (iN - 1). The
		// dimension arguments match Function, where the transition input is the latent state joined with
		// iDControlDyn control columns and the output has iq columns.
		const auto transitionCount = iN - 1;
		dynLayer = new SGPLayer<Scalar>(transitionCount, vNumPseudosPerLayer[0], iq + iDControlDyn, iq);

		// Emission stack: layer k is sized by (vSize[k + 1], vSize[k]).
		for (int layer = 0; layer < iNumLayer; ++layer)
		{
			SGPLayer<Scalar>* const emission =
				new SGPLayer<Scalar>(iN, vNumPseudosPerLayer[layer], vSize[layer + 1], vSize[layer]);
			gpEmissLayer.push_back(emission);
		}
	}

	/**
	 * Builds the model with one description per hidden layer.
	 *
	 * All data/model arguments are forwarded to SparseDeepGPSSMBaseModel; alpha is the AEP power
	 * (the cavity removes alpha times each latent factor, see CavityLatents).
	 * Afterwards one transition SGPLayer and iNumLayer emission SGPLayers are heap-allocated.
	 * NOTE(review): ownership of these raw allocations is not visible here — confirm who deletes them.
	 */
	template<typename Scalar>
	SDGPSSM<Scalar>::SDGPSSM(const af::array& Y, int latentDimension, std::vector<HiddenLayerDescription> descriptions, 
		Scalar alpha, Scalar priorMean, Scalar priorVariance, af::array xControl, PropagationMode probMode, LogLikType lType, XInit emethod)
		: SparseDeepGPSSMBaseModel(Y, latentDimension, descriptions, priorMean, priorVariance, xControl, probMode, lType, emethod), dAlpha(alpha)
	{
		// Transition GP: iN - 1 is presumably the number of transitions; input width iq + iDControlDyn,
		// output width iq (matching how Function builds the transition input/target).
		const auto transitionCount = iN - 1;
		dynLayer = new SGPLayer<Scalar>(transitionCount, vNumPseudosPerLayer[0], iq + iDControlDyn, iq);

		// Emission stack: layer k is sized by (vSize[k + 1], vSize[k]).
		for (int layer = 0; layer < iNumLayer; ++layer)
		{
			SGPLayer<Scalar>* const emission =
				new SGPLayer<Scalar>(iN, vNumPseudosPerLayer[layer], vSize[layer + 1], vSize[layer]);
			gpEmissLayer.push_back(emission);
		}
	}

	/**
	 * Default constructor: sets alpha to zero and allocates no GP layers.
	 * NOTE(review): presumably intended for building an empty model before loading parameters — confirm.
	 */
	template<typename Scalar>
	SDGPSSM<Scalar>::SDGPSSM()
		: SparseDeepGPSSMBaseModel()
		, dAlpha(static_cast<Scalar>(0))
	{
		// Nothing else to initialise here.
	}

	/**
	 * Destructor. This class releases nothing itself.
	 * NOTE(review): the constructors allocate dynLayer and gpEmissLayer entries with `new`; this relies on
	 * the base class freeing them — confirm, otherwise they leak.
	 */
	template<typename Scalar>
	SDGPSSM<Scalar>::~SDGPSSM() = default;

	/**
	 * AEP objective and its gradient for the sparse deep GP state-space model.
	 *
	 * Steps:
	 *  1. Load the parameters from x.
	 *  2. Select a mini-batch (a contiguous window), or reuse the parent's indexes.
	 *  3. Accumulate the tilted transition term (dynamics GP).
	 *  4. Accumulate the tilted emission term propagated through the emission GP stack.
	 *  5. Accumulate the prior / cavity / posterior normaliser terms of the latent variables.
	 *
	 * @param x                flat parameter vector (forwarded to SetParameters).
	 * @param[out] outGradient gradient, laid out as
	 *                         [likelihood | latent factors (2 * iN * iq) | dSn | dynamics layer | emission layers],
	 *                         divided by iN.
	 * @return the accumulated objective divided by iN.
	 * @throws std::invalid_argument if pMode is neither MomentMatching nor MonteCarlo.
	 */
	template<typename Scalar>
	Scalar SDGPSSM<Scalar>::Function(const af::array& x, af::array& outGradient)
	{
		SetParameters(x);

		outGradient = af::constant(0.0f, GetNumParameters(), (m_dType));

		SGPLayer<Scalar>& sDynlayer = dynamic_cast<SGPLayer<Scalar>&>(*dynLayer);

		// ---- Mini-batch selection --------------------------------------------------------------
		af::array yBatch;
		if (iBatchSize >= iN)
		{
			// NOTE(review): afIndexes / afDynIndexes are not assigned on this path; they are assumed to
			// already span the whole sequence — confirm where they are initialised.
			iBatchSize = iN;
			yBatch = afY;
		}
		else
		{
			if (!GetParent())
			{
				// Contiguous window of iBatchSize steps. The start is drawn uniformly from all
				// iN - iBatchSize + 1 valid positions, so the final window is reachable. The generator is
				// seeded once per process instead of reseeding the global rand() state on every call.
				// NOTE: like rand(), this shared generator is not safe for concurrent calls.
				static std::mt19937 generator{ std::random_device{}() };
				std::uniform_int_distribution<int> startDistribution(0, static_cast<int>(iN - iBatchSize));

				const int startIdx = startDistribution(generator);
				const int endIdx = startIdx + static_cast<int>(iBatchSize);

				afIndexes = af::seq(startIdx, endIdx - 1);
				afDynIndexes = af::seq(startIdx, endIdx - 2);
			}

			for (uint i = 0; i < GetNumChildren(); i++)
			{
				GPLVMBaseModel<Scalar>& child = dynamic_cast<GPLVMBaseModel<Scalar>&>(*GetChild(i));
				child.SetIndexes(afIndexes);
			}

			yBatch = afY(afIndexes, af::span);
		}

		// Mini-batch rescaling of the tilted terms (negative sign and 1/alpha folded in).
		const int iBatchSizeDyn = static_cast<int>(afDynIndexes.dims(0));
		const int iBatchSizeEmiss = static_cast<int>(afIndexes.dims(0));
		const Scalar scaleLogZDyn = -(iN - 1) * 1.0 / iBatchSizeDyn / dAlpha;
		const Scalar scaleLogZEmiss = -iN * 1.0 / iBatchSizeEmiss / dAlpha;

		// ---- Cavity moments of the latent states ------------------------------------------------
		af::array mcav, vcav, cav1, cav2;
		CavityLatents(mcav, vcav, cav1, cav2);

		// Transition targets (rows afDynIndexes + 1).
		const af::array prevIdx = afDynIndexes + 1;
		const af::array mcav_t = mcav(prevIdx, af::span);
		const af::array vcav_t = vcav(prevIdx, af::span);

		// Transition inputs (rows afDynIndexes), optionally extended with deterministic controls.
		const af::array nextIdx = afDynIndexes;
		af::array mcav_t1 = mcav(nextIdx, af::span);
		af::array vcav_t1 = vcav(nextIdx, af::span);
		if (iDControlDyn > 0)
		{
			mcav_t1 = CommonUtil<Scalar>::Join(mcav_t1, afXControl(nextIdx, af::span), 1);
			vcav_t1 = CommonUtil<Scalar>::Join(vcav_t1, af::constant(0.0, iBatchSizeDyn, iDControlDyn, m_dType), 1);
		}

		// Emission inputs (rows afIndexes), optionally extended with deterministic controls.
		const af::array upIdx = afIndexes;
		af::array mcav_up = mcav(upIdx, af::span);
		af::array vcav_up = vcav(upIdx, af::span);
		if (iDControlEmiss > 0)
		{
			mcav_up = CommonUtil<Scalar>::Join(mcav_up, afXControl(upIdx, af::span), 1);
			vcav_up = CommonUtil<Scalar>::Join(vcav_up, af::constant(0.0, iBatchSizeEmiss, iDControlEmiss, m_dType), 1);
		}

		// ---- Transition factors --------------------------------------------------------------------
		Scalar logZdyn = 0, dlogZ_sn = 0;
		af::array mprob, vprob, dlogZ_dmProb, dlogZ_dvProb, dlogZ_dmt, dlogZ_dvt, gpDynLayerGradients;
		std::map<std::string, af::array> GradInputDyn, GradInputEmiss;

		bool useMonteCarlo = false;
		switch (pMode)
		{
		case PropagationMode::MomentMatching:
		{
			af::array psi1dyn, psi2dyn;
			sDynlayer.ForwardPredictionCavity(mprob, vprob, &psi1dyn, &psi2dyn, mcav_t1, &vcav_t1, dAlpha);
			logZdyn = ComputeTiltedTransition(mprob, vprob, mcav_t, vcav_t, scaleLogZDyn, dlogZ_dmProb, dlogZ_dvProb, dlogZ_dmt, dlogZ_dvt, dlogZ_sn);
			gpDynLayerGradients = sDynlayer.BackpropGradientsMM(mprob, vprob, dlogZ_dmProb, dlogZ_dvProb, psi1dyn, psi2dyn, mcav_t1, vcav_t1, &GradInputDyn, dAlpha);
			break;
		}
		case PropagationMode::MonteCarlo:
		{
			useMonteCarlo = true;
			af::array xSamples, eps;
			sDynlayer.ForwardPredictionRandomCavityMC(mprob, vprob, xSamples, eps, mcav_t1, vcav_t1, dAlpha);
			logZdyn = ComputeTiltedTransition(mprob, vprob, mcav_t, vcav_t, scaleLogZDyn, dlogZ_dmProb, dlogZ_dvProb, dlogZ_dmt, dlogZ_dvt, dlogZ_sn);
			// Keep the dynamics-layer gradient; it is written into outGradient below.
			gpDynLayerGradients = sDynlayer.BackpropGradientsMC(mcav_t1, vcav_t1, eps, dlogZ_dmProb, dlogZ_dvProb, xSamples, &GradInputDyn, dAlpha);
			break;
		}
		default:
			throw std::invalid_argument("SDGPSSM::Function: unsupported propagation mode");
		}

		// ---- Emission factors: forward pass through the layer stack -------------------------------
		std::vector<af::array> mout, vout, psi1, psi2;
		Scalar emissContribution = 0;
		for (int i = 0; i < iNumLayer; i++)
		{
			SGPLayer<Scalar>& slayer = dynamic_cast<SGPLayer<Scalar>&>(*gpEmissLayer[i]);
			af::array mi, vi, psi1i, psi2i;
			if (i == 0 && useMonteCarlo)
			{
				// Sampled propagation through the bottom layer. psi1i / psi2i receive the samples and noise
				// draws (same argument slots as the dynamics MC call), not psi statistics.
				slayer.ForwardPredictionRandomCavityMC(mi, vi, psi1i, psi2i, mcav_up, vcav_up, dAlpha);

				// Fold the leading sample dimension into the rows, whatever the number of samples.
				mi = af::moddims(mi, mi.dims(0) * mi.dims(1), mi.dims(2));
				vi = af::moddims(vi, vi.dims(0) * vi.dims(1), vi.dims(2));
			}
			else if (i == 0)
				slayer.ForwardPredictionCavity(mi, vi, &psi1i, &psi2i, mcav_up, &vcav_up, dAlpha);
			else
				slayer.ForwardPredictionCavity(mi, vi, &psi1i, &psi2i, mout[i - 1], &vout[i - 1], dAlpha);

			mout.push_back(mi);
			vout.push_back(vi);
			psi1.push_back(psi1i);
			psi2.push_back(psi2i);
			emissContribution += slayer.ComputePhi(dAlpha);
		}

		Scalar logZemiss = likLayer->ComputeLogZ(mout.back(), vout.back(), yBatch, dAlpha);
		af::array dlogZ_dm, dlogZ_dv;
		likLayer->ComputeLogZGradients(mout.back(), vout.back(), yBatch, &dlogZ_dm, &dlogZ_dv, nullptr, dAlpha);

		logZemiss = logZemiss * scaleLogZEmiss;
		af::array dlogZ_dm_scale = dlogZ_dm * scaleLogZEmiss;
		af::array dlogZ_dv_scale = dlogZ_dv * scaleLogZEmiss;

		// ---- Emission factors: backward pass, top layer first -------------------------------------
		af::array emissLayerGradients = af::constant(0.0, GetNumGPLayerParameters(), m_dType);
		int iStartLayer = 0, iEndLayer = 0;
		for (int i = iNumLayer - 1; i >= 0; i--)
		{
			SGPLayer<Scalar>& slayer = dynamic_cast<SGPLayer<Scalar>&>(*gpEmissLayer[i]);

			iEndLayer += slayer.GetNumParameters();
			if (i == 0 && useMonteCarlo)
				emissLayerGradients(af::seq(iStartLayer, iEndLayer - 1)) = slayer.BackpropGradientsMC(mcav_up, vcav_up,
					psi2[i], dlogZ_dm_scale, dlogZ_dv_scale, psi1[i], &GradInputEmiss, dAlpha);
			else if (i == 0)
				emissLayerGradients(af::seq(iStartLayer, iEndLayer - 1)) = slayer.BackpropGradientsMM(mout[i], vout[i], dlogZ_dm_scale,
					dlogZ_dv_scale, psi1[i], psi2[i], mcav_up, vcav_up, &GradInputEmiss, dAlpha);
			else
			{
				emissLayerGradients(af::seq(iStartLayer, iEndLayer - 1)) = slayer.BackpropGradientsMM(mout[i], vout[i], dlogZ_dm_scale,
					dlogZ_dv_scale, psi1[i], psi2[i], mout[i - 1], vout[i - 1], &GradInputEmiss, dAlpha);
				// This layer's input gradient is the upstream gradient of the layer below.
				dlogZ_dm_scale = GradInputEmiss["dL_dmx"];
				dlogZ_dv_scale = GradInputEmiss["dL_dvx"];
			}
			iStartLayer = iEndLayer;
		}

		Scalar likGradient = likLayer->BackpropagationGradients(mout.back(), vout.back(), dlogZ_dm, dlogZ_dv, dAlpha, scaleLogZEmiss);

		// Gradients w.r.t. the latent cavity moments; any control columns beyond iq are dropped.
		const af::array dmcav_up = GradInputEmiss["dL_dmx"](af::span, af::seq(0, iq - 1));
		const af::array dvcav_up = GradInputEmiss["dL_dvx"](af::span, af::seq(0, iq - 1));
		const af::array dmcav_next = GradInputDyn["dL_dmx"](af::span, af::seq(0, iq - 1));
		const af::array dvcav_next = GradInputDyn["dL_dvx"](af::span, af::seq(0, iq - 1));

		const af::array gradsXpost = PosteriorGradientLatents();
		const af::array gradsXcav = CavityGradientLatents(cav1, cav2);
		const af::array gradsXlogZ = LogZGradientLatents(cav1, cav2, dmcav_up, dvcav_up, dlogZ_dmt, dlogZ_dvt, dmcav_next, dvcav_next);

		// ---- Collect gradients (order must match the parameter layout) ---------------------------
		int iStart = 0;
		int iEnd = likLayer->GetNumParameters();
		if (iStart != iEnd)
			outGradient(af::seq(iStart, iEnd - 1)) = likGradient;

		iStart = iEnd; iEnd += iN * iq * 2;
		outGradient(af::seq(iStart, iEnd - 1)) = gradsXcav + gradsXpost + gradsXlogZ;

		iStart = iEnd; iEnd += 1;
		outGradient(af::seq(iStart, iEnd - 1)) = dlogZ_sn;

		iStart = iEnd; iEnd += sDynlayer.GetNumParameters();
		outGradient(af::seq(iStart, iEnd - 1)) = gpDynLayerGradients;

		iStart = iEnd; iEnd += GetNumGPLayerParameters();
		outGradient(af::seq(iStart, iEnd - 1)) = emissLayerGradients;

		outGradient /= iN;

		// ---- Objective --------------------------------------------------------------------------
		const Scalar dynContribution = sDynlayer.ComputePhi(dAlpha);

		const Scalar priorPhiX = ComputePhiPriorLatents();
		const Scalar cavPhiX = ComputePhiCavityLatents();
		const Scalar postPhiX = ComputePhiPosteriorLatents();
		const Scalar xContribution = priorPhiX + postPhiX + cavPhiX;

		return (logZdyn + logZemiss + xContribution + dynContribution + emissContribution) / iN;
	}

	/**
	 * Cavity of the latent states: removes alpha times the latent factor from the posterior.
	 *
	 * @param[out] mcav  cavity mean, cav1 / (cav2 + 1e-16).
	 * @param[out] vcav  cavity variance, 1 / (cav2 + 1e-16).
	 * @param[out] cav1  cavity first natural parameter, afPosteriorX1 - alpha * afFactorX1.
	 * @param[out] cav2  cavity second natural parameter, afPosteriorX2 - alpha * afFactorX2.
	 */
	template<typename Scalar>
	void SDGPSSM<Scalar>::CavityLatents(af::array& mcav, af::array& vcav, af::array& cav1, af::array& cav2)
	{
		cav1 = afPosteriorX1 - dAlpha * afFactorX1;
		cav2 = afPosteriorX2 - dAlpha * afFactorX2;

		// A tiny jitter keeps the reciprocal finite when the cavity precision is ~0.
		const af::array jitteredCav2 = cav2 + 1e-16;
		mcav = cav1 / jitteredCav2;
		vcav = 1.0 / jitteredCav2;
	}

	/**
	 * Tilted (power-EP) log-normaliser of the transition factors and its gradients.
	 *
	 * Per entry, with sn2 = exp(2 * dSn), vSum = vcav + vprob + sn2 / alpha and mDiff = mcav - mprob:
	 *   logZ = -0.5 * mDiff^2 / vSum - 0.5 * log(1 + alpha * (vcav + vprob) / sn2) - 0.5 * alpha * log(2 * pi * sn2)
	 *
	 * Two layouts are handled:
	 *  - moment matching: mprob/vprob have exactly the cavity's shape; entries are summed.
	 *  - Monte Carlo: mprob/vprob carry a leading sample dimension (S x N x q). Each (n, d) entry is the
	 *    log-mean-exp over the S samples, computed with the max-shift trick.
	 *
	 * @param mprob, vprob        predictive moments of the transition GP.
	 * @param mcav_t1, vcav_t1    cavity moments of the transition targets (N x q).
	 * @param scaleLogZDyn        factor applied to logZ and to every gradient.
	 * @param[out] dlogZ_dmProb, dlogZ_dvProb gradients w.r.t. mprob / vprob (same shape as mprob).
	 * @param[out] dlogZ_dmt, dlogZ_dvt       gradients w.r.t. the cavity moments (same shape as mcav_t1).
	 * @param[out] dlogZ_sn       gradient w.r.t. the log noise scale dSn.
	 * @return scaled logZ summed over all entries.
	 */
	template<typename Scalar>
	Scalar SDGPSSM<Scalar>::ComputeTiltedTransition(const af::array& mprob, const af::array& vprob, const af::array& mcav_t1,
		const af::array& vcav_t1, Scalar scaleLogZDyn, af::array& dlogZ_dmProb, af::array& dlogZ_dvProb, af::array& dlogZ_dmt,
		af::array& dlogZ_dvt, Scalar& dlogZ_sn)
	{
		const Scalar sn2 = exp(2.0 * dSn);
		// Noise-only part of logZ, identical for every entry.
		const Scalar alphaTerm = -0.5 * dAlpha * log(2.0 * af::Pi * sn2);

		Scalar logZ = 0;

		// Dispatch on exact shape equality. A size heuristic such as dims(0) > dims(1) misroutes
		// moment-matching batches with N <= q and Monte Carlo runs with more samples than rows.
		if (mprob.dims() == mcav_t1.dims())
		{
			// ---- Moment matching ----
			const af::array vSum = vcav_t1 + vprob + sn2 / dAlpha;
			const af::array mDiff = mcav_t1 - mprob;
			const af::array expTerm = -0.5 * af::pow(mDiff, 2.0) / vSum;
			const af::array constTerm = -0.5 * af::log(1.0 + dAlpha * (vcav_t1 + vprob) / sn2);
			const af::array logZtmp = expTerm + constTerm + alphaTerm;

			logZ = scaleLogZDyn * af::sum<Scalar>(af::sum(logZtmp));

			dlogZ_dvt = scaleLogZDyn * (-0.5 / vSum + 0.5 * af::pow(mDiff, 2.0) / af::pow(vSum, 2.0));
			dlogZ_dvProb = dlogZ_dvt;
			dlogZ_dmt = scaleLogZDyn * (-mDiff / vSum);
			dlogZ_dmProb = -dlogZ_dmt;

			// d/d(dSn) through sn2 = exp(2 dSn): 2 * sn2 / alpha times the vSum gradient,
			// plus (1 - alpha) for each of the N x q entries.
			const Scalar dvSum = af::sum<Scalar>(af::sum(dlogZ_dvt));
			dlogZ_sn = dvSum * 2.0 * sn2 / dAlpha + scaleLogZDyn * mprob.dims(0) * iq * (1.0 - dAlpha);
		}
		else
		{
			// ---- Monte Carlo: tile the cavity over the leading sample dimension ----
			const dim_t numSamples = mprob.dims(0);
			const af::array vcav_t1_tile = af::tile(af::moddims(vcav_t1, 1, vcav_t1.dims(0), vcav_t1.dims(1)), vprob.dims(0));
			const af::array mcav_t1_tile = af::tile(af::moddims(mcav_t1, 1, mcav_t1.dims(0), mcav_t1.dims(1)), numSamples);

			const af::array vSum = vcav_t1_tile + vprob + sn2 / dAlpha;
			const af::array mDiff = mcav_t1_tile - mprob;
			const af::array expTerm = -0.5 * af::pow(mDiff, 2.0) / vSum;
			const af::array constTerm = -0.5 * af::log(1.0 + dAlpha * (vcav_t1_tile + vprob) / sn2);
			const af::array logZtmp = expTerm + constTerm + alphaTerm;

			// Stable log-mean-exp over samples.
			const af::array logZ_max = af::max(logZtmp, 0);
			const af::array exp_term = af::exp(logZtmp - af::tile(logZ_max, numSamples));
			const af::array sumexp = af::sum(exp_term, 0);
			af::array logZ_lse = logZ_max + af::log(sumexp);
			logZ_lse -= log(numSamples);
			logZ = scaleLogZDyn * af::sum<Scalar>(af::sum(logZ_lse));

			// Per-sample weights (softmax over samples) times the per-sample gradients.
			const af::array dlogZ = scaleLogZDyn * exp_term / af::tile(sumexp, numSamples);
			const af::array dlogZ_dm = dlogZ * mDiff / vSum;
			const af::array dlogZ_dv = dlogZ * (-0.5 / vSum + 0.5 * af::pow(mDiff, 2.0) / af::pow(vSum, 2.0));

			dlogZ_dmt = -af::moddims(af::sum(dlogZ_dm, 0), mcav_t1.dims(0), mcav_t1.dims(1));
			dlogZ_dmProb = dlogZ_dm;
			dlogZ_dvt = af::moddims(af::sum(dlogZ_dv, 0), vcav_t1.dims(0), vcav_t1.dims(1));
			dlogZ_dvProb = dlogZ_dv;

			// The sample weights sum to one, so each of the N x q log-mean-exp entries contributes
			// (1 - alpha). The transition output has iq columns (as in the moment-matching branch);
			// the observation dimension iD does not apply here.
			const Scalar dvSum = af::sum<Scalar>(af::sum(dlogZ_dv));
			dlogZ_sn = dvSum * 2.0 * sn2 / dAlpha + scaleLogZDyn * mprob.dims(1) * iq * (1.0 - dAlpha);
		}

		return logZ;
	}

	/**
	 * Gradient of the posterior normaliser term (same row weights as ComputePhiPosteriorLatents)
	 * w.r.t. the latent factor parameters.
	 *
	 * @return vector of length 2 * iN * iq: flattened gradient for afFactorX1 followed by that for afFactorX2.
	 */
	template<typename Scalar>
	af::array SDGPSSM<Scalar>::PosteriorGradientLatents()
	{
		const auto transitionWeight = 1.0 / dAlpha;

		// Row weights: baseline -(1 - 1/alpha), plus 1/alpha for every transition the row takes part in
		// (once as source, once as target).
		af::array rowScale = af::constant(-(1.0 - transitionWeight), iN, iq, m_dType);
		auto addTransitionWeight = [&rowScale, transitionWeight](const af::seq& sources, const af::seq& targets)
		{
			rowScale(sources, af::span) = rowScale(sources, af::span) + transitionWeight;
			rowScale(targets, af::span) = rowScale(targets, af::span) + transitionWeight;
		};

		if (afSegments.isempty())
		{
			addTransitionWeight(af::seq(0, iN - 2), af::seq(1, iN - 1));
		}
		else
		{
			// afSegments holds the first row of each independent sequence.
			const int segmentCount = static_cast<int>(afSegments.dims(0));
			for (int s = 0; s + 1 < segmentCount; ++s)
			{
				const int head = afSegments(s).as(s32).scalar<int>();
				const int tail = afSegments(s + 1).as(s32).scalar<int>() - 1;
				addTransitionWeight(af::seq(head, tail - 1), af::seq(head + 1, tail));
			}
			const int lastHead = afSegments(af::end).as(s32).scalar<int>();
			addTransitionWeight(af::seq(lastHead, af::end - 1), af::seq(lastHead + 1, af::end));
		}

		// Derivatives of 0.5 * (x1^2 / x2 - log x2) at the posterior natural parameters, weighted per row.
		af::array dPhi1 = afPosteriorX1 / afPosteriorX2;
		af::array dPhi2 = -0.5 * af::pow(afPosteriorX1, 2.0) / af::pow(afPosteriorX2, 2.0) - 0.5 / afPosteriorX2;
		dPhi1 *= rowScale;
		dPhi2 *= rowScale;

		// Chain rule to the factor parameters. Interior rows use multipliers 3 and 6 * afFactorX2.
		// The first/last row of each segment uses 2 and 4 * afFactorX2 (presumably because fewer
		// factors touch the sequence ends).
		af::array factorGrad1 = dPhi1 * 3.0;
		af::array factorGrad2 = dPhi2 * 6.0 * afFactorX2;
		if (afSegments.isempty())
		{
			factorGrad1(0, af::span) = 2.0 * dPhi1(0, af::span);
			factorGrad1(af::end, af::span) = 2.0 * dPhi1(af::end, af::span);
			factorGrad2(0, af::span) = 4.0 * dPhi2(0, af::span) * afFactorX2(0, af::span);
			factorGrad2(af::end, af::span) = 4.0 * dPhi2(af::end, af::span) * afFactorX2(af::end, af::span);
		}
		else
		{
			const af::array segmentTails = afSegments(af::seq(1, af::end)) - 1;
			factorGrad1(afSegments, af::span) = 2.0 * dPhi1(afSegments, af::span);
			factorGrad1(segmentTails, af::span) = 2.0 * dPhi1(segmentTails, af::span);
			factorGrad1(af::end, af::span) = 2.0 * dPhi1(af::end, af::span);
			factorGrad2(afSegments, af::span) = 4.0 * dPhi2(afSegments, af::span) * afFactorX2(afSegments, af::span);
			factorGrad2(segmentTails, af::span) = 4.0 * dPhi2(segmentTails, af::span) * afFactorX2(segmentTails, af::span);
			factorGrad2(af::end, af::span) = 4.0 * dPhi2(af::end, af::span) * afFactorX2(af::end, af::span);
		}

		const int blockSize = iN * iq;
		af::array gradX = af::constant(0.0, iN * iq * 2, m_dType);
		gradX(af::seq(0, blockSize - 1)) = af::flat(factorGrad1);
		gradX(af::seq(blockSize, 2 * blockSize - 1)) = af::flat(factorGrad2);
		return gradX;
	}

	/**
	 * Gradient of the cavity normaliser term (same row weights as ComputePhiCavityLatents)
	 * w.r.t. the latent factor parameters.
	 *
	 * @param cav1, cav2 cavity natural parameters, as produced by CavityLatents.
	 * @return vector of length 2 * iN * iq: flattened gradient for afFactorX1 followed by that for afFactorX2.
	 */
	template<typename Scalar>
	af::array SDGPSSM<Scalar>::CavityGradientLatents(const af::array& cav1, const af::array& cav2)
	{
		const Scalar scale = -1.0 / dAlpha;

		// Row weights of the cavity normaliser: `scale` once, plus `scale` for every transition the row
		// takes part in (once as source, once as target).
		af::array phiWeight = af::constant(scale, iN, iq, m_dType);
		auto addTransitionWeight = [&phiWeight, scale](const af::seq& sources, const af::seq& targets)
		{
			phiWeight(sources, af::span) = phiWeight(sources, af::span) + scale;
			phiWeight(targets, af::span) = phiWeight(targets, af::span) + scale;
		};

		if (afSegments.isempty())
		{
			addTransitionWeight(af::seq(0, iN - 2), af::seq(1, iN - 1));
		}
		else
		{
			const int segmentCount = static_cast<int>(afSegments.dims(0));
			for (int s = 0; s + 1 < segmentCount; ++s)
			{
				const int head = afSegments(s).as(s32).scalar<int>();
				const int tail = afSegments(s + 1).as(s32).scalar<int>() - 1;
				addTransitionWeight(af::seq(head, tail - 1), af::seq(head + 1, tail));
			}
			const int lastHead = afSegments(af::end).as(s32).scalar<int>();
			addTransitionWeight(af::seq(lastHead, af::end - 1), af::seq(lastHead + 1, af::end));
		}

		// Derivatives of 0.5 * (cav1^2 / cav2 - log cav2).
		const af::array dPhi1 = cav1 / cav2;
		const af::array dPhi2 = -0.5 * af::pow(cav1, 2.0) / af::pow(cav2, 2.0) - 0.5 / cav2;

		af::array gradsX1 = phiWeight * dPhi1;
		af::array gradsX2 = phiWeight * dPhi2;

		// Chain rule from the cavity (posterior - alpha * factor) to the factor parameters:
		// multiplier 3 - alpha on interior rows, 2 - alpha on the first/last row of each segment.
		af::array chainFactor = af::constant(3.0 - dAlpha, iN, iq, m_dType);
		if (afSegments.isempty())
		{
			chainFactor(0, af::span) = 2.0 - dAlpha;
			chainFactor(af::end, af::span) = 2.0 - dAlpha;
		}
		else
		{
			chainFactor(afSegments, af::span) = 2.0 - dAlpha;
			chainFactor(afSegments(af::seq(1, af::end)) - 1, af::span) = 2.0 - dAlpha;
			chainFactor(af::end, af::span) = 2.0 - dAlpha;
		}

		gradsX1 *= chainFactor;
		gradsX2 *= chainFactor;
		gradsX2 = gradsX2 * 2 * afFactorX2;

		const int blockSize = iN * iq;
		af::array gradX = af::constant(0.0, iN * iq * 2, m_dType);
		gradX(af::seq(0, blockSize - 1)) = af::flat(gradsX1);
		gradX(af::seq(blockSize, 2 * blockSize - 1)) = af::flat(gradsX2);
		return gradX;
	}

	/**
	 * Gradient of the tilted log-normalisers w.r.t. the latent factor parameters, routed through the cavity.
	 *
	 * Collects d logZ / d(cavity moments) from the three factor types touching each latent row:
	 *  - emission ("up"),
	 *  - transition target ("prev"),
	 *  - transition input ("next").
	 * These are mapped onto the cavity natural parameters and then onto the factor parameters.
	 *
	 * @param cav1, cav2             cavity natural parameters (iN x iq).
	 * @param dmcav_up, dvcav_up     gradients w.r.t. the emission-input cavity moments (rows afIndexes).
	 * @param dmcav_prev, dvcav_prev gradients w.r.t. the transition-target cavity moments.
	 * @param dmcav_next, dvcav_next gradients w.r.t. the transition-input cavity moments.
	 * @return vector of length 2 * iN * iq: flattened gradient for afFactorX1 followed by that for afFactorX2.
	 */
	template<typename Scalar>
	af::array SDGPSSM<Scalar>::LogZGradientLatents(const af::array& cav1, const af::array& cav2, const af::array& dmcav_up, const af::array& dvcav_up,
		const af::array& dmcav_prev, const af::array& dvcav_prev, const af::array& dmcav_next, const af::array& dvcav_next)
	{
		// Chain rule from the cavity moments (mcav = cav1 / cav2, vcav = 1 / cav2) to cav1 / cav2,
		// restricted to the rows in idx:
		//   dL/dcav1 = dm / cav2,   dL/dcav2 = -dm * cav1 / cav2^2 - dv / cav2^2.
		auto cavityChainRule = [&cav1, &cav2](const af::array& idx, const af::array& dm, const af::array& dv,
			af::array& dCav1, af::array& dCav2)
		{
			const af::array c1 = cav1(idx, af::span);
			const af::array c2 = cav2(idx, af::span);
			const af::array c2Sq = af::pow(c2, 2.0);
			dCav1 = dm / c2;
			dCav2 = -dm * c1 / c2Sq - dv / c2Sq;
		};

		// Rows receiving transition-target (prev) and transition-input (next) gradients.
		af::array prevIdx, nextIdx;
		if (afSegments.isempty())
		{
			const int firstRow = afIndexes(0).as(s32).scalar<int>();
			const int lastRow = afIndexes(af::end).as(s32).scalar<int>();
			prevIdx = af::seq(firstRow + 1, lastRow);
			nextIdx = af::seq(firstRow, lastRow - 1);
		}
		else
		{
			int iStartIdx, iEndIdx;
			for (auto i = 0; i < afSegments.dims(0) - 1; i++)
			{
				iStartIdx = afSegments(i).as(s32).scalar<int>() + 1;
				iEndIdx = afSegments(i + 1).as(s32).scalar<int>() - 1;
				prevIdx = CommonUtil<Scalar>::Join(prevIdx, af::seq(iStartIdx, iEndIdx));

				iStartIdx = afSegments(i).as(s32).scalar<int>();
				iEndIdx = afSegments(i + 1).as(s32).scalar<int>() - 2;
				nextIdx = CommonUtil<Scalar>::Join(nextIdx, af::seq(iStartIdx, iEndIdx));
			}
			iStartIdx = afSegments(af::end).as(s32).scalar<int>() + 1;
			iEndIdx = afIndexes(af::end).as(s32).scalar<int>();
			prevIdx = CommonUtil<Scalar>::Join(prevIdx, af::seq(iStartIdx, iEndIdx));

			iStartIdx = afSegments(af::end).as(s32).scalar<int>();
			iEndIdx = afIndexes(af::end).as(s32).scalar<int>() - 1;
			nextIdx = CommonUtil<Scalar>::Join(nextIdx, af::seq(iStartIdx, iEndIdx));
		}

		af::array gradsX1 = af::constant(0.0, cav1.dims(), m_dType);
		af::array gradsX2 = af::constant(0.0, gradsX1.dims(), m_dType);
		af::array dCav1, dCav2;

		// Emission factors.
		cavityChainRule(afIndexes, dmcav_up, dvcav_up, dCav1, dCav2);
		gradsX1(afIndexes, af::span) = dCav1;
		gradsX2(afIndexes, af::span) = dCav2;

		// Transition factors, target side.
		cavityChainRule(prevIdx, dmcav_prev, dvcav_prev, dCav1, dCav2);
		gradsX1(prevIdx, af::span) += dCav1;
		gradsX2(prevIdx, af::span) += dCav2;

		// Transition factors, input side. The variance term must use dvcav_next (the mean gradient
		// dmcav_next was previously reused here, leaving dvcav_next ignored).
		cavityChainRule(nextIdx, dmcav_next, dvcav_next, dCav1, dCav2);
		gradsX1(nextIdx, af::span) += dCav1;
		gradsX2(nextIdx, af::span) += dCav2;

		// Chain rule from the cavity (posterior - alpha * factor) to the factor parameters:
		// 3 - alpha on interior rows, 2 - alpha on the first/last row of each segment.
		af::array scaleXcav = af::constant(3.0 - dAlpha, iN, iq, m_dType);
		if (afSegments.isempty())
		{
			scaleXcav(0, af::span) = 2.0 - dAlpha;
			scaleXcav(af::end, af::span) = 2.0 - dAlpha;
		}
		else
		{
			scaleXcav(afSegments, af::span) = 2.0 - dAlpha;
			scaleXcav(afSegments(af::seq(1, af::end)) - 1, af::span) = 2.0 - dAlpha;
			scaleXcav(af::end, af::span) = 2.0 - dAlpha;
		}

		gradsX1 *= scaleXcav;
		gradsX2 *= scaleXcav;
		gradsX2 *= 2.0 * afFactorX2;

		const int blockSize = iN * iq;
		af::array gradX = af::constant(0.0, iN * iq * 2, m_dType);
		gradX(af::seq(0, blockSize - 1)) = af::flat(gradsX1);
		gradX(af::seq(blockSize, 2 * blockSize - 1)) = af::flat(gradsX2);
		return gradX;
	}

	/**
	 * Log-normaliser contribution of the latent prior.
	 *
	 * Converts the scalar prior natural parameters (dPriorX1, dPriorX2) to mean/variance and returns
	 * 0.5 * (mean^2 / variance + log(variance)), counted once for each of the iq latent dimensions.
	 */
	template<typename Scalar>
	Scalar SDGPSSM<Scalar>::ComputePhiPriorLatents()
	{
		const Scalar variance = 1.0 / dPriorX2;
		const Scalar mean = dPriorX1 / dPriorX2;

		const auto perDimension = pow(mean, 2.0) / variance + log(variance);
		return 0.5 * iq * perDimension;
	}

	/**
	 * Log-normaliser contribution of the latent cavities.
	 *
	 * Each row's cavity normaliser 0.5 * (cav1^2 / cav2 - log cav2) is weighted by -1/alpha once, and by
	 * a further -1/alpha for every transition the row takes part in (once as source, once as target).
	 *
	 * @return sum of the weighted normalisers over all iN x iq entries.
	 */
	template<typename Scalar>
	Scalar SDGPSSM<Scalar>::ComputePhiCavityLatents()
	{
		af::array rowScale = af::constant(-1.0 / dAlpha, iN, iq, m_dType);
		auto subtractTransitionWeight = [this, &rowScale](const af::seq& sources, const af::seq& targets)
		{
			rowScale(sources, af::span) = rowScale(sources, af::span) - 1.0 / dAlpha;
			rowScale(targets, af::span) = rowScale(targets, af::span) - 1.0 / dAlpha;
		};

		if (afSegments.isempty())
		{
			subtractTransitionWeight(af::seq(0, iN - 2), af::seq(1, iN - 1));
		}
		else
		{
			const int segmentCount = static_cast<int>(afSegments.dims(0));
			for (int s = 0; s + 1 < segmentCount; ++s)
			{
				const int head = afSegments(s).as(s32).scalar<int>();
				const int tail = afSegments(s + 1).as(s32).scalar<int>() - 1;
				subtractTransitionWeight(af::seq(head, tail - 1), af::seq(head + 1, tail));
			}
			const int lastHead = afSegments(af::end).as(s32).scalar<int>();
			subtractTransitionWeight(af::seq(lastHead, af::end - 1), af::seq(lastHead + 1, af::end));
		}

		const af::array cav1 = afPosteriorX1 - dAlpha * afFactorX1;
		const af::array cav2 = afPosteriorX2 - dAlpha * afFactorX2;
		const af::array phiCav = 0.5 * (af::pow(cav1, 2.0) / cav2 - af::log(cav2));

		return af::sum<Scalar>(af::sum(rowScale * phiCav));
	}

	/**
	 * Log-normaliser contribution of the latent posteriors.
	 *
	 * Each row's posterior normaliser 0.5 * (x1^2 / x2 - log x2) is weighted by -(1 - 1/alpha), plus 1/alpha
	 * for every transition the row takes part in (once as source, once as target).
	 *
	 * @return sum of the weighted normalisers over all iN x iq entries.
	 */
	template<typename Scalar>
	Scalar SDGPSSM<Scalar>::ComputePhiPosteriorLatents()
	{
		af::array rowScale = af::constant(-(1.0 - 1.0 / dAlpha), iN, iq, m_dType);
		auto addTransitionWeight = [this, &rowScale](const af::seq& sources, const af::seq& targets)
		{
			rowScale(sources, af::span) = rowScale(sources, af::span) + 1.0 / dAlpha;
			rowScale(targets, af::span) = rowScale(targets, af::span) + 1.0 / dAlpha;
		};

		if (afSegments.isempty())
		{
			addTransitionWeight(af::seq(0, iN - 2), af::seq(1, iN - 1));
		}
		else
		{
			const int segmentCount = static_cast<int>(afSegments.dims(0));
			for (int s = 0; s + 1 < segmentCount; ++s)
			{
				const int head = afSegments(s).as(s32).scalar<int>();
				const int tail = afSegments(s + 1).as(s32).scalar<int>() - 1;
				addTransitionWeight(af::seq(head, tail - 1), af::seq(head + 1, tail));
			}
			const int lastHead = afSegments(af::end).as(s32).scalar<int>();
			addTransitionWeight(af::seq(lastHead, af::end - 1), af::seq(lastHead + 1, af::end));
		}

		const af::array phiPost = 0.5 * (af::pow(afPosteriorX1, 2.0) / afPosteriorX2 - af::log(afPosteriorX2));
		return af::sum<Scalar>(af::sum(rowScale * phiPost));
	}
}