/**
File:		MachineLearning/GPModels/Models/GPModels/FgAEPSparseGPLVM.cpp

Author:		
Email:		
Site:       

Copyright (c) 2018 . All rights reserved.
*/

#include <NeMachineLearningPCH.h>
#include <MachineLearning/FgAEPSparseGPLVM.h>
#include <MachineLearning/FgAEPSparseGPLayer.h>

namespace NeuralEngine::MachineLearning::GPModels::AEP
{
	// Explicit instantiations for the two supported scalar types.
	// NOTE(review): an explicit instantiation definition of a class template only instantiates
	// the members that are already defined at this point ([temp.explicit]); every member
	// function of SGPLVM in this file is defined further down. For portability, consider
	// moving these two lines to the end of the file.
	template class SGPLVM<float>;
	template class SGPLVM<double>;

	/// Constructs the AEP sparse GPLVM for observations `Y`.
	///
	/// @param Y               Observed data forwarded to the base model.
	/// @param latentDimension Number of latent dimensions.
	/// @param numInducing     Number of inducing points.
	/// @param alpha           AEP alpha, stored in dAlpha.
	/// @param priorMean       Latent prior mean forwarded to the base model.
	/// @param priorVariance   Latent prior variance forwarded to the base model.
	/// @param probMode        Propagation mode used by Function() (moment matching or Monte Carlo).
	/// @param lType           Likelihood type forwarded to the base model.
	/// @param emethod         Latent initialisation method forwarded to the base model.
	template<typename Scalar>
	SGPLVM<Scalar>::SGPLVM(const af::array& Y, int latentDimension, int numInducing, Scalar alpha, Scalar priorMean, Scalar priorVariance, PropagationMode probMode, LogLikType lType, XInit emethod)
		: SparseGPLVMBaseModel(Y, latentDimension, priorMean, priorVariance, numInducing, lType, emethod)
		, dAlpha(alpha)
		, pMode(probMode)
	{
		// The GP layer is sized from iN, ik, iD and iq, presumably set by the base-class
		// constructor above. NOTE(review): allocated with raw new; ownership/deletion is
		// handled outside this file — confirm.
		gpLayer = new SGPLayer<Scalar>(iN, ik, iD, iq);
	}

	/// Default constructor: default-constructs the base model and sets dAlpha to 0.5.
	/// No GP layer is created here.
	/// NOTE(review): pMode is not initialised by this constructor, yet Function() switches on
	/// it; confirm the class declaration gives it a default member initialiser.
	template<typename Scalar>
	SGPLVM<Scalar>::SGPLVM()
		: SparseGPLVMBaseModel()
		, dAlpha(0.5)
	{
	}

	template<typename Scalar>
	Scalar SGPLVM<Scalar>::Function(const af::array& x, af::array& outGradient)
	{
		if (!x.isempty())
			SetParameters(x);

		Scalar logZ = 0, scaleLogZ, scalePrior, scalePost, scaleCavity, xContribution, sgpContribution, priorPhiX, cavPhiX, postPhiX, likGradient;

		af::array yBatch, mcav, vcav, mpost, vpost, mout, vout, psi1, psi2, gpLayerGradients;

		af::array dlogZ_dm, dlogZ_dv, dlogZ_dm_scale, dlogZ_dv_scale, dmCav, dvCav, dmPost, dvPost, gradsXcav, gradsXpost, tmpGradient;

		int iStart = 0, iEnd = 0;
		outGradient = af::constant(0.0f, GetNumParameters(), (m_dType));
		std::map<std::string, af::array> GradInput;

		SGPLayer<Scalar>& slayer = dynamic_cast<SGPLayer<Scalar>&>(*gpLayer);

		if (iBatchSize >= iN)
		{
			iBatchSize = iN;
			yBatch = afY;
		}
		else
		{
			if (!GetParent()) 
			{
				af::setSeed(time(NULL));
				afIndexes = af::round(af::randu(iN) * iN)(af::seq(iBatchSize));
			}

			for (uint i = 0; i < GetNumChildren(); i++)
			{
				GPLVMBaseModel<Scalar>& child = dynamic_cast<GPLVMBaseModel<Scalar>&>(*GetChild(i));
				child.SetIndexes(afIndexes);
			}

			yBatch = afY(afIndexes, af::span);
		}

		scaleLogZ = -iN * 1.0 / iBatchSize / dAlpha;
		scalePost = -iN * 1.0 / iBatchSize * (1.0 - 1.0 / dAlpha);
		scaleCavity = -iN * 1.0 / iBatchSize / dAlpha;
		scalePrior = 1.0;

		if(!bIsLatetsFixed)
		{
			CavityLatents(mcav, vcav);
			PosteriorLatents(mpost, vpost);

			// Gradient computation in Moment Matching mode
			// propagate x forward through cavity

			switch (pMode)
			{
			case PropagationMode::MomentMatching:
				slayer.ForwardPredictionCavity(mout, vout, &psi1, &psi2, mcav, &vcav, dAlpha);

				//if (GetNumChildren() > 0)
				//{
				//	int dimCnt = 0;
				//	for (uint i = 0; i < GetNumChildren(); i++)
				//	{
				//		GPLVMBaseModel<Scalar>& child = dynamic_cast<GPLVMBaseModel<Scalar>&>(*GetChild(i));
				//		iStart = iEnd; iEnd += child.GetNumParameters();
				//		child.SetPrior(mout(span, seq(dimCnt, dimCnt + child.GetLatentDimension() - 1)), vout(span, seq(dimCnt, dimCnt + child.GetLatentDimension() - 1)));
				//		logZ += child.Function(af::array(), tmpGradient);
				//		outGradient(seq(iStart, iEnd - 1)) = tmpGradient;
				//		dlogZ_dm = CommonUtil<Scalar>::Join(dlogZ_dm, child.GetMeanGradient(), 1);
				//		dlogZ_dv = CommonUtil<Scalar>::Join(dlogZ_dv, child.GetVarGradient(), 1);

				//		dimCnt += child.GetLatentDimension();
				//	}
				//}
				//else
				//{
				//	// compute log normalizer
				//	logZ = likLayer->ComputeLogZ(mout, vout, yBatch, dAlpha) * scaleLogZ;
				//	/// Gradient computation
				//	likLayer->ComputeLogZGradients(mout, vout, yBatch, &dlogZ_dm, &dlogZ_dv, nullptr, dAlpha);
				//}

				if (GetNumChildren() > 0)
				{
					af::array mf, vf;
					gpLayer->ForwardPredictionPost(&mpost, &vpost, mf, vf);

					int dimCnt = 0;
					for (uint i = 0; i < GetNumChildren(); i++)
					{
						GPLVMBaseModel<Scalar>& child = dynamic_cast<GPLVMBaseModel<Scalar>&>(*GetChild(i));
						iStart = iEnd; iEnd += child.GetNumParameters();
						child.SetPrior(mf(af::span, af::seq(dimCnt, dimCnt + child.GetLatentDimension() - 1)), vf(af::span, af::seq(dimCnt, dimCnt + child.GetLatentDimension() - 1)));
						logZ += child.Function(af::array(), tmpGradient) * scalePost;
						outGradient(af::seq(iStart, iEnd - 1)) = tmpGradient;
						dlogZ_dm = CommonUtil<Scalar>::Join(dlogZ_dm, child.GetMeanGradient(), 1);
						dlogZ_dv = CommonUtil<Scalar>::Join(dlogZ_dv, child.GetVarGradient(), 1);

						dimCnt += child.GetLatentDimension();
					}

					dlogZ_dm_scale = dlogZ_dm * scalePost;
					dlogZ_dv_scale = dlogZ_dv * scalePost;
				}
				else
				{
					// compute log normalizer
					logZ = likLayer->ComputeLogZ(mout, vout, yBatch, dAlpha) * scaleLogZ;
					/// Gradient computation
					likLayer->ComputeLogZGradients(mout, vout, yBatch, &dlogZ_dm, &dlogZ_dv, nullptr, dAlpha);

					dlogZ_dm_scale = dlogZ_dm * scaleLogZ;
					dlogZ_dv_scale = dlogZ_dv * scaleLogZ;
				}



				gpLayerGradients = slayer.BackpropGradientsMM(mout, vout, dlogZ_dm_scale, dlogZ_dv_scale, psi1, psi2, mcav, vcav, &GradInput, dAlpha);
				break;
			case MonteCarlo:
				af::array x, eps;
				slayer.ForwardPredictionRandomCavityMC(mout, vout, x, eps, mcav, vcav, dAlpha);
				logZ = likLayer->ComputeLogZ(mout, vout, yBatch, dAlpha, &dlogZ_dm, &dlogZ_dv, nullptr) * scaleLogZ;
				dlogZ_dm_scale = dlogZ_dm * scaleLogZ;
				dlogZ_dv_scale = dlogZ_dv * scaleLogZ;
				gpLayerGradients = slayer.BackpropGradientsMC(mcav, vcav, eps, dlogZ_dm_scale, dlogZ_dv_scale, x, &GradInput, dAlpha);
				break;
			}

			if (GetNumChildren() == 0)
				likGradient = likLayer->BackpropagationGradients(mout, vout, dlogZ_dm, dlogZ_dv, dAlpha, scaleLogZ);

			// X contributions
			if (GetParent())
			{
				priorPhiX = ComputePhiLatents(afPriorMean, afPriorVariance, &afGradMean, &afGradVariance);
				priorPhiX = 0;
			}
			else
				priorPhiX = ComputePhiLatents(af::constant(dPriorMean, 1, m_dType), af::constant(dPriorVariance, 1, m_dType)) * iN * iq;
			cavPhiX = ComputePhiLatents(mcav, vcav, &dmCav, &dvCav);
			postPhiX = ComputePhiLatents(mpost, vpost, &dmPost, &dvPost);

			xContribution = priorPhiX * scalePrior + cavPhiX * scaleCavity + postPhiX * scalePost;

			dmCav = scaleCavity * dmCav + GradInput["dL_dmx"];
			dvCav = scaleCavity * dvCav + GradInput["dL_dvx"];

			dmPost *= scalePost;
			dvPost *= scalePost;
			gradsXcav = CavityGradientLatents(dmCav, dvCav, mcav, vcav);
			gradsXpost = PosteriorGradientLatents(dmPost, dvPost);

			if (GetNumChildren() == 0)
			{
				/// Collecting gradients
				iEnd += likLayer->GetNumParameters();
				if (iStart != iEnd)
					outGradient(af::seq(iStart, iEnd - 1)) = likGradient;
			}

			iStart = iEnd;
			if (backConst)
			{
				iEnd += backConst->GetNumParameters() * 2;
				//iEnd += iBatchSize * iq;
			}
			else
				iEnd += iBatchSize * iq * 2.0;

			if (mStyles) for (auto style = mStyles->begin(); style != mStyles->end(); style++) iEnd += style->second.GetNumParameters();

			outGradient(af::seq(iStart, iEnd - 1)) = gradsXcav + gradsXpost;
		}
		else
		{
			af::array mx, vx;
			xContribution = 0;
			PosteriorLatents(mx, vx);

			// propagate x forward through cavity
			slayer.ForwardPredictionCavity(mout, vout, nullptr, nullptr, mx, nullptr, dAlpha);

			// compute log normalizer
			logZ = likLayer->ComputeLogZ(mout, vout, yBatch, dAlpha) * scaleLogZ;

			// likelihood contribution from gp layer
			sgpContribution = slayer.ComputePhi(dAlpha);

			likLayer->ComputeLogZGradients(mout, vout, yBatch, &dlogZ_dm, &dlogZ_dv, nullptr, dAlpha);

			dlogZ_dm_scale = dlogZ_dm * scaleLogZ;
			dlogZ_dv_scale = dlogZ_dv * scaleLogZ;

			slayer.BackpropGradientsReg(mout, vout, dlogZ_dm_scale, dlogZ_dv_scale, mx, nullptr, dAlpha);

			likGradient = likLayer->BackpropagationGradients(mout, vout, dlogZ_dm, dlogZ_dv, dAlpha, scaleLogZ);

			iEnd += likLayer->GetNumParameters();
			if (iStart != iEnd)
				outGradient(af::seq(iStart, iEnd - 1)) = likGradient;
		}

		iStart = iEnd; iEnd += gpLayer->GetNumParameters();
		outGradient(af::seq(iStart, iEnd - 1)) = gpLayerGradients;

		// objective computation
		sgpContribution = slayer.ComputePhi(dAlpha);

		//std::cout << "logZ: " << logZ << "xC: " << xContribution << "sgpC: " << sgpContribution << std::endl;

		if (bIsLatetsFixed)
		{
			outGradient /= iN;

			return (logZ + sgpContribution) / iN;
		}
		else
			return logZ + xContribution + sgpContribution;
	}

	/// Computes the cavity distribution of the current minibatch latents (rows afIndexes).
	///
	/// Works in natural parameters (t1 = mean / variance, t2 = 1 / variance). The cavity keeps
	/// prior + (1 - dAlpha) * site factor:
	///  - without a back-constraint, the stored factors are already natural parameters;
	///  - with a back-constraint, the stored factors are a posterior (mean, variance) pair.
	///    These are converted to natural parameters, and the cavity moves from the prior
	///    towards them by (1 - dAlpha).
	/// The root model uses the scalar prior (dPriorX1, dPriorX2); a child uses its per-row
	/// prior (afPriorX1, afPriorX2).
	///
	/// @param mx  Receives the cavity means.
	/// @param vx  Receives the cavity variances.
	template<typename Scalar>
	void SGPLVM<Scalar>::CavityLatents(af::array& mx, af::array& vx)
	{
		// Fraction of each site factor that stays in the cavity.
		const double keep = 1.0 - dAlpha;

		const af::array factor1 = afFactorX1(afIndexes, af::span);
		const af::array factor2 = afFactorX2(afIndexes, af::span);

		af::array prior1, prior2;
		if (!GetParent())
		{
			prior1 = constant(dPriorX1, factor1.dims(), m_dType);
			prior2 = constant(dPriorX2, factor2.dims(), m_dType);
		}
		else
		{
			prior1 = afPriorX1(afIndexes, af::span);
			prior2 = afPriorX2(afIndexes, af::span);
		}

		af::array cav1, cav2;
		if (!backConst)
		{
			cav1 = prior1 + keep * factor1;
			cav2 = prior2 + keep * factor2;
		}
		else
		{
			const af::array post1 = factor1 / factor2;
			const af::array post2 = 1.0 / factor2;
			cav1 = prior1 + (post1 - prior1) * keep;
			cav2 = prior2 + (post2 - prior2) * keep;
		}

		// Back from natural parameters to moments.
		vx = 1.0 / cav2;
		mx = cav1 / cav2;
	}

	/// Back-propagates gradients w.r.t. the cavity moments of the current minibatch to the
	/// flattened latent-factor parameters.
	///
	/// @param dmx  Gradient of the objective w.r.t. the cavity means.
	/// @param dvx  Gradient of the objective w.r.t. the cavity variances.
	/// @param m    Cavity means.
	/// @param v    Cavity variances.
	/// @return Concatenation of flat(first-factor gradient), flat(second-factor gradient) and,
	///         when mStyles is set, the collapsed style gradients. With a back-constraint, the
	///         first two parts are the back-constraint gradients instead.
	template<typename Scalar>
	af::array SGPLVM<Scalar>::CavityGradientLatents(const af::array& dmx, const af::array& dvx, const af::array& m, const af::array& v)
	{
		af::array gradX, gradStyleFlat, dt1Back, dt2Back;

		// Gradients w.r.t. the cavity natural parameters t1 = m / v and t2 = 1 / v. They are
		// scaled by (1 - dAlpha), the fraction of the site factor that enters the cavity.
		af::array t1 = m / v;
		af::array t2 = 1 / v;
		af::array dt1 = (1.0 - dAlpha) * dmx / t2;
		af::array dt2 = (1.0 - dAlpha) * (-dmx * t1 / af::pow(t2, 2.0) - dvx / af::pow(t2, 2.0));

		if (backConst)
		{
			// The factors hold a posterior (mean X1, variance X2) pair with t1 = X1 / X2 and
			// t2 = 1 / X2. Apply the chain rule. Both dm and dv must use the undivided dt1.
			af::array mpost = afFactorX1(afIndexes, af::span);
			af::array vpost = afFactorX2(afIndexes, af::span);
			af::array dm = dt1 / vpost;
			af::array dv = -dt1 * mpost / af::pow(vpost, 2.0) - dt2 / af::pow(vpost, 2.0);

			// Extra factor 2 * X2 on the variance gradient, matching the non-back-constraint path
			// (parameterization of X2 — TODO confirm against SetParameters).
			dt1 = dm;
			dt2 = dv * 2.0 * vpost;

			dt1Back = backConst->BackconstraintGradient(dt1(af::span, af::seq(0, iq - 1)));
			dt2Back = backConst->BackconstraintGradient(dt2(af::span, af::seq(0, iq - 1)));
		}
		else
		{
			dt2 *= 2.0 * afFactorX2(afIndexes, af::span);
		}

		if (mStyles)
		{
			// Columns after the first iq hold the substyle dimensions of each style, in order.
			// Their gradients are collapsed per style. (Previously, with a back-constraint, the
			// variance gradient here was recomputed from an already-divided dt1.)
			int sStart = iq, sEnd = iq;
			for (auto& style : *mStyles)
			{
				sStart = sEnd; sEnd += style.second.GetNumSubstyles();

				gradStyleFlat = CommonUtil<Scalar>::Join(gradStyleFlat, style.second.GetGradientCollapsed(dt1(af::span, af::seq(sStart, sEnd - 1))));
				gradStyleFlat = CommonUtil<Scalar>::Join(gradStyleFlat, style.second.GetGradientCollapsed(dt2(af::span, af::seq(sStart, sEnd - 1))));
			}

			dt1 = dt1(af::span, af::seq(0, iq - 1));
			dt2 = dt2(af::span, af::seq(0, iq - 1));
		}

		if (backConst)
		{
			dt1 = dt1Back;
			dt2 = dt2Back;
		}

		gradX = CommonUtil<Scalar>::Join(gradX, flat(dt1));
		gradX = CommonUtil<Scalar>::Join(gradX, flat(dt2));

		if (mStyles)
			gradX = CommonUtil<Scalar>::Join(gradX, gradStyleFlat);

		return gradX;
	}

	/// Computes Phi(m, v) = sum over all elements of 0.5 * (m^2 / v + log v) and, optionally,
	/// its element-wise partial derivatives.
	///
	/// @param mx   Means.
	/// @param vx   Variances, element-wise (log(vx) is only finite for positive entries).
	/// @param dmx  If non-null, receives dPhi/dm = m / v.
	/// @param dvx  If non-null, receives dPhi/dv = 0.5 * (1 / v - m^2 / v^2).
	/// @return The scalar Phi.
	template<typename Scalar>
	Scalar SGPLVM<Scalar>::ComputePhiLatents(const af::array& mx, const af::array& vx, af::array* dmx, af::array* dvx)
	{
		// Square by multiplication instead of af::pow(., 2.0). It is cheaper, and the result is
		// shared between the value and the variance gradient.
		const af::array mxSq = mx * mx;

		if (dmx != nullptr)
			*dmx = mx / vx;

		if (dvx != nullptr)
			*dvx = 0.5 * (-mxSq / (vx * vx) + 1.0 / vx);

		// af::sum<T>(array) already reduces over all elements; no per-dimension pre-sum needed.
		return af::sum<Scalar>(0.5 * (mxSq / vx + af::log(vx)));
	}
}