/**
File:		MachineLearning/GPModels/Models/GPModels/FgAEPSparseGPLVM.cpp

Author:		
Email:		
Site:       

Copyright (c) 2018 . All rights reserved.
*/

#include <NeMachineLearningPCH.h>
#include <MachineLearning/FgAEPSparseDGPLVM.h>
#include <MachineLearning/FgAEPSparseGPLayer.h>

namespace NeuralEngine::MachineLearning::GPModels::AEP
{
	// Explicit instantiations of the model for the two scalar types used in this file.
	// NOTE(review): these appear before the member function definitions below. Some
	// compilers only instantiate members that are already defined at the point of an
	// explicit instantiation definition - confirm the build links on non-MSVC toolchains,
	// or move these two lines to the end of the namespace.
	template class SDGPLVM<double>;
	template class SDGPLVM<float>;

	/**
	 * Builds the model from a single hidden-layer description.
	 *
	 * All arguments except alpha and probMode are forwarded to the base model; alpha is
	 * stored in dAlpha and probMode in pMode. After the base is constructed, one
	 * SGPLayer is created per layer (iNumLayer in total) and appended to gpLayer.
	 */
	template<typename Scalar>
	SDGPLVM<Scalar>::SDGPLVM(const af::array& Y, int latentDimension, HiddenLayerDescription description, Scalar alpha, Scalar priorMean, Scalar priorVariance, PropagationMode probMode, LogLikType lType, XInit emethod)
		: SparseDeepGPLVMBaseModel(Y, latentDimension, description, priorMean, priorVariance, lType, emethod), dAlpha(alpha), pMode(probMode)
	{
		for (int layerIdx = 0; layerIdx < iNumLayer; ++layerIdx)
		{
			// Layer layerIdx is built from vSize[layerIdx + 1] and vSize[layerIdx]
			// (presumably output and input dimension - confirm against SGPLayer).
			auto* sparseLayer = new SGPLayer<Scalar>(iN, vNumPseudosPerLayer[layerIdx], vSize[layerIdx + 1], vSize[layerIdx]);
			gpLayer.push_back(sparseLayer);
		}
	}

	/**
	 * Builds the model from a list of hidden-layer descriptions.
	 *
	 * All arguments except alpha and probMode are forwarded to the base model; alpha is
	 * stored in dAlpha and probMode in pMode. After the base is constructed, one
	 * SGPLayer is created per layer (iNumLayer in total) and appended to gpLayer.
	 */
	template<typename Scalar>
	SDGPLVM<Scalar>::SDGPLVM(const af::array& Y, int latentDimension, std::vector<HiddenLayerDescription> descriptions, Scalar alpha, Scalar priorMean, Scalar priorVariance, PropagationMode probMode, LogLikType lType, XInit emethod)
		: SparseDeepGPLVMBaseModel(Y, latentDimension, descriptions, priorMean, priorVariance, lType, emethod), dAlpha(alpha), pMode(probMode)
	{
		for (int layerIdx = 0; layerIdx < iNumLayer; ++layerIdx)
		{
			// Layer layerIdx is built from vSize[layerIdx + 1] and vSize[layerIdx]
			// (presumably output and input dimension - confirm against SGPLayer).
			auto* sparseLayer = new SGPLayer<Scalar>(iN, vNumPseudosPerLayer[layerIdx], vSize[layerIdx + 1], vSize[layerIdx]);
			gpLayer.push_back(sparseLayer);
		}
	}

	/**
	 * Default constructor: default-constructs the base model and sets dAlpha to 0.5.
	 *
	 * pMode is also given a defined value here. Function() switches on pMode, and the
	 * other constructors always set it, so leaving it unset in this constructor would
	 * make that switch read an indeterminate value.
	 * NOTE(review): MomentMatching is chosen as the default - confirm this is the
	 * intended mode for default-constructed models (no layers are created here).
	 */
	template<typename Scalar>
	SDGPLVM<Scalar>::SDGPLVM()
		: SparseDeepGPLVMBaseModel(), dAlpha(0.5), pMode(PropagationMode::MomentMatching)
	{
	}

	/**
	 * Destructor. Performs no work of its own.
	 * NOTE(review): the constructors allocate the SGPLayer objects in gpLayer with
	 * new; nothing is released here, so confirm their owner frees them elsewhere.
	 */
	template<typename Scalar>
	SDGPLVM<Scalar>::~SDGPLVM() = default;

	/**
	 * Evaluates the objective and its gradient for the current (mini-)batch.
	 *
	 * @param x            Flat parameter vector. When non-empty it is applied through
	 *                     SetParameters() before evaluation; an empty array keeps the
	 *                     current parameters.
	 * @param outGradient  Overwritten with a vector of GetNumParameters() entries,
	 *                     filled in this order: child models, likelihood layer,
	 *                     latent variables (plus back-constraint / style parameters),
	 *                     GP layers.
	 * @return             logZ + xContribution + sgpContribution, i.e. the scaled
	 *                     likelihood log-normalizer (plus child objectives), the phi
	 *                     terms of the latents, and the summed ComputePhi() of all layers.
	 *
	 * Side effects: may clamp iBatchSize to iN, may redraw afIndexes and push them to
	 * the children, and (when this model has a parent) refreshes afGradMean /
	 * afGradVariance.
	 */
	template<typename Scalar>
	Scalar SDGPLVM<Scalar>::Function(const af::array& x, af::array& outGradient)
	{
		if (!x.isempty())
			SetParameters(x);

		Scalar logZ = 0.0, scaleLogZ, scalePrior, scalePost, scaleCavity, xContribution, sgpContribution, priorPhiX, cavPhiX, postPhiX;

		af::array yBatch, mcav, vcav, mpost, vpost, gpLayerGradients;

		af::array dlogZ_dm, dlogZ_dv, dlogZ_dmi, dlogZ_dvi, tmpGradient, dmCav, dvCav, dmPost, dvPost, gradsXcav, gradsXpost,
			dlogZ_dmDown, dlogZ_dmCavDown, dlogZ_dvDown, dlogZ_dvCavDown;

		// Per-layer forward results, kept for the backward pass
		std::vector<af::array> mout, vout, psi1, psi2;
		af::array mi(m_dType), vi(m_dType), psi1i(m_dType), psi2i(m_dType);

		outGradient = af::constant(0.0f, GetNumParameters(), (m_dType));
		std::map<std::string, af::array> GradInput;

		// Running [iStart, iEnd) window into outGradient
		int iStart = 0, iEnd = 0;

		if (iBatchSize >= iN)
		{
			// Full batch
			iBatchSize = iN;
			yBatch = afY;
		}
		else
		{
			// Only the root model draws the batch; a model with a parent uses the
			// indexes it received through SetIndexes().
			if (!GetParent())
			{
				// NOTE(review): reseeding from the wall clock on every call means all
				// calls within the same second draw the same batch - confirm intended.
				af::setSeed(time(nullptr));

				// Row indexes drawn with replacement. floor() plus the clamp keeps
				// every index inside [0, iN - 1]; rounding could yield iN and made
				// the first and last rows non-uniformly likely.
				afIndexes = af::min(af::floor(af::randu(iN) * iN), static_cast<double>(iN - 1))(af::seq(iBatchSize));
			}

			for (uint i = 0; i < GetNumChildren(); i++)
			{
				GPLVMBaseModel<Scalar>& child = dynamic_cast<GPLVMBaseModel<Scalar>&>(*GetChild(i));
				child.SetIndexes(afIndexes);
			}

			yBatch = afY(afIndexes, af::span);
		}

		// Scaling of the individual objective terms (batch size and alpha dependent)
		scaleLogZ = -iN * 1.0 / iBatchSize / dAlpha;
		scalePost = -iN * 1.0 / iBatchSize * (1.0 - 1.0 / dAlpha);
		scaleCavity = -iN * 1.0 / iBatchSize / dAlpha;
		scalePrior = 1.0;

		CavityLatents(mcav, vcav);
		PosteriorLatents(mpost, vpost);

		sgpContribution = 0;
		int iStartLayer = 0, iEndLayer = 0;
		switch (pMode)
		{
			case PropagationMode::MomentMatching:
				// Forward pass through all layers, starting from the cavity of the latents
				for (uint i = 0; i < iNumLayer; i++)
				{
					SGPLayer<Scalar>& slayer = dynamic_cast<SGPLayer<Scalar>&>(*gpLayer[i]);
					if (i == 0)
						slayer.ForwardPredictionCavity(mi, vi, &psi1i, &psi2i, mcav, &vcav, dAlpha);
					else
						slayer.ForwardPredictionCavity(mi, vi, &psi1i, &psi2i, mout[i - 1], &vout[i - 1], dAlpha);

					mout.push_back(mi);
					vout.push_back(vi);
					psi1.push_back(psi1i);
					psi2.push_back(psi2i);

					sgpContribution += slayer.ComputePhi(dAlpha);
				}

				if (GetNumChildren() > 0)
				{
					af::array mf, vf, myCav, vyCav, mfTmp, vfTmp, dlogZ_dmC, dlogZ_dvC;

					// Posterior forward pass through all layers
					for (uint j = 0; j < iNumLayer; j++)
					{
						if (j == 0) gpLayer[j]->ForwardPredictionPost(&mpost, &vpost, mf, vf);
						else
						{
							gpLayer[j]->ForwardPredictionPost(&mf, &vf, mfTmp, vfTmp);
							mf = mfTmp.copy();
							vf = vfTmp.copy();
						}
					}

					// NOTE(review): the outputs of the next three calls (myCav/vyCav and
					// the *Down gradients) are not read again in this function; the calls
					// are kept in case the likelihood layer relies on them internally.
					likLayer->ProbabilisticOutput(mout.back(), vout.back(), myCav, vyCav);

					likLayer->ComputeLogZGradients(mout.back(), vout.back(), yBatch, &dlogZ_dmCavDown, &dlogZ_dvCavDown, nullptr, dAlpha);
					likLayer->ComputeLogZGradients(mf, vf, yBatch, &dlogZ_dmDown, &dlogZ_dvDown, nullptr, dAlpha);

					dlogZ_dmCavDown *= scaleCavity;
					dlogZ_dvCavDown *= scaleCavity;

					// Each child takes its slice of the last layer's output columns as
					// prior, is evaluated, and returns gradients w.r.t. that prior.
					int dimCnt = 0;
					for (uint i = 0; i < GetNumChildren(); i++)
					{
						GPLVMBaseModel<Scalar>& child = dynamic_cast<GPLVMBaseModel<Scalar>&>(*GetChild(i));
						iStart = iEnd; iEnd += child.GetNumParameters();

						const af::seq childCols(dimCnt, dimCnt + child.GetLatentDimension() - 1);
						child.SetPrior(mout.back()(af::span, childCols), vout.back()(af::span, childCols));

						logZ += child.Function(af::array(), tmpGradient);
						outGradient(af::seq(iStart, iEnd - 1)) = tmpGradient;
						dlogZ_dmC = CommonUtil<Scalar>::Join(dlogZ_dmC, child.GetMeanGradient(), 1);
						dlogZ_dvC = CommonUtil<Scalar>::Join(dlogZ_dvC, child.GetVarGradient(), 1);

						dimCnt += child.GetLatentDimension();
					}
					dlogZ_dmi = dlogZ_dmC;
					dlogZ_dvi = dlogZ_dvC;

					// This model's own likelihood term on top of the children's
					logZ += likLayer->ComputeLogZ(mout.back(), vout.back(), yBatch, dAlpha) * scaleLogZ;

					likLayer->ComputeLogZGradients(mout.back(), vout.back(), yBatch, &dlogZ_dm, &dlogZ_dv, nullptr, dAlpha);

					dlogZ_dmi = dlogZ_dmi + dlogZ_dm * scaleLogZ;
					dlogZ_dvi = dlogZ_dvi + dlogZ_dv * scaleLogZ;
				}
				else
				{
					// compute log normalizer
					logZ = likLayer->ComputeLogZ(mout.back(), vout.back(), yBatch, dAlpha) * scaleLogZ;

					likLayer->ComputeLogZGradients(mout.back(), vout.back(), yBatch, &dlogZ_dm, &dlogZ_dv, nullptr, dAlpha);

					dlogZ_dmi = dlogZ_dm * scaleLogZ;
					dlogZ_dvi = dlogZ_dv * scaleLogZ;
				}

				// Backpropagation, last layer first.
				// NOTE(review): gradients are packed into gpLayerGradients in this
				// reversed layer order - confirm it matches the parameter layout.
				gpLayerGradients = af::constant(0.0, GetNumGPLayerParameters(), m_dType);
				for (int i = iNumLayer - 1; i >= 0; i--)
				{
					SGPLayer<Scalar>& slayer = dynamic_cast<SGPLayer<Scalar>&>(*gpLayer[i]);

					iEndLayer += slayer.GetNumParameters();
					if (i == 0)
						gpLayerGradients(af::seq(iStartLayer, iEndLayer - 1)) = slayer.BackpropGradientsMM(mout[i], vout[i], dlogZ_dmi, dlogZ_dvi, psi1[i], psi2[i], mcav, vcav, &GradInput, dAlpha);
					else
					{
						gpLayerGradients(af::seq(iStartLayer, iEndLayer - 1)) = slayer.BackpropGradientsMM(mout[i], vout[i], dlogZ_dmi, dlogZ_dvi, psi1[i], psi2[i], mout[i - 1], vout[i - 1], &GradInput, dAlpha);
						// Gradients w.r.t. this layer's input feed the layer below
						dlogZ_dmi = GradInput["dL_dmx"];
						dlogZ_dvi = GradInput["dL_dvx"];
					}
					iStartLayer = iEndLayer;
				}
				break;

			case PropagationMode::MonteCarlo:
				// Forward pass: the first layer uses the random (MC) cavity prediction,
				// all following layers use the regular cavity prediction.
				for (uint i = 0; i < iNumLayer; i++)
				{
					SGPLayer<Scalar>& slayer = dynamic_cast<SGPLayer<Scalar>&>(*gpLayer[i]);
					if (i == 0)
					{
						slayer.ForwardPredictionRandomCavityMC(mi, vi, psi1i, psi2i, mcav, vcav, dAlpha);

						// NOTE(review): the hard-coded 5 presumably is the number of MC
						// samples held in the first dimension - confirm against SGPLayer.
						mi = af::moddims(mi, 5 * mi.dims(1), mi.dims(2));
						vi = af::moddims(vi, 5 * vi.dims(1), vi.dims(2));
					}
					else
						slayer.ForwardPredictionCavity(mi, vi, &psi1i, &psi2i, mout[i - 1], &vout[i - 1], dAlpha);

					mout.push_back(mi);
					vout.push_back(vi);
					psi1.push_back(psi1i);
					psi2.push_back(psi2i);

					sgpContribution += slayer.ComputePhi(dAlpha);
				}

				logZ = likLayer->ComputeLogZ(mout.back(), vout.back(), yBatch, dAlpha) * scaleLogZ;

				likLayer->ComputeLogZGradients(mout.back(), vout.back(), yBatch, &dlogZ_dm, &dlogZ_dv, nullptr, dAlpha);

				dlogZ_dmi = dlogZ_dm * scaleLogZ;
				dlogZ_dvi = dlogZ_dv * scaleLogZ;

				// Backpropagation, last layer first
				gpLayerGradients = af::constant(0.0, GetNumGPLayerParameters(), m_dType);
				for (int i = iNumLayer - 1; i >= 0; i--)
				{
					SGPLayer<Scalar>& slayer = dynamic_cast<SGPLayer<Scalar>&>(*gpLayer[i]);

					iEndLayer += slayer.GetNumParameters();
					if (i == 0)
						gpLayerGradients(af::seq(iStartLayer, iEndLayer - 1)) = slayer.BackpropGradientsMC(mcav, vcav, psi2[i], dlogZ_dmi, dlogZ_dvi, psi1[i], &GradInput, dAlpha);
					else
					{
						gpLayerGradients(af::seq(iStartLayer, iEndLayer - 1)) = slayer.BackpropGradientsMM(mout[i], vout[i], dlogZ_dmi, dlogZ_dvi, psi1[i], psi2[i], mout[i - 1], vout[i - 1], &GradInput, dAlpha);
						dlogZ_dmi = GradInput["dL_dmx"];
						dlogZ_dvi = GradInput["dL_dvx"];
					}
					iStartLayer = iEndLayer;
				}

				break;
		}

		// Likelihood layer parameters (skipped when the layer has none)
		iStart = iEnd; iEnd += likLayer->GetNumParameters();
		if (iStart != iEnd)
			outGradient(af::seq(iStart, iEnd - 1)) = likLayer->BackpropagationGradients(mout.back(), vout.back(), dlogZ_dm, dlogZ_dv, dAlpha, scaleLogZ);


		// X contributions
		if (GetParent())
		{
			// Called for its outputs only: it refreshes afGradMean / afGradVariance.
			// The prior phi term itself does not enter this model's objective.
			ComputePhiLatents(afPriorMean, afPriorVariance, &afGradMean, &afGradVariance);
			priorPhiX = 0;
		}
		else
		{
			priorPhiX = ComputePhiLatents(af::constant(dPriorMean, 1, m_dType), af::constant(dPriorVariance, 1, m_dType)) * iN * iq;
		}

		cavPhiX = ComputePhiLatents(mcav, vcav, &dmCav, &dvCav);
		postPhiX = ComputePhiLatents(mpost, vpost, &dmPost, &dvPost);

		xContribution = priorPhiX * scalePrior + cavPhiX * scaleCavity + postPhiX * scalePost;

		// Cavity gradients: phi term plus what the first layer backpropagated
		dmCav = scaleCavity * dmCav + GradInput["dL_dmx"];
		dvCav = scaleCavity * dvCav + GradInput["dL_dvx"];

		dmPost *= scalePost;
		dvPost *= scalePost;
		gradsXcav = CavityGradientLatents(dmCav, dvCav, mcav, vcav);
		gradsXpost = PosteriorGradientLatents(dmPost, dvPost);

		// Size of the latent block depends on back-constraint and style parameters
		iStart = iEnd;
		if (backConst)
		{
			iEnd += backConst->GetNumParameters();
			iEnd += iBatchSize * iq;
		}
		else
			iEnd += iBatchSize * iq * 2;

		if (mStyles)
		{
			for (auto style = mStyles->begin(); style != mStyles->end(); ++style)
				iEnd += style->second.GetNumParameters();
		}

		outGradient(af::seq(iStart, iEnd - 1)) = gradsXcav + gradsXpost;

		// GP layer parameters
		iStart = iEnd; iEnd += GetNumGPLayerParameters();
		outGradient(af::seq(iStart, iEnd - 1)) = gpLayerGradients;

		return logZ + xContribution + sgpContribution;
	}

	/**
	 * Computes mean and variance of the cavity distribution of the latent variables
	 * for the rows selected by afIndexes.
	 *
	 * @param mx  Output: cavity_t1 / cavity_t2.
	 * @param vx  Output: 1 / cavity_t2.
	 *
	 * The prior terms come from afPriorX1 / afPriorX2 when the model has a parent and
	 * from the scalars dPriorX1 / dPriorX2 otherwise. With a back-constraint the stored
	 * factors are first converted via (X1 / X2, 1 / X2) before being mixed with the prior.
	 */
	template<typename Scalar>
	void SDGPLVM<Scalar>::CavityLatents(af::array& mx, af::array& vx)
	{
		const af::array factor1 = afFactorX1(afIndexes, af::span);
		const af::array factor2 = afFactorX2(afIndexes, af::span);

		// Prior terms: per-row from the parent, or constant otherwise
		af::array prior1, prior2;
		if (GetParent())
		{
			prior1 = afPriorX1(afIndexes, af::span);
			prior2 = afPriorX2(afIndexes, af::span);
		}
		else
		{
			prior1 = af::constant(dPriorX1, factor1.dims(), m_dType);
			prior2 = af::constant(dPriorX2, factor2.dims(), m_dType);
		}

		// Remove a fraction dAlpha of the factor
		af::array cavity1, cavity2;
		if (backConst)
		{
			const af::array nat1 = factor1 / factor2;
			const af::array nat2 = 1.0 / factor2;
			cavity1 = prior1 + (nat1 - prior1) * (1.0 - dAlpha);
			cavity2 = prior2 + (nat2 - prior2) * (1.0 - dAlpha);
		}
		else
		{
			cavity1 = prior1 + (1.0 - dAlpha) * factor1;
			cavity2 = prior2 + (1.0 - dAlpha) * factor2;
		}

		vx = 1.0 / cavity2;
		mx = cavity1 / cavity2;
	}

	/**
	 * Converts gradients w.r.t. the cavity mean / variance of the latents into
	 * gradients w.r.t. the stored latent parameters.
	 *
	 * @param dmx  Gradient w.r.t. the cavity mean.
	 * @param dvx  Gradient w.r.t. the cavity variance.
	 * @param m    Cavity mean.
	 * @param v    Cavity variance.
	 * @return     Flat vector: flattened first-parameter gradients, then flattened
	 *             second-parameter gradients, then (if styles are present) the
	 *             collapsed style gradients.
	 *
	 * With a back-constraint, the first iq columns are routed through
	 * backConst->BackconstraintGradient(); columns from iq onward belong to the styles.
	 */
	template<typename Scalar>
	af::array SDGPLVM<Scalar>::CavityGradientLatents(const af::array& dmx, const af::array& dvx, const af::array& m, const af::array& v)
	{
		af::array gradX, dt1Back, dt2Back, gradStyleFlat;

		// Gradients w.r.t. t1 = m / v and t2 = 1 / v, scaled by (1 - alpha)
		af::array t1 = m / v;
		af::array t2 = 1 / v;
		af::array dt1 = (1.0 - dAlpha) * dmx / t2;
		af::array dt2 = (1.0 - dAlpha) * (-dmx * t1 / pow(t2, 2.0) - dvx / pow(t2, 2.0));

		if (backConst)
		{
			// Chain rule through t1 = mpost / vpost, t2 = 1 / vpost
			af::array mpost = afFactorX1(afIndexes, af::span);
			af::array vpost = afFactorX2(afIndexes, af::span);
			af::array dm = dt1 / vpost;
			af::array dv = -dt1 * mpost / pow(vpost, 2.0) - dt2 / pow(vpost, 2.0);
			dt1Back = backConst->BackconstraintGradient(dm(af::span, af::seq(0, iq - 1)));
			dt2Back = (dv * 2.0 * vpost)(af::span, af::seq(0, iq - 1));

			if (mStyles)
			{
				// The style columns use the same chain rule as the latent columns.
				// dv must come from the unmodified dt1; computing it after dt1 had
				// been replaced by dt1 / vpost divided by vpost once too often.
				dt1 = dm;
				dt2 = dv * 2.0 * vpost;
			}
		}
		else
			dt2 *= 2.0 * afFactorX2(afIndexes, af::span);

		if (mStyles)
		{
			// Style columns follow the iq latent columns
			int sStart = iq, sEnd = iq;
			for (auto style = mStyles->begin(); style != mStyles->end(); ++style)
			{
				sStart = sEnd; sEnd += style->second.GetNumSubstyles();

				gradStyleFlat = CommonUtil<Scalar>::Join(gradStyleFlat, style->second.GetGradientCollapsed(dt1(af::span, af::seq(sStart, sEnd - 1))));
				gradStyleFlat = CommonUtil<Scalar>::Join(gradStyleFlat, style->second.GetGradientCollapsed(dt2(af::span, af::seq(sStart, sEnd - 1))));
			}

			dt1 = dt1(af::span, af::seq(0, iq - 1));
			dt2 = dt2(af::span, af::seq(0, iq - 1));
		}

		if (backConst)
		{
			// Latent columns are represented by the back-constraint gradients
			dt1 = dt1Back;
			dt2 = dt2Back;
		}

		gradX = CommonUtil<Scalar>::Join(gradX, flat(dt1));
		gradX = CommonUtil<Scalar>::Join(gradX, flat(dt2));

		if (mStyles)
		{
			gradX = CommonUtil<Scalar>::Join(gradX, gradStyleFlat);
		}

		return gradX;
	}

	/**
	 * Computes the sum over all entries of 0.5 * (mx^2 / vx + log(vx)).
	 *
	 * @param mx   Means.
	 * @param vx   Variances.
	 * @param dmx  Optional output (skipped when nullptr): mx / vx.
	 * @param dvx  Optional output (skipped when nullptr):
	 *             0.5 * (-mx^2 / vx^2 + 1 / vx).
	 * @return     The summed value as a Scalar.
	 */
	template<typename Scalar>
	Scalar SDGPLVM<Scalar>::ComputePhiLatents(const af::array& mx, const af::array& vx, af::array* dmx, af::array* dvx)
	{
		if (dmx)
		{
			*dmx = mx / vx;
		}

		if (dvx)
		{
			const af::array negMeanSq = -af::pow(mx, 2.0);
			const af::array varSq = af::pow(vx, 2.0);
			*dvx = 0.5 * (negMeanSq / varSq + 1.0 / vx);
		}

		const af::array phiPerEntry = 0.5 * (af::pow(mx, 2.0) / vx + af::log(vx));
		return af::sum<Scalar>(af::sum(phiPerEntry));
	}
}