/**
File:		MachineLearning/GPModels/Models/GPModels/FgPEPSparseGPR2nd.cpp

Author:		
Email:		
Site:       

Copyright (c) 2018 . All rights reserved.
*/

#include <NeMachineLearningPCH.h>
#include <MachineLearning/FgPEPSparseGPR2nd.h>

namespace NeuralEngine::MachineLearning::GPModels::PowerEP
{
	// Explicit instantiation definitions of SGPR2nd for float and double.
	// NOTE(review): these appear before the member-function definitions in this
	// file. Under the C++ standard an explicit instantiation definition of a
	// class template only instantiates members already defined at that point;
	// confirm the target toolchain tolerates this order, or move these two
	// lines to the end of the namespace.
	template class SGPR2nd<float>;
	template class SGPR2nd<double>;

	/**
	 * Constructs the model from data.
	 *
	 * All four arguments are forwarded unchanged to the SparseGPBaseModel
	 * constructor; afterwards the SGPLayer2nd GP layer is created from
	 * iN, ik, iD and iq and its parameters are initialised via
	 * InitParameters(&afX).
	 *
	 * @param Y            forwarded to SparseGPBaseModel.
	 * @param X            forwarded to SparseGPBaseModel.
	 * @param numInducing  forwarded to SparseGPBaseModel.
	 * @param lType        forwarded to SparseGPBaseModel.
	 */
	template<typename Scalar>
	SGPR2nd<Scalar>::SGPR2nd(const af::array& Y, const af::array& X, int numInducing, LogLikType lType)
		: SparseGPBaseModel<Scalar>(Y, X, numInducing, lType)
	{
		// The former body, `SGPR2nd<Scalar>();`, did not delegate to the default
		// constructor: as an expression statement it creates and immediately
		// destroys an unrelated temporary, so this object's gpLayer was never
		// assigned here. A delegating constructor cannot be combined with the
		// base-class initialiser above, so the layer setup is performed directly.
		// NOTE(review): assumes the base constructor has set iN, ik, iD, iq and
		// afX, and has not already allocated gpLayer - confirm against
		// SparseGPBaseModel.
		gpLayer = new SGPLayer2nd<Scalar>(iN, ik, iD, iq);
		gpLayer->InitParameters(&afX);
	}

	/**
	 * Runs numIter sweeps of factor updates over the training data.
	 *
	 * In each sweep the data are either visited one row at a time
	 * (parallelUpdate == false, rows 0 .. iN-1 in order) or handled as a single
	 * batch covering all iN rows (parallelUpdate == true). For every selected
	 * index set the same sequence is executed: cavity forward prediction on the
	 * GP layer, log-Z and its gradients on the likelihood layer, regression
	 * back-propagation on the GP layer, and finally a factor update.
	 *
	 * @param alpha           not read by this implementation.
	 * @param numIter         number of sweeps; the loop body never runs when <= 0.
	 * @param parallelUpdate  true: one batched update per sweep; false: one update per row.
	 * @param decay           not read by this implementation.
	 *
	 * NOTE(review): UpdateFactor is called with the literals 1.0, 1.0 while
	 * alpha and decay are unused - confirm whether they were meant to be
	 * forwarded.
	 */
	template<typename Scalar>
	void SGPR2nd<Scalar>::Inference(Scalar alpha, int numIter, bool parallelUpdate, Scalar decay)
	{
		Scalar logZ;
		af::array inputs(m_dType);
		af::array targets(m_dType);
		af::array mOut(m_dType);
		af::array vOut(m_dType);
		af::array gradM(m_dType);
		af::array gradV(m_dType);
		af::array gradM2(m_dType);
		af::array rows;
		std::map<std::string, af::array> gradHyper;
		std::map<std::string, af::array> gradCav;

		// Throws std::bad_cast if the stored layer is not an SGPLayer2nd.
		SGPLayer2nd<Scalar>& layer = dynamic_cast<SGPLayer2nd<Scalar>&>(*gpLayer);

		// One complete update for whatever index set `rows` currently holds.
		auto processRows = [&]()
		{
			targets = afY(rows, af::span);
			inputs = afX(rows, af::span);

			layer.ForwardPredictionCavity(mOut, vOut, rows, inputs);
			logZ = likLayer->ComputeLogZ(mOut, vOut, targets);
			likLayer->ComputeLogZGradients(mOut, vOut, targets, &gradM, &gradV, &gradM2);

			layer.BackpropGradientsReg(mOut, vOut, gradM, gradV, inputs, gradHyper, gradCav);
			layer.UpdateFactor(rows, gradCav, 1.0, 1.0);
		};

		for (int sweep = 0; sweep < numIter; ++sweep)
		{
			if (parallelUpdate)
			{
				// Batched sweep: all iN rows at once.
				rows = af::seq(iN);
				processRows();
				continue;
			}

			// Sequential sweep: one row per update.
			for (int row = 0; row < iN; ++row)
			{
				rows = af::constant(row, 1, 1);
				processRows();
			}
		}
	}

	/**
	 * Default constructor.
	 *
	 * Allocates a new SGPLayer2nd from the current values of iN, ik, iD and iq,
	 * stores it in gpLayer, and then calls InitParameters with the address of
	 * afX.
	 *
	 * NOTE(review): no base-class initialiser is listed, so the base is
	 * default-constructed; iN, ik, iD, iq and afX are presumably base members -
	 * confirm they hold meaningful values on this path.
	 * NOTE(review): gpLayer receives a raw `new` result; where it is released
	 * is not visible in this file.
	 */
	template<typename Scalar>
	SGPR2nd<Scalar>::SGPR2nd()
	{
		gpLayer = new SGPLayer2nd<Scalar>(iN, ik, iD, iq);
		gpLayer->InitParameters(&afX);
	}
}